mirror of
https://github.com/henrygd/beszel.git
synced 2025-12-17 02:36:17 +01:00
feat(agent): NETWORK env var and support for multiple keys
- merges agent.Run with agent.NewAgent - separates StartServer method - bumps go version to 1.24 - add tests
This commit is contained in:
@@ -28,29 +28,16 @@ type Agent struct {
|
||||
}
|
||||
|
||||
func NewAgent() *Agent {
|
||||
newAgent := &Agent{
|
||||
sensorsContext: context.Background(),
|
||||
fsStats: make(map[string]*system.FsStats),
|
||||
agent := &Agent{
|
||||
fsStats: make(map[string]*system.FsStats),
|
||||
}
|
||||
newAgent.memCalc, _ = GetEnv("MEM_CALC")
|
||||
return newAgent
|
||||
}
|
||||
agent.memCalc, _ = GetEnv("MEM_CALC")
|
||||
|
||||
// GetEnv retrieves an environment variable with a "BESZEL_AGENT_" prefix, or falls back to the unprefixed key.
|
||||
func GetEnv(key string) (value string, exists bool) {
|
||||
if value, exists = os.LookupEnv("BESZEL_AGENT_" + key); exists {
|
||||
return value, exists
|
||||
}
|
||||
// Fallback to the old unprefixed key
|
||||
return os.LookupEnv(key)
|
||||
}
|
||||
|
||||
func (a *Agent) Run(pubKey []byte, addr string) {
|
||||
// Set up slog with a log level determined by the LOG_LEVEL env var
|
||||
if logLevelStr, exists := GetEnv("LOG_LEVEL"); exists {
|
||||
switch strings.ToLower(logLevelStr) {
|
||||
case "debug":
|
||||
a.debug = true
|
||||
agent.debug = true
|
||||
slog.SetLogLoggerLevel(slog.LevelDebug)
|
||||
case "warn":
|
||||
slog.SetLogLoggerLevel(slog.LevelWarn)
|
||||
@@ -64,40 +51,51 @@ func (a *Agent) Run(pubKey []byte, addr string) {
|
||||
// Set sensors context (allows overriding sys location for sensors)
|
||||
if sysSensors, exists := GetEnv("SYS_SENSORS"); exists {
|
||||
slog.Info("SYS_SENSORS", "path", sysSensors)
|
||||
a.sensorsContext = context.WithValue(a.sensorsContext,
|
||||
agent.sensorsContext = context.WithValue(agent.sensorsContext,
|
||||
common.EnvKey, common.EnvMap{common.HostSysEnvKey: sysSensors},
|
||||
)
|
||||
} else {
|
||||
agent.sensorsContext = context.Background()
|
||||
}
|
||||
|
||||
// Set sensors whitelist
|
||||
if sensors, exists := GetEnv("SENSORS"); exists {
|
||||
a.sensorsWhitelist = make(map[string]struct{})
|
||||
agent.sensorsWhitelist = make(map[string]struct{})
|
||||
for _, sensor := range strings.Split(sensors, ",") {
|
||||
if sensor != "" {
|
||||
a.sensorsWhitelist[sensor] = struct{}{}
|
||||
agent.sensorsWhitelist[sensor] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// initialize system info / docker manager
|
||||
a.initializeSystemInfo()
|
||||
a.initializeDiskInfo()
|
||||
a.initializeNetIoStats()
|
||||
a.dockerManager = newDockerManager(a)
|
||||
agent.initializeSystemInfo()
|
||||
agent.initializeDiskInfo()
|
||||
agent.initializeNetIoStats()
|
||||
agent.dockerManager = newDockerManager(agent)
|
||||
|
||||
// initialize GPU manager
|
||||
if gm, err := NewGPUManager(); err != nil {
|
||||
slog.Debug("GPU", "err", err)
|
||||
} else {
|
||||
a.gpuManager = gm
|
||||
agent.gpuManager = gm
|
||||
}
|
||||
|
||||
// if debugging, print stats
|
||||
if a.debug {
|
||||
slog.Debug("Stats", "data", a.gatherStats())
|
||||
if agent.debug {
|
||||
slog.Debug("Stats", "data", agent.gatherStats())
|
||||
}
|
||||
|
||||
a.startServer(pubKey, addr)
|
||||
return agent
|
||||
}
|
||||
|
||||
// GetEnv retrieves an environment variable with a "BESZEL_AGENT_" prefix, or falls back to the unprefixed key.
|
||||
func GetEnv(key string) (value string, exists bool) {
|
||||
if value, exists = os.LookupEnv("BESZEL_AGENT_" + key); exists {
|
||||
return value, exists
|
||||
}
|
||||
// Fallback to the old unprefixed key
|
||||
return os.LookupEnv(key)
|
||||
}
|
||||
|
||||
func (a *Agent) gatherStats() system.CombinedData {
|
||||
|
||||
@@ -2,33 +2,96 @@ package agent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
sshServer "github.com/gliderlabs/ssh"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func (a *Agent) startServer(pubKey []byte, addr string) {
|
||||
type ServerConfig struct {
|
||||
Addr string
|
||||
Network string
|
||||
Keys []ssh.PublicKey
|
||||
}
|
||||
|
||||
func (a *Agent) StartServer(cfg ServerConfig) error {
|
||||
sshServer.Handle(a.handleSession)
|
||||
|
||||
slog.Info("Starting SSH server", "address", addr)
|
||||
if err := sshServer.ListenAndServe(addr, nil, sshServer.NoPty(),
|
||||
sshServer.PublicKeyAuth(func(ctx sshServer.Context, key sshServer.PublicKey) bool {
|
||||
allowed, _, _, _, _ := sshServer.ParseAuthorizedKey(pubKey)
|
||||
return sshServer.KeysEqual(key, allowed)
|
||||
}),
|
||||
); err != nil {
|
||||
slog.Error("Error starting SSH server", "err", err)
|
||||
os.Exit(1)
|
||||
slog.Info("Starting SSH server", "addr", cfg.Addr, "network", cfg.Network)
|
||||
|
||||
switch cfg.Network {
|
||||
case "unix":
|
||||
// remove existing socket file if it exists
|
||||
if err := os.Remove(cfg.Addr); err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
// prefix with : if only port was provided
|
||||
if !strings.Contains(cfg.Addr, ":") {
|
||||
cfg.Addr = ":" + cfg.Addr
|
||||
}
|
||||
}
|
||||
|
||||
// Listen on the address
|
||||
ln, err := net.Listen(cfg.Network, cfg.Addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
// Start server on the listener
|
||||
err = sshServer.Serve(ln, nil, sshServer.NoPty(),
|
||||
sshServer.PublicKeyAuth(func(ctx sshServer.Context, key sshServer.PublicKey) bool {
|
||||
for _, pubKey := range cfg.Keys {
|
||||
if sshServer.KeysEqual(key, pubKey) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Agent) handleSession(s sshServer.Session) {
|
||||
// slog.Debug("connection", "remoteaddr", s.RemoteAddr(), "user", s.User())
|
||||
stats := a.gatherStats()
|
||||
if err := json.NewEncoder(s).Encode(stats); err != nil {
|
||||
slog.Error("Error encoding stats", "err", err, "stats", stats)
|
||||
s.Exit(1)
|
||||
return
|
||||
}
|
||||
s.Exit(0)
|
||||
}
|
||||
|
||||
// ParseKeys parses a string containing SSH public keys in authorized_keys format.
|
||||
// It returns a slice of ssh.PublicKey and an error if any key fails to parse.
|
||||
func ParseKeys(input string) ([]ssh.PublicKey, error) {
|
||||
var parsedKeys []ssh.PublicKey
|
||||
|
||||
for line := range strings.Lines(input) {
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
// Skip empty lines or comments
|
||||
if len(line) == 0 || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse the key
|
||||
parsedKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(line))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse key: %s, error: %w", line, err)
|
||||
}
|
||||
|
||||
// Append the parsed key to the list
|
||||
parsedKeys = append(parsedKeys, parsedKey)
|
||||
}
|
||||
|
||||
return parsedKeys, nil
|
||||
}
|
||||
|
||||
281
beszel/internal/agent/server_test.go
Normal file
281
beszel/internal/agent/server_test.go
Normal file
@@ -0,0 +1,281 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func TestStartServer(t *testing.T) {
|
||||
// Generate a test key pair
|
||||
pubKey, privKey, err := ed25519.GenerateKey(nil)
|
||||
require.NoError(t, err)
|
||||
signer, err := ssh.NewSignerFromKey(privKey)
|
||||
require.NoError(t, err)
|
||||
sshPubKey, err := ssh.NewPublicKey(pubKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate a different key pair for bad key test
|
||||
badPubKey, badPrivKey, err := ed25519.GenerateKey(nil)
|
||||
require.NoError(t, err)
|
||||
badSigner, err := ssh.NewSignerFromKey(badPrivKey)
|
||||
require.NoError(t, err)
|
||||
sshBadPubKey, err := ssh.NewPublicKey(badPubKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
socketFile := filepath.Join(t.TempDir(), "beszel-test.sock")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config ServerConfig
|
||||
wantErr bool
|
||||
errContains string
|
||||
setup func() error
|
||||
cleanup func() error
|
||||
}{
|
||||
{
|
||||
name: "tcp port only",
|
||||
config: ServerConfig{
|
||||
Network: "tcp",
|
||||
Addr: "45987",
|
||||
Keys: []ssh.PublicKey{sshPubKey},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tcp with ipv4",
|
||||
config: ServerConfig{
|
||||
Network: "tcp4",
|
||||
Addr: "127.0.0.1:45988",
|
||||
Keys: []ssh.PublicKey{sshPubKey},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tcp with ipv6",
|
||||
config: ServerConfig{
|
||||
Network: "tcp6",
|
||||
Addr: "[::1]:45989",
|
||||
Keys: []ssh.PublicKey{sshPubKey},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unix socket",
|
||||
config: ServerConfig{
|
||||
Network: "unix",
|
||||
Addr: socketFile,
|
||||
Keys: []ssh.PublicKey{sshPubKey},
|
||||
},
|
||||
setup: func() error {
|
||||
// Create a socket file that should be removed
|
||||
f, err := os.Create(socketFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return f.Close()
|
||||
},
|
||||
cleanup: func() error {
|
||||
return os.Remove(socketFile)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bad key should fail",
|
||||
config: ServerConfig{
|
||||
Network: "tcp",
|
||||
Addr: "45987",
|
||||
Keys: []ssh.PublicKey{sshBadPubKey},
|
||||
},
|
||||
wantErr: true,
|
||||
errContains: "ssh: handshake failed",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.setup != nil {
|
||||
err := tt.setup()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
if tt.cleanup != nil {
|
||||
defer tt.cleanup()
|
||||
}
|
||||
|
||||
agent := NewAgent()
|
||||
|
||||
// Start server in a goroutine since it blocks
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
errChan <- agent.StartServer(tt.config)
|
||||
}()
|
||||
|
||||
// Add a short delay to allow the server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Try to connect to verify server is running
|
||||
var client *ssh.Client
|
||||
var err error
|
||||
|
||||
// Choose the appropriate signer based on the test case
|
||||
testSigner := signer
|
||||
if tt.name == "bad key should fail" {
|
||||
testSigner = badSigner
|
||||
}
|
||||
|
||||
sshClientConfig := &ssh.ClientConfig{
|
||||
User: "a",
|
||||
Auth: []ssh.AuthMethod{
|
||||
ssh.PublicKeys(testSigner),
|
||||
},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: 4 * time.Second,
|
||||
}
|
||||
|
||||
switch tt.config.Network {
|
||||
case "unix":
|
||||
client, err = ssh.Dial("unix", tt.config.Addr, sshClientConfig)
|
||||
default:
|
||||
if !strings.Contains(tt.config.Addr, ":") {
|
||||
tt.config.Addr = ":" + tt.config.Addr
|
||||
}
|
||||
client, err = ssh.Dial("tcp", tt.config.Addr, sshClientConfig)
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, client)
|
||||
client.Close()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////
|
||||
//////////////////// ParseKeys Tests ////////////////////////////
|
||||
/////////////////////////////////////////////////////////////////
|
||||
|
||||
// Helper function to generate a temporary file with content
|
||||
func createTempFile(content string) (string, error) {
|
||||
tmpFile, err := os.CreateTemp("", "ssh_keys_*.txt")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temp file: %w", err)
|
||||
}
|
||||
defer tmpFile.Close()
|
||||
|
||||
if _, err := tmpFile.WriteString(content); err != nil {
|
||||
return "", fmt.Errorf("failed to write to temp file: %w", err)
|
||||
}
|
||||
|
||||
return tmpFile.Name(), nil
|
||||
}
|
||||
|
||||
// Test case 1: String with a single SSH key
|
||||
func TestParseSingleKeyFromString(t *testing.T) {
|
||||
input := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKCBM91kukN7hbvFKtbpEeo2JXjCcNxXcdBH7V7ADMBo"
|
||||
keys, err := ParseKeys(input)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
if len(keys) != 1 {
|
||||
t.Fatalf("Expected 1 key, got %d keys", len(keys))
|
||||
}
|
||||
if keys[0].Type() != "ssh-ed25519" {
|
||||
t.Fatalf("Expected key type 'ssh-ed25519', got '%s'", keys[0].Type())
|
||||
}
|
||||
}
|
||||
|
||||
// Test case 2: String with multiple SSH keys
|
||||
func TestParseMultipleKeysFromString(t *testing.T) {
|
||||
input := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKCBM91kukN7hbvFKtbpEeo2JXjCcNxXcdBH7V7ADMBo\nssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJDMtAOQfxDlCxe+A5lVbUY/DHxK1LAF2Z3AV0FYv36D \n #comment\n ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJDMtAOQfxDlCxe+A5lVbUY/DHxK1LAF2Z3AV0FYv36D"
|
||||
keys, err := ParseKeys(input)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
if len(keys) != 3 {
|
||||
t.Fatalf("Expected 3 keys, got %d keys", len(keys))
|
||||
}
|
||||
if keys[0].Type() != "ssh-ed25519" || keys[1].Type() != "ssh-ed25519" || keys[2].Type() != "ssh-ed25519" {
|
||||
t.Fatalf("Unexpected key types: %s, %s, %s", keys[0].Type(), keys[1].Type(), keys[2].Type())
|
||||
}
|
||||
}
|
||||
|
||||
// Test case 3: File with a single SSH key
|
||||
func TestParseSingleKeyFromFile(t *testing.T) {
|
||||
content := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKCBM91kukN7hbvFKtbpEeo2JXjCcNxXcdBH7V7ADMBo"
|
||||
filePath, err := createTempFile(content)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp file: %v", err)
|
||||
}
|
||||
defer os.Remove(filePath) // Clean up the file after the test
|
||||
|
||||
// Read the file content
|
||||
fileContent, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read temp file: %v", err)
|
||||
}
|
||||
|
||||
// Parse the keys
|
||||
keys, err := ParseKeys(string(fileContent))
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
if len(keys) != 1 {
|
||||
t.Fatalf("Expected 1 key, got %d keys", len(keys))
|
||||
}
|
||||
if keys[0].Type() != "ssh-ed25519" {
|
||||
t.Fatalf("Expected key type 'ssh-ed25519', got '%s'", keys[0].Type())
|
||||
}
|
||||
}
|
||||
|
||||
// Test case 4: File with multiple SSH keys
|
||||
func TestParseMultipleKeysFromFile(t *testing.T) {
|
||||
content := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKCBM91kukN7hbvFKtbpEeo2JXjCcNxXcdBH7V7ADMBo\nssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJDMtAOQfxDlCxe+A5lVbUY/DHxK1LAF2Z3AV0FYv36D \n #comment\n ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJDMtAOQfxDlCxe+A5lVbUY/DHxK1LAF2Z3AV0FYv36D"
|
||||
filePath, err := createTempFile(content)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp file: %v", err)
|
||||
}
|
||||
// defer os.Remove(filePath) // Clean up the file after the test
|
||||
|
||||
// Read the file content
|
||||
fileContent, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read temp file: %v", err)
|
||||
}
|
||||
|
||||
// Parse the keys
|
||||
keys, err := ParseKeys(string(fileContent))
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
if len(keys) != 3 {
|
||||
t.Fatalf("Expected 3 keys, got %d keys", len(keys))
|
||||
}
|
||||
if keys[0].Type() != "ssh-ed25519" || keys[1].Type() != "ssh-ed25519" || keys[2].Type() != "ssh-ed25519" {
|
||||
t.Fatalf("Unexpected key types: %s, %s, %s", keys[0].Type(), keys[1].Type(), keys[2].Type())
|
||||
}
|
||||
}
|
||||
|
||||
// Test case 5: Invalid SSH key input
|
||||
func TestParseInvalidKey(t *testing.T) {
|
||||
input := "invalid-key-data"
|
||||
_, err := ParseKeys(input)
|
||||
if err == nil {
|
||||
t.Fatalf("Expected an error for invalid key, got nil")
|
||||
}
|
||||
expectedErrMsg := "failed to parse key"
|
||||
if !strings.Contains(err.Error(), expectedErrMsg) {
|
||||
t.Fatalf("Expected error message to contain '%s', got: %v", expectedErrMsg, err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user