mirror of
https://github.com/henrygd/beszel.git
synced 2025-12-17 10:46:16 +01:00
refactor(agent): refactor option parsing logic for agent command
This commit is contained in:
@@ -12,41 +12,41 @@ import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type ServerConfig struct {
|
||||
type ServerOptions struct {
|
||||
Addr string
|
||||
Network string
|
||||
Keys []ssh.PublicKey
|
||||
}
|
||||
|
||||
func (a *Agent) StartServer(cfg ServerConfig) error {
|
||||
func (a *Agent) StartServer(opts ServerOptions) error {
|
||||
sshServer.Handle(a.handleSession)
|
||||
|
||||
slog.Info("Starting SSH server", "addr", cfg.Addr, "network", cfg.Network)
|
||||
slog.Info("Starting SSH server", "addr", opts.Addr, "network", opts.Network)
|
||||
|
||||
switch cfg.Network {
|
||||
switch opts.Network {
|
||||
case "unix":
|
||||
// remove existing socket file if it exists
|
||||
if err := os.Remove(cfg.Addr); err != nil && !os.IsNotExist(err) {
|
||||
if err := os.Remove(opts.Addr); err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
// prefix with : if only port was provided
|
||||
if !strings.Contains(cfg.Addr, ":") {
|
||||
cfg.Addr = ":" + cfg.Addr
|
||||
if !strings.Contains(opts.Addr, ":") {
|
||||
opts.Addr = ":" + opts.Addr
|
||||
}
|
||||
}
|
||||
|
||||
// Listen on the address
|
||||
ln, err := net.Listen(cfg.Network, cfg.Addr)
|
||||
ln, err := net.Listen(opts.Network, opts.Addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
// Start server on the listener
|
||||
// Start SSH server on the listener
|
||||
err = sshServer.Serve(ln, nil, sshServer.NoPty(),
|
||||
sshServer.PublicKeyAuth(func(ctx sshServer.Context, key sshServer.PublicKey) bool {
|
||||
for _, pubKey := range cfg.Keys {
|
||||
for _, pubKey := range opts.Keys {
|
||||
if sshServer.KeysEqual(key, pubKey) {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@ func TestStartServer(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config ServerConfig
|
||||
config ServerOptions
|
||||
wantErr bool
|
||||
errContains string
|
||||
setup func() error
|
||||
@@ -43,7 +43,7 @@ func TestStartServer(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "tcp port only",
|
||||
config: ServerConfig{
|
||||
config: ServerOptions{
|
||||
Network: "tcp",
|
||||
Addr: "45987",
|
||||
Keys: []ssh.PublicKey{sshPubKey},
|
||||
@@ -51,7 +51,7 @@ func TestStartServer(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "tcp with ipv4",
|
||||
config: ServerConfig{
|
||||
config: ServerOptions{
|
||||
Network: "tcp4",
|
||||
Addr: "127.0.0.1:45988",
|
||||
Keys: []ssh.PublicKey{sshPubKey},
|
||||
@@ -59,7 +59,7 @@ func TestStartServer(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "tcp with ipv6",
|
||||
config: ServerConfig{
|
||||
config: ServerOptions{
|
||||
Network: "tcp6",
|
||||
Addr: "[::1]:45989",
|
||||
Keys: []ssh.PublicKey{sshPubKey},
|
||||
@@ -67,7 +67,7 @@ func TestStartServer(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "unix socket",
|
||||
config: ServerConfig{
|
||||
config: ServerOptions{
|
||||
Network: "unix",
|
||||
Addr: socketFile,
|
||||
Keys: []ssh.PublicKey{sshPubKey},
|
||||
@@ -86,7 +86,7 @@ func TestStartServer(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "bad key should fail",
|
||||
config: ServerConfig{
|
||||
config: ServerOptions{
|
||||
Network: "tcp",
|
||||
Addr: "45987",
|
||||
Keys: []ssh.PublicKey{sshBadPubKey},
|
||||
@@ -94,6 +94,14 @@ func TestStartServer(t *testing.T) {
|
||||
wantErr: true,
|
||||
errContains: "ssh: handshake failed",
|
||||
},
|
||||
{
|
||||
name: "good key still good",
|
||||
config: ServerOptions{
|
||||
Network: "tcp",
|
||||
Addr: "45987",
|
||||
Keys: []ssh.PublicKey{sshPubKey},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
Reference in New Issue
Block a user