mirror of
https://github.com/henrygd/beszel.git
synced 2025-12-17 10:46:16 +01:00
feat(agent): NETWORK env var and support for multiple keys
- merges agent.Run with agent.NewAgent - separates StartServer method - bumps go version to 1.24 - add tests
This commit is contained in:
@@ -8,12 +8,19 @@ import (
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Define flags for key and port
|
||||
keyFlag := flag.String("key", "", "Public key")
|
||||
portFlag := flag.String("port", "45876", "Port number")
|
||||
type cmdConfig struct {
|
||||
key string // key is the public key(s) for SSH authentication.
|
||||
addr string // addr is the address or port to listen on.
|
||||
}
|
||||
|
||||
// parseFlags parses the command line flags and populates the config struct.
|
||||
func parseFlags(cfg *cmdConfig) {
|
||||
flag.StringVar(&cfg.key, "key", "", "Public key(s) for SSH authentication")
|
||||
flag.StringVar(&cfg.addr, "addr", "", "Address or port to listen on")
|
||||
|
||||
flag.Usage = func() {
|
||||
fmt.Printf("Usage: %s [options] [subcommand]\n", os.Args[0])
|
||||
@@ -24,65 +31,103 @@ func main() {
|
||||
fmt.Println(" help Display this help message")
|
||||
fmt.Println(" update Update the agent to the latest version")
|
||||
}
|
||||
}
|
||||
|
||||
// handleSubcommand handles subcommands such as version, help, and update.
|
||||
// It returns true if a subcommand was handled, false otherwise.
|
||||
func handleSubcommand() bool {
|
||||
if len(os.Args) <= 1 {
|
||||
return false
|
||||
}
|
||||
switch os.Args[1] {
|
||||
case "version", "-v":
|
||||
fmt.Println(beszel.AppName+"-agent", beszel.Version)
|
||||
os.Exit(0)
|
||||
case "help":
|
||||
flag.Usage()
|
||||
os.Exit(0)
|
||||
case "update":
|
||||
agent.Update()
|
||||
os.Exit(0)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// loadPublicKeys loads the public keys from the command line flag, environment variable, or key file.
|
||||
func loadPublicKeys(cfg cmdConfig) ([]ssh.PublicKey, error) {
|
||||
// Try command line flag first
|
||||
if cfg.key != "" {
|
||||
return agent.ParseKeys(cfg.key)
|
||||
}
|
||||
|
||||
// Try environment variable
|
||||
if key, ok := agent.GetEnv("KEY"); ok && key != "" {
|
||||
return agent.ParseKeys(key)
|
||||
}
|
||||
|
||||
// Try key file
|
||||
keyFile, ok := agent.GetEnv("KEY_FILE")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no key provided: must set -key flag, KEY env var, or KEY_FILE env var. ")
|
||||
}
|
||||
|
||||
pubKey, err := os.ReadFile(keyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read key file: %w", err)
|
||||
}
|
||||
return agent.ParseKeys(string(pubKey))
|
||||
}
|
||||
|
||||
// getAddress gets the address to listen on from the command line flag, environment variable, or default value.
|
||||
func getAddress(addr string) string {
|
||||
// Try command line flag first
|
||||
if addr != "" {
|
||||
return addr
|
||||
}
|
||||
// Try environment variables
|
||||
if addr, ok := agent.GetEnv("ADDR"); ok && addr != "" {
|
||||
return addr
|
||||
}
|
||||
// Legacy PORT environment variable support
|
||||
if port, ok := agent.GetEnv("PORT"); ok && port != "" {
|
||||
return port
|
||||
}
|
||||
return ":45876"
|
||||
}
|
||||
|
||||
// getNetwork returns the network type to use for the server.
|
||||
func getNetwork(addr string) string {
|
||||
if network, _ := agent.GetEnv("NETWORK"); network != "" {
|
||||
return network
|
||||
}
|
||||
if strings.HasPrefix(addr, "/") {
|
||||
return "unix"
|
||||
}
|
||||
return "tcp"
|
||||
}
|
||||
|
||||
func main() {
|
||||
var cfg cmdConfig
|
||||
parseFlags(&cfg)
|
||||
|
||||
if handleSubcommand() {
|
||||
return
|
||||
}
|
||||
|
||||
// Parse the flags
|
||||
flag.Parse()
|
||||
|
||||
// handle flags / subcommands
|
||||
if len(os.Args) > 1 {
|
||||
switch os.Args[1] {
|
||||
case "version":
|
||||
fmt.Println(beszel.AppName+"-agent", beszel.Version)
|
||||
os.Exit(0)
|
||||
case "help":
|
||||
flag.Usage()
|
||||
os.Exit(0)
|
||||
case "update":
|
||||
agent.Update()
|
||||
os.Exit(0)
|
||||
}
|
||||
var serverConfig agent.ServerConfig
|
||||
var err error
|
||||
serverConfig.Keys, err = loadPublicKeys(cfg)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to load public keys:", err)
|
||||
}
|
||||
|
||||
var pubKey []byte
|
||||
// Override the key if the -key flag is provided
|
||||
if *keyFlag != "" {
|
||||
pubKey = []byte(*keyFlag)
|
||||
} else {
|
||||
// Try to get the key from the KEY environment variable.
|
||||
key, _ := agent.GetEnv("KEY")
|
||||
pubKey = []byte(key)
|
||||
}
|
||||
serverConfig.Addr = getAddress(cfg.addr)
|
||||
serverConfig.Network = getNetwork(cfg.addr)
|
||||
|
||||
// If KEY is not set, try to read the key from the file specified by KEY_FILE.
|
||||
if len(pubKey) == 0 {
|
||||
keyFile, exists := agent.GetEnv("KEY_FILE")
|
||||
if !exists {
|
||||
log.Fatal("Must set KEY or KEY_FILE environment variable or supply as input argument. Use 'beszel-agent help' for more information.")
|
||||
}
|
||||
var err error
|
||||
pubKey, err = os.ReadFile(keyFile)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
agent := agent.NewAgent()
|
||||
if err := agent.StartServer(serverConfig); err != nil {
|
||||
log.Fatal("Failed to start server:", err)
|
||||
}
|
||||
|
||||
// Init with default port
|
||||
addr := ":" + *portFlag
|
||||
|
||||
//Use port from ENV if it exists
|
||||
// TODO: change env var to ADDR
|
||||
if portEnvVar, exists := agent.GetEnv("PORT"); exists {
|
||||
// allow passing an address in the form of "127.0.0.1:45876"
|
||||
if !strings.Contains(portEnvVar, ":") {
|
||||
portEnvVar = ":" + portEnvVar
|
||||
}
|
||||
addr = portEnvVar
|
||||
}
|
||||
|
||||
// Override the default and ENV port if the -port flag is provided and is non default
|
||||
if *portFlag != "45876" {
|
||||
addr = ":" + *portFlag
|
||||
}
|
||||
|
||||
agent.NewAgent().Run(pubKey, addr)
|
||||
}
|
||||
|
||||
285
beszel/cmd/agent/agent_test.go
Normal file
285
beszel/cmd/agent/agent_test.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"flag"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func TestGetAddress(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg cmdConfig
|
||||
envVars map[string]string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "default port when no config",
|
||||
cfg: cmdConfig{},
|
||||
expected: ":45876",
|
||||
},
|
||||
{
|
||||
name: "use address from flag",
|
||||
cfg: cmdConfig{
|
||||
addr: "8080",
|
||||
},
|
||||
expected: "8080",
|
||||
},
|
||||
{
|
||||
name: "use unix socket from flag",
|
||||
cfg: cmdConfig{
|
||||
addr: "/tmp/beszel.sock",
|
||||
},
|
||||
expected: "/tmp/beszel.sock",
|
||||
},
|
||||
{
|
||||
name: "use ADDR env var",
|
||||
cfg: cmdConfig{},
|
||||
envVars: map[string]string{
|
||||
"ADDR": "1.2.3.4:9090",
|
||||
},
|
||||
expected: "1.2.3.4:9090",
|
||||
},
|
||||
{
|
||||
name: "use legacy PORT env var",
|
||||
cfg: cmdConfig{},
|
||||
envVars: map[string]string{
|
||||
"PORT": "7070",
|
||||
},
|
||||
expected: "7070",
|
||||
},
|
||||
{
|
||||
name: "flag takes precedence over env vars",
|
||||
cfg: cmdConfig{
|
||||
addr: ":8080",
|
||||
},
|
||||
envVars: map[string]string{
|
||||
"ADDR": ":9090",
|
||||
"PORT": "7070",
|
||||
},
|
||||
expected: ":8080",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Setup environment
|
||||
for k, v := range tt.envVars {
|
||||
t.Setenv(k, v)
|
||||
}
|
||||
|
||||
addr := getAddress(tt.cfg.addr)
|
||||
assert.Equal(t, tt.expected, addr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadPublicKeys(t *testing.T) {
|
||||
// Generate a test key
|
||||
_, priv, err := ed25519.GenerateKey(nil)
|
||||
require.NoError(t, err)
|
||||
signer, err := ssh.NewSignerFromKey(priv)
|
||||
require.NoError(t, err)
|
||||
pubKey := ssh.MarshalAuthorizedKey(signer.PublicKey())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg cmdConfig
|
||||
envVars map[string]string
|
||||
setupFiles map[string][]byte
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "load key from flag",
|
||||
cfg: cmdConfig{
|
||||
key: string(pubKey),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "load key from env var",
|
||||
envVars: map[string]string{
|
||||
"KEY": string(pubKey),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "load key from file",
|
||||
envVars: map[string]string{
|
||||
"KEY_FILE": "testkey.pub",
|
||||
},
|
||||
setupFiles: map[string][]byte{
|
||||
"testkey.pub": pubKey,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "error when no key provided",
|
||||
wantErr: true,
|
||||
errContains: "no key provided",
|
||||
},
|
||||
{
|
||||
name: "error on invalid key file",
|
||||
envVars: map[string]string{
|
||||
"KEY_FILE": "nonexistent.pub",
|
||||
},
|
||||
wantErr: true,
|
||||
errContains: "failed to read key file",
|
||||
},
|
||||
{
|
||||
name: "error on invalid key data",
|
||||
cfg: cmdConfig{
|
||||
key: "invalid-key-data",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create a temporary directory for test files
|
||||
if len(tt.setupFiles) > 0 {
|
||||
tmpDir := t.TempDir()
|
||||
for name, content := range tt.setupFiles {
|
||||
path := filepath.Join(tmpDir, name)
|
||||
err := os.WriteFile(path, content, 0600)
|
||||
require.NoError(t, err)
|
||||
if tt.envVars != nil {
|
||||
tt.envVars["KEY_FILE"] = path
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set up environment
|
||||
for k, v := range tt.envVars {
|
||||
t.Setenv(k, v)
|
||||
}
|
||||
|
||||
keys, err := loadPublicKeys(tt.cfg)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, keys, 1)
|
||||
assert.Equal(t, signer.PublicKey().Type(), keys[0].Type())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetNetwork(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr string
|
||||
envVars map[string]string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "only port",
|
||||
addr: "8080",
|
||||
expected: "tcp",
|
||||
},
|
||||
{
|
||||
name: "ipv4 address",
|
||||
addr: "1.2.3.4:8080",
|
||||
expected: "tcp",
|
||||
},
|
||||
{
|
||||
name: "ipv6 address",
|
||||
addr: "[2001:db8::1]:8080",
|
||||
expected: "tcp",
|
||||
},
|
||||
{
|
||||
name: "unix network",
|
||||
addr: "/tmp/beszel.sock",
|
||||
expected: "unix",
|
||||
},
|
||||
{
|
||||
name: "env var network",
|
||||
addr: ":8080",
|
||||
envVars: map[string]string{"NETWORK": "tcp4"},
|
||||
expected: "tcp4",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Setup environment
|
||||
for k, v := range tt.envVars {
|
||||
t.Setenv(k, v)
|
||||
}
|
||||
network := getNetwork(tt.addr)
|
||||
assert.Equal(t, tt.expected, network)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFlags(t *testing.T) {
|
||||
// Save original command line arguments and restore after test
|
||||
oldArgs := os.Args
|
||||
defer func() {
|
||||
os.Args = oldArgs
|
||||
flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expected cmdConfig
|
||||
}{
|
||||
{
|
||||
name: "no flags",
|
||||
args: []string{"cmd"},
|
||||
expected: cmdConfig{
|
||||
key: "",
|
||||
addr: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "key flag only",
|
||||
args: []string{"cmd", "-key", "testkey"},
|
||||
expected: cmdConfig{
|
||||
key: "testkey",
|
||||
addr: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "addr flag only",
|
||||
args: []string{"cmd", "-addr", ":8080"},
|
||||
expected: cmdConfig{
|
||||
key: "",
|
||||
addr: ":8080",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "both flags",
|
||||
args: []string{"cmd", "-key", "testkey", "-addr", ":8080"},
|
||||
expected: cmdConfig{
|
||||
key: "testkey",
|
||||
addr: ":8080",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset flags for each test
|
||||
flag.CommandLine = flag.NewFlagSet(tt.args[0], flag.ExitOnError)
|
||||
os.Args = tt.args
|
||||
|
||||
var cfg cmdConfig
|
||||
parseFlags(&cfg)
|
||||
flag.Parse()
|
||||
|
||||
assert.Equal(t, tt.expected, cfg)
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user