Refactor SSH configuration and key management

- Restrict to specific key exchanges / MACs / ciphers.
- Refactored GetSSHKey method to return an ssh.Signer instead of byte array.
- Added common package.

Co-authored-by: nhas <jordanatararimu@gmail.com>
This commit is contained in:
henrygd
2025-05-07 20:03:21 -04:00
parent c0a6153a43
commit 63af81666b
4 changed files with 74 additions and 64 deletions

View File

@@ -1,6 +1,7 @@
package agent
import (
"beszel/internal/common"
"encoding/json"
"fmt"
"log/slog"
@@ -19,8 +20,6 @@ type ServerOptions struct {
}
func (a *Agent) StartServer(opts ServerOptions) error {
ssh.Handle(a.handleSession)
slog.Info("Starting SSH server", "addr", opts.Addr, "network", opts.Network)
if opts.Network == "unix" {
@@ -37,17 +36,40 @@ func (a *Agent) StartServer(opts ServerOptions) error {
}
defer ln.Close()
// Start SSH server on the listener
return ssh.Serve(ln, nil, ssh.NoPty(),
ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
// base config (limit to allowed algorithms)
config := &gossh.ServerConfig{}
config.KeyExchanges = common.DefaultKeyExchanges
config.MACs = common.DefaultMACs
config.Ciphers = common.DefaultCiphers
// set default handler
ssh.Handle(a.handleSession)
server := ssh.Server{
ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig {
return config
},
// check public key(s)
PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool {
for _, pubKey := range opts.Keys {
if ssh.KeysEqual(key, pubKey) {
return true
}
}
return false
}),
)
},
// disable pty
PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool {
return false
},
// log failed connections
ConnectionFailedCallback: func(conn net.Conn, err error) {
slog.Warn("Failed connection attempt", "addr", conn.RemoteAddr().String(), "err", err)
},
}
// Start SSH server on the listener
return server.Serve(ln)
}
func (a *Agent) handleSession(s ssh.Session) {
@@ -56,6 +78,7 @@ func (a *Agent) handleSession(s ssh.Session) {
if err := json.NewEncoder(s).Encode(stats); err != nil {
slog.Error("Error encoding stats", "err", err, "stats", stats)
s.Exit(1)
return
}
s.Exit(0)
}