Skip to content

Commit

Permalink
Add support for ssh-agent
Browse files Browse the repository at this point in the history
Signed-off-by: yuguorui <[email protected]>
  • Loading branch information
yuguorui committed Dec 5, 2023
1 parent 2c4ef83 commit 5d55354
Showing 1 changed file with 37 additions and 4 deletions.
41 changes: 37 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import (

"github.com/dsnet/golib/jsonfmt"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"golang.org/x/crypto/ssh/knownhosts"
)

Expand All @@ -56,6 +57,8 @@ type TunnelConfig struct {
// If the path is empty, then the server will output to os.Stderr.
LogFile string `json:",omitempty"`

SshAgentSocket string `json:",omitempty"`

// KeyFiles is a list of SSH private key files.
KeyFiles []string

Expand Down Expand Up @@ -116,6 +119,21 @@ type KeepAliveConfig struct {
CountMax uint
}

func setupSshAgent(socket string) ssh.AuthMethod {
if len(socket) == 0 {
return nil
}

conn, err := net.Dial("unix", socket)
if err != nil {
log.Printf("Failed to open SSH_AUTH_SOCK %s: %v\n", socket, err)
return nil
}

agentClient := agent.NewClient(conn)
return ssh.PublicKeysCallback(agentClient.Signers)
}

func loadConfig(conf string) (tunns []tunnel, logger *log.Logger, closer func() error) {
var logBuf bytes.Buffer
logger = log.New(io.MultiWriter(os.Stderr, &logBuf), "", log.Ldate|log.Ltime|log.Lshortfile)
Expand All @@ -137,6 +155,10 @@ func loadConfig(conf string) (tunns []tunnel, logger *log.Logger, closer func()
if err := json.Unmarshal(c, &config); err != nil {
logger.Fatalf("unable to decode config: %v", err)
}
if config.SshAgentSocket == "" {
// ssh-agent(1) provides a UNIX socket at $SSH_AUTH_SOCK.
config.SshAgentSocket = os.Getenv("SSH_AUTH_SOCK")
}
for _, t := range config.Tunnels {
if config.KeepAlive == nil && t.KeepAlive == nil {
config.KeepAlive = &KeepAliveConfig{Interval: 30, CountMax: 2}
Expand Down Expand Up @@ -171,11 +193,10 @@ func loadConfig(conf string) (tunns []tunnel, logger *log.Logger, closer func()
closer = f.Close
}

var auth []ssh.AuthMethod

// Parse all of the private keys.
var keys []ssh.Signer
if len(config.KeyFiles) == 0 {
logger.Fatal("no private keys specified")
}
for _, kf := range config.KeyFiles {
b, err := ioutil.ReadFile(kf)
if err != nil {
Expand All @@ -187,7 +208,19 @@ func loadConfig(conf string) (tunns []tunnel, logger *log.Logger, closer func()
}
keys = append(keys, k)
}
auth := []ssh.AuthMethod{ssh.PublicKeys(keys...)}
if len(keys) > 0 {
auth = append(auth, ssh.PublicKeys(keys...))
}

// Setup ssh-agent(1)
agent := setupSshAgent(config.SshAgentSocket)
if agent != nil {
auth = append(auth, agent)
}

if len(auth) == 0 {
logger.Panic("no private keys and ssh-agent usable")
}

// Parse all of the host public keys.
if len(config.KnownHostFiles) == 0 {
Expand Down

0 comments on commit 5d55354

Please sign in to comment.