diff --git a/beszel/cmd/agent/agent.go b/beszel/cmd/agent/agent.go index f789b544..631abc07 100644 --- a/beszel/cmd/agent/agent.go +++ b/beszel/cmd/agent/agent.go @@ -12,15 +12,16 @@ import ( "golang.org/x/crypto/ssh" ) -type cmdConfig struct { +// cli options +type cmdOptions struct { key string // key is the public key(s) for SSH authentication. addr string // addr is the address or port to listen on. } // parseFlags parses the command line flags and populates the config struct. -func parseFlags(cfg *cmdConfig) { - flag.StringVar(&cfg.key, "key", "", "Public key(s) for SSH authentication") - flag.StringVar(&cfg.addr, "addr", "", "Address or port to listen on") +func (opts *cmdOptions) parseFlags() { + flag.StringVar(&opts.key, "key", "", "Public key(s) for SSH authentication") + flag.StringVar(&opts.addr, "addr", "", "Address or port to listen on") flag.Usage = func() { fmt.Printf("Usage: %s [options] [subcommand]\n", os.Args[0]) @@ -54,10 +55,10 @@ func handleSubcommand() bool { } // loadPublicKeys loads the public keys from the command line flag, environment variable, or key file. -func loadPublicKeys(cfg cmdConfig) ([]ssh.PublicKey, error) { +func (opts *cmdOptions) loadPublicKeys() ([]ssh.PublicKey, error) { // Try command line flag first - if cfg.key != "" { - return agent.ParseKeys(cfg.key) + if opts.key != "" { + return agent.ParseKeys(opts.key) } // Try environment variable @@ -68,7 +69,7 @@ func loadPublicKeys(cfg cmdConfig) ([]ssh.PublicKey, error) { // Try key file keyFile, ok := agent.GetEnv("KEY_FILE") if !ok { - return nil, fmt.Errorf("no key provided: must set -key flag, KEY env var, or KEY_FILE env var. ") + return nil, fmt.Errorf("no key provided: must set -key flag, KEY env var, or KEY_FILE env var. Use 'beszel-agent help' for usage") } pubKey, err := os.ReadFile(keyFile) @@ -79,10 +80,10 @@ func loadPublicKeys(cfg cmdConfig) ([]ssh.PublicKey, error) { } // getAddress gets the address to listen on from the command line flag, environment variable, or default value. -func getAddress(addr string) string { +func (opts *cmdOptions) getAddress() string { // Try command line flag first - if addr != "" { - return addr + if opts.addr != "" { + return opts.addr } // Try environment variables if addr, ok := agent.GetEnv("ADDR"); ok && addr != "" { @@ -96,19 +97,19 @@ func getAddress(addr string) string { } // getNetwork returns the network type to use for the server. -func getNetwork(addr string) string { +func (opts *cmdOptions) getNetwork() string { if network, _ := agent.GetEnv("NETWORK"); network != "" { return network } - if strings.HasPrefix(addr, "/") { + if strings.HasPrefix(opts.addr, "/") { return "unix" } return "tcp" } func main() { - var cfg cmdConfig - parseFlags(&cfg) + var opts cmdOptions + opts.parseFlags() if handleSubcommand() { return @@ -116,15 +117,15 @@ func main() { flag.Parse() - var serverConfig agent.ServerConfig + var serverConfig agent.ServerOptions var err error - serverConfig.Keys, err = loadPublicKeys(cfg) + serverConfig.Keys, err = opts.loadPublicKeys() if err != nil { log.Fatal("Failed to load public keys:", err) } - serverConfig.Addr = getAddress(cfg.addr) - serverConfig.Network = getNetwork(cfg.addr) + serverConfig.Addr = opts.getAddress() + serverConfig.Network = opts.getNetwork() agent := agent.NewAgent() if err := agent.StartServer(serverConfig); err != nil { diff --git a/beszel/cmd/agent/agent_test.go b/beszel/cmd/agent/agent_test.go index 94582f83..cc72fb77 100644 --- a/beszel/cmd/agent/agent_test.go +++ b/beszel/cmd/agent/agent_test.go @@ -15,32 +15,32 @@ import ( func TestGetAddress(t *testing.T) { tests := []struct { name string - cfg cmdConfig + opts cmdOptions envVars map[string]string expected string }{ { name: "default port when no config", - cfg: cmdConfig{}, + opts: cmdOptions{}, expected: ":45876", }, { name: "use address from flag", - cfg: cmdConfig{ + opts: cmdOptions{ addr: "8080", }, expected: "8080", }, { name: "use unix socket from flag", - cfg: cmdConfig{ + opts: cmdOptions{ addr: "/tmp/beszel.sock", }, expected: "/tmp/beszel.sock", }, { name: "use ADDR env var", - cfg: cmdConfig{}, + opts: cmdOptions{}, envVars: map[string]string{ "ADDR": "1.2.3.4:9090", }, @@ -48,7 +48,7 @@ func TestGetAddress(t *testing.T) { }, { name: "use legacy PORT env var", - cfg: cmdConfig{}, + opts: cmdOptions{}, envVars: map[string]string{ "PORT": "7070", }, @@ -56,7 +56,7 @@ func TestGetAddress(t *testing.T) { }, { name: "flag takes precedence over env vars", - cfg: cmdConfig{ + opts: cmdOptions{ addr: ":8080", }, envVars: map[string]string{ @@ -74,7 +74,7 @@ func TestGetAddress(t *testing.T) { t.Setenv(k, v) } - addr := getAddress(tt.cfg.addr) + addr := tt.opts.getAddress() assert.Equal(t, tt.expected, addr) }) } @@ -90,7 +90,7 @@ func TestLoadPublicKeys(t *testing.T) { tests := []struct { name string - cfg cmdConfig + opts cmdOptions envVars map[string]string setupFiles map[string][]byte wantErr bool @@ -98,7 +98,7 @@ func TestLoadPublicKeys(t *testing.T) { }{ { name: "load key from flag", - cfg: cmdConfig{ + opts: cmdOptions{ key: string(pubKey), }, }, @@ -132,7 +132,7 @@ func TestLoadPublicKeys(t *testing.T) { }, { name: "error on invalid key data", - cfg: cmdConfig{ + opts: cmdOptions{ key: "invalid-key-data", }, wantErr: true, @@ -159,7 +159,7 @@ func TestLoadPublicKeys(t *testing.T) { t.Setenv(k, v) } - keys, err := loadPublicKeys(tt.cfg) + keys, err := tt.opts.loadPublicKeys() if tt.wantErr { assert.Error(t, err) if tt.errContains != "" { @@ -178,33 +178,40 @@ func TestLoadPublicKeys(t *testing.T) { func TestGetNetwork(t *testing.T) { tests := []struct { name string - addr string + opts cmdOptions envVars map[string]string expected string }{ + { + name: "NETWORK env var", + envVars: map[string]string{ + "NETWORK": "tcp4", + }, + expected: "tcp4", + }, { name: "only port", - addr: "8080", + opts: cmdOptions{addr: "8080"}, expected: "tcp", }, { name: "ipv4 address", - addr: "1.2.3.4:8080", + opts: cmdOptions{addr: "1.2.3.4:8080"}, expected: "tcp", }, { name: "ipv6 address", - addr: "[2001:db8::1]:8080", + opts: cmdOptions{addr: "[2001:db8::1]:8080"}, expected: "tcp", }, { name: "unix network", - addr: "/tmp/beszel.sock", + opts: cmdOptions{addr: "/tmp/beszel.sock"}, expected: "unix", }, { name: "env var network", - addr: ":8080", + opts: cmdOptions{addr: ":8080"}, envVars: map[string]string{"NETWORK": "tcp4"}, expected: "tcp4", }, @@ -216,7 +223,7 @@ func TestGetNetwork(t *testing.T) { for k, v := range tt.envVars { t.Setenv(k, v) } - network := getNetwork(tt.addr) + network := tt.opts.getNetwork() assert.Equal(t, tt.expected, network) }) } @@ -233,12 +240,12 @@ func TestParseFlags(t *testing.T) { tests := []struct { name string args []string - expected cmdConfig + expected cmdOptions }{ { name: "no flags", args: []string{"cmd"}, - expected: cmdConfig{ + expected: cmdOptions{ key: "", addr: "", }, @@ -246,7 +253,7 @@ func TestParseFlags(t *testing.T) { { name: "key flag only", args: []string{"cmd", "-key", "testkey"}, - expected: cmdConfig{ + expected: cmdOptions{ key: "testkey", addr: "", }, @@ -254,7 +261,7 @@ func TestParseFlags(t *testing.T) { { name: "addr flag only", args: []string{"cmd", "-addr", ":8080"}, - expected: cmdConfig{ + expected: cmdOptions{ key: "", addr: ":8080", }, @@ -262,7 +269,7 @@ func TestParseFlags(t *testing.T) { { name: "both flags", args: []string{"cmd", "-key", "testkey", "-addr", ":8080"}, - expected: cmdConfig{ + expected: cmdOptions{ key: "testkey", addr: ":8080", }, @@ -275,11 +282,11 @@ func TestParseFlags(t *testing.T) { flag.CommandLine = flag.NewFlagSet(tt.args[0], flag.ExitOnError) os.Args = tt.args - var cfg cmdConfig - parseFlags(&cfg) + var opts cmdOptions + opts.parseFlags() flag.Parse() - assert.Equal(t, tt.expected, cfg) + assert.Equal(t, tt.expected, opts) }) } } diff --git a/beszel/internal/agent/server.go b/beszel/internal/agent/server.go index 1830fc65..627b5c85 100644 --- a/beszel/internal/agent/server.go +++ b/beszel/internal/agent/server.go @@ -12,41 +12,41 @@ import ( "golang.org/x/crypto/ssh" ) -type ServerConfig struct { +type ServerOptions struct { Addr string Network string Keys []ssh.PublicKey } -func (a *Agent) StartServer(cfg ServerConfig) error { +func (a *Agent) StartServer(opts ServerOptions) error { sshServer.Handle(a.handleSession) - slog.Info("Starting SSH server", "addr", cfg.Addr, "network", cfg.Network) + slog.Info("Starting SSH server", "addr", opts.Addr, "network", opts.Network) - switch cfg.Network { + switch opts.Network { case "unix": // remove existing socket file if it exists - if err := os.Remove(cfg.Addr); err != nil && !os.IsNotExist(err) { + if err := os.Remove(opts.Addr); err != nil && !os.IsNotExist(err) { return err } default: // prefix with : if only port was provided - if !strings.Contains(cfg.Addr, ":") { - cfg.Addr = ":" + cfg.Addr + if !strings.Contains(opts.Addr, ":") { + opts.Addr = ":" + opts.Addr } } // Listen on the address - ln, err := net.Listen(cfg.Network, cfg.Addr) + ln, err := net.Listen(opts.Network, opts.Addr) if err != nil { return err } defer ln.Close() - // Start server on the listener + // Start SSH server on the listener err = sshServer.Serve(ln, nil, sshServer.NoPty(), sshServer.PublicKeyAuth(func(ctx sshServer.Context, key sshServer.PublicKey) bool { - for _, pubKey := range cfg.Keys { + for _, pubKey := range opts.Keys { if sshServer.KeysEqual(key, pubKey) { return true } diff --git a/beszel/internal/agent/server_test.go b/beszel/internal/agent/server_test.go index 41b3399c..6bbc90e0 100644 --- a/beszel/internal/agent/server_test.go +++ b/beszel/internal/agent/server_test.go @@ -35,7 +35,7 @@ func TestStartServer(t *testing.T) { tests := []struct { name string - config ServerConfig + config ServerOptions wantErr bool errContains string setup func() error @@ -43,7 +43,7 @@ func TestStartServer(t *testing.T) { }{ { name: "tcp port only", - config: ServerConfig{ + config: ServerOptions{ Network: "tcp", Addr: "45987", Keys: []ssh.PublicKey{sshPubKey}, @@ -51,7 +51,7 @@ func TestStartServer(t *testing.T) { }, { name: "tcp with ipv4", - config: ServerConfig{ + config: ServerOptions{ Network: "tcp4", Addr: "127.0.0.1:45988", Keys: []ssh.PublicKey{sshPubKey}, @@ -59,7 +59,7 @@ func TestStartServer(t *testing.T) { }, { name: "tcp with ipv6", - config: ServerConfig{ + config: ServerOptions{ Network: "tcp6", Addr: "[::1]:45989", Keys: []ssh.PublicKey{sshPubKey}, @@ -67,7 +67,7 @@ func TestStartServer(t *testing.T) { }, { name: "unix socket", - config: ServerConfig{ + config: ServerOptions{ Network: "unix", Addr: socketFile, Keys: []ssh.PublicKey{sshPubKey}, @@ -86,7 +86,7 @@ func TestStartServer(t *testing.T) { }, { name: "bad key should fail", - config: ServerConfig{ + config: ServerOptions{ Network: "tcp", Addr: "45987", Keys: []ssh.PublicKey{sshBadPubKey}, @@ -94,6 +94,14 @@ func TestStartServer(t *testing.T) { wantErr: true, errContains: "ssh: handshake failed", }, + { + name: "good key still good", + config: ServerOptions{ + Network: "tcp", + Addr: "45987", + Keys: []ssh.PublicKey{sshPubKey}, + }, + }, } for _, tt := range tests {