add health check for agent

- Updated command-line flag parsing.
- Moved GetAddress and GetNetwork to server.go
This commit is contained in:
henrygd
2025-03-14 03:33:25 -04:00
parent 400ea89587
commit edefc6f53e
7 changed files with 228 additions and 74 deletions

View File

@@ -0,0 +1,19 @@
package agent
import (
"net"
"time"
)
// Health checks if the agent's server is running by attempting to connect to it.
// It returns 0 if the server is running, 1 otherwise (as in exit codes).
//
// If an error occurs when attempting to connect to the server, it returns the error.
func Health(addr string, network string) (int, error) {
conn, err := net.DialTimeout(network, addr, 4*time.Second)
if err != nil {
return 1, err
}
conn.Close()
return 0, nil
}

View File

@@ -0,0 +1,130 @@
//go:build testing
// +build testing
package agent_test
import (
"fmt"
"net"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"beszel/internal/agent"
)
// setupTestServer creates a temporary server for testing
func setupTestServer(t *testing.T) (string, func()) {
// Create a temporary socket file for Unix socket testing
tempSockFile := os.TempDir() + "/beszel_health_test.sock"
// Clean up any existing socket file
os.Remove(tempSockFile)
// Create a listener
listener, err := net.Listen("unix", tempSockFile)
require.NoError(t, err, "Failed to create test listener")
// Start a simple server in a goroutine
go func() {
conn, err := listener.Accept()
if err != nil {
return // Listener closed
}
defer conn.Close()
// Just accept the connection and do nothing
}()
// Return the socket file path and a cleanup function
return tempSockFile, func() {
listener.Close()
os.Remove(tempSockFile)
}
}
// setupTCPTestServer creates a temporary TCP server for testing
func setupTCPTestServer(t *testing.T) (string, func()) {
// Listen on a random available port
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "Failed to create test listener")
// Get the port that was assigned
addr := listener.Addr().(*net.TCPAddr)
port := addr.Port
// Start a simple server in a goroutine
go func() {
conn, err := listener.Accept()
if err != nil {
return // Listener closed
}
defer conn.Close()
// Just accept the connection and do nothing
}()
// Return the address and a cleanup function
return fmt.Sprintf("127.0.0.1:%d", port), func() {
listener.Close()
}
}
func TestHealth(t *testing.T) {
t.Run("server is running (unix socket)", func(t *testing.T) {
// Setup a test server
sockFile, cleanup := setupTestServer(t)
defer cleanup()
// Run the health check with explicit parameters
result, err := agent.Health(sockFile, "unix")
require.NoError(t, err, "Failed to check health")
// Verify the result
assert.Equal(t, 0, result, "Health check should return 0 when server is running")
})
t.Run("server is running (tcp address)", func(t *testing.T) {
// Setup a test server
addr, cleanup := setupTCPTestServer(t)
defer cleanup()
// Run the health check with explicit parameters
result, err := agent.Health(addr, "tcp")
require.NoError(t, err, "Failed to check health")
// Verify the result
assert.Equal(t, 0, result, "Health check should return 0 when server is running")
})
t.Run("server is not running", func(t *testing.T) {
// Use an address that's likely not in use
addr := "127.0.0.1:65535"
// Run the health check with explicit parameters
result, err := agent.Health(addr, "tcp")
require.Error(t, err, "Health check should return an error when server is not running")
// Verify the result
assert.Equal(t, 1, result, "Health check should return 1 when server is not running")
})
t.Run("invalid network", func(t *testing.T) {
// Use an invalid network type
result, err := agent.Health("127.0.0.1:8080", "invalid_network")
require.Error(t, err, "Health check should return an error with invalid network")
assert.Equal(t, 1, result, "Health check should return 1 when network is invalid")
})
t.Run("unix socket not found", func(t *testing.T) {
// Use a non-existent unix socket
nonExistentSocket := os.TempDir() + "/non_existent_socket.sock"
// Make sure it really doesn't exist
os.Remove(nonExistentSocket)
result, err := agent.Health(nonExistentSocket, "unix")
require.Error(t, err, "Health check should return an error when socket doesn't exist")
assert.Equal(t, 1, result, "Health check should return 1 when socket doesn't exist")
})
}

View File

@@ -17,7 +17,7 @@ func (a *Agent) initializeNetIoStats() {
nics, nicsEnvExists := GetEnv("NICS")
if nicsEnvExists {
nicsMap = make(map[string]struct{}, 0)
for _, nic := range strings.Split(nics, ",") {
for nic := range strings.SplitSeq(nics, ",") {
nicsMap[nic] = struct{}{}
}
}

View File

@@ -23,20 +23,14 @@ func (a *Agent) StartServer(opts ServerOptions) error {
slog.Info("Starting SSH server", "addr", opts.Addr, "network", opts.Network)
switch opts.Network {
case "unix":
if opts.Network == "unix" {
// remove existing socket file if it exists
if err := os.Remove(opts.Addr); err != nil && !os.IsNotExist(err) {
return err
}
default:
// prefix with : if only port was provided
if !strings.Contains(opts.Addr, ":") {
opts.Addr = ":" + opts.Addr
}
}
// Listen on the address
// start listening on the address
ln, err := net.Listen(opts.Network, opts.Addr)
if err != nil {
return err
@@ -44,7 +38,7 @@ func (a *Agent) StartServer(opts ServerOptions) error {
defer ln.Close()
// Start SSH server on the listener
err = sshServer.Serve(ln, nil, sshServer.NoPty(),
return sshServer.Serve(ln, nil, sshServer.NoPty(),
sshServer.PublicKeyAuth(func(ctx sshServer.Context, key sshServer.PublicKey) bool {
for _, pubKey := range opts.Keys {
if sshServer.KeysEqual(key, pubKey) {
@@ -54,10 +48,6 @@ func (a *Agent) StartServer(opts ServerOptions) error {
return false
}),
)
if err != nil {
return err
}
return nil
}
func (a *Agent) handleSession(s sshServer.Session) {
@@ -89,3 +79,33 @@ func ParseKeys(input string) ([]ssh.PublicKey, error) {
}
return parsedKeys, nil
}
// GetAddress gets the address to listen on or connect to from environment variables or default value.
func GetAddress(addr string) string {
if addr == "" {
addr, _ = GetEnv("LISTEN")
}
if addr == "" {
// Legacy PORT environment variable support
addr, _ = GetEnv("PORT")
}
if addr == "" {
return ":45876"
}
// prefix with : if only port was provided
if GetNetwork(addr) != "unix" && !strings.Contains(addr, ":") {
addr = ":" + addr
}
return addr
}
// GetNetwork returns the network type to use based on the address
func GetNetwork(addr string) string {
if network, ok := GetEnv("NETWORK"); ok && network != "" {
return network
}
if strings.HasPrefix(addr, "/") {
return "unix"
}
return "tcp"
}

View File

@@ -45,7 +45,7 @@ func TestStartServer(t *testing.T) {
name: "tcp port only",
config: ServerOptions{
Network: "tcp",
Addr: "45987",
Addr: ":45987",
Keys: []ssh.PublicKey{sshPubKey},
},
},
@@ -88,7 +88,7 @@ func TestStartServer(t *testing.T) {
name: "bad key should fail",
config: ServerOptions{
Network: "tcp",
Addr: "45987",
Addr: ":45987",
Keys: []ssh.PublicKey{sshBadPubKey},
},
wantErr: true,
@@ -98,7 +98,7 @@ func TestStartServer(t *testing.T) {
name: "good key still good",
config: ServerOptions{
Network: "tcp",
Addr: "45987",
Addr: ":45987",
Keys: []ssh.PublicKey{sshPubKey},
},
},