diff --git a/agent/agent.go b/agent/agent.go index 6c6a9072..8a3cd801 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -12,33 +12,36 @@ import ( "path/filepath" "strings" "sync" - "time" "github.com/gliderlabs/ssh" "github.com/henrygd/beszel" + "github.com/henrygd/beszel/agent/deltatracker" "github.com/henrygd/beszel/internal/entities/system" "github.com/shirou/gopsutil/v4/host" gossh "golang.org/x/crypto/ssh" ) type Agent struct { - sync.Mutex // Used to lock agent while collecting data - debug bool // true if LOG_LEVEL is set to debug - zfs bool // true if system has arcstats - memCalc string // Memory calculation formula - fsNames []string // List of filesystem device names being monitored - fsStats map[string]*system.FsStats // Keeps track of disk stats for each filesystem - netInterfaces map[string]struct{} // Stores all valid network interfaces - netIoStats system.NetIoStats // Keeps track of bandwidth usage - dockerManager *dockerManager // Manages Docker API requests - sensorConfig *SensorConfig // Sensors config - systemInfo system.Info // Host system info - gpuManager *GPUManager // Manages GPU data - cache *SessionCache // Cache for system stats based on primary session ID - connectionManager *ConnectionManager // Channel to signal connection events - server *ssh.Server // SSH server - dataDir string // Directory for persisting data - keys []gossh.PublicKey // SSH public keys + sync.Mutex // Used to lock agent while collecting data + debug bool // true if LOG_LEVEL is set to debug + zfs bool // true if system has arcstats + memCalc string // Memory calculation formula + fsNames []string // List of filesystem device names being monitored + fsStats map[string]*system.FsStats // Keeps track of disk stats for each filesystem + diskPrev map[uint16]map[string]prevDisk // Previous disk I/O counters per cache interval + netInterfaces map[string]struct{} // Stores all valid network interfaces + netIoStats map[uint16]system.NetIoStats // Keeps track of 
bandwidth usage per cache interval + netInterfaceDeltaTrackers map[uint16]*deltatracker.DeltaTracker[string, uint64] // Per-cache-time NIC delta trackers + dockerManager *dockerManager // Manages Docker API requests + sensorConfig *SensorConfig // Sensors config + systemInfo system.Info // Host system info + gpuManager *GPUManager // Manages GPU data + cache *systemDataCache // Cache for system stats based on cache time + connectionManager *ConnectionManager // Channel to signal connection events + handlerRegistry *HandlerRegistry // Registry for routing incoming messages + server *ssh.Server // SSH server + dataDir string // Directory for persisting data + keys []gossh.PublicKey // SSH public keys } // NewAgent creates a new agent with the given data directory for persisting data. @@ -46,9 +49,15 @@ type Agent struct { func NewAgent(dataDir ...string) (agent *Agent, err error) { agent = &Agent{ fsStats: make(map[string]*system.FsStats), - cache: NewSessionCache(69 * time.Second), + cache: NewSystemDataCache(), } + // Initialize disk I/O previous counters storage + agent.diskPrev = make(map[uint16]map[string]prevDisk) + // Initialize per-cache-time network tracking structures + agent.netIoStats = make(map[uint16]system.NetIoStats) + agent.netInterfaceDeltaTrackers = make(map[uint16]*deltatracker.DeltaTracker[string, uint64]) + agent.dataDir, err = getDataDir(dataDir...) 
if err != nil { slog.Warn("Data directory not found") @@ -79,6 +88,9 @@ func NewAgent(dataDir ...string) (agent *Agent, err error) { // initialize connection manager agent.connectionManager = newConnectionManager(agent) + // initialize handler registry + agent.handlerRegistry = NewHandlerRegistry() + // initialize disk info agent.initializeDiskInfo() @@ -97,7 +109,7 @@ func NewAgent(dataDir ...string) (agent *Agent, err error) { // if debugging, print stats if agent.debug { - slog.Debug("Stats", "data", agent.gatherStats("")) + slog.Debug("Stats", "data", agent.gatherStats(0)) } return agent, nil @@ -112,24 +124,24 @@ func GetEnv(key string) (value string, exists bool) { return os.LookupEnv(key) } -func (a *Agent) gatherStats(sessionID string) *system.CombinedData { +func (a *Agent) gatherStats(cacheTimeMs uint16) *system.CombinedData { a.Lock() defer a.Unlock() - data, isCached := a.cache.Get(sessionID) + data, isCached := a.cache.Get(cacheTimeMs) if isCached { - slog.Debug("Cached data", "session", sessionID) + slog.Debug("Cached data", "cacheTimeMs", cacheTimeMs) return data } *data = system.CombinedData{ - Stats: a.getSystemStats(), + Stats: a.getSystemStats(cacheTimeMs), Info: a.systemInfo, } - slog.Debug("System data", "data", data) + // slog.Info("System data", "data", data, "cacheTimeMs", cacheTimeMs) if a.dockerManager != nil { - if containerStats, err := a.dockerManager.getDockerStats(); err == nil { + if containerStats, err := a.dockerManager.getDockerStats(cacheTimeMs); err == nil { data.Containers = containerStats slog.Debug("Containers", "data", data.Containers) } else { @@ -145,7 +157,7 @@ func (a *Agent) gatherStats(sessionID string) *system.CombinedData { } slog.Debug("Extra FS", "data", data.Stats.ExtraFs) - a.cache.Set(sessionID, data) + a.cache.Set(data, cacheTimeMs) return data } diff --git a/agent/agent_cache.go b/agent/agent_cache.go index 7425cd06..bec929aa 100644 --- a/agent/agent_cache.go +++ b/agent/agent_cache.go @@ -1,37 +1,55 @@ 
package agent import ( + "sync" "time" "github.com/henrygd/beszel/internal/entities/system" ) -// Not thread safe since we only access from gatherStats which is already locked -type SessionCache struct { - data *system.CombinedData - lastUpdate time.Time - primarySession string - leaseTime time.Duration +type systemDataCache struct { + sync.RWMutex + cache map[uint16]*cacheNode } -func NewSessionCache(leaseTime time.Duration) *SessionCache { - return &SessionCache{ - leaseTime: leaseTime, - data: &system.CombinedData{}, +type cacheNode struct { + data *system.CombinedData + lastUpdate time.Time +} + +// NewSystemDataCache creates a cache keyed by the polling interval in milliseconds. +func NewSystemDataCache() *systemDataCache { + return &systemDataCache{ + cache: make(map[uint16]*cacheNode), } } -func (c *SessionCache) Get(sessionID string) (stats *system.CombinedData, isCached bool) { - if sessionID != c.primarySession && time.Since(c.lastUpdate) < c.leaseTime { - return c.data, true +// Get returns cached combined data when the entry is still considered fresh. +func (c *systemDataCache) Get(cacheTimeMs uint16) (stats *system.CombinedData, isCached bool) { + c.RLock() + defer c.RUnlock() + + node, ok := c.cache[cacheTimeMs] + if !ok { + return &system.CombinedData{}, false } - return c.data, false + // allowedSkew := time.Second + // isFresh := time.Since(node.lastUpdate) < time.Duration(cacheTimeMs)*time.Millisecond-allowedSkew + // allow a 50% skew of the cache time + isFresh := time.Since(node.lastUpdate) < time.Duration(cacheTimeMs/2)*time.Millisecond + return node.data, isFresh } -func (c *SessionCache) Set(sessionID string, data *system.CombinedData) { - if data != nil { - *c.data = *data +// Set stores the latest combined data snapshot for the given interval. 
+func (c *systemDataCache) Set(data *system.CombinedData, cacheTimeMs uint16) { + c.Lock() + defer c.Unlock() + + node, ok := c.cache[cacheTimeMs] + if !ok { + node = &cacheNode{} + c.cache[cacheTimeMs] = node } - c.primarySession = sessionID - c.lastUpdate = time.Now() + node.data = data + node.lastUpdate = time.Now() } diff --git a/agent/agent_cache_test.go b/agent/agent_cache_test.go index 87984c95..2930563f 100644 --- a/agent/agent_cache_test.go +++ b/agent/agent_cache_test.go @@ -8,82 +8,239 @@ import ( "testing/synctest" "time" + "github.com/henrygd/beszel/internal/entities/container" "github.com/henrygd/beszel/internal/entities/system" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestSessionCache_GetSet(t *testing.T) { - synctest.Test(t, func(t *testing.T) { - cache := NewSessionCache(69 * time.Second) +func createTestCacheData() *system.CombinedData { + return &system.CombinedData{ + Stats: system.Stats{ + Cpu: 50.5, + Mem: 8192, + DiskTotal: 100000, + }, + Info: system.Info{ + Hostname: "test-host", + }, + Containers: []*container.Stats{ + { + Name: "test-container", + Cpu: 25.0, + }, + }, + } +} - testData := &system.CombinedData{ - Info: system.Info{ - Hostname: "test-host", - Cores: 4, - }, +func TestNewSystemDataCache(t *testing.T) { + cache := NewSystemDataCache() + require.NotNil(t, cache) + assert.NotNil(t, cache.cache) + assert.Empty(t, cache.cache) +} + +func TestCacheGetSet(t *testing.T) { + cache := NewSystemDataCache() + data := createTestCacheData() + + // Test setting data + cache.Set(data, 1000) // 1 second cache + + // Test getting fresh data + retrieved, isCached := cache.Get(1000) + assert.True(t, isCached) + assert.Equal(t, data, retrieved) + + // Test getting non-existent cache key + _, isCached = cache.Get(2000) + assert.False(t, isCached) +} + +func TestCacheFreshness(t *testing.T) { + cache := NewSystemDataCache() + data := createTestCacheData() + + testCases := []struct { + name string + 
cacheTimeMs uint16 + sleepMs time.Duration + expectFresh bool + }{ + { + name: "fresh data - well within cache time", + cacheTimeMs: 1000, // 1 second + sleepMs: 100, // 100ms + expectFresh: true, + }, + { + name: "fresh data - at 50% of cache time boundary", + cacheTimeMs: 1000, // 1 second, 50% = 500ms + sleepMs: 499, // just under 500ms + expectFresh: true, + }, + { + name: "stale data - exactly at 50% cache time", + cacheTimeMs: 1000, // 1 second, 50% = 500ms + sleepMs: 500, // exactly 500ms + expectFresh: false, + }, + { + name: "stale data - well beyond cache time", + cacheTimeMs: 1000, // 1 second + sleepMs: 800, // 800ms + expectFresh: false, + }, + { + name: "short cache time", + cacheTimeMs: 200, // 200ms, 50% = 100ms + sleepMs: 150, // 150ms > 100ms + expectFresh: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + // Set data + cache.Set(data, tc.cacheTimeMs) + + // Wait for the specified duration + if tc.sleepMs > 0 { + time.Sleep(tc.sleepMs * time.Millisecond) + } + + // Check freshness + _, isCached := cache.Get(tc.cacheTimeMs) + assert.Equal(t, tc.expectFresh, isCached) + }) + }) + } +} + +func TestCacheMultipleIntervals(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + cache := NewSystemDataCache() + data1 := createTestCacheData() + data2 := &system.CombinedData{ Stats: system.Stats{ - Cpu: 50.0, - MemPct: 30.0, - DiskPct: 40.0, + Cpu: 75.0, + Mem: 16384, }, + Info: system.Info{ + Hostname: "test-host-2", + }, + Containers: []*container.Stats{}, } - // Test initial state - should not be cached - data, isCached := cache.Get("session1") - assert.False(t, isCached, "Expected no cached data initially") - assert.NotNil(t, data, "Expected data to be initialized") - // Set data for session1 - cache.Set("session1", testData) + // Set data for different intervals + cache.Set(data1, 500) // 500ms cache + cache.Set(data2, 1000) // 1000ms cache - time.Sleep(15 * time.Second) 
+ // Both should be fresh immediately + retrieved1, isCached1 := cache.Get(500) + assert.True(t, isCached1) + assert.Equal(t, data1, retrieved1) - // Get data for a different session - should be cached - data, isCached = cache.Get("session2") - assert.True(t, isCached, "Expected data to be cached for non-primary session") - require.NotNil(t, data, "Expected cached data to be returned") - assert.Equal(t, "test-host", data.Info.Hostname, "Hostname should match test data") - assert.Equal(t, 4, data.Info.Cores, "Cores should match test data") - assert.Equal(t, 50.0, data.Stats.Cpu, "CPU should match test data") - assert.Equal(t, 30.0, data.Stats.MemPct, "Memory percentage should match test data") - assert.Equal(t, 40.0, data.Stats.DiskPct, "Disk percentage should match test data") + retrieved2, isCached2 := cache.Get(1000) + assert.True(t, isCached2) + assert.Equal(t, data2, retrieved2) - time.Sleep(10 * time.Second) + // Wait 300ms - 500ms cache should be stale (250ms threshold), 1000ms should still be fresh (500ms threshold) + time.Sleep(300 * time.Millisecond) - // Get data for the primary session - should not be cached - data, isCached = cache.Get("session1") - assert.False(t, isCached, "Expected data not to be cached for primary session") - require.NotNil(t, data, "Expected data to be returned even if not cached") - assert.Equal(t, "test-host", data.Info.Hostname, "Hostname should match test data") - // if not cached, agent will update the data - cache.Set("session1", testData) + _, isCached1 = cache.Get(500) + assert.False(t, isCached1) - time.Sleep(45 * time.Second) + _, isCached2 = cache.Get(1000) + assert.True(t, isCached2) - // Get data for a different session - should still be cached - _, isCached = cache.Get("session2") - assert.True(t, isCached, "Expected data to be cached for non-primary session") - - // Wait for the lease to expire - time.Sleep(30 * time.Second) - - // Get data for session2 - should not be cached - _, isCached = cache.Get("session2") - 
assert.False(t, isCached, "Expected data not to be cached after lease expiration") + // Wait another 300ms (total 600ms) - now 1000ms cache should also be stale + time.Sleep(300 * time.Millisecond) + _, isCached2 = cache.Get(1000) + assert.False(t, isCached2) }) } -func TestSessionCache_NilData(t *testing.T) { - // Create a new SessionCache - cache := NewSessionCache(30 * time.Second) +func TestCacheOverwrite(t *testing.T) { + cache := NewSystemDataCache() + data1 := createTestCacheData() + data2 := &system.CombinedData{ + Stats: system.Stats{ + Cpu: 90.0, + Mem: 32768, + }, + Info: system.Info{ + Hostname: "updated-host", + }, + Containers: []*container.Stats{}, + } - // Test setting nil data (should not panic) - assert.NotPanics(t, func() { - cache.Set("session1", nil) - }, "Setting nil data should not panic") + // Set initial data + cache.Set(data1, 1000) + retrieved, isCached := cache.Get(1000) + assert.True(t, isCached) + assert.Equal(t, data1, retrieved) - // Get data - should not be nil even though we set nil - data, _ := cache.Get("session2") - assert.NotNil(t, data, "Expected data to not be nil after setting nil data") + // Overwrite with new data + cache.Set(data2, 1000) + retrieved, isCached = cache.Get(1000) + assert.True(t, isCached) + assert.Equal(t, data2, retrieved) + assert.NotEqual(t, data1, retrieved) +} + +func TestCacheMiss(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + cache := NewSystemDataCache() + + // Test getting from empty cache + _, isCached := cache.Get(1000) + assert.False(t, isCached) + + // Set data for one interval + data := createTestCacheData() + cache.Set(data, 1000) + + // Test getting different interval + _, isCached = cache.Get(2000) + assert.False(t, isCached) + + // Test getting after data has expired + time.Sleep(600 * time.Millisecond) // 600ms > 500ms (50% of 1000ms) + _, isCached = cache.Get(1000) + assert.False(t, isCached) + }) +} + +func TestCacheZeroInterval(t *testing.T) { + cache := NewSystemDataCache() 
+ data := createTestCacheData() + + // Set with zero interval - the entry is stored but can never be served from cache + cache.Set(data, 0) + + // With 0 interval, 50% is 0, so it should never be considered fresh + // (time.Since(lastUpdate) >= 0, which is not < 0) + _, isCached := cache.Get(0) + assert.False(t, isCached) +} + +func TestCacheLargeInterval(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + cache := NewSystemDataCache() + data := createTestCacheData() + + // Test with maximum uint16 value + cache.Set(data, 65535) // ~65 seconds + + // Should be fresh immediately + _, isCached := cache.Get(65535) + assert.True(t, isCached) + + // Should still be fresh after a short time + time.Sleep(100 * time.Millisecond) + _, isCached = cache.Get(65535) + assert.True(t, isCached) + }) } diff --git a/agent/client.go b/agent/client.go index 87449cfe..2f65c758 100644 --- a/agent/client.go +++ b/agent/client.go @@ -15,6 +15,7 @@ import ( "github.com/henrygd/beszel" "github.com/henrygd/beszel/internal/common" + "github.com/henrygd/beszel/internal/entities/system" "github.com/fxamacker/cbor/v2" "github.com/lxzan/gws" @@ -156,11 +157,15 @@ func (client *WebSocketClient) OnMessage(conn *gws.Conn, message *gws.Message) { return } - if err := cbor.NewDecoder(message.Data).Decode(client.hubRequest); err != nil { + var HubRequest common.HubRequest[cbor.RawMessage] + + err := cbor.Unmarshal(message.Data.Bytes(), &HubRequest) + if err != nil { slog.Error("Error parsing message", "err", err) return } - if err := client.handleHubRequest(client.hubRequest); err != nil { + + if err := client.handleHubRequest(&HubRequest, HubRequest.Id); err != nil { slog.Error("Error handling message", "err", err) } } @@ -173,7 +178,7 @@ func (client *WebSocketClient) OnPing(conn *gws.Conn, message []byte) { } // handleAuthChallenge verifies the authenticity of the hub and returns the system's fingerprint. 
-func (client *WebSocketClient) handleAuthChallenge(msg *common.HubRequest[cbor.RawMessage]) (err error) { +func (client *WebSocketClient) handleAuthChallenge(msg *common.HubRequest[cbor.RawMessage], requestID *uint32) (err error) { var authRequest common.FingerprintRequest if err := cbor.Unmarshal(msg.Data, &authRequest); err != nil { return err @@ -196,7 +201,7 @@ func (client *WebSocketClient) handleAuthChallenge(msg *common.HubRequest[cbor.R _, response.Port, _ = net.SplitHostPort(serverAddr) } - return client.sendMessage(response) + return client.sendResponse(response, requestID) } // verifySignature verifies the signature of the token using the public keys. @@ -221,25 +226,17 @@ func (client *WebSocketClient) Close() { } } -// handleHubRequest routes the request to the appropriate handler. -// It ensures the hub is verified before processing most requests. -func (client *WebSocketClient) handleHubRequest(msg *common.HubRequest[cbor.RawMessage]) error { - if !client.hubVerified && msg.Action != common.CheckFingerprint { - return errors.New("hub not verified") +// handleHubRequest routes the request to the appropriate handler using the handler registry. +func (client *WebSocketClient) handleHubRequest(msg *common.HubRequest[cbor.RawMessage], requestID *uint32) error { + ctx := &HandlerContext{ + Client: client, + Agent: client.agent, + Request: msg, + RequestID: requestID, + HubVerified: client.hubVerified, + SendResponse: client.sendResponse, } - switch msg.Action { - case common.GetData: - return client.sendSystemData() - case common.CheckFingerprint: - return client.handleAuthChallenge(msg) - } - return nil -} - -// sendSystemData gathers and sends current system statistics to the hub. 
-func (client *WebSocketClient) sendSystemData() error { - sysStats := client.agent.gatherStats(client.token) - return client.sendMessage(sysStats) + return client.agent.handlerRegistry.Handle(ctx) } // sendMessage encodes the given data to CBOR and sends it as a binary message over the WebSocket connection to the hub. @@ -251,6 +248,36 @@ func (client *WebSocketClient) sendMessage(data any) error { return client.Conn.WriteMessage(gws.OpcodeBinary, bytes) } +// sendResponse sends a response with optional request ID for the new protocol +func (client *WebSocketClient) sendResponse(data any, requestID *uint32) error { + if requestID != nil { + // New format with ID - use typed fields + response := common.AgentResponse{ + Id: requestID, + } + + // Set the appropriate typed field based on data type + switch v := data.(type) { + case *system.CombinedData: + response.SystemData = v + case *common.FingerprintResponse: + response.Fingerprint = v + // case []byte: + // response.RawBytes = v + // case string: + // response.RawBytes = []byte(v) + default: + // For any other type, convert to error + response.Error = fmt.Sprintf("unsupported response type: %T", data) + } + + return client.sendMessage(response) + } else { + // Legacy format - send data directly + return client.sendMessage(data) + } +} + // getUserAgent returns one of two User-Agent strings based on current time. // This is used to avoid being blocked by Cloudflare or other anti-bot measures. 
func getUserAgent() string { diff --git a/agent/client_test.go b/agent/client_test.go index 5741884c..5b712827 100644 --- a/agent/client_test.go +++ b/agent/client_test.go @@ -301,7 +301,7 @@ func TestWebSocketClient_HandleHubRequest(t *testing.T) { Data: cbor.RawMessage{}, } - err := client.handleHubRequest(hubRequest) + err := client.handleHubRequest(hubRequest, nil) if tc.expectError { assert.Error(t, err) diff --git a/agent/cpu.go b/agent/cpu.go new file mode 100644 index 00000000..bd4afc21 --- /dev/null +++ b/agent/cpu.go @@ -0,0 +1,66 @@ +package agent + +import ( + "math" + "runtime" + + "github.com/shirou/gopsutil/v4/cpu" +) + +var lastCpuTimes = make(map[uint16]cpu.TimesStat) + +// init initializes the CPU monitoring by storing the initial CPU times +// for the default 60-second cache interval. +func init() { + if times, err := cpu.Times(false); err == nil { + lastCpuTimes[60000] = times[0] + } +} + +// getCpuPercent calculates the CPU usage percentage using cached previous measurements. +// It uses the specified cache time interval to determine the time window for calculation. +// Returns the CPU usage percentage (0-100) and any error encountered. +func getCpuPercent(cacheTimeMs uint16) (float64, error) { + times, err := cpu.Times(false) + if err != nil || len(times) == 0 { + return 0, err + } + // if cacheTimeMs is not in lastCpuTimes, use 60000 as fallback lastCpuTime + if _, ok := lastCpuTimes[cacheTimeMs]; !ok { + lastCpuTimes[cacheTimeMs] = lastCpuTimes[60000] + } + delta := calculateBusy(lastCpuTimes[cacheTimeMs], times[0]) + lastCpuTimes[cacheTimeMs] = times[0] + return delta, nil +} + +// calculateBusy calculates the CPU busy percentage between two time points. +// It computes the ratio of busy time to total time elapsed between t1 and t2, +// returning a percentage clamped between 0 and 100. 
+func calculateBusy(t1, t2 cpu.TimesStat) float64 { + t1All, t1Busy := getAllBusy(t1) + t2All, t2Busy := getAllBusy(t2) + + if t2Busy <= t1Busy { + return 0 + } + if t2All <= t1All { + return 100 + } + return math.Min(100, math.Max(0, (t2Busy-t1Busy)/(t2All-t1All)*100)) +} + +// getAllBusy calculates the total CPU time and busy CPU time from CPU times statistics. +// On Linux, it excludes guest and guest_nice time from the total to match kernel behavior. +// Returns total CPU time and busy CPU time (total minus idle and I/O wait time). +func getAllBusy(t cpu.TimesStat) (float64, float64) { + tot := t.Total() + if runtime.GOOS == "linux" { + tot -= t.Guest // Linux 2.6.24+ + tot -= t.GuestNice // Linux 3.2.0+ + } + + busy := tot - t.Idle - t.Iowait + + return tot, busy +} diff --git a/agent/disk.go b/agent/disk.go index 0b64ec29..e117a284 100644 --- a/agent/disk.go +++ b/agent/disk.go @@ -189,3 +189,96 @@ func (a *Agent) initializeDiskIoStats(diskIoCounters map[string]disk.IOCountersS a.fsNames = append(a.fsNames, device) } } + +// Updates disk usage statistics for all monitored filesystems +func (a *Agent) updateDiskUsage(systemStats *system.Stats) { + // disk usage + for _, stats := range a.fsStats { + if d, err := disk.Usage(stats.Mountpoint); err == nil { + stats.DiskTotal = bytesToGigabytes(d.Total) + stats.DiskUsed = bytesToGigabytes(d.Used) + if stats.Root { + systemStats.DiskTotal = bytesToGigabytes(d.Total) + systemStats.DiskUsed = bytesToGigabytes(d.Used) + systemStats.DiskPct = twoDecimals(d.UsedPercent) + } + } else { + // reset stats if error (likely unmounted) + slog.Error("Error getting disk stats", "name", stats.Mountpoint, "err", err) + stats.DiskTotal = 0 + stats.DiskUsed = 0 + stats.TotalRead = 0 + stats.TotalWrite = 0 + } + } +} + +// Updates disk I/O statistics for all monitored filesystems +func (a *Agent) updateDiskIo(cacheTimeMs uint16, systemStats *system.Stats) { + // disk i/o (cache-aware per interval) + if ioCounters, err := 
disk.IOCounters(a.fsNames...); err == nil { + // Ensure map for this interval exists + if _, ok := a.diskPrev[cacheTimeMs]; !ok { + a.diskPrev[cacheTimeMs] = make(map[string]prevDisk) + } + now := time.Now() + for name, d := range ioCounters { + stats := a.fsStats[d.Name] + if stats == nil { + // skip devices not tracked + continue + } + + // Previous snapshot for this interval and device + prev, hasPrev := a.diskPrev[cacheTimeMs][name] + if !hasPrev { + // Seed from agent-level fsStats if present, else seed from current + prev = prevDisk{readBytes: stats.TotalRead, writeBytes: stats.TotalWrite, at: stats.Time} + if prev.at.IsZero() { + prev = prevDisk{readBytes: d.ReadBytes, writeBytes: d.WriteBytes, at: now} + } + } + + msElapsed := uint64(now.Sub(prev.at).Milliseconds()) + if msElapsed < 100 { + // Avoid division by zero or clock issues; update snapshot and continue + a.diskPrev[cacheTimeMs][name] = prevDisk{readBytes: d.ReadBytes, writeBytes: d.WriteBytes, at: now} + continue + } + + diskIORead := (d.ReadBytes - prev.readBytes) * 1000 / msElapsed + diskIOWrite := (d.WriteBytes - prev.writeBytes) * 1000 / msElapsed + readMbPerSecond := bytesToMegabytes(float64(diskIORead)) + writeMbPerSecond := bytesToMegabytes(float64(diskIOWrite)) + + // validate values + if readMbPerSecond > 50_000 || writeMbPerSecond > 50_000 { + slog.Warn("Invalid disk I/O. 
Resetting.", "name", d.Name, "read", readMbPerSecond, "write", writeMbPerSecond) + // Reset interval snapshot and seed from current + a.diskPrev[cacheTimeMs][name] = prevDisk{readBytes: d.ReadBytes, writeBytes: d.WriteBytes, at: now} + // also refresh agent baseline to avoid future negatives + a.initializeDiskIoStats(ioCounters) + continue + } + + // Update per-interval snapshot + a.diskPrev[cacheTimeMs][name] = prevDisk{readBytes: d.ReadBytes, writeBytes: d.WriteBytes, at: now} + + // Update global fsStats baseline for cross-interval correctness + stats.Time = now + stats.TotalRead = d.ReadBytes + stats.TotalWrite = d.WriteBytes + stats.DiskReadPs = readMbPerSecond + stats.DiskWritePs = writeMbPerSecond + stats.DiskReadBytes = diskIORead + stats.DiskWriteBytes = diskIOWrite + + if stats.Root { + systemStats.DiskReadPs = stats.DiskReadPs + systemStats.DiskWritePs = stats.DiskWritePs + systemStats.DiskIO[0] = diskIORead + systemStats.DiskIO[1] = diskIOWrite + } + } + } +} diff --git a/agent/docker.go b/agent/docker.go index 1b941a2a..bd76b3a3 100644 --- a/agent/docker.go +++ b/agent/docker.go @@ -14,17 +14,25 @@ import ( "sync" "time" + "github.com/henrygd/beszel/agent/deltatracker" "github.com/henrygd/beszel/internal/entities/container" "github.com/blang/semver" ) +const ( + // Docker API timeout in milliseconds + dockerTimeoutMs = 2100 + // Maximum realistic network speed (5 GB/s) to detect bad deltas + maxNetworkSpeedBps uint64 = 5e9 +) + type dockerManager struct { client *http.Client // Client to query Docker API wg sync.WaitGroup // WaitGroup to wait for all goroutines to finish sem chan struct{} // Semaphore to limit concurrent container requests containerStatsMutex sync.RWMutex // Mutex to prevent concurrent access to containerStatsMap - apiContainerList []*container.ApiInfo // List of containers from Docker API (no pointer) + apiContainerList []*container.ApiInfo // List of containers from Docker API containerStatsMap map[string]*container.Stats // Keeps 
track of container stats validIds map[string]struct{} // Map of valid container ids, used to prune invalid containers from containerStatsMap goodDockerVersion bool // Whether docker version is at least 25.0.0 (one-shot works correctly) @@ -32,6 +40,17 @@ type dockerManager struct { buf *bytes.Buffer // Buffer to store and read response bodies decoder *json.Decoder // Reusable JSON decoder that reads from buf apiStats *container.ApiStats // Reusable API stats object + + // Cache-time-aware tracking for CPU stats (similar to cpu.go) + // Maps cache time intervals to container-specific CPU usage tracking + lastCpuContainer map[uint16]map[string]uint64 // cacheTimeMs -> containerId -> last cpu container usage + lastCpuSystem map[uint16]map[string]uint64 // cacheTimeMs -> containerId -> last cpu system usage + lastCpuReadTime map[uint16]map[string]time.Time // cacheTimeMs -> containerId -> last read time (Windows) + + // Network delta trackers - one per cache time to avoid interference + // cacheTimeMs -> DeltaTracker for network bytes sent/received + networkSentTrackers map[uint16]*deltatracker.DeltaTracker[string, uint64] + networkRecvTrackers map[uint16]*deltatracker.DeltaTracker[string, uint64] } // userAgentRoundTripper is a custom http.RoundTripper that adds a User-Agent header to all requests @@ -62,8 +81,8 @@ func (d *dockerManager) dequeue() { } } -// Returns stats for all running containers -func (dm *dockerManager) getDockerStats() ([]*container.Stats, error) { +// Returns stats for all running containers with cache-time-aware delta tracking +func (dm *dockerManager) getDockerStats(cacheTimeMs uint16) ([]*container.Stats, error) { resp, err := dm.client.Get("http://localhost/containers/json") if err != nil { return nil, err @@ -87,8 +106,7 @@ func (dm *dockerManager) getDockerStats() ([]*container.Stats, error) { var failedContainers []*container.ApiInfo - for i := range dm.apiContainerList { - ctr := dm.apiContainerList[i] + for _, ctr := range 
dm.apiContainerList { ctr.IdShort = ctr.Id[:12] dm.validIds[ctr.IdShort] = struct{}{} // check if container is less than 1 minute old (possible restart) @@ -98,9 +116,9 @@ func (dm *dockerManager) getDockerStats() ([]*container.Stats, error) { dm.deleteContainerStatsSync(ctr.IdShort) } dm.queue() - go func() { + go func(ctr *container.ApiInfo) { defer dm.dequeue() - err := dm.updateContainerStats(ctr) + err := dm.updateContainerStats(ctr, cacheTimeMs) // if error, delete from map and add to failed list to retry if err != nil { dm.containerStatsMutex.Lock() @@ -108,7 +126,7 @@ func (dm *dockerManager) getDockerStats() ([]*container.Stats, error) { failedContainers = append(failedContainers, ctr) dm.containerStatsMutex.Unlock() } - }() + }(ctr) } dm.wg.Wait() @@ -119,13 +137,12 @@ func (dm *dockerManager) getDockerStats() ([]*container.Stats, error) { for i := range failedContainers { ctr := failedContainers[i] dm.queue() - go func() { + go func(ctr *container.ApiInfo) { defer dm.dequeue() - err = dm.updateContainerStats(ctr) - if err != nil { - slog.Error("Error getting container stats", "err", err) + if err2 := dm.updateContainerStats(ctr, cacheTimeMs); err2 != nil { + slog.Error("Error getting container stats", "err", err2) } - }() + }(ctr) } dm.wg.Wait() } @@ -140,18 +157,156 @@ func (dm *dockerManager) getDockerStats() ([]*container.Stats, error) { } } + // prepare network trackers for next interval for this cache time + dm.cycleNetworkDeltasForCacheTime(cacheTimeMs) + return stats, nil } -// Updates stats for individual container -func (dm *dockerManager) updateContainerStats(ctr *container.ApiInfo) error { +// initializeCpuTracking initializes CPU tracking maps for a specific cache time interval +func (dm *dockerManager) initializeCpuTracking(cacheTimeMs uint16) { + // Initialize cache time maps if they don't exist + if dm.lastCpuContainer[cacheTimeMs] == nil { + dm.lastCpuContainer[cacheTimeMs] = make(map[string]uint64) + } + if dm.lastCpuSystem[cacheTimeMs] 
== nil { + dm.lastCpuSystem[cacheTimeMs] = make(map[string]uint64) + } + // Ensure the outer map exists before indexing + if dm.lastCpuReadTime == nil { + dm.lastCpuReadTime = make(map[uint16]map[string]time.Time) + } + if dm.lastCpuReadTime[cacheTimeMs] == nil { + dm.lastCpuReadTime[cacheTimeMs] = make(map[string]time.Time) + } +} + +// getCpuPreviousValues returns previous CPU values for a container and cache time interval +func (dm *dockerManager) getCpuPreviousValues(cacheTimeMs uint16, containerId string) (uint64, uint64) { + return dm.lastCpuContainer[cacheTimeMs][containerId], dm.lastCpuSystem[cacheTimeMs][containerId] +} + +// setCpuCurrentValues stores current CPU values for a container and cache time interval +func (dm *dockerManager) setCpuCurrentValues(cacheTimeMs uint16, containerId string, cpuContainer, cpuSystem uint64) { + dm.lastCpuContainer[cacheTimeMs][containerId] = cpuContainer + dm.lastCpuSystem[cacheTimeMs][containerId] = cpuSystem +} + +// calculateMemoryUsage calculates memory usage from Docker API stats +func calculateMemoryUsage(apiStats *container.ApiStats, isWindows bool) (uint64, error) { + if isWindows { + return apiStats.MemoryStats.PrivateWorkingSet, nil + } + + // Check if container has valid data, otherwise may be in restart loop (#103) + if apiStats.MemoryStats.Usage == 0 { + return 0, fmt.Errorf("no memory stats available") + } + + memCache := apiStats.MemoryStats.Stats.InactiveFile + if memCache == 0 { + memCache = apiStats.MemoryStats.Stats.Cache + } + + return apiStats.MemoryStats.Usage - memCache, nil +} + +// getNetworkTracker returns the DeltaTracker for a specific cache time, creating it if needed +func (dm *dockerManager) getNetworkTracker(cacheTimeMs uint16, isSent bool) *deltatracker.DeltaTracker[string, uint64] { + var trackers map[uint16]*deltatracker.DeltaTracker[string, uint64] + if isSent { + trackers = dm.networkSentTrackers + } else { + trackers = dm.networkRecvTrackers + } + + if trackers[cacheTimeMs] == nil { 
+ trackers[cacheTimeMs] = deltatracker.NewDeltaTracker[string, uint64]() + } + + return trackers[cacheTimeMs] +} + +// cycleNetworkDeltasForCacheTime cycles the network delta trackers for a specific cache time +func (dm *dockerManager) cycleNetworkDeltasForCacheTime(cacheTimeMs uint16) { + if dm.networkSentTrackers[cacheTimeMs] != nil { + dm.networkSentTrackers[cacheTimeMs].Cycle() + } + if dm.networkRecvTrackers[cacheTimeMs] != nil { + dm.networkRecvTrackers[cacheTimeMs].Cycle() + } +} + +// calculateNetworkStats calculates network sent/receive deltas using DeltaTracker +func (dm *dockerManager) calculateNetworkStats(ctr *container.ApiInfo, apiStats *container.ApiStats, stats *container.Stats, initialized bool, name string, cacheTimeMs uint16) (uint64, uint64) { + var total_sent, total_recv uint64 + for _, v := range apiStats.Networks { + total_sent += v.TxBytes + total_recv += v.RxBytes + } + + // Get the DeltaTracker for this specific cache time + sentTracker := dm.getNetworkTracker(cacheTimeMs, true) + recvTracker := dm.getNetworkTracker(cacheTimeMs, false) + + // Set current values in the cache-time-specific DeltaTracker + sentTracker.Set(ctr.IdShort, total_sent) + recvTracker.Set(ctr.IdShort, total_recv) + + // Get deltas (bytes since last measurement) + sent_delta_raw := sentTracker.Delta(ctr.IdShort) + recv_delta_raw := recvTracker.Delta(ctr.IdShort) + + // Calculate bytes per second independently for Tx and Rx if we have previous data + var sent_delta, recv_delta uint64 + if initialized { + millisecondsElapsed := uint64(time.Since(stats.PrevReadTime).Milliseconds()) + if millisecondsElapsed > 0 { + if sent_delta_raw > 0 { + sent_delta = sent_delta_raw * 1000 / millisecondsElapsed + if sent_delta > maxNetworkSpeedBps { + slog.Warn("Bad network delta", "container", name) + sent_delta = 0 + } + } + if recv_delta_raw > 0 { + recv_delta = recv_delta_raw * 1000 / millisecondsElapsed + if recv_delta > maxNetworkSpeedBps { + slog.Warn("Bad network delta", 
"container", name) + recv_delta = 0 + } + } + } + } + + return sent_delta, recv_delta +} + +// validateCpuPercentage checks if CPU percentage is within valid range +func validateCpuPercentage(cpuPct float64, containerName string) error { + if cpuPct > 100 { + return fmt.Errorf("%s cpu pct greater than 100: %+v", containerName, cpuPct) + } + return nil +} + +// updateContainerStatsValues updates the final stats values +func updateContainerStatsValues(stats *container.Stats, cpuPct float64, usedMemory uint64, sent_delta, recv_delta uint64, readTime time.Time) { + stats.Cpu = twoDecimals(cpuPct) + stats.Mem = bytesToMegabytes(float64(usedMemory)) + stats.NetworkSent = bytesToMegabytes(float64(sent_delta)) + stats.NetworkRecv = bytesToMegabytes(float64(recv_delta)) + stats.PrevReadTime = readTime +} + +// Updates stats for individual container with cache-time-aware delta tracking +func (dm *dockerManager) updateContainerStats(ctr *container.ApiInfo, cacheTimeMs uint16) error { name := ctr.Names[0][1:] resp, err := dm.client.Get("http://localhost/containers/" + ctr.IdShort + "/stats?stream=0&one-shot=1") if err != nil { return err } - defer resp.Body.Close() dm.containerStatsMutex.Lock() defer dm.containerStatsMutex.Unlock() @@ -169,72 +324,58 @@ func (dm *dockerManager) updateContainerStats(ctr *container.ApiInfo) error { stats.NetworkSent = 0 stats.NetworkRecv = 0 - // docker host container stats response - // res := dm.getApiStats() - // defer dm.putApiStats(res) - // - res := dm.apiStats res.Networks = nil if err := dm.decode(resp, res); err != nil { return err } - // calculate cpu and memory stats - var usedMemory uint64 + // Initialize CPU tracking for this cache time interval + dm.initializeCpuTracking(cacheTimeMs) + + // Get previous CPU values + prevCpuContainer, prevCpuSystem := dm.getCpuPreviousValues(cacheTimeMs, ctr.IdShort) + + // Calculate CPU percentage based on platform var cpuPct float64 - - // store current cpu stats - prevCpuContainer, prevCpuSystem 
:= stats.CpuContainer, stats.CpuSystem - stats.CpuContainer = res.CPUStats.CPUUsage.TotalUsage - stats.CpuSystem = res.CPUStats.SystemUsage - if dm.isWindows { - usedMemory = res.MemoryStats.PrivateWorkingSet - cpuPct = res.CalculateCpuPercentWindows(prevCpuContainer, stats.PrevReadTime) + prevRead := dm.lastCpuReadTime[cacheTimeMs][ctr.IdShort] + cpuPct = res.CalculateCpuPercentWindows(prevCpuContainer, prevRead) } else { - // check if container has valid data, otherwise may be in restart loop (#103) - if res.MemoryStats.Usage == 0 { - return fmt.Errorf("%s - no memory stats - see https://github.com/henrygd/beszel/issues/144", name) - } - memCache := res.MemoryStats.Stats.InactiveFile - if memCache == 0 { - memCache = res.MemoryStats.Stats.Cache - } - usedMemory = res.MemoryStats.Usage - memCache - cpuPct = res.CalculateCpuPercentLinux(prevCpuContainer, prevCpuSystem) } - if cpuPct > 100 { - return fmt.Errorf("%s cpu pct greater than 100: %+v", name, cpuPct) + // Calculate memory usage + usedMemory, err := calculateMemoryUsage(res, dm.isWindows) + if err != nil { + return fmt.Errorf("%s - %w - see https://github.com/henrygd/beszel/issues/144", name, err) } - // network + // Store current CPU stats for next calculation + currentCpuContainer := res.CPUStats.CPUUsage.TotalUsage + currentCpuSystem := res.CPUStats.SystemUsage + dm.setCpuCurrentValues(cacheTimeMs, ctr.IdShort, currentCpuContainer, currentCpuSystem) + + // Validate CPU percentage + if err := validateCpuPercentage(cpuPct, name); err != nil { + return err + } + + // Calculate network stats using DeltaTracker + sent_delta, recv_delta := dm.calculateNetworkStats(ctr, res, stats, initialized, name, cacheTimeMs) + + // Store current network values for legacy compatibility var total_sent, total_recv uint64 for _, v := range res.Networks { total_sent += v.TxBytes total_recv += v.RxBytes } - var sent_delta, recv_delta uint64 - millisecondsElapsed := uint64(time.Since(stats.PrevReadTime).Milliseconds()) - if 
initialized && millisecondsElapsed > 0 { - // get bytes per second - sent_delta = (total_sent - stats.PrevNet.Sent) * 1000 / millisecondsElapsed - recv_delta = (total_recv - stats.PrevNet.Recv) * 1000 / millisecondsElapsed - // check for unrealistic network values (> 5GB/s) - if sent_delta > 5e9 || recv_delta > 5e9 { - slog.Warn("Bad network delta", "container", name) - sent_delta, recv_delta = 0, 0 - } - } stats.PrevNet.Sent, stats.PrevNet.Recv = total_sent, total_recv - stats.Cpu = twoDecimals(cpuPct) - stats.Mem = bytesToMegabytes(float64(usedMemory)) - stats.NetworkSent = bytesToMegabytes(float64(sent_delta)) - stats.NetworkRecv = bytesToMegabytes(float64(recv_delta)) - stats.PrevReadTime = res.Read + // Update final stats values + updateContainerStatsValues(stats, cpuPct, usedMemory, sent_delta, recv_delta, res.Read) + // store per-cache-time read time for Windows CPU percent calc + dm.lastCpuReadTime[cacheTimeMs][ctr.IdShort] = res.Read return nil } @@ -244,6 +385,15 @@ func (dm *dockerManager) deleteContainerStatsSync(id string) { dm.containerStatsMutex.Lock() defer dm.containerStatsMutex.Unlock() delete(dm.containerStatsMap, id) + for ct := range dm.lastCpuContainer { + delete(dm.lastCpuContainer[ct], id) + } + for ct := range dm.lastCpuSystem { + delete(dm.lastCpuSystem[ct], id) + } + for ct := range dm.lastCpuReadTime { + delete(dm.lastCpuReadTime[ct], id) + } } // Creates a new http client for Docker or Podman API @@ -283,7 +433,7 @@ func newDockerManager(a *Agent) *dockerManager { } // configurable timeout - timeout := time.Millisecond * 2100 + timeout := time.Millisecond * time.Duration(dockerTimeoutMs) if t, set := GetEnv("DOCKER_TIMEOUT"); set { timeout, err = time.ParseDuration(t) if err != nil { @@ -308,6 +458,13 @@ func newDockerManager(a *Agent) *dockerManager { sem: make(chan struct{}, 5), apiContainerList: []*container.ApiInfo{}, apiStats: &container.ApiStats{}, + + // Initialize cache-time-aware tracking structures + lastCpuContainer: 
make(map[uint16]map[string]uint64), + lastCpuSystem: make(map[uint16]map[string]uint64), + lastCpuReadTime: make(map[uint16]map[string]time.Time), + networkSentTrackers: make(map[uint16]*deltatracker.DeltaTracker[string, uint64]), + networkRecvTrackers: make(map[uint16]*deltatracker.DeltaTracker[string, uint64]), } // If using podman, return client diff --git a/agent/docker_test.go b/agent/docker_test.go new file mode 100644 index 00000000..4ce58adf --- /dev/null +++ b/agent/docker_test.go @@ -0,0 +1,875 @@ +//go:build testing +// +build testing + +package agent + +import ( + "encoding/json" + "os" + "testing" + "time" + + "github.com/henrygd/beszel/agent/deltatracker" + "github.com/henrygd/beszel/internal/entities/container" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var defaultCacheTimeMs = uint16(60_000) + +// cycleCpuDeltas cycles the CPU tracking data for a specific cache time interval +func (dm *dockerManager) cycleCpuDeltas(cacheTimeMs uint16) { + // Clear the CPU tracking maps for this cache time interval + if dm.lastCpuContainer[cacheTimeMs] != nil { + clear(dm.lastCpuContainer[cacheTimeMs]) + } + if dm.lastCpuSystem[cacheTimeMs] != nil { + clear(dm.lastCpuSystem[cacheTimeMs]) + } +} + +func TestCalculateMemoryUsage(t *testing.T) { + tests := []struct { + name string + apiStats *container.ApiStats + isWindows bool + expected uint64 + expectError bool + }{ + { + name: "Linux with valid memory stats", + apiStats: &container.ApiStats{ + MemoryStats: container.MemoryStats{ + Usage: 1048576, // 1MB + Stats: container.MemoryStatsStats{ + Cache: 524288, // 512KB + InactiveFile: 262144, // 256KB + }, + }, + }, + isWindows: false, + expected: 786432, // 1MB - 256KB (inactive_file takes precedence) = 768KB + expectError: false, + }, + { + name: "Linux with zero cache uses inactive_file", + apiStats: &container.ApiStats{ + MemoryStats: container.MemoryStats{ + Usage: 1048576, // 1MB + Stats: container.MemoryStatsStats{ + 
Cache: 0, + InactiveFile: 262144, // 256KB + }, + }, + }, + isWindows: false, + expected: 786432, // 1MB - 256KB = 768KB + expectError: false, + }, + { + name: "Windows with valid memory stats", + apiStats: &container.ApiStats{ + MemoryStats: container.MemoryStats{ + PrivateWorkingSet: 524288, // 512KB + }, + }, + isWindows: true, + expected: 524288, + expectError: false, + }, + { + name: "Linux with zero usage returns error", + apiStats: &container.ApiStats{ + MemoryStats: container.MemoryStats{ + Usage: 0, + Stats: container.MemoryStatsStats{ + Cache: 0, + InactiveFile: 0, + }, + }, + }, + isWindows: false, + expected: 0, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := calculateMemoryUsage(tt.apiStats, tt.isWindows) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func TestValidateCpuPercentage(t *testing.T) { + tests := []struct { + name string + cpuPct float64 + containerName string + expectError bool + expectedError string + }{ + { + name: "valid CPU percentage", + cpuPct: 50.5, + containerName: "test-container", + expectError: false, + }, + { + name: "zero CPU percentage", + cpuPct: 0.0, + containerName: "test-container", + expectError: false, + }, + { + name: "CPU percentage over 100", + cpuPct: 150.5, + containerName: "test-container", + expectError: true, + expectedError: "test-container cpu pct greater than 100: 150.5", + }, + { + name: "CPU percentage exactly 100", + cpuPct: 100.0, + containerName: "test-container", + expectError: false, + }, + { + name: "negative CPU percentage", + cpuPct: -10.0, + containerName: "test-container", + expectError: false, // Function only checks for > 100, not negative + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateCpuPercentage(tt.cpuPct, tt.containerName) + + if tt.expectError { + assert.Error(t, err) + 
assert.Contains(t, err.Error(), tt.expectedError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestUpdateContainerStatsValues(t *testing.T) { + stats := &container.Stats{ + Name: "test-container", + Cpu: 0.0, + Mem: 0.0, + NetworkSent: 0.0, + NetworkRecv: 0.0, + PrevReadTime: time.Time{}, + } + + testTime := time.Now() + updateContainerStatsValues(stats, 75.5, 1048576, 524288, 262144, testTime) + + // Check CPU percentage (should be rounded to 2 decimals) + assert.Equal(t, 75.5, stats.Cpu) + + // Check memory (should be converted to MB: 1048576 bytes = 1 MB) + assert.Equal(t, 1.0, stats.Mem) + + // Check network sent (should be converted to MB: 524288 bytes = 0.5 MB) + assert.Equal(t, 0.5, stats.NetworkSent) + + // Check network recv (should be converted to MB: 262144 bytes = 0.25 MB) + assert.Equal(t, 0.25, stats.NetworkRecv) + + // Check read time + assert.Equal(t, testTime, stats.PrevReadTime) +} + +func TestTwoDecimals(t *testing.T) { + tests := []struct { + name string + input float64 + expected float64 + }{ + {"round down", 1.234, 1.23}, + {"round half up", 1.235, 1.24}, // math.Round rounds half up + {"no rounding needed", 1.23, 1.23}, + {"negative number", -1.235, -1.24}, // math.Round rounds half up (more negative) + {"zero", 0.0, 0.0}, + {"large number", 123.456, 123.46}, // rounds 5 up + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := twoDecimals(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestBytesToMegabytes(t *testing.T) { + tests := []struct { + name string + input float64 + expected float64 + }{ + {"1 MB", 1048576, 1.0}, + {"512 KB", 524288, 0.5}, + {"zero", 0, 0}, + {"large value", 1073741824, 1024}, // 1 GB = 1024 MB + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := bytesToMegabytes(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestInitializeCpuTracking(t *testing.T) { + dm := &dockerManager{ + lastCpuContainer: 
make(map[uint16]map[string]uint64), + lastCpuSystem: make(map[uint16]map[string]uint64), + lastCpuReadTime: make(map[uint16]map[string]time.Time), + } + + cacheTimeMs := uint16(30000) + + // Test initializing a new cache time + dm.initializeCpuTracking(cacheTimeMs) + + // Check that maps were created + assert.NotNil(t, dm.lastCpuContainer[cacheTimeMs]) + assert.NotNil(t, dm.lastCpuSystem[cacheTimeMs]) + assert.NotNil(t, dm.lastCpuReadTime[cacheTimeMs]) + assert.Empty(t, dm.lastCpuContainer[cacheTimeMs]) + assert.Empty(t, dm.lastCpuSystem[cacheTimeMs]) + + // Test initializing existing cache time (should not overwrite) + dm.lastCpuContainer[cacheTimeMs]["test"] = 100 + dm.lastCpuSystem[cacheTimeMs]["test"] = 200 + + dm.initializeCpuTracking(cacheTimeMs) + + // Should still have the existing values + assert.Equal(t, uint64(100), dm.lastCpuContainer[cacheTimeMs]["test"]) + assert.Equal(t, uint64(200), dm.lastCpuSystem[cacheTimeMs]["test"]) +} + +func TestGetCpuPreviousValues(t *testing.T) { + dm := &dockerManager{ + lastCpuContainer: map[uint16]map[string]uint64{ + 30000: {"container1": 100, "container2": 200}, + }, + lastCpuSystem: map[uint16]map[string]uint64{ + 30000: {"container1": 150, "container2": 250}, + }, + } + + // Test getting existing values + container, system := dm.getCpuPreviousValues(30000, "container1") + assert.Equal(t, uint64(100), container) + assert.Equal(t, uint64(150), system) + + // Test getting non-existing container + container, system = dm.getCpuPreviousValues(30000, "nonexistent") + assert.Equal(t, uint64(0), container) + assert.Equal(t, uint64(0), system) + + // Test getting non-existing cache time + container, system = dm.getCpuPreviousValues(60000, "container1") + assert.Equal(t, uint64(0), container) + assert.Equal(t, uint64(0), system) +} + +func TestSetCpuCurrentValues(t *testing.T) { + dm := &dockerManager{ + lastCpuContainer: make(map[uint16]map[string]uint64), + lastCpuSystem: make(map[uint16]map[string]uint64), + } + + 
cacheTimeMs := uint16(30000) + containerId := "test-container" + + // Initialize the cache time maps first + dm.initializeCpuTracking(cacheTimeMs) + + // Set values + dm.setCpuCurrentValues(cacheTimeMs, containerId, 500, 750) + + // Check that values were set + assert.Equal(t, uint64(500), dm.lastCpuContainer[cacheTimeMs][containerId]) + assert.Equal(t, uint64(750), dm.lastCpuSystem[cacheTimeMs][containerId]) +} + +func TestCalculateNetworkStats(t *testing.T) { + // Create docker manager with tracker maps + dm := &dockerManager{ + networkSentTrackers: make(map[uint16]*deltatracker.DeltaTracker[string, uint64]), + networkRecvTrackers: make(map[uint16]*deltatracker.DeltaTracker[string, uint64]), + } + + cacheTimeMs := uint16(30000) + + // Pre-populate tracker for this cache time with initial values + sentTracker := deltatracker.NewDeltaTracker[string, uint64]() + recvTracker := deltatracker.NewDeltaTracker[string, uint64]() + sentTracker.Set("container1", 1000) + recvTracker.Set("container1", 800) + sentTracker.Cycle() // Move to previous + recvTracker.Cycle() + + dm.networkSentTrackers[cacheTimeMs] = sentTracker + dm.networkRecvTrackers[cacheTimeMs] = recvTracker + + ctr := &container.ApiInfo{ + IdShort: "container1", + } + + apiStats := &container.ApiStats{ + Networks: map[string]container.NetworkStats{ + "eth0": {TxBytes: 2000, RxBytes: 1800}, // New values + }, + } + + stats := &container.Stats{ + PrevReadTime: time.Now().Add(-time.Second), // 1 second ago + } + + // Test with initialized container + sent, recv := dm.calculateNetworkStats(ctr, apiStats, stats, true, "test-container", cacheTimeMs) + + // Should return calculated byte rates per second + assert.GreaterOrEqual(t, sent, uint64(0)) + assert.GreaterOrEqual(t, recv, uint64(0)) + + // Cycle and test one-direction change (Tx only) is reflected independently + dm.cycleNetworkDeltasForCacheTime(cacheTimeMs) + apiStats.Networks["eth0"] = container.NetworkStats{TxBytes: 2500, RxBytes: 1800} // +500 Tx only + 
sent, recv = dm.calculateNetworkStats(ctr, apiStats, stats, true, "test-container", cacheTimeMs) + assert.Greater(t, sent, uint64(0)) + assert.Equal(t, uint64(0), recv) +} + +func TestDockerManagerCreation(t *testing.T) { + // Test that dockerManager can be created without panicking + dm := &dockerManager{ + lastCpuContainer: make(map[uint16]map[string]uint64), + lastCpuSystem: make(map[uint16]map[string]uint64), + lastCpuReadTime: make(map[uint16]map[string]time.Time), + networkSentTrackers: make(map[uint16]*deltatracker.DeltaTracker[string, uint64]), + networkRecvTrackers: make(map[uint16]*deltatracker.DeltaTracker[string, uint64]), + } + + assert.NotNil(t, dm) + assert.NotNil(t, dm.lastCpuContainer) + assert.NotNil(t, dm.lastCpuSystem) + assert.NotNil(t, dm.networkSentTrackers) + assert.NotNil(t, dm.networkRecvTrackers) +} + +func TestCycleCpuDeltas(t *testing.T) { + dm := &dockerManager{ + lastCpuContainer: map[uint16]map[string]uint64{ + 30000: {"container1": 100, "container2": 200}, + }, + lastCpuSystem: map[uint16]map[string]uint64{ + 30000: {"container1": 150, "container2": 250}, + }, + lastCpuReadTime: map[uint16]map[string]time.Time{ + 30000: {"container1": time.Now()}, + }, + } + + cacheTimeMs := uint16(30000) + + // Verify values exist before cycling + assert.Equal(t, uint64(100), dm.lastCpuContainer[cacheTimeMs]["container1"]) + assert.Equal(t, uint64(200), dm.lastCpuContainer[cacheTimeMs]["container2"]) + + // Cycle the CPU deltas + dm.cycleCpuDeltas(cacheTimeMs) + + // Verify values are cleared + assert.Empty(t, dm.lastCpuContainer[cacheTimeMs]) + assert.Empty(t, dm.lastCpuSystem[cacheTimeMs]) + // lastCpuReadTime is not affected by cycleCpuDeltas + assert.NotEmpty(t, dm.lastCpuReadTime[cacheTimeMs]) +} + +func TestCycleNetworkDeltas(t *testing.T) { + // Create docker manager with tracker maps + dm := &dockerManager{ + networkSentTrackers: make(map[uint16]*deltatracker.DeltaTracker[string, uint64]), + networkRecvTrackers: 
make(map[uint16]*deltatracker.DeltaTracker[string, uint64]), + } + + cacheTimeMs := uint16(30000) + + // Get trackers for this cache time (creates them) + sentTracker := dm.getNetworkTracker(cacheTimeMs, true) + recvTracker := dm.getNetworkTracker(cacheTimeMs, false) + + // Set some test data + sentTracker.Set("test", 100) + recvTracker.Set("test", 200) + + // This should not panic + assert.NotPanics(t, func() { + dm.cycleNetworkDeltasForCacheTime(cacheTimeMs) + }) + + // Verify that cycle worked by checking deltas are now zero (no previous values) + assert.Equal(t, uint64(0), sentTracker.Delta("test")) + assert.Equal(t, uint64(0), recvTracker.Delta("test")) +} + +func TestConstants(t *testing.T) { + // Test that constants are properly defined + assert.Equal(t, uint16(60000), defaultCacheTimeMs) + assert.Equal(t, uint64(5e9), maxNetworkSpeedBps) + assert.Equal(t, 2100, dockerTimeoutMs) +} + +func TestDockerStatsWithMockData(t *testing.T) { + // Create a docker manager with initialized tracking + dm := &dockerManager{ + lastCpuContainer: make(map[uint16]map[string]uint64), + lastCpuSystem: make(map[uint16]map[string]uint64), + lastCpuReadTime: make(map[uint16]map[string]time.Time), + networkSentTrackers: make(map[uint16]*deltatracker.DeltaTracker[string, uint64]), + networkRecvTrackers: make(map[uint16]*deltatracker.DeltaTracker[string, uint64]), + containerStatsMap: make(map[string]*container.Stats), + } + + cacheTimeMs := uint16(30000) + + // Test that initializeCpuTracking works + dm.initializeCpuTracking(cacheTimeMs) + assert.NotNil(t, dm.lastCpuContainer[cacheTimeMs]) + assert.NotNil(t, dm.lastCpuSystem[cacheTimeMs]) + + // Test that we can set and get CPU values + dm.setCpuCurrentValues(cacheTimeMs, "test-container", 1000, 2000) + container, system := dm.getCpuPreviousValues(cacheTimeMs, "test-container") + assert.Equal(t, uint64(1000), container) + assert.Equal(t, uint64(2000), system) +} + +func TestMemoryStatsEdgeCases(t *testing.T) { + tests := []struct { 
+ name string + usage uint64 + cache uint64 + inactive uint64 + isWindows bool + expected uint64 + hasError bool + }{ + {"Linux normal case", 1000, 200, 0, false, 800, false}, + {"Linux with inactive file", 1000, 0, 300, false, 700, false}, + {"Windows normal case", 0, 0, 0, true, 500, false}, + {"Linux zero usage error", 0, 0, 0, false, 0, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + apiStats := &container.ApiStats{ + MemoryStats: container.MemoryStats{ + Usage: tt.usage, + Stats: container.MemoryStatsStats{ + Cache: tt.cache, + InactiveFile: tt.inactive, + }, + }, + } + + if tt.isWindows { + apiStats.MemoryStats.PrivateWorkingSet = tt.expected + } + + result, err := calculateMemoryUsage(apiStats, tt.isWindows) + + if tt.hasError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func TestContainerStatsInitialization(t *testing.T) { + stats := &container.Stats{Name: "test-container"} + + // Verify initial values + assert.Equal(t, "test-container", stats.Name) + assert.Equal(t, 0.0, stats.Cpu) + assert.Equal(t, 0.0, stats.Mem) + assert.Equal(t, 0.0, stats.NetworkSent) + assert.Equal(t, 0.0, stats.NetworkRecv) + assert.Equal(t, time.Time{}, stats.PrevReadTime) + + // Test updating values + testTime := time.Now() + updateContainerStatsValues(stats, 45.67, 2097152, 1048576, 524288, testTime) + + assert.Equal(t, 45.67, stats.Cpu) + assert.Equal(t, 2.0, stats.Mem) + assert.Equal(t, 1.0, stats.NetworkSent) + assert.Equal(t, 0.5, stats.NetworkRecv) + assert.Equal(t, testTime, stats.PrevReadTime) +} + +// Test with real Docker API test data +func TestCalculateMemoryUsageWithRealData(t *testing.T) { + // Load minimal container stats from test data + data, err := os.ReadFile("test-data/container.json") + require.NoError(t, err) + + var apiStats container.ApiStats + err = json.Unmarshal(data, &apiStats) + require.NoError(t, err) + + // Test memory calculation with real 
data + usedMemory, err := calculateMemoryUsage(&apiStats, false) + require.NoError(t, err) + + // From the real data: usage - inactive_file = 507400192 - 165130240 = 342269952 + expected := uint64(507400192 - 165130240) + assert.Equal(t, expected, usedMemory) +} + +func TestCpuPercentageCalculationWithRealData(t *testing.T) { + // Load minimal container stats from test data + data1, err := os.ReadFile("test-data/container.json") + require.NoError(t, err) + + data2, err := os.ReadFile("test-data/container2.json") + require.NoError(t, err) + + var apiStats1, apiStats2 container.ApiStats + err = json.Unmarshal(data1, &apiStats1) + require.NoError(t, err) + err = json.Unmarshal(data2, &apiStats2) + require.NoError(t, err) + + // Calculate delta manually: 314891801000 - 312055276000 = 2836525000 + // System delta: 1368474900000000 - 1366399830000000 = 2075070000000 + // Expected %: (2836525000 / 2075070000000) * 100 ≈ 0.1367% + expectedPct := float64(2836525000) / float64(2075070000000) * 100.0 + actualPct := apiStats2.CalculateCpuPercentLinux(apiStats1.CPUStats.CPUUsage.TotalUsage, apiStats1.CPUStats.SystemUsage) + + assert.InDelta(t, expectedPct, actualPct, 0.01) +} + +func TestNetworkStatsCalculationWithRealData(t *testing.T) { + // Create synthetic test data to avoid timing issues + apiStats1 := &container.ApiStats{ + Networks: map[string]container.NetworkStats{ + "eth0": {TxBytes: 1000000, RxBytes: 500000}, + }, + } + + apiStats2 := &container.ApiStats{ + Networks: map[string]container.NetworkStats{ + "eth0": {TxBytes: 3000000, RxBytes: 1500000}, // 2MB sent, 1MB received increase + }, + } + + // Create docker manager with tracker maps + dm := &dockerManager{ + networkSentTrackers: make(map[uint16]*deltatracker.DeltaTracker[string, uint64]), + networkRecvTrackers: make(map[uint16]*deltatracker.DeltaTracker[string, uint64]), + } + + ctr := &container.ApiInfo{IdShort: "test-container"} + cacheTimeMs := uint16(30000) // Test with 30 second cache + + // Use exact 
timing for deterministic results + exactly1000msAgo := time.Now().Add(-1000 * time.Millisecond) + stats := &container.Stats{ + PrevReadTime: exactly1000msAgo, + } + + // First call sets baseline + sent1, recv1 := dm.calculateNetworkStats(ctr, apiStats1, stats, true, "test", cacheTimeMs) + assert.Equal(t, uint64(0), sent1) + assert.Equal(t, uint64(0), recv1) + + // Cycle to establish baseline for this cache time + dm.cycleNetworkDeltasForCacheTime(cacheTimeMs) + + // Calculate expected results precisely + deltaSent := uint64(2000000) // 3000000 - 1000000 + deltaRecv := uint64(1000000) // 1500000 - 500000 + expectedElapsedMs := uint64(1000) // Exactly 1000ms + expectedSentRate := deltaSent * 1000 / expectedElapsedMs // Should be exactly 2000000 + expectedRecvRate := deltaRecv * 1000 / expectedElapsedMs // Should be exactly 1000000 + + // Second call with changed data + sent2, recv2 := dm.calculateNetworkStats(ctr, apiStats2, stats, true, "test", cacheTimeMs) + + // Should be exactly the expected rates (no tolerance needed) + assert.Equal(t, expectedSentRate, sent2) + assert.Equal(t, expectedRecvRate, recv2) + + // Bad speed cap: set absurd delta over 1ms and expect 0 due to cap + dm.cycleNetworkDeltasForCacheTime(cacheTimeMs) + stats.PrevReadTime = time.Now().Add(-1 * time.Millisecond) + apiStats1.Networks["eth0"] = container.NetworkStats{TxBytes: 0, RxBytes: 0} + apiStats2.Networks["eth0"] = container.NetworkStats{TxBytes: 10 * 1024 * 1024 * 1024, RxBytes: 0} // 10GB delta + _, _ = dm.calculateNetworkStats(ctr, apiStats1, stats, true, "test", cacheTimeMs) // baseline + dm.cycleNetworkDeltasForCacheTime(cacheTimeMs) + sent3, recv3 := dm.calculateNetworkStats(ctr, apiStats2, stats, true, "test", cacheTimeMs) + assert.Equal(t, uint64(0), sent3) + assert.Equal(t, uint64(0), recv3) +} + +func TestContainerStatsEndToEndWithRealData(t *testing.T) { + // Load minimal container stats + data, err := os.ReadFile("test-data/container.json") + require.NoError(t, err) + + var 
apiStats container.ApiStats + err = json.Unmarshal(data, &apiStats) + require.NoError(t, err) + + // Create a docker manager with proper initialization + dm := &dockerManager{ + lastCpuContainer: make(map[uint16]map[string]uint64), + lastCpuSystem: make(map[uint16]map[string]uint64), + lastCpuReadTime: make(map[uint16]map[string]time.Time), + networkSentTrackers: make(map[uint16]*deltatracker.DeltaTracker[string, uint64]), + networkRecvTrackers: make(map[uint16]*deltatracker.DeltaTracker[string, uint64]), + containerStatsMap: make(map[string]*container.Stats), + } + + // Initialize CPU tracking + cacheTimeMs := uint16(30000) + dm.initializeCpuTracking(cacheTimeMs) + + // Create container info + ctr := &container.ApiInfo{ + IdShort: "abc123", + } + + // Initialize container stats + stats := &container.Stats{Name: "jellyfin"} + dm.containerStatsMap[ctr.IdShort] = stats + + // Test individual components that we can verify + usedMemory, memErr := calculateMemoryUsage(&apiStats, false) + assert.NoError(t, memErr) + assert.Greater(t, usedMemory, uint64(0)) + + // Test CPU percentage validation + cpuPct := 85.5 + err = validateCpuPercentage(cpuPct, "jellyfin") + assert.NoError(t, err) + + err = validateCpuPercentage(150.0, "jellyfin") + assert.Error(t, err) + + // Test stats value updates + testStats := &container.Stats{} + testTime := time.Now() + updateContainerStatsValues(testStats, cpuPct, usedMemory, 1000000, 500000, testTime) + + assert.Equal(t, cpuPct, testStats.Cpu) + assert.Equal(t, bytesToMegabytes(float64(usedMemory)), testStats.Mem) + assert.Equal(t, bytesToMegabytes(1000000), testStats.NetworkSent) + assert.Equal(t, bytesToMegabytes(500000), testStats.NetworkRecv) + assert.Equal(t, testTime, testStats.PrevReadTime) +} + +func TestEdgeCasesWithRealData(t *testing.T) { + // Test with minimal container stats + minimalStats := &container.ApiStats{ + CPUStats: container.CPUStats{ + CPUUsage: container.CPUUsage{TotalUsage: 1000}, + SystemUsage: 50000, + }, + 
MemoryStats: container.MemoryStats{ + Usage: 1000000, + Stats: container.MemoryStatsStats{ + Cache: 0, + InactiveFile: 0, + }, + }, + Networks: map[string]container.NetworkStats{ + "eth0": {TxBytes: 1000, RxBytes: 500}, + }, + } + + // Test memory calculation with zero cache/inactive + usedMemory, err := calculateMemoryUsage(minimalStats, false) + assert.NoError(t, err) + assert.Equal(t, uint64(1000000), usedMemory) // Should equal usage when no cache + + // Test CPU percentage calculation + cpuPct := minimalStats.CalculateCpuPercentLinux(0, 0) // First run + assert.Equal(t, 0.0, cpuPct) + + // Test with Windows data + minimalStats.MemoryStats.PrivateWorkingSet = 800000 + usedMemory, err = calculateMemoryUsage(minimalStats, true) + assert.NoError(t, err) + assert.Equal(t, uint64(800000), usedMemory) +} + +func TestDockerStatsWorkflow(t *testing.T) { + // Test the complete workflow that can be tested without HTTP calls + dm := &dockerManager{ + lastCpuContainer: make(map[uint16]map[string]uint64), + lastCpuSystem: make(map[uint16]map[string]uint64), + networkSentTrackers: make(map[uint16]*deltatracker.DeltaTracker[string, uint64]), + networkRecvTrackers: make(map[uint16]*deltatracker.DeltaTracker[string, uint64]), + containerStatsMap: make(map[string]*container.Stats), + } + + cacheTimeMs := uint16(30000) + + // Test CPU tracking workflow + dm.initializeCpuTracking(cacheTimeMs) + assert.NotNil(t, dm.lastCpuContainer[cacheTimeMs]) + + // Test setting and getting CPU values + dm.setCpuCurrentValues(cacheTimeMs, "test-container", 1000, 50000) + containerVal, systemVal := dm.getCpuPreviousValues(cacheTimeMs, "test-container") + assert.Equal(t, uint64(1000), containerVal) + assert.Equal(t, uint64(50000), systemVal) + + // Test network tracking workflow (multi-interface summation) + sentTracker := dm.getNetworkTracker(cacheTimeMs, true) + recvTracker := dm.getNetworkTracker(cacheTimeMs, false) + + // Simulate two interfaces summed by setting combined totals + 
sentTracker.Set("test-container", 1000+2000) + recvTracker.Set("test-container", 500+700) + + deltaSent := sentTracker.Delta("test-container") + deltaRecv := recvTracker.Delta("test-container") + assert.Equal(t, uint64(0), deltaSent) // No previous value + assert.Equal(t, uint64(0), deltaRecv) + + // Cycle and test again + dm.cycleNetworkDeltasForCacheTime(cacheTimeMs) + + // Increase each interface total (combined totals go up by 1500 and 800) + sentTracker.Set("test-container", (1000+2000)+1500) + recvTracker.Set("test-container", (500+700)+800) + + deltaSent = sentTracker.Delta("test-container") + deltaRecv = recvTracker.Delta("test-container") + assert.Equal(t, uint64(1500), deltaSent) + assert.Equal(t, uint64(800), deltaRecv) +} + +func TestNetworkRateCalculationFormula(t *testing.T) { + // Test the exact formula used in calculateNetworkStats + testCases := []struct { + name string + deltaBytes uint64 + elapsedMs uint64 + expectedRate uint64 + }{ + {"1MB over 1 second", 1000000, 1000, 1000000}, + {"2MB over 1 second", 2000000, 1000, 2000000}, + {"1MB over 2 seconds", 1000000, 2000, 500000}, + {"500KB over 500ms", 500000, 500, 1000000}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // This is the exact formula from calculateNetworkStats + actualRate := tc.deltaBytes * 1000 / tc.elapsedMs + assert.Equal(t, tc.expectedRate, actualRate, + "Rate calculation should be exact: %d bytes * 1000 / %d ms = %d", + tc.deltaBytes, tc.elapsedMs, tc.expectedRate) + }) + } +} + +func TestDeltaTrackerCacheTimeIsolation(t *testing.T) { + // Test that different cache times have separate DeltaTracker instances + dm := &dockerManager{ + networkSentTrackers: make(map[uint16]*deltatracker.DeltaTracker[string, uint64]), + networkRecvTrackers: make(map[uint16]*deltatracker.DeltaTracker[string, uint64]), + } + + ctr := &container.ApiInfo{IdShort: "web-server"} + cacheTime1 := uint16(30000) + cacheTime2 := uint16(60000) + + // Get trackers for different 
cache times (creates separate instances) + sentTracker1 := dm.getNetworkTracker(cacheTime1, true) + recvTracker1 := dm.getNetworkTracker(cacheTime1, false) + + sentTracker2 := dm.getNetworkTracker(cacheTime2, true) + recvTracker2 := dm.getNetworkTracker(cacheTime2, false) + + // Verify they are different instances + assert.NotSame(t, sentTracker1, sentTracker2) + assert.NotSame(t, recvTracker1, recvTracker2) + + // Set values for cache time 1 + sentTracker1.Set(ctr.IdShort, 1000000) + recvTracker1.Set(ctr.IdShort, 500000) + + // Set values for cache time 2 + sentTracker2.Set(ctr.IdShort, 2000000) + recvTracker2.Set(ctr.IdShort, 1000000) + + // Verify they don't interfere (both should return 0 since no previous values) + assert.Equal(t, uint64(0), sentTracker1.Delta(ctr.IdShort)) + assert.Equal(t, uint64(0), recvTracker1.Delta(ctr.IdShort)) + assert.Equal(t, uint64(0), sentTracker2.Delta(ctr.IdShort)) + assert.Equal(t, uint64(0), recvTracker2.Delta(ctr.IdShort)) + + // Cycle cache time 1 trackers + dm.cycleNetworkDeltasForCacheTime(cacheTime1) + + // Set new values for cache time 1 + sentTracker1.Set(ctr.IdShort, 3000000) // 2MB increase + recvTracker1.Set(ctr.IdShort, 1500000) // 1MB increase + + // Cache time 1 should show deltas, cache time 2 should still be 0 + assert.Equal(t, uint64(2000000), sentTracker1.Delta(ctr.IdShort)) + assert.Equal(t, uint64(1000000), recvTracker1.Delta(ctr.IdShort)) + assert.Equal(t, uint64(0), sentTracker2.Delta(ctr.IdShort)) // Unaffected + assert.Equal(t, uint64(0), recvTracker2.Delta(ctr.IdShort)) // Unaffected + + // Cycle cache time 2 and verify it works independently + dm.cycleNetworkDeltasForCacheTime(cacheTime2) + sentTracker2.Set(ctr.IdShort, 2500000) // 0.5MB increase + recvTracker2.Set(ctr.IdShort, 1200000) // 0.2MB increase + + assert.Equal(t, uint64(500000), sentTracker2.Delta(ctr.IdShort)) + assert.Equal(t, uint64(200000), recvTracker2.Delta(ctr.IdShort)) +} + +func TestConstantsAndUtilityFunctions(t *testing.T) { + // 
Test constants are properly defined + assert.Equal(t, uint16(60000), defaultCacheTimeMs) + assert.Equal(t, uint64(5e9), maxNetworkSpeedBps) + assert.Equal(t, 2100, dockerTimeoutMs) + + // Test utility functions + assert.Equal(t, 1.5, twoDecimals(1.499)) + assert.Equal(t, 1.5, twoDecimals(1.5)) + assert.Equal(t, 1.5, twoDecimals(1.501)) + + assert.Equal(t, 1.0, bytesToMegabytes(1048576)) // 1 MB + assert.Equal(t, 0.5, bytesToMegabytes(524288)) // 512 KB + assert.Equal(t, 0.0, bytesToMegabytes(0)) +} diff --git a/agent/gpu.go b/agent/gpu.go index d6862fc8..b63e56cc 100644 --- a/agent/gpu.go +++ b/agent/gpu.go @@ -44,6 +44,21 @@ type GPUManager struct { tegrastats bool intelGpuStats bool GpuDataMap map[string]*system.GPUData + // lastAvgData stores the last calculated averages for each GPU + // Used when a collection happens before new data arrives (Count == 0) + lastAvgData map[string]system.GPUData + // Per-cache-key tracking for delta calculations + // cacheKey -> gpuId -> snapshot of last count/usage/power values + lastSnapshots map[uint16]map[string]*gpuSnapshot +} + +// gpuSnapshot stores the last observed incremental values for delta tracking +type gpuSnapshot struct { + count uint32 + usage float64 + power float64 + powerPkg float64 + engines map[string]float64 } // RocmSmiJson represents the JSON structure of rocm-smi output @@ -229,48 +244,21 @@ func (gm *GPUManager) parseAmdData(output []byte) bool { return true } -// sums and resets the current GPU utilization data since the last update -func (gm *GPUManager) GetCurrentData() map[string]system.GPUData { +// GetCurrentData returns GPU utilization data averaged since the last call with this cacheKey +func (gm *GPUManager) GetCurrentData(cacheKey uint16) map[string]system.GPUData { gm.Lock() defer gm.Unlock() - // check for GPUs with the same name - nameCounts := make(map[string]int) - for _, gpu := range gm.GpuDataMap { - nameCounts[gpu.Name]++ - } + gm.initializeSnapshots(cacheKey) + nameCounts := 
gm.countGPUNames() - // copy / reset the data gpuData := make(map[string]system.GPUData, len(gm.GpuDataMap)) for id, gpu := range gm.GpuDataMap { - // avoid division by zero - count := max(gpu.Count, 1) + gpuAvg := gm.calculateGPUAverage(id, gpu, cacheKey) + gm.updateInstantaneousValues(&gpuAvg, gpu) + gm.storeSnapshot(id, gpu, cacheKey) - // average the data - gpuAvg := *gpu - gpuAvg.Temperature = twoDecimals(gpu.Temperature) - gpuAvg.Power = twoDecimals(gpu.Power / count) - - // intel gpu stats doesn't provide usage, memory used, or memory total - if gpu.Engines != nil { - maxEngineUsage := 0.0 - for name, engine := range gpu.Engines { - gpuAvg.Engines[name] = twoDecimals(engine / count) - maxEngineUsage = max(maxEngineUsage, engine/count) - } - gpuAvg.PowerPkg = twoDecimals(gpu.PowerPkg / count) - gpuAvg.Usage = twoDecimals(maxEngineUsage) - } else { - gpuAvg.Usage = twoDecimals(gpu.Usage / count) - gpuAvg.MemoryUsed = twoDecimals(gpu.MemoryUsed) - gpuAvg.MemoryTotal = twoDecimals(gpu.MemoryTotal) - } - - // reset accumulators in the original gpu data for next collection - gpu.Usage, gpu.Power, gpu.PowerPkg, gpu.Count = gpuAvg.Usage, gpuAvg.Power, gpuAvg.PowerPkg, 1 - gpu.Engines = gpuAvg.Engines - - // append id to the name if there are multiple GPUs with the same name + // Append id to name if there are multiple GPUs with the same name if nameCounts[gpu.Name] > 1 { gpuAvg.Name = fmt.Sprintf("%s %s", gpu.Name, id) } @@ -280,6 +268,114 @@ func (gm *GPUManager) GetCurrentData() map[string]system.GPUData { return gpuData } +// initializeSnapshots ensures snapshot maps are initialized for the given cache key +func (gm *GPUManager) initializeSnapshots(cacheKey uint16) { + if gm.lastAvgData == nil { + gm.lastAvgData = make(map[string]system.GPUData) + } + if gm.lastSnapshots == nil { + gm.lastSnapshots = make(map[uint16]map[string]*gpuSnapshot) + } + if gm.lastSnapshots[cacheKey] == nil { + gm.lastSnapshots[cacheKey] = make(map[string]*gpuSnapshot) + } +} + +// 
countGPUNames returns a map of GPU names to their occurrence count +func (gm *GPUManager) countGPUNames() map[string]int { + nameCounts := make(map[string]int) + for _, gpu := range gm.GpuDataMap { + nameCounts[gpu.Name]++ + } + return nameCounts +} + +// calculateGPUAverage computes the average GPU metrics since the last snapshot for this cache key +func (gm *GPUManager) calculateGPUAverage(id string, gpu *system.GPUData, cacheKey uint16) system.GPUData { + lastSnapshot := gm.lastSnapshots[cacheKey][id] + currentCount := uint32(gpu.Count) + deltaCount := gm.calculateDeltaCount(currentCount, lastSnapshot) + + // If no new data arrived, use last known average + if deltaCount == 0 { + return gm.lastAvgData[id] // zero value if not found + } + + // Calculate new average + gpuAvg := *gpu + deltaUsage, deltaPower, deltaPowerPkg := gm.calculateDeltas(gpu, lastSnapshot) + + gpuAvg.Power = twoDecimals(deltaPower / float64(deltaCount)) + + if gpu.Engines != nil { + gpuAvg.Usage = gm.calculateIntelGPUUsage(&gpuAvg, gpu, lastSnapshot, deltaCount) + gpuAvg.PowerPkg = twoDecimals(deltaPowerPkg / float64(deltaCount)) + } else { + gpuAvg.Usage = twoDecimals(deltaUsage / float64(deltaCount)) + } + + gm.lastAvgData[id] = gpuAvg + return gpuAvg +} + +// calculateDeltaCount returns the change in count since the last snapshot +func (gm *GPUManager) calculateDeltaCount(currentCount uint32, lastSnapshot *gpuSnapshot) uint32 { + if lastSnapshot != nil { + return currentCount - lastSnapshot.count + } + return currentCount +} + +// calculateDeltas computes the change in usage, power, and powerPkg since the last snapshot +func (gm *GPUManager) calculateDeltas(gpu *system.GPUData, lastSnapshot *gpuSnapshot) (deltaUsage, deltaPower, deltaPowerPkg float64) { + if lastSnapshot != nil { + return gpu.Usage - lastSnapshot.usage, + gpu.Power - lastSnapshot.power, + gpu.PowerPkg - lastSnapshot.powerPkg + } + return gpu.Usage, gpu.Power, gpu.PowerPkg +} + +// calculateIntelGPUUsage computes Intel GPU 
usage from engine metrics and returns max engine usage +func (gm *GPUManager) calculateIntelGPUUsage(gpuAvg, gpu *system.GPUData, lastSnapshot *gpuSnapshot, deltaCount uint32) float64 { + maxEngineUsage := 0.0 + for name, engine := range gpu.Engines { + var deltaEngine float64 + if lastSnapshot != nil && lastSnapshot.engines != nil { + deltaEngine = engine - lastSnapshot.engines[name] + } else { + deltaEngine = engine + } + gpuAvg.Engines[name] = twoDecimals(deltaEngine / float64(deltaCount)) + maxEngineUsage = max(maxEngineUsage, deltaEngine/float64(deltaCount)) + } + return twoDecimals(maxEngineUsage) +} + +// updateInstantaneousValues updates values that should reflect current state, not averages +func (gm *GPUManager) updateInstantaneousValues(gpuAvg *system.GPUData, gpu *system.GPUData) { + gpuAvg.Temperature = twoDecimals(gpu.Temperature) + gpuAvg.MemoryUsed = twoDecimals(gpu.MemoryUsed) + gpuAvg.MemoryTotal = twoDecimals(gpu.MemoryTotal) +} + +// storeSnapshot saves the current GPU state for this cache key +func (gm *GPUManager) storeSnapshot(id string, gpu *system.GPUData, cacheKey uint16) { + snapshot := &gpuSnapshot{ + count: uint32(gpu.Count), + usage: gpu.Usage, + power: gpu.Power, + powerPkg: gpu.PowerPkg, + } + if gpu.Engines != nil { + snapshot.engines = make(map[string]float64, len(gpu.Engines)) + for name, value := range gpu.Engines { + snapshot.engines[name] = value + } + } + gm.lastSnapshots[cacheKey][id] = snapshot +} + // detectGPUs checks for the presence of GPU management tools (nvidia-smi, rocm-smi, tegrastats) // in the system path. It sets the corresponding flags in the GPUManager struct if any of these // tools are found. 
If none of the tools are found, it returns an error indicating that no GPU diff --git a/agent/gpu_test.go b/agent/gpu_test.go index 92ec531a..d204cf77 100644 --- a/agent/gpu_test.go +++ b/agent/gpu_test.go @@ -332,7 +332,7 @@ func TestParseJetsonData(t *testing.T) { } func TestGetCurrentData(t *testing.T) { - t.Run("calculates averages and resets accumulators", func(t *testing.T) { + t.Run("calculates averages with per-cache-key delta tracking", func(t *testing.T) { gm := &GPUManager{ GpuDataMap: map[string]*system.GPUData{ "0": { @@ -365,7 +365,8 @@ func TestGetCurrentData(t *testing.T) { }, } - result := gm.GetCurrentData() + cacheKey := uint16(5000) + result := gm.GetCurrentData(cacheKey) // Verify name disambiguation assert.Equal(t, "GPU1 0", result["0"].Name) @@ -378,13 +379,19 @@ func TestGetCurrentData(t *testing.T) { assert.InDelta(t, 30.0, result["1"].Usage, 0.01) assert.InDelta(t, 60.0, result["1"].Power, 0.01) - // Verify that accumulators in the original map are reset - assert.EqualValues(t, float64(1), gm.GpuDataMap["0"].Count, "GPU 0 Count should be reset") - assert.EqualValues(t, float64(50.0), gm.GpuDataMap["0"].Usage, "GPU 0 Usage should be reset") - assert.Equal(t, float64(100.0), gm.GpuDataMap["0"].Power, "GPU 0 Power should be reset") - assert.Equal(t, float64(1), gm.GpuDataMap["1"].Count, "GPU 1 Count should be reset") - assert.Equal(t, float64(30), gm.GpuDataMap["1"].Usage, "GPU 1 Usage should be reset") - assert.Equal(t, float64(60), gm.GpuDataMap["1"].Power, "GPU 1 Power should be reset") + // Verify that accumulators in the original map are NOT reset (they keep growing) + assert.EqualValues(t, 2, gm.GpuDataMap["0"].Count, "GPU 0 Count should remain at 2") + assert.EqualValues(t, 100, gm.GpuDataMap["0"].Usage, "GPU 0 Usage should remain at 100") + assert.Equal(t, 200.0, gm.GpuDataMap["0"].Power, "GPU 0 Power should remain at 200") + assert.Equal(t, 1.0, gm.GpuDataMap["1"].Count, "GPU 1 Count should remain at 1") + assert.Equal(t, 30.0, 
gm.GpuDataMap["1"].Usage, "GPU 1 Usage should remain at 30") + assert.Equal(t, 60.0, gm.GpuDataMap["1"].Power, "GPU 1 Power should remain at 60") + + // Verify snapshots were stored for this cache key + assert.NotNil(t, gm.lastSnapshots[cacheKey]["0"]) + assert.Equal(t, uint32(2), gm.lastSnapshots[cacheKey]["0"].count) + assert.Equal(t, 100.0, gm.lastSnapshots[cacheKey]["0"].usage) + assert.Equal(t, 200.0, gm.lastSnapshots[cacheKey]["0"].power) }) t.Run("handles zero count without panicking", func(t *testing.T) { @@ -399,17 +406,543 @@ func TestGetCurrentData(t *testing.T) { }, } + cacheKey := uint16(5000) var result map[string]system.GPUData assert.NotPanics(t, func() { - result = gm.GetCurrentData() + result = gm.GetCurrentData(cacheKey) }) // Check that usage and power are 0 assert.Equal(t, 0.0, result["0"].Usage) assert.Equal(t, 0.0, result["0"].Power) - // Verify reset count - assert.EqualValues(t, 1, gm.GpuDataMap["0"].Count) + // Verify count remains 0 + assert.EqualValues(t, 0, gm.GpuDataMap["0"].Count) + }) + + t.Run("uses last average when no new data arrives", func(t *testing.T) { + gm := &GPUManager{ + GpuDataMap: map[string]*system.GPUData{ + "0": { + Name: "TestGPU", + Temperature: 55.0, + MemoryUsed: 1500, + MemoryTotal: 8000, + Usage: 100, // Will average to 50 + Power: 200, // Will average to 100 + Count: 2, + }, + }, + } + + cacheKey := uint16(5000) + + // First collection - should calculate averages and store them + result1 := gm.GetCurrentData(cacheKey) + assert.InDelta(t, 50.0, result1["0"].Usage, 0.01) + assert.InDelta(t, 100.0, result1["0"].Power, 0.01) + assert.EqualValues(t, 2, gm.GpuDataMap["0"].Count, "Count should remain at 2") + + // Update temperature but no new usage/power data (count stays same) + gm.GpuDataMap["0"].Temperature = 60.0 + gm.GpuDataMap["0"].MemoryUsed = 1600 + + // Second collection - should use last averages since count hasn't changed (delta = 0) + result2 := gm.GetCurrentData(cacheKey) + assert.InDelta(t, 50.0, 
result2["0"].Usage, 0.01, "Should use last average") + assert.InDelta(t, 100.0, result2["0"].Power, 0.01, "Should use last average") + assert.InDelta(t, 60.0, result2["0"].Temperature, 0.01, "Should use current temperature") + assert.InDelta(t, 1600.0, result2["0"].MemoryUsed, 0.01, "Should use current memory") + assert.EqualValues(t, 2, gm.GpuDataMap["0"].Count, "Count should still be 2") + }) + + t.Run("tracks separate averages per cache key", func(t *testing.T) { + gm := &GPUManager{ + GpuDataMap: map[string]*system.GPUData{ + "0": { + Name: "TestGPU", + Temperature: 55.0, + MemoryUsed: 1500, + MemoryTotal: 8000, + Usage: 100, // Initial: 100 over 2 counts = 50 avg + Power: 200, // Initial: 200 over 2 counts = 100 avg + Count: 2, + }, + }, + } + + cacheKey1 := uint16(5000) + cacheKey2 := uint16(10000) + + // First check with cacheKey1 - baseline + result1 := gm.GetCurrentData(cacheKey1) + assert.InDelta(t, 50.0, result1["0"].Usage, 0.01, "CacheKey1: Initial average should be 50") + assert.InDelta(t, 100.0, result1["0"].Power, 0.01, "CacheKey1: Initial average should be 100") + + // Simulate GPU activity - accumulate more data + gm.GpuDataMap["0"].Usage += 60 // Now total: 160 + gm.GpuDataMap["0"].Power += 150 // Now total: 350 + gm.GpuDataMap["0"].Count += 3 // Now total: 5 + + // Check with cacheKey1 again - should get delta since last cacheKey1 check + result2 := gm.GetCurrentData(cacheKey1) + assert.InDelta(t, 20.0, result2["0"].Usage, 0.01, "CacheKey1: Delta average should be 60/3 = 20") + assert.InDelta(t, 50.0, result2["0"].Power, 0.01, "CacheKey1: Delta average should be 150/3 = 50") + + // Check with cacheKey2 for the first time - should get average since beginning + result3 := gm.GetCurrentData(cacheKey2) + assert.InDelta(t, 32.0, result3["0"].Usage, 0.01, "CacheKey2: Total average should be 160/5 = 32") + assert.InDelta(t, 70.0, result3["0"].Power, 0.01, "CacheKey2: Total average should be 350/5 = 70") + + // Simulate more GPU activity + 
gm.GpuDataMap["0"].Usage += 80 // Now total: 240 + gm.GpuDataMap["0"].Power += 160 // Now total: 510 + gm.GpuDataMap["0"].Count += 2 // Now total: 7 + + // Check with cacheKey1 - should get delta since last cacheKey1 check + result4 := gm.GetCurrentData(cacheKey1) + assert.InDelta(t, 40.0, result4["0"].Usage, 0.01, "CacheKey1: New delta average should be 80/2 = 40") + assert.InDelta(t, 80.0, result4["0"].Power, 0.01, "CacheKey1: New delta average should be 160/2 = 80") + + // Check with cacheKey2 - should get delta since last cacheKey2 check + result5 := gm.GetCurrentData(cacheKey2) + assert.InDelta(t, 40.0, result5["0"].Usage, 0.01, "CacheKey2: Delta average should be 80/2 = 40") + assert.InDelta(t, 80.0, result5["0"].Power, 0.01, "CacheKey2: Delta average should be 160/2 = 80") + + // Verify snapshots exist for both cache keys + assert.NotNil(t, gm.lastSnapshots[cacheKey1]) + assert.NotNil(t, gm.lastSnapshots[cacheKey2]) + assert.NotNil(t, gm.lastSnapshots[cacheKey1]["0"]) + assert.NotNil(t, gm.lastSnapshots[cacheKey2]["0"]) + }) +} + +func TestCalculateDeltaCount(t *testing.T) { + gm := &GPUManager{} + + t.Run("with no previous snapshot", func(t *testing.T) { + delta := gm.calculateDeltaCount(10, nil) + assert.Equal(t, uint32(10), delta, "Should return current count when no snapshot exists") + }) + + t.Run("with previous snapshot", func(t *testing.T) { + snapshot := &gpuSnapshot{count: 5} + delta := gm.calculateDeltaCount(15, snapshot) + assert.Equal(t, uint32(10), delta, "Should return difference between current and snapshot") + }) + + t.Run("with same count", func(t *testing.T) { + snapshot := &gpuSnapshot{count: 10} + delta := gm.calculateDeltaCount(10, snapshot) + assert.Equal(t, uint32(0), delta, "Should return zero when count hasn't changed") + }) +} + +func TestCalculateDeltas(t *testing.T) { + gm := &GPUManager{} + + t.Run("with no previous snapshot", func(t *testing.T) { + gpu := &system.GPUData{ + Usage: 100.5, + Power: 250.75, + PowerPkg: 300.25, + } 
+ deltaUsage, deltaPower, deltaPowerPkg := gm.calculateDeltas(gpu, nil) + assert.Equal(t, 100.5, deltaUsage) + assert.Equal(t, 250.75, deltaPower) + assert.Equal(t, 300.25, deltaPowerPkg) + }) + + t.Run("with previous snapshot", func(t *testing.T) { + gpu := &system.GPUData{ + Usage: 150.5, + Power: 300.75, + PowerPkg: 400.25, + } + snapshot := &gpuSnapshot{ + usage: 100.5, + power: 250.75, + powerPkg: 300.25, + } + deltaUsage, deltaPower, deltaPowerPkg := gm.calculateDeltas(gpu, snapshot) + assert.InDelta(t, 50.0, deltaUsage, 0.01) + assert.InDelta(t, 50.0, deltaPower, 0.01) + assert.InDelta(t, 100.0, deltaPowerPkg, 0.01) + }) +} + +func TestCalculateIntelGPUUsage(t *testing.T) { + gm := &GPUManager{} + + t.Run("with no previous snapshot", func(t *testing.T) { + gpuAvg := &system.GPUData{ + Engines: make(map[string]float64), + } + gpu := &system.GPUData{ + Engines: map[string]float64{ + "Render/3D": 80.0, + "Video": 40.0, + "Compute": 60.0, + }, + } + maxUsage := gm.calculateIntelGPUUsage(gpuAvg, gpu, nil, 2) + + assert.Equal(t, 40.0, maxUsage, "Should return max engine usage (80/2=40)") + assert.Equal(t, 40.0, gpuAvg.Engines["Render/3D"]) + assert.Equal(t, 20.0, gpuAvg.Engines["Video"]) + assert.Equal(t, 30.0, gpuAvg.Engines["Compute"]) + }) + + t.Run("with previous snapshot", func(t *testing.T) { + gpuAvg := &system.GPUData{ + Engines: make(map[string]float64), + } + gpu := &system.GPUData{ + Engines: map[string]float64{ + "Render/3D": 180.0, + "Video": 100.0, + "Compute": 140.0, + }, + } + snapshot := &gpuSnapshot{ + engines: map[string]float64{ + "Render/3D": 80.0, + "Video": 40.0, + "Compute": 60.0, + }, + } + maxUsage := gm.calculateIntelGPUUsage(gpuAvg, gpu, snapshot, 5) + + // Deltas: Render/3D=100, Video=60, Compute=80 over 5 counts + assert.Equal(t, 20.0, maxUsage, "Should return max engine delta (100/5=20)") + assert.Equal(t, 20.0, gpuAvg.Engines["Render/3D"]) + assert.Equal(t, 12.0, gpuAvg.Engines["Video"]) + assert.Equal(t, 16.0, 
gpuAvg.Engines["Compute"]) + }) + + t.Run("handles missing engine in snapshot", func(t *testing.T) { + gpuAvg := &system.GPUData{ + Engines: make(map[string]float64), + } + gpu := &system.GPUData{ + Engines: map[string]float64{ + "Render/3D": 100.0, + "NewEngine": 50.0, + }, + } + snapshot := &gpuSnapshot{ + engines: map[string]float64{ + "Render/3D": 80.0, + // NewEngine doesn't exist in snapshot + }, + } + maxUsage := gm.calculateIntelGPUUsage(gpuAvg, gpu, snapshot, 2) + + assert.Equal(t, 25.0, maxUsage) + assert.Equal(t, 10.0, gpuAvg.Engines["Render/3D"], "Should use delta for existing engine") + assert.Equal(t, 25.0, gpuAvg.Engines["NewEngine"], "Should use full value for new engine") + }) +} + +func TestUpdateInstantaneousValues(t *testing.T) { + gm := &GPUManager{} + + t.Run("updates temperature, memory used and total", func(t *testing.T) { + gpuAvg := &system.GPUData{ + Temperature: 50.123, + MemoryUsed: 1000.456, + MemoryTotal: 8000.789, + } + gpu := &system.GPUData{ + Temperature: 75.567, + MemoryUsed: 2500.891, + MemoryTotal: 8192.234, + } + + gm.updateInstantaneousValues(gpuAvg, gpu) + + assert.Equal(t, 75.57, gpuAvg.Temperature, "Should update and round temperature") + assert.Equal(t, 2500.89, gpuAvg.MemoryUsed, "Should update and round memory used") + assert.Equal(t, 8192.23, gpuAvg.MemoryTotal, "Should update and round memory total") + }) +} + +func TestStoreSnapshot(t *testing.T) { + gm := &GPUManager{ + lastSnapshots: make(map[uint16]map[string]*gpuSnapshot), + } + + t.Run("stores standard GPU snapshot", func(t *testing.T) { + cacheKey := uint16(5000) + gm.lastSnapshots[cacheKey] = make(map[string]*gpuSnapshot) + + gpu := &system.GPUData{ + Count: 10.0, + Usage: 150.5, + Power: 250.75, + PowerPkg: 300.25, + } + + gm.storeSnapshot("0", gpu, cacheKey) + + snapshot := gm.lastSnapshots[cacheKey]["0"] + assert.NotNil(t, snapshot) + assert.Equal(t, uint32(10), snapshot.count) + assert.Equal(t, 150.5, snapshot.usage) + assert.Equal(t, 250.75, 
snapshot.power) + assert.Equal(t, 300.25, snapshot.powerPkg) + assert.Nil(t, snapshot.engines, "Should not have engines for standard GPU") + }) + + t.Run("stores Intel GPU snapshot with engines", func(t *testing.T) { + cacheKey := uint16(10000) + gm.lastSnapshots[cacheKey] = make(map[string]*gpuSnapshot) + + gpu := &system.GPUData{ + Count: 5.0, + Usage: 100.0, + Power: 200.0, + PowerPkg: 250.0, + Engines: map[string]float64{ + "Render/3D": 80.0, + "Video": 40.0, + }, + } + + gm.storeSnapshot("0", gpu, cacheKey) + + snapshot := gm.lastSnapshots[cacheKey]["0"] + assert.NotNil(t, snapshot) + assert.Equal(t, uint32(5), snapshot.count) + assert.NotNil(t, snapshot.engines, "Should have engines for Intel GPU") + assert.Equal(t, 80.0, snapshot.engines["Render/3D"]) + assert.Equal(t, 40.0, snapshot.engines["Video"]) + assert.Len(t, snapshot.engines, 2) + }) + + t.Run("overwrites existing snapshot", func(t *testing.T) { + cacheKey := uint16(5000) + gm.lastSnapshots[cacheKey] = make(map[string]*gpuSnapshot) + + // Store initial snapshot + gpu1 := &system.GPUData{Count: 5.0, Usage: 100.0, Power: 200.0} + gm.storeSnapshot("0", gpu1, cacheKey) + + // Store updated snapshot + gpu2 := &system.GPUData{Count: 10.0, Usage: 250.0, Power: 400.0} + gm.storeSnapshot("0", gpu2, cacheKey) + + snapshot := gm.lastSnapshots[cacheKey]["0"] + assert.Equal(t, uint32(10), snapshot.count, "Should overwrite previous count") + assert.Equal(t, 250.0, snapshot.usage, "Should overwrite previous usage") + assert.Equal(t, 400.0, snapshot.power, "Should overwrite previous power") + }) +} + +func TestCountGPUNames(t *testing.T) { + t.Run("returns empty map for no GPUs", func(t *testing.T) { + gm := &GPUManager{ + GpuDataMap: make(map[string]*system.GPUData), + } + counts := gm.countGPUNames() + assert.Empty(t, counts) + }) + + t.Run("counts unique GPU names", func(t *testing.T) { + gm := &GPUManager{ + GpuDataMap: map[string]*system.GPUData{ + "0": {Name: "GPU A"}, + "1": {Name: "GPU B"}, + "2": {Name: 
"GPU C"}, + }, + } + counts := gm.countGPUNames() + assert.Equal(t, 1, counts["GPU A"]) + assert.Equal(t, 1, counts["GPU B"]) + assert.Equal(t, 1, counts["GPU C"]) + assert.Len(t, counts, 3) + }) + + t.Run("counts duplicate GPU names", func(t *testing.T) { + gm := &GPUManager{ + GpuDataMap: map[string]*system.GPUData{ + "0": {Name: "RTX 4090"}, + "1": {Name: "RTX 4090"}, + "2": {Name: "RTX 4090"}, + "3": {Name: "RTX 3080"}, + }, + } + counts := gm.countGPUNames() + assert.Equal(t, 3, counts["RTX 4090"]) + assert.Equal(t, 1, counts["RTX 3080"]) + assert.Len(t, counts, 2) + }) +} + +func TestInitializeSnapshots(t *testing.T) { + t.Run("initializes all maps from scratch", func(t *testing.T) { + gm := &GPUManager{} + cacheKey := uint16(5000) + + gm.initializeSnapshots(cacheKey) + + assert.NotNil(t, gm.lastAvgData) + assert.NotNil(t, gm.lastSnapshots) + assert.NotNil(t, gm.lastSnapshots[cacheKey]) + }) + + t.Run("initializes only missing maps", func(t *testing.T) { + gm := &GPUManager{ + lastAvgData: make(map[string]system.GPUData), + } + cacheKey := uint16(5000) + + gm.initializeSnapshots(cacheKey) + + assert.NotNil(t, gm.lastAvgData, "Should preserve existing lastAvgData") + assert.NotNil(t, gm.lastSnapshots) + assert.NotNil(t, gm.lastSnapshots[cacheKey]) + }) + + t.Run("adds new cache key to existing snapshots", func(t *testing.T) { + existingKey := uint16(5000) + newKey := uint16(10000) + + gm := &GPUManager{ + lastSnapshots: map[uint16]map[string]*gpuSnapshot{ + existingKey: {"0": {count: 10}}, + }, + } + + gm.initializeSnapshots(newKey) + + assert.NotNil(t, gm.lastSnapshots[existingKey], "Should preserve existing cache key") + assert.NotNil(t, gm.lastSnapshots[newKey], "Should add new cache key") + assert.NotNil(t, gm.lastSnapshots[existingKey]["0"], "Should preserve existing snapshot data") + }) +} + +func TestCalculateGPUAverage(t *testing.T) { + t.Run("returns zero value when deltaCount is zero", func(t *testing.T) { + gm := &GPUManager{ + lastSnapshots: 
map[uint16]map[string]*gpuSnapshot{ + 5000: { + "0": {count: 10, usage: 100, power: 200}, + }, + }, + lastAvgData: map[string]system.GPUData{ + "0": {Usage: 50.0, Power: 100.0}, + }, + } + + gpu := &system.GPUData{ + Count: 10.0, // Same as snapshot, so delta = 0 + Usage: 100.0, + Power: 200.0, + } + + result := gm.calculateGPUAverage("0", gpu, 5000) + + assert.Equal(t, 50.0, result.Usage, "Should return cached average") + assert.Equal(t, 100.0, result.Power, "Should return cached average") + }) + + t.Run("calculates average for standard GPU", func(t *testing.T) { + gm := &GPUManager{ + lastSnapshots: map[uint16]map[string]*gpuSnapshot{ + 5000: {}, + }, + lastAvgData: make(map[string]system.GPUData), + } + + gpu := &system.GPUData{ + Name: "Test GPU", + Count: 4.0, + Usage: 200.0, // 200 / 4 = 50 + Power: 400.0, // 400 / 4 = 100 + } + + result := gm.calculateGPUAverage("0", gpu, 5000) + + assert.Equal(t, 50.0, result.Usage) + assert.Equal(t, 100.0, result.Power) + assert.Equal(t, "Test GPU", result.Name) + }) + + t.Run("calculates average for Intel GPU with engines", func(t *testing.T) { + gm := &GPUManager{ + lastSnapshots: map[uint16]map[string]*gpuSnapshot{ + 5000: {}, + }, + lastAvgData: make(map[string]system.GPUData), + } + + gpu := &system.GPUData{ + Name: "Intel GPU", + Count: 5.0, + Power: 500.0, + PowerPkg: 600.0, + Engines: map[string]float64{ + "Render/3D": 100.0, // 100 / 5 = 20 + "Video": 50.0, // 50 / 5 = 10 + }, + } + + result := gm.calculateGPUAverage("0", gpu, 5000) + + assert.Equal(t, 100.0, result.Power) + assert.Equal(t, 120.0, result.PowerPkg) + assert.Equal(t, 20.0, result.Usage, "Should use max engine usage") + assert.Equal(t, 20.0, result.Engines["Render/3D"]) + assert.Equal(t, 10.0, result.Engines["Video"]) + }) + + t.Run("calculates delta from previous snapshot", func(t *testing.T) { + gm := &GPUManager{ + lastSnapshots: map[uint16]map[string]*gpuSnapshot{ + 5000: { + "0": { + count: 2, + usage: 50.0, + power: 100.0, + powerPkg: 120.0, + 
}, + }, + }, + lastAvgData: make(map[string]system.GPUData), + } + + gpu := &system.GPUData{ + Name: "Test GPU", + Count: 7.0, // Delta = 7 - 2 = 5 + Usage: 200.0, // Delta = 200 - 50 = 150, avg = 150/5 = 30 + Power: 350.0, // Delta = 350 - 100 = 250, avg = 250/5 = 50 + PowerPkg: 420.0, // Delta = 420 - 120 = 300, avg = 300/5 = 60 + } + + result := gm.calculateGPUAverage("0", gpu, 5000) + + assert.Equal(t, 30.0, result.Usage) + assert.Equal(t, 50.0, result.Power) + }) + + t.Run("stores result in lastAvgData", func(t *testing.T) { + gm := &GPUManager{ + lastSnapshots: map[uint16]map[string]*gpuSnapshot{ + 5000: {}, + }, + lastAvgData: make(map[string]system.GPUData), + } + + gpu := &system.GPUData{ + Count: 2.0, + Usage: 100.0, + Power: 200.0, + } + + result := gm.calculateGPUAverage("0", gpu, 5000) + + assert.Equal(t, result, gm.lastAvgData["0"], "Should store calculated average") }) } @@ -765,7 +1298,8 @@ func TestAccumulation(t *testing.T) { } // Verify average calculation in GetCurrentData - result := gm.GetCurrentData() + cacheKey := uint16(5000) + result := gm.GetCurrentData(cacheKey) for id, expected := range tt.expectedValues { gpu, exists := result[id] assert.True(t, exists, "GPU with ID %s should exist in GetCurrentData result", id) @@ -778,16 +1312,16 @@ func TestAccumulation(t *testing.T) { assert.EqualValues(t, expected.avgPower, gpu.Power, "Average power in GetCurrentData should match") } - // Verify that accumulators in the original map are reset + // Verify that accumulators in the original map are NOT reset (they keep growing) for id, expected := range tt.expectedValues { gpu, exists := gm.GpuDataMap[id] assert.True(t, exists, "GPU with ID %s should still exist after GetCurrentData", id) if !exists { continue } - assert.EqualValues(t, 1, gpu.Count, "Count should be reset for GPU ID %s", id) - assert.EqualValues(t, expected.avgUsage, gpu.Usage, "Usage should be reset for GPU ID %s", id) - assert.EqualValues(t, expected.avgPower, gpu.Power, "Power 
should be reset for GPU ID %s", id) + assert.EqualValues(t, expected.count, gpu.Count, "Count should remain at accumulated value for GPU ID %s", id) + assert.EqualValues(t, expected.usage, gpu.Usage, "Usage should remain at accumulated value for GPU ID %s", id) + assert.EqualValues(t, expected.power, gpu.Power, "Power should remain at accumulated value for GPU ID %s", id) } }) } diff --git a/agent/handlers.go b/agent/handlers.go new file mode 100644 index 00000000..0553af09 --- /dev/null +++ b/agent/handlers.go @@ -0,0 +1,101 @@ +package agent + +import ( + "errors" + "fmt" + + "github.com/fxamacker/cbor/v2" + "github.com/henrygd/beszel/internal/common" +) + +// HandlerContext provides context for request handlers +type HandlerContext struct { + Client *WebSocketClient + Agent *Agent + Request *common.HubRequest[cbor.RawMessage] + RequestID *uint32 + HubVerified bool + // SendResponse abstracts how a handler sends responses (WS or SSH) + SendResponse func(data any, requestID *uint32) error +} + +// RequestHandler defines the interface for handling specific websocket request types +type RequestHandler interface { + // Handle processes the request and returns an error if unsuccessful + Handle(hctx *HandlerContext) error +} + +// Responder sends handler responses back to the hub (over WS or SSH) +type Responder interface { + SendResponse(data any, requestID *uint32) error +} + +// HandlerRegistry manages the mapping between actions and their handlers +type HandlerRegistry struct { + handlers map[common.WebSocketAction]RequestHandler +} + +// NewHandlerRegistry creates a new handler registry with default handlers +func NewHandlerRegistry() *HandlerRegistry { + registry := &HandlerRegistry{ + handlers: make(map[common.WebSocketAction]RequestHandler), + } + + registry.Register(common.GetData, &GetDataHandler{}) + registry.Register(common.CheckFingerprint, &CheckFingerprintHandler{}) + + return registry +} + +// Register registers a handler for a specific action type 
+func (hr *HandlerRegistry) Register(action common.WebSocketAction, handler RequestHandler) { + hr.handlers[action] = handler +} + +// Handle routes the request to the appropriate handler +func (hr *HandlerRegistry) Handle(hctx *HandlerContext) error { + handler, exists := hr.handlers[hctx.Request.Action] + if !exists { + return fmt.Errorf("unknown action: %d", hctx.Request.Action) + } + + // Check verification requirement - default to requiring verification + if hctx.Request.Action != common.CheckFingerprint && !hctx.HubVerified { + return errors.New("hub not verified") + } + + // Log handler execution for debugging + // slog.Debug("Executing handler", "action", hctx.Request.Action) + + return handler.Handle(hctx) +} + +// GetHandler returns the handler for a specific action +func (hr *HandlerRegistry) GetHandler(action common.WebSocketAction) (RequestHandler, bool) { + handler, exists := hr.handlers[action] + return handler, exists +} + +//////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////// + +// GetDataHandler handles system data requests +type GetDataHandler struct{} + +func (h *GetDataHandler) Handle(hctx *HandlerContext) error { + var options common.DataRequestOptions + _ = cbor.Unmarshal(hctx.Request.Data, &options) + + sysStats := hctx.Agent.gatherStats(options.CacheTimeMs) + return hctx.SendResponse(sysStats, hctx.RequestID) +} + +//////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////// + +// CheckFingerprintHandler handles authentication challenges +type CheckFingerprintHandler struct{} + +func (h *CheckFingerprintHandler) Handle(hctx *HandlerContext) error { + return hctx.Client.handleAuthChallenge(hctx.Request, hctx.RequestID) +} diff --git a/agent/handlers_test.go b/agent/handlers_test.go new file mode 100644 index 00000000..6a40c618 --- /dev/null +++ 
b/agent/handlers_test.go
@@ -0,0 +1,112 @@
+//go:build testing
+// +build testing
+
+package agent
+
+import (
+	"testing"
+
+	"github.com/fxamacker/cbor/v2"
+	"github.com/henrygd/beszel/internal/common"
+	"github.com/stretchr/testify/assert"
+)
+
+// MockHandler is a RequestHandler test double. When handleFunc is set, Handle
+// returns its result; otherwise Handle returns nil.
+// requiresVerification and description are only stored here: RequestHandler
+// has a single Handle method, so the registry never reads them.
+type MockHandler struct {
+	requiresVerification bool
+	description          string
+	handleFunc           func(ctx *HandlerContext) error
+}
+
+// Handle delegates to handleFunc when one is configured.
+func (m *MockHandler) Handle(ctx *HandlerContext) error {
+	if m.handleFunc != nil {
+		return m.handleFunc(ctx)
+	}
+	return nil
+}
+
+// RequiresVerification reports the configured requiresVerification flag.
+func (m *MockHandler) RequiresVerification() bool {
+	return m.requiresVerification
+}
+
+// TestHandlerRegistry covers default registration, custom registration, and the
+// two error paths of HandlerRegistry.Handle (unknown action, unverified hub).
+func TestHandlerRegistry(t *testing.T) {
+	t.Run("default registration", func(t *testing.T) {
+		registry := NewHandlerRegistry()
+
+		// Check default handlers are registered
+		getDataHandler, exists := registry.GetHandler(common.GetData)
+		assert.True(t, exists)
+		assert.IsType(t, &GetDataHandler{}, getDataHandler)
+
+		fingerprintHandler, exists := registry.GetHandler(common.CheckFingerprint)
+		assert.True(t, exists)
+		assert.IsType(t, &CheckFingerprintHandler{}, fingerprintHandler)
+	})
+
+	t.Run("custom handler registration", func(t *testing.T) {
+		registry := NewHandlerRegistry()
+		mockHandler := &MockHandler{
+			requiresVerification: true,
+			description:          "Test handler",
+		}
+
+		// Register a custom handler for a mock action
+		const mockAction common.WebSocketAction = 99
+		registry.Register(mockAction, mockHandler)
+
+		// Verify registration
+		handler, exists := registry.GetHandler(mockAction)
+		assert.True(t, exists)
+		assert.Equal(t, mockHandler, handler)
+	})
+
+	t.Run("unknown action", func(t *testing.T) {
+		registry := NewHandlerRegistry()
+		ctx := &HandlerContext{
+			Request: &common.HubRequest[cbor.RawMessage]{
+				Action: common.WebSocketAction(255), // Unknown action
+			},
+			HubVerified: true,
+		}
+
+		err := registry.Handle(ctx)
+		assert.Error(t, err)
+		assert.Contains(t, err.Error(), "unknown action: 255")
+	})
+
+	t.Run("verification required", func(t *testing.T) {
+		registry := NewHandlerRegistry()
+		ctx := &HandlerContext{
+			Request: &common.HubRequest[cbor.RawMessage]{
+				Action: common.GetData, // Requires verification
+			},
+			HubVerified: false, // Not verified
+		}
+
+		err := registry.Handle(ctx)
+		assert.Error(t, err)
+		assert.Contains(t, err.Error(), "hub not verified")
+	})
+}
+
+// TestCheckFingerprintHandler tests the CheckFingerprint handler
+func TestCheckFingerprintHandler(t *testing.T) {
+	handler := &CheckFingerprintHandler{}
+
+	t.Run("handle with invalid data", func(t *testing.T) {
+		client := &WebSocketClient{}
+		ctx := &HandlerContext{
+			Client:      client,
+			HubVerified: false,
+			Request: &common.HubRequest[cbor.RawMessage]{
+				Action: common.CheckFingerprint,
+				Data:   cbor.RawMessage{}, // Empty/invalid data
+			},
+		}
+
+		// Should fail to decode the fingerprint request
+		err := handler.Handle(ctx)
+		assert.Error(t, err)
+	})
+}
diff --git a/agent/network.go b/agent/network.go
index 47843d6a..90bb67bd 100644
--- a/agent/network.go
+++ b/agent/network.go
@@ -12,8 +12,6 @@ import (
 	psutilNet "github.com/shirou/gopsutil/v4/net"
 )
 
-var netInterfaceDeltaTracker = deltatracker.NewDeltaTracker[string, uint64]()
-
 // NicConfig controls inclusion/exclusion of network interfaces via the NICS env var
 //
 // Behavior mirrors SensorConfig's matching logic:
@@ -77,75 +75,17 @@ func isValidNic(nicName string, cfg *NicConfig) bool {
 	return cfg.isBlacklist
 }
 
-func (a *Agent) updateNetworkStats(systemStats *system.Stats) {
+// updateNetworkStats fills systemStats with per-interface and total network
+// rates, using the baseline and delta tracker kept for cacheTimeMs.
+func (a *Agent) updateNetworkStats(cacheTimeMs uint16, systemStats *system.Stats) {
 	// network stats
-	if len(a.netInterfaces) == 0 {
-		// if no network interfaces, initialize again
-		// this is a fix if agent started before network is online (#466)
-		// maybe refactor this in the future to not cache interface names at all so we
-		// don't miss an interface that's been added after agent started in any circumstance
-		a.initializeNetIoStats()
-	}
+	a.ensureNetInterfacesInitialized()
 
-	if systemStats.NetworkInterfaces == nil {
-		systemStats.NetworkInterfaces = make(map[string][4]uint64, 0)
-	}
+	a.ensureNetworkInterfacesMap(systemStats)
 
 	if netIO, err := psutilNet.IOCounters(true); err == nil {
-		msElapsed := uint64(time.Since(a.netIoStats.Time).Milliseconds())
-		a.netIoStats.Time = time.Now()
-		totalBytesSent := uint64(0)
-		totalBytesRecv := uint64(0)
-		netInterfaceDeltaTracker.Cycle()
-		// sum all bytes sent and received
-		for _, v := range netIO {
-			// skip if not in valid network interfaces list
-			if _, exists := a.netInterfaces[v.Name]; !exists {
-				continue
-			}
-			totalBytesSent += v.BytesSent
-			totalBytesRecv += v.BytesRecv
-
-			// track deltas for each network interface
-			var upDelta, downDelta uint64
-			upKey, downKey := fmt.Sprintf("%sup", v.Name), fmt.Sprintf("%sdown", v.Name)
-			netInterfaceDeltaTracker.Set(upKey, v.BytesSent)
-			netInterfaceDeltaTracker.Set(downKey, v.BytesRecv)
-			if msElapsed > 0 {
-				upDelta = netInterfaceDeltaTracker.Delta(upKey) * 1000 / msElapsed
-				downDelta = netInterfaceDeltaTracker.Delta(downKey) * 1000 / msElapsed
-			}
-			// add interface to systemStats
-			systemStats.NetworkInterfaces[v.Name] = [4]uint64{upDelta, downDelta, v.BytesSent, v.BytesRecv}
-		}
-
-		// add to systemStats
-		var bytesSentPerSecond, bytesRecvPerSecond uint64
-		if msElapsed > 0 {
-			bytesSentPerSecond = (totalBytesSent - a.netIoStats.BytesSent) * 1000 / msElapsed
-			bytesRecvPerSecond = (totalBytesRecv - a.netIoStats.BytesRecv) * 1000 / msElapsed
-		}
-		networkSentPs := bytesToMegabytes(float64(bytesSentPerSecond))
-		networkRecvPs := bytesToMegabytes(float64(bytesRecvPerSecond))
-		// add check for issue (#150) where sent is a massive number
-		if networkSentPs > 10_000 || networkRecvPs > 10_000 {
-			slog.Warn("Invalid net stats. Resetting.", "sent", networkSentPs, "recv", networkRecvPs)
-			for _, v := range netIO {
-				if _, exists := a.netInterfaces[v.Name]; !exists {
-					continue
-				}
-				slog.Info(v.Name, "recv", v.BytesRecv, "sent", v.BytesSent)
-			}
-			// reset network I/O stats
-			a.initializeNetIoStats()
-		} else {
-			systemStats.NetworkSent = networkSentPs
-			systemStats.NetworkRecv = networkRecvPs
-			systemStats.Bandwidth[0], systemStats.Bandwidth[1] = bytesSentPerSecond, bytesRecvPerSecond
-			// update netIoStats
-			a.netIoStats.BytesSent = totalBytesSent
-			a.netIoStats.BytesRecv = totalBytesRecv
-		}
+		nis, msElapsed := a.loadAndTickNetBaseline(cacheTimeMs)
+		totalBytesSent, totalBytesRecv := a.sumAndTrackPerNicDeltas(cacheTimeMs, msElapsed, netIO, systemStats)
+		bytesSentPerSecond, bytesRecvPerSecond := a.computeBytesPerSecond(msElapsed, totalBytesSent, totalBytesRecv, nis)
+		a.applyNetworkTotals(cacheTimeMs, netIO, systemStats, nis, totalBytesSent, totalBytesRecv, bytesSentPerSecond, bytesRecvPerSecond)
 	}
 }
 
@@ -160,13 +100,8 @@ func (a *Agent) initializeNetIoStats() {
 		nicCfg = newNicConfig(nicsEnvVal)
 	}
 
-	// reset network I/O stats
-	a.netIoStats.BytesSent = 0
-	a.netIoStats.BytesRecv = 0
-
-	// get intial network I/O stats
+	// get current network I/O stats and record valid interfaces
 	if netIO, err := psutilNet.IOCounters(true); err == nil {
-		a.netIoStats.Time = time.Now()
 		for _, v := range netIO {
 			if nicsEnvExists && !isValidNic(v.Name, nicCfg) {
 				continue
@@ -175,12 +110,116 @@ func (a *Agent) initializeNetIoStats() {
 				continue
 			}
 			slog.Info("Detected network interface", "name", v.Name, "sent", v.BytesSent, "recv", v.BytesRecv)
-			a.netIoStats.BytesSent += v.BytesSent
-			a.netIoStats.BytesRecv += v.BytesRecv
 			// store as a valid network interface
 			a.netInterfaces[v.Name] = struct{}{}
 		}
 	}
+
+	// Reset per-cache-time trackers and baselines so they will reinitialize on next use
+	a.netInterfaceDeltaTrackers = make(map[uint16]*deltatracker.DeltaTracker[string, uint64])
+	a.netIoStats =
make(map[uint16]system.NetIoStats)
+}
+
+// ensureNetInterfacesInitialized re-initializes NICs if none are currently tracked
+func (a *Agent) ensureNetInterfacesInitialized() {
+	if len(a.netInterfaces) > 0 {
+		return
+	}
+	// if no network interfaces, initialize again
+	// this is a fix if agent started before network is online (#466)
+	// maybe refactor this in the future to not cache interface names at all so we
+	// don't miss an interface that's been added after agent started in any circumstance
+	a.initializeNetIoStats()
+}
+
+// ensureNetworkInterfacesMap ensures systemStats.NetworkInterfaces map exists
+func (a *Agent) ensureNetworkInterfacesMap(systemStats *system.Stats) {
+	if systemStats.NetworkInterfaces == nil {
+		systemStats.NetworkInterfaces = make(map[string][4]uint64)
+	}
+}
+
+// loadAndTickNetBaseline returns a copy of the NetIoStats baseline stored for
+// cacheTimeMs with its Time advanced to now, plus the milliseconds elapsed since
+// the previous baseline (0 when no baseline exists yet). The copy is not written
+// back here; applyNetworkTotals stores it.
+func (a *Agent) loadAndTickNetBaseline(cacheTimeMs uint16) (netIoStat system.NetIoStats, msElapsed uint64) {
+	netIoStat = a.netIoStats[cacheTimeMs]
+	// Read the clock once so the elapsed time and the new baseline agree.
+	now := time.Now()
+	if !netIoStat.Time.IsZero() {
+		msElapsed = uint64(now.Sub(netIoStat.Time).Milliseconds())
+	}
+	netIoStat.Time = now
+	return netIoStat, msElapsed
+}
+
+// sumAndTrackPerNicDeltas accumulates totals and records per-NIC up/down deltas into systemStats.
+// Only interfaces present in a.netInterfaces are counted. The delta tracker for
+// cacheTimeMs is created on first use and cycled once per call. Per-NIC rates
+// stay 0 when msElapsed is 0. systemStats.NetworkInterfaces must be non-nil.
+func (a *Agent) sumAndTrackPerNicDeltas(cacheTimeMs uint16, msElapsed uint64, netIO []psutilNet.IOCountersStat, systemStats *system.Stats) (totalBytesSent, totalBytesRecv uint64) {
+	tracker := a.netInterfaceDeltaTrackers[cacheTimeMs]
+	if tracker == nil {
+		tracker = deltatracker.NewDeltaTracker[string, uint64]()
+		a.netInterfaceDeltaTrackers[cacheTimeMs] = tracker
+	}
+	tracker.Cycle()
+
+	for _, v := range netIO {
+		if _, exists := a.netInterfaces[v.Name]; !exists {
+			continue
+		}
+		totalBytesSent += v.BytesSent
+		totalBytesRecv += v.BytesRecv
+
+		var upDelta, downDelta uint64
+		upKey, downKey := fmt.Sprintf("%sup", v.Name), fmt.Sprintf("%sdown", v.Name)
+		tracker.Set(upKey, v.BytesSent)
+		tracker.Set(downKey, v.BytesRecv)
+		if msElapsed > 0 {
+			upDelta = tracker.Delta(upKey) * 1000 / msElapsed
+			downDelta = tracker.Delta(downKey) * 1000 / msElapsed
+		}
+		// [upload bytes/s, download bytes/s, total upload, total download]
+		systemStats.NetworkInterfaces[v.Name] = [4]uint64{upDelta, downDelta, v.BytesSent, v.BytesRecv}
+	}
+
+	return totalBytesSent, totalBytesRecv
+}
+
+// computeBytesPerSecond calculates per-second totals from elapsed time and totals.
+// Both rates are 0 when msElapsed is 0. The subtraction is unsigned, so totals
+// below the baseline (nis) wrap around to a very large rate.
+// NOTE(review): the threshold check in applyNetworkTotals appears to be what
+// catches that case - confirm.
+func (a *Agent) computeBytesPerSecond(msElapsed, totalBytesSent, totalBytesRecv uint64, nis system.NetIoStats) (bytesSentPerSecond, bytesRecvPerSecond uint64) {
+	if msElapsed > 0 {
+		bytesSentPerSecond = (totalBytesSent - nis.BytesSent) * 1000 / msElapsed
+		bytesRecvPerSecond = (totalBytesRecv - nis.BytesRecv) * 1000 / msElapsed
+	}
+	return bytesSentPerSecond, bytesRecvPerSecond
+}
+
+// applyNetworkTotals validates and writes computed network stats, or resets on anomaly.
+//
+// Rates at or below 10,000 MB/s are written to systemStats. A higher rate is
+// treated as bogus: it is logged, the network tracking state is re-initialized,
+// and systemStats is left untouched so the invalid value is never reported.
+// In both cases the current totals are stored as the baseline for cacheTimeMs so
+// the next sample is measured from the counters just read.
+func (a *Agent) applyNetworkTotals(
+	cacheTimeMs uint16,
+	netIO []psutilNet.IOCountersStat,
+	systemStats *system.Stats,
+	nis system.NetIoStats,
+	totalBytesSent, totalBytesRecv uint64,
+	bytesSentPerSecond, bytesRecvPerSecond uint64,
+) {
+	networkSentPs := bytesToMegabytes(float64(bytesSentPerSecond))
+	networkRecvPs := bytesToMegabytes(float64(bytesRecvPerSecond))
+	if networkSentPs > 10_000 || networkRecvPs > 10_000 {
+		slog.Warn("Invalid net stats. Resetting.", "sent", networkSentPs, "recv", networkRecvPs)
+		for _, v := range netIO {
+			if _, exists := a.netInterfaces[v.Name]; !exists {
+				continue
+			}
+			slog.Info(v.Name, "recv", v.BytesRecv, "sent", v.BytesSent)
+		}
+		// initializeNetIoStats replaces both per-interval maps; the explicit
+		// delete keeps the tracker reset obvious at the call site.
+		a.initializeNetIoStats()
+		delete(a.netInterfaceDeltaTrackers, cacheTimeMs)
+	} else {
+		systemStats.NetworkSent = networkSentPs
+		systemStats.NetworkRecv = networkRecvPs
+		systemStats.Bandwidth[0], systemStats.Bandwidth[1] = bytesSentPerSecond, bytesRecvPerSecond
+	}
+
+	nis.BytesSent = totalBytesSent
+	nis.BytesRecv = totalBytesRecv
+	a.netIoStats[cacheTimeMs] = nis
 }
 
 func (a *Agent) skipNetworkInterface(v psutilNet.IOCountersStat) bool {
diff --git a/agent/network_test.go b/agent/network_test.go
index beb56fbe..bc3833fa 100644
--- a/agent/network_test.go
+++ b/agent/network_test.go
@@ -4,7 +4,11 @@ package agent
 
 import (
 	"testing"
+	"time"
 
+	"github.com/henrygd/beszel/agent/deltatracker"
+	"github.com/henrygd/beszel/internal/entities/system"
+	psutilNet "github.com/shirou/gopsutil/v4/net"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
@@ -257,3 +261,202 @@ func TestNewNicConfig(t *testing.T) {
 		})
 	}
 }
+func TestEnsureNetworkInterfacesMap(t *testing.T) {
+	var a Agent
+	var stats system.Stats
+
+	// Initially nil
+	assert.Nil(t, stats.NetworkInterfaces)
+	// Ensure map is created
+	a.ensureNetworkInterfacesMap(&stats)
+	assert.NotNil(t, stats.NetworkInterfaces)
+	// Idempotent
+	a.ensureNetworkInterfacesMap(&stats)
+	assert.NotNil(t, stats.NetworkInterfaces)
+}
+
+func TestLoadAndTickNetBaseline(t *testing.T) {
+	a := &Agent{netIoStats: make(map[uint16]system.NetIoStats)}
+
+	// First call initializes time and returns 0 elapsed
+	ni, elapsed := a.loadAndTickNetBaseline(100)
+	assert.Equal(t, uint64(0), elapsed)
+	assert.False(t, ni.Time.IsZero())
+
+	// Store back what loadAndTick returns to mimic updateNetworkStats behavior
+	a.netIoStats[100] = ni
+
+	time.Sleep(2 * time.Millisecond)
+
+	// Next call should produce > 0 elapsed and update time
+	ni2, elapsed2 := a.loadAndTickNetBaseline(100)
+	assert.True(t, elapsed2 > 0)
+	assert.False(t, ni2.Time.IsZero())
+}
+
+func TestComputeBytesPerSecond(t *testing.T) {
+	a := &Agent{}
+
+	// No elapsed -> zero rate
+	bytesUp, bytesDown := a.computeBytesPerSecond(0, 2000, 3000, system.NetIoStats{BytesSent: 1000, BytesRecv: 1000})
+	assert.Equal(t, uint64(0), bytesUp)
+	assert.Equal(t, uint64(0), bytesDown)
+
+	// With elapsed -> per-second calculation
+	bytesUp, bytesDown = a.computeBytesPerSecond(500, 6000, 11000, system.NetIoStats{BytesSent: 1000, BytesRecv: 1000})
+	// (6000-1000)*1000/500 = 10000; (11000-1000)*1000/500 = 20000
+	assert.Equal(t, uint64(10000), bytesUp)
+	assert.Equal(t, uint64(20000), bytesDown)
+}
+
+func TestSumAndTrackPerNicDeltas(t *testing.T) {
+	a := &Agent{
+		netInterfaces:             map[string]struct{}{"eth0": {}, "wlan0": {}},
+		netInterfaceDeltaTrackers: make(map[uint16]*deltatracker.DeltaTracker[string, uint64]),
+	}
+
+	// Two samples for same cache interval to verify delta behavior
+	cache := uint16(42)
+	net1 := []psutilNet.IOCountersStat{{Name: "eth0", BytesSent: 1000, BytesRecv: 2000}}
+	stats1 := &system.Stats{}
+	a.ensureNetworkInterfacesMap(stats1)
+	tx1, rx1 := a.sumAndTrackPerNicDeltas(cache, 0, net1, stats1)
+	assert.Equal(t, uint64(1000), tx1)
+	assert.Equal(t, uint64(2000), rx1)
+
+	// Second cycle with elapsed, larger counters -> deltas computed inside
+	net2 := []psutilNet.IOCountersStat{{Name: "eth0", BytesSent: 4000, BytesRecv: 9000}}
+	stats := &system.Stats{}
+	a.ensureNetworkInterfacesMap(stats)
+	tx2, rx2 := a.sumAndTrackPerNicDeltas(cache, 1000, net2, stats)
+	assert.Equal(t, uint64(4000), tx2)
+	assert.Equal(t, uint64(9000), rx2)
+	// Up/Down deltas per second should be (4000-1000)/1s = 3000 and (9000-2000)/1s = 7000
+	ni, ok := stats.NetworkInterfaces["eth0"]
+	assert.True(t, ok)
+	assert.Equal(t, uint64(3000), ni[0])
+	assert.Equal(t, uint64(7000), ni[1])
+}
+
+// TestApplyNetworkTotals drives applyNetworkTotals with rates below, at and
+// above the 10,000 MB/s validity threshold. Valid rows check the values written
+// to systemStats and the stored baseline; reset rows check the tracking state
+// left behind for the cache interval.
+func TestApplyNetworkTotals(t *testing.T) {
+	tests := []struct {
+		name                  string
+		bytesSentPerSecond    uint64
+		bytesRecvPerSecond    uint64
+		totalBytesSent        uint64
+		totalBytesRecv        uint64
+		expectReset           bool
+		expectedNetworkSent   float64
+		expectedNetworkRecv   float64
+		expectedBandwidthSent uint64
+		expectedBandwidthRecv uint64
+	}{
+		{
+			name:                  "Valid network stats - normal values",
+			bytesSentPerSecond:    1000000, // 1 MB/s
+			bytesRecvPerSecond:    2000000, // 2 MB/s
+			totalBytesSent:        10000000,
+			totalBytesRecv:        20000000,
+			expectReset:           false,
+			expectedNetworkSent:   0.95, // ~1 MB/s rounded to 2 decimals
+			expectedNetworkRecv:   1.91, // ~2 MB/s rounded to 2 decimals
+			expectedBandwidthSent: 1000000,
+			expectedBandwidthRecv: 2000000,
+		},
+		{
+			name:               "Invalid network stats - sent exceeds threshold",
+			bytesSentPerSecond: 11000000000, // ~10.5 GB/s > 10 GB/s threshold
+			bytesRecvPerSecond: 1000000,     // 1 MB/s
+			totalBytesSent:     10000000,
+			totalBytesRecv:     20000000,
+			expectReset:        true,
+		},
+		{
+			name:               "Invalid network stats - recv exceeds threshold",
+			bytesSentPerSecond: 1000000,     // 1 MB/s
+			bytesRecvPerSecond: 11000000000, // ~10.5 GB/s > 10 GB/s threshold
+			totalBytesSent:     10000000,
+			totalBytesRecv:     20000000,
+			expectReset:        true,
+		},
+		{
+			name:               "Invalid network stats - both exceed threshold",
+			bytesSentPerSecond: 12000000000, // ~11.4 GB/s
+			bytesRecvPerSecond: 13000000000, // ~12.4 GB/s
+			totalBytesSent:     10000000,
+			totalBytesRecv:     20000000,
+			expectReset:        true,
+		},
+		{
+			name:                  "Valid network stats - at threshold boundary",
+			bytesSentPerSecond:    10485750000, // ~9999.99 MB/s (rounds to 9999.99)
+			bytesRecvPerSecond:    10485750000, // ~9999.99 MB/s (rounds to 9999.99)
+			totalBytesSent:        10000000,
+			totalBytesRecv:        20000000,
+			expectReset:           false,
+			expectedNetworkSent:   9999.99,
+			expectedNetworkRecv:   9999.99,
+			expectedBandwidthSent: 10485750000,
+			expectedBandwidthRecv: 10485750000,
+		},
+		{
+			name:                  "Zero values",
+			bytesSentPerSecond:    0,
+			bytesRecvPerSecond:    0,
+			totalBytesSent:        0,
+			totalBytesRecv:        0,
+			expectReset:           false,
+			expectedNetworkSent:   0.0,
+			expectedNetworkRecv:   0.0,
+			expectedBandwidthSent: 0,
+			expectedBandwidthRecv: 0,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Setup agent with initialized maps
+			a := &Agent{
+				netInterfaces:             make(map[string]struct{}),
+				netIoStats:                make(map[uint16]system.NetIoStats),
+				netInterfaceDeltaTrackers: make(map[uint16]*deltatracker.DeltaTracker[string, uint64]),
+			}
+
+			cacheTimeMs := uint16(100)
+			netIO := []psutilNet.IOCountersStat{
+				{Name: "eth0", BytesSent: 1000, BytesRecv: 2000},
+			}
+			systemStats := &system.Stats{}
+			nis := system.NetIoStats{}
+
+			a.applyNetworkTotals(
+				cacheTimeMs,
+				netIO,
+				systemStats,
+				nis,
+				tt.totalBytesSent,
+				tt.totalBytesRecv,
+				tt.bytesSentPerSecond,
+				tt.bytesRecvPerSecond,
+			)
+
+			if tt.expectReset {
+				// Should have reset network tracking state - delta trackers should be cleared
+				// Note: initializeNetIoStats resets the maps, then applyNetworkTotals sets nis back
+				assert.Contains(t, a.netIoStats, cacheTimeMs, "cache entry should exist after reset")
+				assert.NotContains(t, a.netInterfaceDeltaTrackers, cacheTimeMs, "tracker should be cleared on reset")
+			} else {
+				// Should have applied stats
+				assert.Equal(t, tt.expectedNetworkSent, systemStats.NetworkSent)
+				assert.Equal(t, tt.expectedNetworkRecv, systemStats.NetworkRecv)
+				assert.Equal(t, tt.expectedBandwidthSent, systemStats.Bandwidth[0])
+				assert.Equal(t, tt.expectedBandwidthRecv, systemStats.Bandwidth[1])
+
+				// Should have updated NetIoStats
+				updatedNis := a.netIoStats[cacheTimeMs]
+				assert.Equal(t, tt.totalBytesSent, updatedNis.BytesSent)
+				assert.Equal(t, tt.totalBytesRecv, updatedNis.BytesRecv)
+			}
+		})
+	}
+}
diff --git a/agent/server.go b/agent/server.go
index 58ac9a8e..bd103431 100644
--- a/agent/server.go
+++ b/agent/server.go
@@ -127,15 +127,75 @@ func (a *Agent) handleSession(s ssh.Session) {
 	hubVersion := a.getHubVersion(sessionID,
sessionCtx)
 
-	stats := a.gatherStats(sessionID)
-
-	err := a.writeToSession(s, stats, hubVersion)
-	if err != nil {
-		slog.Error("Error encoding stats", "err", err, "stats", stats)
-		s.Exit(1)
-	} else {
-		s.Exit(0)
+	// Legacy one-shot behavior for older hubs
+	if hubVersion.LT(beszel.MinVersionAgentResponse) {
+		if err := a.handleLegacyStats(s, hubVersion); err != nil {
+			slog.Error("Error encoding stats", "err", err)
+			s.Exit(1)
+			return
+		}
+		// The one-shot payload has been written: finish the session here instead
+		// of falling through to the request decoder, which would wait for a
+		// request an older hub never sends or write the payload a second time.
+		s.Exit(0)
+		return
 	}
+
+	var req common.HubRequest[cbor.RawMessage]
+	if err := cbor.NewDecoder(s).Decode(&req); err != nil {
+		// Fallback to legacy one-shot if the first decode fails
+		if err2 := a.handleLegacyStats(s, hubVersion); err2 != nil {
+			slog.Error("Error encoding stats (fallback)", "err", err2)
+			s.Exit(1)
+			return
+		}
+		s.Exit(0)
+		return
+	}
+	if err := a.handleSSHRequest(s, &req); err != nil {
+		slog.Error("SSH request handling failed", "err", err)
+		s.Exit(1)
+		return
+	}
+	s.Exit(0)
+}
+
+// handleSSHRequest builds a handler context and dispatches to the shared registry.
+//
+// Handler failures and unsupported or unknown actions are reported to the hub as
+// an AgentResponse carrying the error text. The returned error is only non-nil
+// when writing a response to w fails.
+func (a *Agent) handleSSHRequest(w io.Writer, req *common.HubRequest[cbor.RawMessage]) error {
+	// SSH does not support fingerprint auth action
+	if req.Action == common.CheckFingerprint {
+		return cbor.NewEncoder(w).Encode(common.AgentResponse{Error: "unsupported action"})
+	}
+
+	// responder that writes AgentResponse to stdout
+	sshResponder := func(data any, requestID *uint32) error {
+		response := common.AgentResponse{Id: requestID}
+		switch v := data.(type) {
+		case *system.CombinedData:
+			response.SystemData = v
+		default:
+			response.Error = fmt.Sprintf("unsupported response type: %T", data)
+		}
+		return cbor.NewEncoder(w).Encode(response)
+	}
+
+	ctx := &HandlerContext{
+		Client:       nil,
+		Agent:        a,
+		Request:      req,
+		RequestID:    nil,
+		HubVerified:  true, // the SSH path marks every request as coming from a verified hub
+		SendResponse: sshResponder,
+	}
+
+	// Route through the registry's own dispatcher so the unknown-action and
+	// verification rules live in exactly one place.
+	if err := a.handlerRegistry.Handle(ctx); err != nil {
+		return cbor.NewEncoder(w).Encode(common.AgentResponse{Error: err.Error()})
+	}
+	return nil
+}
+
+// handleLegacyStats serves the legacy one-shot stats payload for older hubs.
+// Stats are gathered with a fixed cache time of 60,000 ms.
+func (a *Agent) handleLegacyStats(w io.Writer, hubVersion semver.Version) error {
+	stats := a.gatherStats(60_000)
+	return a.writeToSession(w, stats, hubVersion)
 }
 
 // writeToSession encodes and writes system statistics to the session.
diff --git a/agent/system.go b/agent/system.go
index b2952a3e..ee8b87f3 100644
--- a/agent/system.go
+++ b/agent/system.go
@@ -14,12 +14,18 @@ import (
 	"github.com/henrygd/beszel/internal/entities/system"
 
 	"github.com/shirou/gopsutil/v4/cpu"
-	"github.com/shirou/gopsutil/v4/disk"
 	"github.com/shirou/gopsutil/v4/host"
 	"github.com/shirou/gopsutil/v4/load"
 	"github.com/shirou/gopsutil/v4/mem"
 )
 
+// prevDisk stores previous per-device disk counters for a given cache interval
+type prevDisk struct {
+	readBytes  uint64
+	writeBytes uint64
+	at         time.Time
+}
+
 // Sets initial / non-changing values about the host system
 func (a *Agent) initializeSystemInfo() {
 	a.systemInfo.AgentVersion = beszel.Version
@@ -68,7 +74,7 @@ func (a *Agent) initializeSystemInfo() {
 }
 
 // Returns current info, stats about the host system
-func (a *Agent) getSystemStats() system.Stats {
+func (a *Agent) getSystemStats(cacheTimeMs uint16) system.Stats {
 	var systemStats system.Stats
 
 	// battery
@@ -77,11 +83,11 @@ func (a *Agent) getSystemStats() system.Stats {
 	}
 
 	// cpu percent
-	cpuPct, err := cpu.Percent(0, false)
-	if err != nil {
+	cpuPercent, err := getCpuPercent(cacheTimeMs)
+	if err == nil {
+		systemStats.Cpu = twoDecimals(cpuPercent)
+	} else {
 		slog.Error("Error getting cpu percent", "err", err)
-	} else if len(cpuPct) > 0 {
-		systemStats.Cpu = twoDecimals(cpuPct[0])
 	}
 
 	// load average
@@ -131,56 +137,13 @@ func (a *Agent) getSystemStats() system.Stats {
 	}
 
 	// disk usage
-	for _, stats :=
range a.fsStats { - if d, err := disk.Usage(stats.Mountpoint); err == nil { - stats.DiskTotal = bytesToGigabytes(d.Total) - stats.DiskUsed = bytesToGigabytes(d.Used) - if stats.Root { - systemStats.DiskTotal = bytesToGigabytes(d.Total) - systemStats.DiskUsed = bytesToGigabytes(d.Used) - systemStats.DiskPct = twoDecimals(d.UsedPercent) - } - } else { - // reset stats if error (likely unmounted) - slog.Error("Error getting disk stats", "name", stats.Mountpoint, "err", err) - stats.DiskTotal = 0 - stats.DiskUsed = 0 - stats.TotalRead = 0 - stats.TotalWrite = 0 - } - } + a.updateDiskUsage(&systemStats) - // disk i/o - if ioCounters, err := disk.IOCounters(a.fsNames...); err == nil { - for _, d := range ioCounters { - stats := a.fsStats[d.Name] - if stats == nil { - continue - } - secondsElapsed := time.Since(stats.Time).Seconds() - readPerSecond := bytesToMegabytes(float64(d.ReadBytes-stats.TotalRead) / secondsElapsed) - writePerSecond := bytesToMegabytes(float64(d.WriteBytes-stats.TotalWrite) / secondsElapsed) - // check for invalid values and reset stats if so - if readPerSecond < 0 || writePerSecond < 0 || readPerSecond > 50_000 || writePerSecond > 50_000 { - slog.Warn("Invalid disk I/O. 
Resetting.", "name", d.Name, "read", readPerSecond, "write", writePerSecond) - a.initializeDiskIoStats(ioCounters) - break - } - stats.Time = time.Now() - stats.DiskReadPs = readPerSecond - stats.DiskWritePs = writePerSecond - stats.TotalRead = d.ReadBytes - stats.TotalWrite = d.WriteBytes - // if root filesystem, update system stats - if stats.Root { - systemStats.DiskReadPs = stats.DiskReadPs - systemStats.DiskWritePs = stats.DiskWritePs - } - } - } + // disk i/o (cache-aware per interval) + a.updateDiskIo(cacheTimeMs, &systemStats) - // network stats - a.updateNetworkStats(&systemStats) + // network stats (per cache interval) + a.updateNetworkStats(cacheTimeMs, &systemStats) // temperatures // TODO: maybe refactor to methods on systemStats @@ -191,7 +154,7 @@ func (a *Agent) getSystemStats() system.Stats { // reset high gpu percent a.systemInfo.GpuPct = 0 // get current GPU data - if gpuData := a.gpuManager.GetCurrentData(); len(gpuData) > 0 { + if gpuData := a.gpuManager.GetCurrentData(cacheTimeMs); len(gpuData) > 0 { systemStats.GPUData = gpuData // add temperatures diff --git a/agent/test-data/container.json b/agent/test-data/container.json new file mode 100644 index 00000000..4b0e234c --- /dev/null +++ b/agent/test-data/container.json @@ -0,0 +1,24 @@ +{ + "cpu_stats": { + "cpu_usage": { + "total_usage": 312055276000 + }, + "system_cpu_usage": 1366399830000000 + }, + "memory_stats": { + "usage": 507400192, + "stats": { + "inactive_file": 165130240 + } + }, + "networks": { + "eth0": { + "tx_bytes": 20376558, + "rx_bytes": 537029455 + }, + "eth1": { + "tx_bytes": 2003766, + "rx_bytes": 6241 + } + } +} diff --git a/agent/test-data/container2.json b/agent/test-data/container2.json new file mode 100644 index 00000000..cd4e882d --- /dev/null +++ b/agent/test-data/container2.json @@ -0,0 +1,24 @@ +{ + "cpu_stats": { + "cpu_usage": { + "total_usage": 314891801000 + }, + "system_cpu_usage": 1368474900000000 + }, + "memory_stats": { + "usage": 507400192, + "stats": { 
+ "inactive_file": 165130240 + } + }, + "networks": { + "eth0": { + "tx_bytes": 20376558, + "rx_bytes": 537029455 + }, + "eth1": { + "tx_bytes": 2003766, + "rx_bytes": 6241 + } + } +} diff --git a/beszel.go b/beszel.go index 1d74522e..76b90e66 100644 --- a/beszel.go +++ b/beszel.go @@ -6,10 +6,13 @@ import "github.com/blang/semver" const ( // Version is the current version of the application. - Version = "0.12.12" + Version = "0.13.0-alpha.1" // AppName is the name of the application. AppName = "beszel" ) // MinVersionCbor is the minimum supported version for CBOR compatibility. var MinVersionCbor = semver.MustParse("0.12.0") + +// MinVersionAgentResponse is the minimum supported version for AgentResponse compatibility. +var MinVersionAgentResponse = semver.MustParse("0.13.0-alpha.1") diff --git a/go.mod b/go.mod index 7e0966f5..1df45943 100644 --- a/go.mod +++ b/go.mod @@ -12,16 +12,16 @@ require ( github.com/gliderlabs/ssh v0.3.8 github.com/google/uuid v1.6.0 github.com/lxzan/gws v1.8.9 - github.com/nicholas-fedor/shoutrrr v0.9.1 + github.com/nicholas-fedor/shoutrrr v0.10.0 github.com/pocketbase/dbx v1.11.0 github.com/pocketbase/pocketbase v0.30.0 - github.com/shirou/gopsutil/v4 v4.25.8 + github.com/shirou/gopsutil/v4 v4.25.9 github.com/spf13/cast v1.10.0 github.com/spf13/cobra v1.10.1 github.com/spf13/pflag v1.0.10 github.com/stretchr/testify v1.11.1 golang.org/x/crypto v0.42.0 - golang.org/x/exp v0.0.0-20250911091902-df9299821621 + golang.org/x/exp v0.0.0-20251002181428-27f1f14c8bb9 gopkg.in/yaml.v3 v3.0.1 ) diff --git a/go.sum b/go.sum index 55337bfd..f7187102 100644 --- a/go.sum +++ b/go.sum @@ -99,8 +99,8 @@ github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qq github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/russross/blackfriday/v2 v2.1.0/go.mod 
h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/shirou/gopsutil/v4 v4.25.8 h1:NnAsw9lN7587WHxjJA9ryDnqhJpFH6A+wagYWTOH970= -github.com/shirou/gopsutil/v4 v4.25.8/go.mod h1:q9QdMmfAOVIw7a+eF86P7ISEU6ka+NLgkUxlopV4RwI= +github.com/shirou/gopsutil/v4 v4.25.9 h1:JImNpf6gCVhKgZhtaAHJ0serfFGtlfIlSC08eaKdTrU= +github.com/shirou/gopsutil/v4 v4.25.9/go.mod h1:gxIxoC+7nQRwUl/xNhutXlD8lq+jxTgpIkEf3rADHL8= github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s= @@ -127,8 +127,8 @@ go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI= golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8= -golang.org/x/exp v0.0.0-20250911091902-df9299821621 h1:2id6c1/gto0kaHYyrixvknJ8tUK/Qs5IsmBtrc+FtgU= -golang.org/x/exp v0.0.0-20250911091902-df9299821621/go.mod h1:TwQYMMnGpvZyc+JpB/UAuTNIsVJifOlSkrZkhcvpVUk= +golang.org/x/exp v0.0.0-20251002181428-27f1f14c8bb9 h1:TQwNpfvNkxAVlItJf6Cr5JTsVZoC/Sj7K3OZv2Pc14A= +golang.org/x/exp v0.0.0-20251002181428-27f1f14c8bb9/go.mod h1:TwQYMMnGpvZyc+JpB/UAuTNIsVJifOlSkrZkhcvpVUk= golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/image v0.31.0 h1:mLChjE2MV6g1S7oqbXC0/UcKijjm5fnJLUYKIYrLESA= golang.org/x/image v0.31.0/go.mod h1:R9ec5Lcp96v9FTF+ajwaH3uGxPH4fKfHHAVbUILxghA= @@ -169,20 +169,20 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= howett.net/plist v1.0.1 h1:37GdZ8tP09Q35o9ych3ehygcsL+HqKSwzctveSlarvM= howett.net/plist v1.0.1/go.mod 
h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g= -modernc.org/cc/v4 v4.26.4 h1:jPhG8oNjtTYuP2FA4YefTJ/wioNUGALmGuEWt7SUR6s= -modernc.org/cc/v4 v4.26.4/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= +modernc.org/cc/v4 v4.26.5 h1:xM3bX7Mve6G8K8b+T11ReenJOT+BmVqQj0FY5T4+5Y4= +modernc.org/cc/v4 v4.26.5/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= modernc.org/ccgo/v4 v4.28.1 h1:wPKYn5EC/mYTqBO373jKjvX2n+3+aK7+sICCv4Fjy1A= modernc.org/ccgo/v4 v4.28.1/go.mod h1:uD+4RnfrVgE6ec9NGguUNdhqzNIeeomeXf6CL0GTE5Q= -modernc.org/fileutil v1.3.28 h1:Vp156KUA2nPu9F1NEv036x9UGOjg2qsi5QlWTjZmtMk= -modernc.org/fileutil v1.3.28/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc= +modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA= +modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc= modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= modernc.org/libc v1.66.3 h1:cfCbjTUcdsKyyZZfEUKfoHcP3S0Wkvz3jgSzByEWVCQ= modernc.org/libc v1.66.3/go.mod h1:XD9zO8kt59cANKvHPXpx7yS2ELPheAey0vjIuZOhOU8= -modernc.org/libc v1.66.9 h1:YkHp7E1EWrN2iyNav7JE/nHasmshPvlGkon1VxGqOw0= -modernc.org/libc v1.66.9/go.mod h1:aVdcY7udcawRqauu0HukYYxtBSizV+R80n/6aQe9D5k= +modernc.org/libc v1.66.10 h1:yZkb3YeLx4oynyR+iUsXsybsX4Ubx7MQlSYEw4yj59A= +modernc.org/libc v1.66.10/go.mod h1:8vGSEwvoUoltr4dlywvHqjtAqHBaw0j1jI7iFBTAr2I= modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= @@ -193,6 +193,8 @@ modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= modernc.org/sortutil v1.2.1/go.mod 
h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= modernc.org/sqlite v1.38.2 h1:Aclu7+tgjgcQVShZqim41Bbw9Cho0y/7WzYptXqkEek= modernc.org/sqlite v1.38.2/go.mod h1:cPTJYSlgg3Sfg046yBShXENNtPrWrDX8bsbAQBzgQ5E= +modernc.org/sqlite v1.39.0 h1:6bwu9Ooim0yVYA7IZn9demiQk/Ejp0BtTjBWFLymSeY= +modernc.org/sqlite v1.39.0/go.mod h1:cPTJYSlgg3Sfg046yBShXENNtPrWrDX8bsbAQBzgQ5E= modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= diff --git a/internal/common/common-ws.go b/internal/common/common-ws.go index 374f477d..b07f3a18 100644 --- a/internal/common/common-ws.go +++ b/internal/common/common-ws.go @@ -1,22 +1,33 @@ package common -type WebSocketAction = uint8 +import ( + "github.com/henrygd/beszel/internal/entities/system" +) -// Not implemented yet -// type AgentError = uint8 +type WebSocketAction = uint8 const ( // Request system data from agent GetData WebSocketAction = iota // Check the fingerprint of the agent CheckFingerprint + // Add new actions here... ) // HubRequest defines the structure for requests sent from hub to agent. type HubRequest[T any] struct { Action WebSocketAction `cbor:"0,keyasint"` Data T `cbor:"1,keyasint,omitempty,omitzero"` - // Error AgentError `cbor:"error,omitempty,omitzero"` + Id *uint32 `cbor:"2,keyasint,omitempty"` +} + +// AgentResponse defines the structure for responses sent from agent to hub. 
+type AgentResponse struct { + Id *uint32 `cbor:"0,keyasint,omitempty"` + SystemData *system.CombinedData `cbor:"1,keyasint,omitempty,omitzero"` + Fingerprint *FingerprintResponse `cbor:"2,keyasint,omitempty,omitzero"` + Error string `cbor:"3,keyasint,omitempty,omitzero"` + // RawBytes []byte `cbor:"4,keyasint,omitempty,omitzero"` } type FingerprintRequest struct { @@ -30,3 +41,8 @@ type FingerprintResponse struct { Hostname string `cbor:"1,keyasint,omitempty,omitzero"` Port string `cbor:"2,keyasint,omitempty,omitzero"` } + +type DataRequestOptions struct { + CacheTimeMs uint16 `cbor:"0,keyasint"` + // ResourceType uint8 `cbor:"1,keyasint,omitempty,omitzero"` +} diff --git a/internal/entities/system/system.go b/internal/entities/system/system.go index c04e0404..5676e182 100644 --- a/internal/entities/system/system.go +++ b/internal/entities/system/system.go @@ -42,6 +42,8 @@ type Stats struct { Battery [2]uint8 `json:"bat,omitzero" cbor:"29,keyasint,omitzero"` // [percent, charge state, current] MaxMem float64 `json:"mm,omitempty" cbor:"30,keyasint,omitempty"` NetworkInterfaces map[string][4]uint64 `json:"ni,omitempty" cbor:"31,keyasint,omitempty"` // [upload bytes, download bytes, total upload, total download] + DiskIO [2]uint64 `json:"dio,omitzero" cbor:"32,keyasint,omitzero"` // [read bytes, write bytes] + MaxDiskIO [2]uint64 `json:"diom,omitzero" cbor:"-"` // [max read bytes, max write bytes] } type GPUData struct { @@ -68,6 +70,11 @@ type FsStats struct { DiskWritePs float64 `json:"w" cbor:"3,keyasint"` MaxDiskReadPS float64 `json:"rm,omitempty" cbor:"4,keyasint,omitempty"` MaxDiskWritePS float64 `json:"wm,omitempty" cbor:"5,keyasint,omitempty"` + // TODO: remove DiskReadPs and DiskWritePs in future release in favor of DiskReadBytes and DiskWriteBytes + DiskReadBytes uint64 `json:"rb" cbor:"6,keyasint,omitempty"` + DiskWriteBytes uint64 `json:"wb" cbor:"7,keyasint,omitempty"` + MaxDiskReadBytes uint64 `json:"rbm,omitempty" cbor:"-"` + MaxDiskWriteBytes uint64 
`json:"wbm,omitempty" cbor:"-"` } type NetIoStats struct { diff --git a/internal/hub/agent_connect.go b/internal/hub/agent_connect.go index 25e8c883..225d535d 100644 --- a/internal/hub/agent_connect.go +++ b/internal/hub/agent_connect.go @@ -1,6 +1,7 @@ package hub import ( + "context" "errors" "net" "net/http" @@ -93,7 +94,7 @@ func (acr *agentConnectRequest) agentConnect() (err error) { // verifyWsConn verifies the WebSocket connection using the agent's fingerprint and // SSH key signature, then adds the system to the system manager. func (acr *agentConnectRequest) verifyWsConn(conn *gws.Conn, fpRecords []ws.FingerprintRecord) (err error) { - wsConn := ws.NewWsConnection(conn) + wsConn := ws.NewWsConnection(conn, acr.agentSemVer) // must set wsConn in connection store before the read loop conn.Session().Store("wsConn", wsConn) @@ -112,7 +113,7 @@ func (acr *agentConnectRequest) verifyWsConn(conn *gws.Conn, fpRecords []ws.Fing return err } - agentFingerprint, err := wsConn.GetFingerprint(acr.token, signer, acr.isUniversalToken) + agentFingerprint, err := wsConn.GetFingerprint(context.Background(), acr.token, signer, acr.isUniversalToken) if err != nil { return err } diff --git a/internal/hub/systems/system.go b/internal/hub/systems/system.go index 03c6a557..00144706 100644 --- a/internal/hub/systems/system.go +++ b/internal/hub/systems/system.go @@ -10,6 +10,7 @@ import ( "strings" "time" + "github.com/henrygd/beszel/internal/common" "github.com/henrygd/beszel/internal/hub/ws" "github.com/henrygd/beszel/internal/entities/system" @@ -107,7 +108,7 @@ func (sys *System) update() error { sys.handlePaused() return nil } - data, err := sys.fetchDataFromAgent() + data, err := sys.fetchDataFromAgent(common.DataRequestOptions{CacheTimeMs: uint16(interval)}) if err == nil { _, err = sys.createRecords(data) } @@ -209,13 +210,13 @@ func (sys *System) getContext() (context.Context, context.CancelFunc) { // fetchDataFromAgent attempts to fetch data from the agent, // 
prioritizing WebSocket if available. -func (sys *System) fetchDataFromAgent() (*system.CombinedData, error) { +func (sys *System) fetchDataFromAgent(options common.DataRequestOptions) (*system.CombinedData, error) { if sys.data == nil { sys.data = &system.CombinedData{} } if sys.WsConn != nil && sys.WsConn.IsConnected() { - wsData, err := sys.fetchDataViaWebSocket() + wsData, err := sys.fetchDataViaWebSocket(options) if err == nil { return wsData, nil } @@ -223,18 +224,18 @@ func (sys *System) fetchDataFromAgent() (*system.CombinedData, error) { sys.closeWebSocketConnection() } - sshData, err := sys.fetchDataViaSSH() + sshData, err := sys.fetchDataViaSSH(options) if err != nil { return nil, err } return sshData, nil } -func (sys *System) fetchDataViaWebSocket() (*system.CombinedData, error) { +func (sys *System) fetchDataViaWebSocket(options common.DataRequestOptions) (*system.CombinedData, error) { if sys.WsConn == nil || !sys.WsConn.IsConnected() { return nil, errors.New("no websocket connection") } - err := sys.WsConn.RequestSystemData(sys.data) + err := sys.WsConn.RequestSystemData(context.Background(), sys.data, options) if err != nil { return nil, err } @@ -244,7 +245,7 @@ func (sys *System) fetchDataViaWebSocket() (*system.CombinedData, error) { // fetchDataViaSSH handles fetching data using SSH. // This function encapsulates the original SSH logic. // It updates sys.data directly upon successful fetch. 
-func (sys *System) fetchDataViaSSH() (*system.CombinedData, error) { +func (sys *System) fetchDataViaSSH(options common.DataRequestOptions) (*system.CombinedData, error) { maxRetries := 1 for attempt := 0; attempt <= maxRetries; attempt++ { if sys.client == nil || sys.Status == down { @@ -269,12 +270,31 @@ func (sys *System) fetchDataViaSSH() (*system.CombinedData, error) { if err != nil { return nil, err } + stdin, stdinErr := session.StdinPipe() if err := session.Shell(); err != nil { return nil, err } *sys.data = system.CombinedData{} + if sys.agentVersion.GTE(beszel.MinVersionAgentResponse) && stdinErr == nil { + req := common.HubRequest[any]{Action: common.GetData, Data: options} + _ = cbor.NewEncoder(stdin).Encode(req) + // Close write side to signal end of request + _ = stdin.Close() + + var resp common.AgentResponse + if decErr := cbor.NewDecoder(stdout).Decode(&resp); decErr == nil && resp.SystemData != nil { + *sys.data = *resp.SystemData + // wait for the session to complete + if err := session.Wait(); err != nil { + return nil, err + } + return sys.data, nil + } + // If decoding failed, fall back below + } + if sys.agentVersion.GTE(beszel.MinVersionCbor) { err = cbor.NewDecoder(stdout).Decode(sys.data) } else { @@ -379,11 +399,11 @@ func extractAgentVersion(versionString string) (semver.Version, error) { } // getJitter returns a channel that will be triggered after a random delay -// between 40% and 90% of the interval. +// between 51% and 95% of the interval. // This is used to stagger the initial WebSocket connections to prevent clustering. 
func getJitter() <-chan time.Time { - minPercent := 40 - maxPercent := 90 + minPercent := 51 + maxPercent := 95 jitterRange := maxPercent - minPercent msDelay := (interval * minPercent / 100) + rand.Intn(interval*jitterRange/100) return time.After(time.Duration(msDelay) * time.Millisecond) diff --git a/internal/hub/systems/system_manager.go b/internal/hub/systems/system_manager.go index 4211a030..35e52141 100644 --- a/internal/hub/systems/system_manager.go +++ b/internal/hub/systems/system_manager.go @@ -106,6 +106,8 @@ func (sm *SystemManager) bindEventHooks() { sm.hub.OnRecordAfterUpdateSuccess("systems").BindFunc(sm.onRecordAfterUpdateSuccess) sm.hub.OnRecordAfterDeleteSuccess("systems").BindFunc(sm.onRecordAfterDeleteSuccess) sm.hub.OnRecordAfterUpdateSuccess("fingerprints").BindFunc(sm.onTokenRotated) + sm.hub.OnRealtimeSubscribeRequest().BindFunc(sm.onRealtimeSubscribeRequest) + sm.hub.OnRealtimeConnectRequest().BindFunc(sm.onRealtimeConnectRequest) } // onTokenRotated handles fingerprint token rotation events. diff --git a/internal/hub/systems/system_realtime.go b/internal/hub/systems/system_realtime.go new file mode 100644 index 00000000..20debda0 --- /dev/null +++ b/internal/hub/systems/system_realtime.go @@ -0,0 +1,187 @@ +package systems + +import ( + "encoding/json" + "strings" + "sync" + "time" + + "github.com/henrygd/beszel/internal/common" + "github.com/pocketbase/pocketbase/core" + "github.com/pocketbase/pocketbase/tools/subscriptions" +) + +type subscriptionInfo struct { + subscription string + connectedClients uint8 +} + +var ( + activeSubscriptions = make(map[string]*subscriptionInfo) + workerRunning bool + realtimeTicker *time.Ticker + tickerStopChan chan struct{} + realtimeMutex sync.Mutex +) + +// onRealtimeConnectRequest handles client connection events for realtime subscriptions. +// It cleans up existing subscriptions when a client connects. 
+func (sm *SystemManager) onRealtimeConnectRequest(e *core.RealtimeConnectRequestEvent) error {
+	// e.Next() returns once the client has disconnected, so everything below
+	// is disconnect cleanup rather than connect handling. Its error is kept
+	// and returned instead of being silently dropped.
+	err := e.Next()
+	subs := e.Client.Subscriptions()
+	for name, options := range subs {
+		sm.removeRealtimeSubscription(name, options)
+	}
+	return err
+}
+
+// onRealtimeSubscribeRequest handles client subscription events for realtime metrics.
+// It compares the client's subscriptions before and after the request is processed:
+// every newly added "rt_metrics" subscription increments the client count of the
+// system named in its "system" query option and makes sure the realtime worker is
+// running; every subscription that disappeared is passed to removeRealtimeSubscription.
+// The error of the subscribe request itself is returned unchanged.
+func (sm *SystemManager) onRealtimeSubscribeRequest(e *core.RealtimeSubscribeRequestEvent) error {
+	oldSubs := e.Client.Subscriptions()
+	// after e.Next() is the result of the subscribe request
+	err := e.Next()
+	newSubs := e.Client.Subscriptions()
+
+	// handle new subscriptions
+	for k, options := range newSubs {
+		if _, existed := oldSubs[k]; existed {
+			continue
+		}
+		if !strings.HasPrefix(k, "rt_metrics") {
+			continue
+		}
+		systemId := options.Query["system"]
+
+		// activeSubscriptions is shared with the realtime worker goroutine and
+		// with other in-flight requests, so it is only touched under realtimeMutex.
+		realtimeMutex.Lock()
+		info, ok := activeSubscriptions[systemId]
+		if !ok {
+			info = &subscriptionInfo{
+				subscription: k,
+			}
+			activeSubscriptions[systemId] = info
+		}
+		info.connectedClients++
+		realtimeMutex.Unlock()
+
+		// Called after unlocking: onRealtimeSubscriptionAdded takes realtimeMutex itself.
+		sm.onRealtimeSubscriptionAdded()
+	}
+	// handle unsubscriptions
+	for k, options := range oldSubs {
+		if _, kept := newSubs[k]; !kept {
+			sm.removeRealtimeSubscription(k, options)
+		}
+	}
+
+	return err
+}
+
+// onRealtimeSubscriptionAdded initializes or starts the realtime worker when the first subscription is added.
+// It ensures only one worker runs at a time and creates the ticker for periodic data fetching.
+func (sm *SystemManager) onRealtimeSubscriptionAdded() {
+	realtimeMutex.Lock()
+	defer realtimeMutex.Unlock()
+
+	// Create the ticker before the worker is started so the worker can never
+	// observe a nil ticker on its first pass through the loop.
+	if realtimeTicker == nil {
+		realtimeTicker = time.NewTicker(1 * time.Second)
+	}
+
+	// Start the worker if it's not already running
+	if !workerRunning {
+		workerRunning = true
+		// Create a new stop channel for this worker instance
+		tickerStopChan = make(chan struct{})
+		go sm.startRealtimeWorker()
+	}
+}
+
+// checkSubscriptions stops the realtime worker when there are no active subscriptions.
+// This prevents unnecessary resource usage when no clients are listening for realtime data.
+// It is a no-op while the worker is not running or subscriptions remain.
+func (sm *SystemManager) checkSubscriptions() {
+	realtimeMutex.Lock()
+	defer realtimeMutex.Unlock()
+
+	// The shared state is inspected under the lock; checking it before locking
+	// races with concurrent subscribe/unsubscribe requests.
+	if !workerRunning || len(activeSubscriptions) > 0 {
+		return
+	}
+
+	// Signal the worker to stop. Closing (instead of a non-blocking send) means
+	// the signal is not lost when the worker is busy fetching and not currently
+	// waiting on the channel. The channel is cleared so it is closed only once.
+	if tickerStopChan != nil {
+		close(tickerStopChan)
+		tickerStopChan = nil
+	}
+
+	if realtimeTicker != nil {
+		realtimeTicker.Stop()
+		realtimeTicker = nil
+	}
+
+	// Mark worker as stopped (will be reset when next subscription comes in)
+	workerRunning = false
+}
+
+// removeRealtimeSubscription removes a realtime subscription and checks if the worker should be stopped.
+// It only processes subscriptions with the "rt_metrics" prefix and triggers cleanup when subscriptions are removed.
+// The system is looked up through the "system" query option of the subscription.
+func (sm *SystemManager) removeRealtimeSubscription(subscription string, options subscriptions.SubscriptionOptions) {
+	if !strings.HasPrefix(subscription, "rt_metrics") {
+		return
+	}
+	systemId := options.Query["system"]
+
+	realtimeMutex.Lock()
+	if info, ok := activeSubscriptions[systemId]; ok {
+		// connectedClients is unsigned: decrementing at zero would wrap around
+		// to 255 and the entry would never be deleted.
+		if info.connectedClients > 0 {
+			info.connectedClients--
+		}
+		if info.connectedClients == 0 {
+			delete(activeSubscriptions, systemId)
+		}
+	}
+	realtimeMutex.Unlock()
+
+	// Called after unlocking: checkSubscriptions takes realtimeMutex itself.
+	sm.checkSubscriptions()
+}
+
+// startRealtimeWorker runs the main loop for fetching realtime data from agents.
+// It continuously fetches system data and broadcasts it to subscribed clients via WebSocket.
+func (sm *SystemManager) startRealtimeWorker() {
+	// Remember the stop channel that belongs to this worker instance. A later
+	// worker gets a fresh channel, which lets this one notice it was superseded.
+	realtimeMutex.Lock()
+	stop := tickerStopChan
+	realtimeMutex.Unlock()
+
+	sm.fetchRealtimeDataAndNotify()
+
+	for {
+		// Snapshot the shared ticker under the lock. Evaluating realtimeTicker.C
+		// directly in the select would dereference a nil pointer once the
+		// ticker has been cleared by another goroutine.
+		realtimeMutex.Lock()
+		ticker, currentStop := realtimeTicker, tickerStopChan
+		realtimeMutex.Unlock()
+		if ticker == nil || currentStop != stop {
+			return
+		}
+
+		select {
+		case <-stop:
+			return
+		case <-ticker.C:
+			// Check if ticker is still valid (might have been stopped)
+			realtimeMutex.Lock()
+			active := realtimeTicker == ticker && len(activeSubscriptions) > 0
+			realtimeMutex.Unlock()
+			if !active {
+				return
+			}
+			// slog.Debug("activeSubscriptions", "count", len(activeSubscriptions))
+			sm.fetchRealtimeDataAndNotify()
+		}
+	}
+}
+
+// fetchRealtimeDataAndNotify fetches realtime data for all active subscriptions and notifies the clients.
+// Each known system is queried in its own goroutine with a cache time of 1000 ms;
+// systems that cannot be found, fetches that fail and payloads that cannot be
+// marshalled are skipped without reporting an error.
+func (sm *SystemManager) fetchRealtimeDataAndNotify() {
+	type target struct {
+		systemId     string
+		subscription string
+	}
+
+	// Copy what is needed out of activeSubscriptions while holding the lock, so
+	// the map is never iterated while another goroutine adds or removes entries.
+	realtimeMutex.Lock()
+	targets := make([]target, 0, len(activeSubscriptions))
+	for systemId, info := range activeSubscriptions {
+		targets = append(targets, target{systemId: systemId, subscription: info.subscription})
+	}
+	realtimeMutex.Unlock()
+
+	for _, t := range targets {
+		sys, ok := sm.systems.GetOk(t.systemId)
+		if !ok {
+			continue
+		}
+		go func(subscription string) {
+			data, err := sys.fetchDataFromAgent(common.DataRequestOptions{CacheTimeMs: 1000})
+			if err != nil {
+				return
+			}
+			payload, err := json.Marshal(data)
+			if err != nil {
+				return
+			}
+			// notify always returns nil, so there is nothing to handle here.
+			_ = notify(sm.hub, subscription, payload)
+		}(t.subscription)
+	}
+}
+
+// notify broadcasts realtime data to all clients subscribed to a specific subscription.
+// It iterates through all connected clients and sends the data only to those with matching subscriptions.
+func notify(app core.App, subscription string, data []byte) error {
+	// The same message value is sent to every matching client.
+	message := subscriptions.Message{
+		Name: subscription,
+		Data: data,
+	}
+	for _, client := range app.SubscriptionsBroker().Clients() {
+		if !client.HasSubscription(subscription) {
+			continue
+		}
+		client.Send(message)
+	}
+	// No failure is reported from this function; it always returns nil.
+	return nil
+}
diff --git a/internal/hub/ws/handlers.go b/internal/hub/ws/handlers.go
new file mode 100644
index 00000000..26ac2d4c
--- /dev/null
+++ b/internal/hub/ws/handlers.go
@@ -0,0 +1,107 @@
+package ws
+
+import (
+	"context"
+	"errors"
+
+	"github.com/fxamacker/cbor/v2"
+	"github.com/henrygd/beszel/internal/common"
+	"github.com/henrygd/beszel/internal/entities/system"
+	"github.com/lxzan/gws"
+	"golang.org/x/crypto/ssh"
+)
+
+// ResponseHandler defines the interface for handling agent responses.
+// Handle receives a response already decoded into common.AgentResponse, while
+// HandleLegacy receives the raw response bytes and decodes them itself.
+type ResponseHandler interface {
+	Handle(agentResponse common.AgentResponse) error
+	HandleLegacy(rawData []byte) error
+}
+
+// BaseHandler provides a default implementation that can be embedded to make HandleLegacy optional
+// type BaseHandler struct{}
+
+// func (h *BaseHandler) HandleLegacy(rawData []byte) error {
+// 	return errors.New("legacy format not supported")
+// }
+
+////////////////////////////////////////////////////////////////////////////
+////////////////////////////////////////////////////////////////////////////
+////////////////////////////////////////////////////////////////////////////
+
+// systemDataHandler implements ResponseHandler for system data requests.
+// Results are written into the CombinedData that data points to.
+type systemDataHandler struct {
+	data *system.CombinedData
+}
+
+// HandleLegacy decodes rawData as CBOR directly into the target CombinedData.
+func (h *systemDataHandler) HandleLegacy(rawData []byte) error {
+	return cbor.Unmarshal(rawData, h.data)
+}
+
+// Handle copies the SystemData of agentResponse into the target CombinedData.
+// NOTE(review): a response without SystemData leaves the target untouched and
+// still returns nil, unlike fingerprintHandler.Handle which returns an error
+// for a missing payload - confirm this is intended.
+func (h *systemDataHandler) Handle(agentResponse common.AgentResponse) error {
+	if agentResponse.SystemData != nil {
+		*h.data = *agentResponse.SystemData
+	}
+	return nil
+}
+
+// RequestSystemData requests system metrics from the agent and unmarshals the response.
+// It returns gws.ErrConnClosed when the connection is not active, the error
+// from SendRequest when the request could not be sent, and otherwise the
+// result of handleAgentRequest, which fills data through a systemDataHandler.
+func (ws *WsConn) RequestSystemData(ctx context.Context, data *system.CombinedData, options common.DataRequestOptions) error {
+	if !ws.IsConnected() {
+		return gws.ErrConnClosed
+	}
+
+	// options is sent to the agent as the data of the GetData request.
+	req, err := ws.requestManager.SendRequest(ctx, common.GetData, options)
+	if err != nil {
+		return err
+	}
+
+	handler := &systemDataHandler{data: data}
+	return ws.handleAgentRequest(req, handler)
+}
+
+////////////////////////////////////////////////////////////////////////////
+////////////////////////////////////////////////////////////////////////////
+////////////////////////////////////////////////////////////////////////////
+
+// fingerprintHandler implements ResponseHandler for fingerprint requests.
+// Results are written into the FingerprintResponse that result points to.
+type fingerprintHandler struct {
+	result *common.FingerprintResponse
+}
+
+// HandleLegacy decodes rawData as CBOR directly into the target FingerprintResponse.
+func (h *fingerprintHandler) HandleLegacy(rawData []byte) error {
+	return cbor.Unmarshal(rawData, h.result)
+}
+
+// Handle copies the Fingerprint of agentResponse into the target
+// FingerprintResponse. It returns an error when the response carries no
+// fingerprint data.
+func (h *fingerprintHandler) Handle(agentResponse common.AgentResponse) error {
+	if agentResponse.Fingerprint != nil {
+		*h.result = *agentResponse.Fingerprint
+		return nil
+	}
+	return errors.New("no fingerprint data in response")
+}
+
+// GetFingerprint authenticates with the agent using SSH signature and returns the agent's fingerprint.
+// The token is signed with signer and the signature is sent to the agent in a
+// CheckFingerprint request. It returns gws.ErrConnClosed when the connection is
+// not active.
+func (ws *WsConn) GetFingerprint(ctx context.Context, token string, signer ssh.Signer, needSysInfo bool) (common.FingerprintResponse, error) {
+	if !ws.IsConnected() {
+		return common.FingerprintResponse{}, gws.ErrConnClosed
+	}
+
+	challenge := []byte(token)
+	signature, err := signer.Sign(nil, challenge)
+	if err != nil {
+		return common.FingerprintResponse{}, err
+	}
+
+	req, err := ws.requestManager.SendRequest(ctx, common.CheckFingerprint, common.FingerprintRequest{
+		Signature:   signature.Blob,
+		NeedSysInfo: needSysInfo,
+	})
+	if err != nil {
+		return common.FingerprintResponse{}, err
+	}
+
+	var result common.FingerprintResponse
+	handler := &fingerprintHandler{result: &result}
+	err = ws.handleAgentRequest(req, handler)
+	return result, err
+}
diff --git a/internal/hub/ws/request_manager.go b/internal/hub/ws/request_manager.go
new file mode 100644
index 00000000..28dab40d
--- /dev/null
+++ b/internal/hub/ws/request_manager.go
@@ -0,0 +1,186 @@
+package ws
+
+import (
+	"context"
+	"fmt"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/fxamacker/cbor/v2"
+	"github.com/henrygd/beszel/internal/common"
+	"github.com/lxzan/gws"
+)
+
+// RequestID uniquely identifies a request
+type RequestID uint32
+
+// PendingRequest tracks an in-flight request
+type PendingRequest struct {
+	ID         RequestID
+	ResponseCh chan *gws.Message // receives the response message; buffered with capacity 1 by SendRequest
+	Context    context.Context   // done when the request times out or is cancelled
+	Cancel     context.CancelFunc
+	CreatedAt  time.Time // used to find the oldest pending request for legacy responses
+}
+
+// RequestManager handles concurrent requests to an agent
+type RequestManager struct {
+	sync.RWMutex // guards pendingReqs
+	conn        *gws.Conn
+	pendingReqs map[RequestID]*PendingRequest
+	nextID      atomic.Uint32
+}
+
+// NewRequestManager creates a new request manager for a WebSocket connection
+func NewRequestManager(conn *gws.Conn) *RequestManager {
+	rm := &RequestManager{
+		conn:        conn,
+		pendingReqs: make(map[RequestID]*PendingRequest),
+	}
+	return rm
+}
+
+// SendRequest registers a new pending request, sends it to the agent and
+// returns the PendingRequest whose ResponseCh will receive the response.
+// The request context is derived from ctx with a 5 second timeout. When sending
+// fails the request is removed again and an error is returned.
+func (rm *RequestManager) SendRequest(ctx context.Context, action common.WebSocketAction, data any) (*PendingRequest, error) {
+	reqID := RequestID(rm.nextID.Add(1))
+
+	reqCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
+
+	req := &PendingRequest{
+		ID:         reqID,
+		ResponseCh: make(chan *gws.Message, 1),
+		Context:    reqCtx,
+		Cancel:     cancel,
+		CreatedAt:  time.Now(),
+	}
+
+	// Register before sending so a fast response always finds its request.
+	rm.Lock()
+	rm.pendingReqs[reqID] = req
+	rm.Unlock()
+
+	hubReq := common.HubRequest[any]{
+		Id:     (*uint32)(&reqID),
+		Action: action,
+		Data:   data,
+	}
+
+	// Send the request
+	if err := rm.sendMessage(hubReq); err != nil {
+		rm.cancelRequest(reqID)
+		return nil, fmt.Errorf("failed to send request: %w", err)
+	}
+
+	// Start cleanup watcher for timeout/cancellation
+	go rm.cleanupRequest(req)
+
+	return req, nil
+}
+
+// sendMessage encodes and sends a message over WebSocket
+func (rm *RequestManager) sendMessage(data any) error {
+	if rm.conn == nil {
+		return gws.ErrConnClosed
+	}
+
+	bytes, err := cbor.Marshal(data)
+	if err != nil {
+		return fmt.Errorf("failed to marshal request: %w", err)
+	}
+
+	return rm.conn.WriteMessage(gws.OpcodeBinary, bytes)
+}
+
+// handleResponse routes an incoming message to the pending request it answers.
+// Messages that do not decode as an AgentResponse, or that decode without a
+// request ID, are treated as legacy responses and handed to the oldest pending
+// request. A message that cannot be delivered is closed here; a delivered
+// message is left for the receiver to close.
+func (rm *RequestManager) handleResponse(message *gws.Message) {
+	var response common.AgentResponse
+	if err := cbor.Unmarshal(message.Data.Bytes(), &response); err != nil {
+		// Legacy response without ID - route to first pending request of any type
+		rm.routeLegacyResponse(message)
+		return
+	}
+
+	// Id is an optional pointer field, so a payload can decode without error
+	// and still leave it nil. Dereferencing it unchecked would panic, so such
+	// a response is routed like any other response without an ID.
+	if response.Id == nil {
+		rm.routeLegacyResponse(message)
+		return
+	}
+
+	reqID := RequestID(*response.Id)
+
+	rm.RLock()
+	req, exists := rm.pendingReqs[reqID]
+	rm.RUnlock()
+
+	if !exists {
+		// Request not found (might have timed out) - close the message
+		message.Close()
+		return
+	}
+
+	select {
+	case req.ResponseCh <- message:
+		// Message successfully delivered - the receiver will close it
+		rm.deleteRequest(reqID)
+	case <-req.Context.Done():
+		// Request was cancelled/timed out - close the message
+		message.Close()
+	}
+}
+
+// routeLegacyResponse handles responses that don't have request IDs (backwards compatibility)
+// by delivering them to the pending request with the earliest CreatedAt.
+func (rm *RequestManager) routeLegacyResponse(message *gws.Message) {
+	// Snapshot the oldest pending request without holding the lock during send
+	rm.RLock()
+	var oldestReq *PendingRequest
+	for _, req := range rm.pendingReqs {
+		if oldestReq == nil || req.CreatedAt.Before(oldestReq.CreatedAt) {
+			oldestReq = req
+		}
+	}
+	rm.RUnlock()
+
+	if oldestReq != nil {
+		select {
+		case oldestReq.ResponseCh <- message:
+			// Message successfully delivered - the receiver will close it
+			rm.deleteRequest(oldestReq.ID)
+		case <-oldestReq.Context.Done():
+			// Request was cancelled - close the message
+			message.Close()
+		}
+	} else {
+		// No pending requests - close the message
+		message.Close()
+	}
+}
+
+// cleanupRequest handles request timeout and cleanup.
+// It blocks until the request context is done and then removes the request.
+func (rm *RequestManager) cleanupRequest(req *PendingRequest) {
+	<-req.Context.Done()
+	rm.cancelRequest(req.ID)
+}
+
+// cancelRequest removes a request and cancels its context.
+// It does nothing when the request is no longer pending.
+func (rm *RequestManager) cancelRequest(reqID RequestID) {
+	rm.Lock()
+	defer rm.Unlock()
+
+	if req, exists := rm.pendingReqs[reqID]; exists {
+		req.Cancel()
+		delete(rm.pendingReqs, reqID)
+	}
+}
+
+// deleteRequest removes a request from the pending map without cancelling its context.
+func (rm *RequestManager) deleteRequest(reqID RequestID) { + rm.Lock() + defer rm.Unlock() + delete(rm.pendingReqs, reqID) +} + +// Close shuts down the request manager +func (rm *RequestManager) Close() { + rm.Lock() + defer rm.Unlock() + + // Cancel all pending requests + for _, req := range rm.pendingReqs { + req.Cancel() + } + rm.pendingReqs = make(map[RequestID]*PendingRequest) +} diff --git a/internal/hub/ws/request_manager_test.go b/internal/hub/ws/request_manager_test.go new file mode 100644 index 00000000..b5140f04 --- /dev/null +++ b/internal/hub/ws/request_manager_test.go @@ -0,0 +1,81 @@ +//go:build testing +// +build testing + +package ws + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +// TestRequestManager_BasicFunctionality tests the request manager without mocking gws.Conn +func TestRequestManager_BasicFunctionality(t *testing.T) { + // We'll test the core logic without mocking the connection + // since the gws.Conn interface is complex to mock properly + + t.Run("request ID generation", func(t *testing.T) { + // Test that request IDs are generated sequentially and uniquely + rm := &RequestManager{} + + // Simulate multiple ID generations + id1 := rm.nextID.Add(1) + id2 := rm.nextID.Add(1) + id3 := rm.nextID.Add(1) + + assert.NotEqual(t, id1, id2) + assert.NotEqual(t, id2, id3) + assert.Greater(t, id2, id1) + assert.Greater(t, id3, id2) + }) + + t.Run("pending request tracking", func(t *testing.T) { + rm := &RequestManager{ + pendingReqs: make(map[RequestID]*PendingRequest), + } + + // Initially no pending requests + assert.Equal(t, 0, rm.GetPendingCount()) + + // Add some fake pending requests + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + req1 := &PendingRequest{ + ID: RequestID(1), + Context: ctx, + Cancel: cancel, + } + req2 := &PendingRequest{ + ID: RequestID(2), + Context: ctx, + Cancel: cancel, + } + + rm.pendingReqs[req1.ID] = req1 + rm.pendingReqs[req2.ID] = req2 + 
+ assert.Equal(t, 2, rm.GetPendingCount()) + + // Remove one + delete(rm.pendingReqs, req1.ID) + assert.Equal(t, 1, rm.GetPendingCount()) + + // Remove all + delete(rm.pendingReqs, req2.ID) + assert.Equal(t, 0, rm.GetPendingCount()) + }) + + t.Run("context cancellation", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + + // Wait for context to timeout + <-ctx.Done() + + // Verify context was cancelled + assert.Equal(t, context.DeadlineExceeded, ctx.Err()) + }) +} diff --git a/internal/hub/ws/ws.go b/internal/hub/ws/ws.go index 412bee1a..0539ec7d 100644 --- a/internal/hub/ws/ws.go +++ b/internal/hub/ws/ws.go @@ -5,13 +5,13 @@ import ( "time" "weak" - "github.com/henrygd/beszel/internal/entities/system" + "github.com/blang/semver" + "github.com/henrygd/beszel" "github.com/henrygd/beszel/internal/common" "github.com/fxamacker/cbor/v2" "github.com/lxzan/gws" - "golang.org/x/crypto/ssh" ) const ( @@ -25,9 +25,10 @@ type Handler struct { // WsConn represents a WebSocket connection to an agent. type WsConn struct { - conn *gws.Conn - responseChan chan *gws.Message - DownChan chan struct{} + conn *gws.Conn + requestManager *RequestManager + DownChan chan struct{} + agentVersion semver.Version } // FingerprintRecord is fingerprints collection record data in the hub @@ -50,21 +51,22 @@ func GetUpgrader() *gws.Upgrader { return upgrader } -// NewWsConnection creates a new WebSocket connection wrapper. -func NewWsConnection(conn *gws.Conn) *WsConn { +// NewWsConnection creates a new WebSocket connection wrapper with agent version. +func NewWsConnection(conn *gws.Conn, agentVersion semver.Version) *WsConn { return &WsConn{ - conn: conn, - responseChan: make(chan *gws.Message, 1), - DownChan: make(chan struct{}, 1), + conn: conn, + requestManager: NewRequestManager(conn), + DownChan: make(chan struct{}, 1), + agentVersion: agentVersion, } } -// OnOpen sets a deadline for the WebSocket connection. 
+// OnOpen sets a deadline for the WebSocket connection and extracts agent version. func (h *Handler) OnOpen(conn *gws.Conn) { conn.SetDeadline(time.Now().Add(deadline)) } -// OnMessage routes incoming WebSocket messages to the response channel. +// OnMessage routes incoming WebSocket messages to the request manager. func (h *Handler) OnMessage(conn *gws.Conn, message *gws.Message) { conn.SetDeadline(time.Now().Add(deadline)) if message.Opcode != gws.OpcodeBinary || message.Data.Len() == 0 { @@ -75,12 +77,7 @@ func (h *Handler) OnMessage(conn *gws.Conn, message *gws.Message) { _ = conn.WriteClose(1000, nil) return } - select { - case wsConn.(*WsConn).responseChan <- message: - default: - // close if the connection is not expecting a response - wsConn.(*WsConn).Close(nil) - } + wsConn.(*WsConn).requestManager.handleResponse(message) } // OnClose handles WebSocket connection closures and triggers system down status after delay. @@ -106,6 +103,9 @@ func (ws *WsConn) Close(msg []byte) { if ws.IsConnected() { ws.conn.WriteClose(1000, msg) } + if ws.requestManager != nil { + ws.requestManager.Close() + } } // Ping sends a ping frame to keep the connection alive. @@ -115,6 +115,7 @@ func (ws *WsConn) Ping() error { } // sendMessage encodes data to CBOR and sends it as a binary message to the agent. +// This is kept for backwards compatibility but new actions should use RequestManager. func (ws *WsConn) sendMessage(data common.HubRequest[any]) error { if ws.conn == nil { return gws.ErrConnClosed @@ -126,54 +127,34 @@ func (ws *WsConn) sendMessage(data common.HubRequest[any]) error { return ws.conn.WriteMessage(gws.OpcodeBinary, bytes) } -// RequestSystemData requests system metrics from the agent and unmarshals the response. 
-func (ws *WsConn) RequestSystemData(data *system.CombinedData) error { - var message *gws.Message - - ws.sendMessage(common.HubRequest[any]{ - Action: common.GetData, - }) +// handleAgentRequest processes a request to the agent, handling both legacy and new formats. +func (ws *WsConn) handleAgentRequest(req *PendingRequest, handler ResponseHandler) error { + // Wait for response select { - case <-time.After(10 * time.Second): - ws.Close(nil) - return gws.ErrConnClosed - case message = <-ws.responseChan: + case message := <-req.ResponseCh: + defer message.Close() + // Cancel request context to stop timeout watcher promptly + defer req.Cancel() + data := message.Data.Bytes() + + // Legacy format - unmarshal directly + if ws.agentVersion.LT(beszel.MinVersionAgentResponse) { + return handler.HandleLegacy(data) + } + + // New format with AgentResponse wrapper + var agentResponse common.AgentResponse + if err := cbor.Unmarshal(data, &agentResponse); err != nil { + return err + } + if agentResponse.Error != "" { + return errors.New(agentResponse.Error) + } + return handler.Handle(agentResponse) + + case <-req.Context.Done(): + return req.Context.Err() } - defer message.Close() - return cbor.Unmarshal(message.Data.Bytes(), data) -} - -// GetFingerprint authenticates with the agent using SSH signature and returns the agent's fingerprint. 
-func (ws *WsConn) GetFingerprint(token string, signer ssh.Signer, needSysInfo bool) (common.FingerprintResponse, error) { - var clientFingerprint common.FingerprintResponse - challenge := []byte(token) - - signature, err := signer.Sign(nil, challenge) - if err != nil { - return clientFingerprint, err - } - - err = ws.sendMessage(common.HubRequest[any]{ - Action: common.CheckFingerprint, - Data: common.FingerprintRequest{ - Signature: signature.Blob, - NeedSysInfo: needSysInfo, - }, - }) - if err != nil { - return clientFingerprint, err - } - - var message *gws.Message - select { - case message = <-ws.responseChan: - case <-time.After(10 * time.Second): - return clientFingerprint, errors.New("request expired") - } - defer message.Close() - - err = cbor.Unmarshal(message.Data.Bytes(), &clientFingerprint) - return clientFingerprint, err } // IsConnected returns true if the WebSocket connection is active. diff --git a/internal/hub/ws/ws_test.go b/internal/hub/ws/ws_test.go index c6a74c34..fac446e5 100644 --- a/internal/hub/ws/ws_test.go +++ b/internal/hub/ws/ws_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/blang/semver" "github.com/henrygd/beszel/internal/common" "github.com/fxamacker/cbor/v2" @@ -36,26 +37,25 @@ func TestGetUpgrader(t *testing.T) { // TestNewWsConnection tests WebSocket connection creation func TestNewWsConnection(t *testing.T) { // We can't easily mock gws.Conn, so we'll pass nil and test the structure - wsConn := NewWsConnection(nil) + wsConn := NewWsConnection(nil, semver.MustParse("0.12.10")) assert.NotNil(t, wsConn, "WebSocket connection should not be nil") assert.Nil(t, wsConn.conn, "Connection should be nil as passed") - assert.NotNil(t, wsConn.responseChan, "Response channel should be initialized") + assert.NotNil(t, wsConn.requestManager, "Request manager should be initialized") assert.NotNil(t, wsConn.DownChan, "Down channel should be initialized") - assert.Equal(t, 1, cap(wsConn.responseChan), "Response channel should 
have capacity of 1") assert.Equal(t, 1, cap(wsConn.DownChan), "Down channel should have capacity of 1") } // TestWsConn_IsConnected tests the connection status check func TestWsConn_IsConnected(t *testing.T) { // Test with nil connection - wsConn := NewWsConnection(nil) + wsConn := NewWsConnection(nil, semver.MustParse("0.12.10")) assert.False(t, wsConn.IsConnected(), "Should not be connected when conn is nil") } // TestWsConn_Close tests the connection closing with nil connection func TestWsConn_Close(t *testing.T) { - wsConn := NewWsConnection(nil) + wsConn := NewWsConnection(nil, semver.MustParse("0.12.10")) // Should handle nil connection gracefully assert.NotPanics(t, func() { @@ -65,7 +65,7 @@ func TestWsConn_Close(t *testing.T) { // TestWsConn_SendMessage_CBOR tests CBOR encoding in sendMessage func TestWsConn_SendMessage_CBOR(t *testing.T) { - wsConn := NewWsConnection(nil) + wsConn := NewWsConnection(nil, semver.MustParse("0.12.10")) testData := common.HubRequest[any]{ Action: common.GetData, @@ -194,7 +194,7 @@ func TestHandler(t *testing.T) { // TestWsConnChannelBehavior tests channel behavior without WebSocket connections func TestWsConnChannelBehavior(t *testing.T) { - wsConn := NewWsConnection(nil) + wsConn := NewWsConnection(nil, semver.MustParse("0.12.10")) // Test that channels are properly initialized and can be used select { @@ -212,11 +212,6 @@ func TestWsConnChannelBehavior(t *testing.T) { t.Error("Should be able to read from DownChan") } - // Response channel should be empty initially - select { - case <-wsConn.responseChan: - t.Error("Response channel should be empty initially") - default: - // Expected - channel should be empty - } + // Request manager should have no pending requests initially + assert.Equal(t, 0, wsConn.requestManager.GetPendingCount(), "Should have no pending requests initially") } diff --git a/internal/hub/ws/ws_test_helpers.go b/internal/hub/ws/ws_test_helpers.go new file mode 100644 index 00000000..daf84741 --- 
/dev/null +++ b/internal/hub/ws/ws_test_helpers.go @@ -0,0 +1,11 @@ +//go:build testing +// +build testing + +package ws + +// GetPendingCount returns the number of pending requests (for monitoring) +func (rm *RequestManager) GetPendingCount() int { + rm.RLock() + defer rm.RUnlock() + return len(rm.pendingReqs) +} diff --git a/internal/records/records.go b/internal/records/records.go index 416ecb24..7c53aed3 100644 --- a/internal/records/records.go +++ b/internal/records/records.go @@ -213,6 +213,8 @@ func (rm *RecordManager) AverageSystemStats(db dbx.Builder, records RecordIds) * sum.LoadAvg[2] += stats.LoadAvg[2] sum.Bandwidth[0] += stats.Bandwidth[0] sum.Bandwidth[1] += stats.Bandwidth[1] + sum.DiskIO[0] += stats.DiskIO[0] + sum.DiskIO[1] += stats.DiskIO[1] batterySum += int(stats.Battery[0]) sum.Battery[1] = stats.Battery[1] // Set peak values @@ -224,6 +226,8 @@ func (rm *RecordManager) AverageSystemStats(db dbx.Builder, records RecordIds) * sum.MaxDiskWritePs = max(sum.MaxDiskWritePs, stats.MaxDiskWritePs, stats.DiskWritePs) sum.MaxBandwidth[0] = max(sum.MaxBandwidth[0], stats.MaxBandwidth[0], stats.Bandwidth[0]) sum.MaxBandwidth[1] = max(sum.MaxBandwidth[1], stats.MaxBandwidth[1], stats.Bandwidth[1]) + sum.MaxDiskIO[0] = max(sum.MaxDiskIO[0], stats.MaxDiskIO[0], stats.DiskIO[0]) + sum.MaxDiskIO[1] = max(sum.MaxDiskIO[1], stats.MaxDiskIO[1], stats.DiskIO[1]) // Accumulate network interfaces if sum.NetworkInterfaces == nil { @@ -314,6 +318,8 @@ func (rm *RecordManager) AverageSystemStats(db dbx.Builder, records RecordIds) * sum.DiskPct = twoDecimals(sum.DiskPct / count) sum.DiskReadPs = twoDecimals(sum.DiskReadPs / count) sum.DiskWritePs = twoDecimals(sum.DiskWritePs / count) + sum.DiskIO[0] = sum.DiskIO[0] / uint64(count) + sum.DiskIO[1] = sum.DiskIO[1] / uint64(count) sum.NetworkSent = twoDecimals(sum.NetworkSent / count) sum.NetworkRecv = twoDecimals(sum.NetworkRecv / count) sum.LoadAvg[0] = twoDecimals(sum.LoadAvg[0] / count) diff --git 
a/internal/site/biome.json b/internal/site/biome.json index 14bd3e89..a2da8da2 100644 --- a/internal/site/biome.json +++ b/internal/site/biome.json @@ -1,41 +1,83 @@ { "$schema": "https://biomejs.dev/schemas/2.2.3/schema.json", "vcs": { - "enabled": false, + "enabled": true, "clientKind": "git", - "useIgnoreFile": false - }, - "files": { - "ignoreUnknown": false + "useIgnoreFile": true, + "defaultBranch": "main" }, "formatter": { "enabled": true, "indentStyle": "tab", - "indentWidth": 2, - "lineWidth": 120 + "lineWidth": 120, + "formatWithErrors": true }, + "assist": { "actions": { "source": { "organizeImports": "on" } } }, "linter": { "enabled": true, "rules": { "recommended": true, + "complexity": { + "noUselessStringConcat": "error", + "noUselessUndefinedInitialization": "error", + "noVoid": "error", + "useDateNow": "error" + }, "correctness": { - "useUniqueElementIds": "off" + "noConstantMathMinMaxClamp": "error", + "noUndeclaredVariables": "error", + "noUnusedImports": "error", + "noUnusedFunctionParameters": "error", + "noUnusedPrivateClassMembers": "error", + "useExhaustiveDependencies": { + "level": "error", + "options": { + "reportUnnecessaryDependencies": false + } + }, + "noUnusedVariables": "error" + }, + "style": { + "noParameterProperties": "error", + "noYodaExpression": "error", + "useConsistentBuiltinInstantiation": "error", + "useFragmentSyntax": "error", + "useShorthandAssign": "error", + "useArrayLiterals": "error" + }, + "suspicious": { + "useAwait": "error", + "noEvolvingTypes": "error" } } }, "javascript": { "formatter": { "quoteStyle": "double", - "semicolons": "asNeeded", - "trailingCommas": "es5" + "trailingCommas": "es5", + "semicolons": "asNeeded" } }, - "assist": { - "enabled": true, - "actions": { - "source": { - "organizeImports": "on" + "overrides": [ + { + "includes": ["**/*.jsx", "**/*.tsx"], + "linter": { + "rules": { + "style": { + "noParameterAssign": "error" + } + } + } + }, + { + "includes": ["**/*.ts", "**/*.tsx"], + "linter": 
{ + "rules": { + "correctness": { + "noUnusedVariables": "off" + } + } } } - } + ] } diff --git a/internal/site/src/components/charts/chart-time-select.tsx b/internal/site/src/components/charts/chart-time-select.tsx index 0b9896ad..ab3473aa 100644 --- a/internal/site/src/components/charts/chart-time-select.tsx +++ b/internal/site/src/components/charts/chart-time-select.tsx @@ -2,12 +2,27 @@ import { useStore } from "@nanostores/react" import { HistoryIcon } from "lucide-react" import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select" import { $chartTime } from "@/lib/stores" -import { chartTimeData, cn } from "@/lib/utils" -import type { ChartTimes } from "@/types" +import { chartTimeData, cn, compareSemVer, parseSemVer } from "@/lib/utils" +import type { ChartTimes, SemVer } from "@/types" +import { memo } from "react" -export default function ChartTimeSelect({ className }: { className?: string }) { +export default memo(function ChartTimeSelect({ + className, + agentVersion, +}: { + className?: string + agentVersion: SemVer +}) { const chartTime = useStore($chartTime) + // remove chart times that are not supported by the system agent version + const availableChartTimes = Object.entries(chartTimeData).filter(([_, { minVersion }]) => { + if (!minVersion) { + return true + } + return compareSemVer(agentVersion, parseSemVer(minVersion)) >= 0 + }) + return ( ) -} +}) diff --git a/internal/site/src/components/charts/load-average-chart.tsx b/internal/site/src/components/charts/load-average-chart.tsx index 8ea7300c..c8763638 100644 --- a/internal/site/src/components/charts/load-average-chart.tsx +++ b/internal/site/src/components/charts/load-average-chart.tsx @@ -59,8 +59,6 @@ export default memo(function LoadAverageChart({ chartData }: { chartData: ChartD b.value - a.value} content={ formatShortDate(data[0].payload.created)} @@ -70,14 +68,15 @@ export default memo(function LoadAverageChart({ chartData }: { chartData: ChartD /> 
{keys.map(({ legacy, color, label }, i) => { const dataKey = (value: { stats: SystemStats }) => { - if (chartData.agentVersion.patch < 1) { + const { minor, patch } = chartData.agentVersion + if (minor <= 12 && patch < 1) { return value.stats?.[legacy] } return value.stats?.la?.[i] ?? value.stats?.[legacy] } return ( ( +function addEmptyValues( prevRecords: T[], newRecords: T[], expectedInterval: number -) { +): T[] { const modifiedRecords: T[] = [] let prevTime = (prevRecords.at(-1)?.created ?? 0) as number for (let i = 0; i < newRecords.length; i++) { const record = newRecords[i] - record.created = new Date(record.created).getTime() - if (prevTime) { + if (record.created !== null) { + record.created = new Date(record.created).getTime() + } + if (prevTime && record.created !== null) { const interval = record.created - prevTime // if interval is too large, add a null record if (interval > expectedInterval / 2 + expectedInterval) { - // @ts-expect-error - modifiedRecords.push({ created: null, stats: null }) + modifiedRecords.push({ created: null, ...("stats" in record ? 
{ stats: null } : {}) } as T) } } - prevTime = record.created + if (record.created !== null) { + prevTime = record.created + } modifiedRecords.push(record) } return modifiedRecords @@ -137,7 +149,7 @@ async function getStats( }) } -function dockerOrPodman(str: string, system: SystemRecord) { +function dockerOrPodman(str: string, system: SystemRecord): string { if (system.info.p) { return str.replace("docker", "podman").replace("Docker", "Podman") } @@ -156,10 +168,9 @@ export default memo(function SystemDetail({ name }: { name: string }) { const [containerData, setContainerData] = useState([] as ChartData["containerData"]) const netCardRef = useRef(null) const persistChartTime = useRef(false) - const [containerFilterBar, setContainerFilterBar] = useState(null as null | JSX.Element) const [bottomSpacing, setBottomSpacing] = useState(0) const [chartLoading, setChartLoading] = useState(true) - const isLongerChart = chartTime !== "1h" + const isLongerChart = !["1m", "1h"].includes(chartTime) // true if chart time is not 1m or 1h const userSettings = $userSettings.get() const chartWrapRef = useRef(null) @@ -172,7 +183,6 @@ export default memo(function SystemDetail({ name }: { name: string }) { persistChartTime.current = false setSystemStats([]) setContainerData([]) - setContainerFilterBar(null) $containerFilter.set("") } }, [name]) @@ -185,6 +195,51 @@ export default memo(function SystemDetail({ name }: { name: string }) { }) }, [name]) + // hide 1m chart time if system agent version is less than 0.13.0 + useEffect(() => { + if (parseSemVer(system?.info?.v) < parseSemVer("0.13.0")) { + $chartTime.set("1h") + } + }, [system?.info?.v]) + + // subscribe to realtime metrics if chart time is 1m + // biome-ignore lint/correctness/useExhaustiveDependencies: not necessary + useEffect(() => { + let unsub = () => {} + if (!system.id || chartTime !== "1m") { + return + } + if (system.status !== SystemStatus.Up || parseSemVer(system?.info?.v).minor < 13) { + $chartTime.set("1h") + 
return + } + pb.realtime + .subscribe( + `rt_metrics`, + (data: { container: ContainerStatsRecord[]; info: SystemInfo; stats: SystemStats }) => { + // console.log("received realtime metrics", data) + const newContainerData = makeContainerData([ + { created: Date.now(), stats: data.container } as unknown as ContainerStatsRecord, + ]) + setContainerData((prevData) => addEmptyValues(prevData, prevData.slice(-59).concat(newContainerData), 1000)) + setSystemStats((prevStats) => + addEmptyValues( + prevStats, + prevStats.slice(-59).concat({ created: Date.now(), stats: data.stats } as SystemStatsRecord), + 1000 + ) + ) + }, + { query: { system: system.id } } + ) + .then((us) => { + unsub = us + }) + return () => { + unsub?.() + } + }, [chartTime, system.id]) + // biome-ignore lint/correctness/useExhaustiveDependencies: not necessary const chartData: ChartData = useMemo(() => { const lastCreated = Math.max( @@ -221,13 +276,13 @@ export default memo(function SystemDetail({ name }: { name: string }) { } containerData.push(containerStats) } - setContainerData(containerData) + return containerData }, []) // get stats // biome-ignore lint/correctness/useExhaustiveDependencies: not necessary useEffect(() => { - if (!system.id || !chartTime) { + if (!system.id || !chartTime || chartTime === "1m") { return } // loading: true @@ -261,12 +316,7 @@ export default memo(function SystemDetail({ name }: { name: string }) { } cache.set(cs_cache_key, containerData) } - if (containerData.length) { - !containerFilterBar && setContainerFilterBar() - } else if (containerFilterBar) { - setContainerFilterBar(null) - } - makeContainerData(containerData) + setContainerData(makeContainerData(containerData)) }) }, [system, chartTime]) @@ -392,9 +442,10 @@ export default memo(function SystemDetail({ name }: { name: string }) { // select field for switching between avg and max values const maxValSelect = isLongerChart ? 
: null - const showMax = chartTime !== "1h" && maxValues + const showMax = maxValues && isLongerChart + + const containerFilterBar = containerData.length ? : null - // if no data, show empty message const dataEmpty = !chartLoading && chartData.systemStats.length === 0 const lastGpuVals = Object.values(systemStats.at(-1)?.stats.g ?? {}) const hasGpuData = lastGpuVals.length > 0 @@ -483,7 +534,7 @@ export default memo(function SystemDetail({ name }: { name: string }) {
- + @@ -594,23 +645,33 @@ export default memo(function SystemDetail({ name }: { name: string }) { dataPoints={[ { label: t({ message: "Write", comment: "Disk write" }), - dataKey: ({ stats }: SystemStatsRecord) => (showMax ? stats?.dwm : stats?.dw), + dataKey: ({ stats }: SystemStatsRecord) => { + if (showMax) { + return stats?.diom?.[1] ?? (stats?.dwm ?? 0) * 1024 * 1024 /* max write in bytes; legacy dwm is MB */ + } + return stats?.dio?.[1] ?? (stats?.dw ?? 0) * 1024 * 1024 + }, color: 3, opacity: 0.3, }, { label: t({ message: "Read", comment: "Disk read" }), - dataKey: ({ stats }: SystemStatsRecord) => (showMax ? stats?.drm : stats?.dr), + dataKey: ({ stats }: SystemStatsRecord) => { + if (showMax) { + return stats?.diom?.[0] ?? (stats?.drm ?? 0) * 1024 * 1024 + } + return stats?.dio?.[0] ?? (stats?.dr ?? 0) * 1024 * 1024 + }, color: 1, opacity: 0.3, }, ]} tickFormatter={(val) => { - const { value, unit } = formatBytes(val, true, userSettings.unitDisk, true) + const { value, unit } = formatBytes(val, true, userSettings.unitDisk, false) return `${toFixedFloat(value, value >= 10 ? 0 : 1)} ${unit}` }} contentFormatter={({ value }) => { - const { value: convertedValue, unit } = formatBytes(value, true, userSettings.unitDisk, true) + const { value: convertedValue, unit } = formatBytes(value, true, userSettings.unitDisk, false) return `${decimalString(convertedValue, convertedValue >= 100 ? 1 : 2)} ${unit}` }} /> @@ -791,7 +852,7 @@ export default memo(function SystemDetail({ name }: { name: string }) { return (
stats?.efs?.[extraFsName]?.[showMax ? "wm" : "w"] ?? 0, + dataKey: ({ stats }) => { + if (showMax) { + return stats?.efs?.[extraFsName]?.wbm ?? (stats?.efs?.[extraFsName]?.wm ?? 0) * 1024 * 1024 /* max write in bytes; legacy wm is MB */ + } + return stats?.efs?.[extraFsName]?.wb ?? (stats?.efs?.[extraFsName]?.w ?? 0) * 1024 * 1024 + }, color: 3, opacity: 0.3, }, { label: t`Read`, - dataKey: ({ stats }) => stats?.efs?.[extraFsName]?.[showMax ? "rm" : "r"] ?? 0, + dataKey: ({ stats }) => { + if (showMax) { + return ( + stats?.efs?.[extraFsName]?.rbm ?? (stats?.efs?.[extraFsName]?.rm ?? 0) * 1024 * 1024 + ) + } + return stats?.efs?.[extraFsName]?.rb ?? (stats?.efs?.[extraFsName]?.r ?? 0) * 1024 * 1024 + }, color: 1, opacity: 0.3, }, ]} maxToggled={maxValues} tickFormatter={(val) => { - const { value, unit } = formatBytes(val, true, userSettings.unitDisk, true) + const { value, unit } = formatBytes(val, true, userSettings.unitDisk, false) return `${toFixedFloat(value, value >= 10 ? 0 : 1)} ${unit}` }} contentFormatter={({ value }) => { - const { value: convertedValue, unit } = formatBytes(value, true, userSettings.unitDisk, true) + const { value: convertedValue, unit } = formatBytes(value, true, userSettings.unitDisk, false) return `${decimalString(convertedValue, convertedValue >= 100 ? 1 : 2)} ${unit}` }} /> @@ -913,7 +986,7 @@ export default memo(function SystemDetail({ name }: { name: string }) { }) function GpuEnginesChart({ chartData }: { chartData: ChartData }) { - const dataPoints = [] + const dataPoints: DataPoint[] = [] const engines = Object.keys(chartData.systemStats?.at(-1)?.stats.g?.[0]?.e ??
{}).sort() for (const engine of engines) { dataPoints.push({ diff --git a/internal/site/src/components/routes/system/network-sheet.tsx b/internal/site/src/components/routes/system/network-sheet.tsx index c87e87ea..62eb16bc 100644 --- a/internal/site/src/components/routes/system/network-sheet.tsx +++ b/internal/site/src/components/routes/system/network-sheet.tsx @@ -53,7 +53,7 @@ export default memo(function NetworkSheet({ {hasOpened.current && ( - + { return ( diff --git a/internal/site/src/lib/api.ts b/internal/site/src/lib/api.ts index 38a52145..eed6469f 100644 --- a/internal/site/src/lib/api.ts +++ b/internal/site/src/lib/api.ts @@ -26,7 +26,7 @@ export const verifyAuth = () => { } /** Logs the user out by clearing the auth store and unsubscribing from realtime updates. */ -export async function logOut() { +export function logOut() { $allSystemsByName.set({}) $alerts.set({}) $userSettings.set({} as UserSettings) diff --git a/internal/site/src/lib/utils.ts b/internal/site/src/lib/utils.ts index d5b36eca..8c6a4782 100644 --- a/internal/site/src/lib/utils.ts +++ b/internal/site/src/lib/utils.ts @@ -1,7 +1,7 @@ import { t } from "@lingui/core/macro" import { type ClassValue, clsx } from "clsx" -import { timeDay, timeHour } from "d3-time" import { listenKeys } from "nanostores" +import { timeDay, timeHour, timeMinute } from "d3-time" import { useEffect, useState } from "react" import { twMerge } from "tailwind-merge" import { prependBasePath } from "@/components/router" @@ -54,9 +54,18 @@ const createShortDateFormatter = (hour12?: boolean) => hour12, }) +const createHourWithSecondsFormatter = (hour12?: boolean) => + new Intl.DateTimeFormat(undefined, { + hour: "numeric", + minute: "numeric", + second: "numeric", + hour12, + }) + // Initialize formatters with default values let hourWithMinutesFormatter = createHourWithMinutesFormatter() let shortDateFormatter = createShortDateFormatter() +let hourWithSecondsFormatter = createHourWithSecondsFormatter() export const 
currentHour12 = () => shortDateFormatter.resolvedOptions().hour12 @@ -68,6 +77,10 @@ export const formatShortDate = (timestamp: string) => { return shortDateFormatter.format(new Date(timestamp)) } +export const hourWithSeconds = (timestamp: string) => { + return hourWithSecondsFormatter.format(new Date(timestamp)) +} + // Update the time formatters if user changes hourFormat listenKeys($userSettings, ["hourFormat"], ({ hourFormat }) => { if (!hourFormat) return @@ -75,6 +88,7 @@ listenKeys($userSettings, ["hourFormat"], ({ hourFormat }) => { if (currentHour12() !== newHour12) { hourWithMinutesFormatter = createHourWithMinutesFormatter(newHour12) shortDateFormatter = createShortDateFormatter(newHour12) + hourWithSecondsFormatter = createHourWithSecondsFormatter(newHour12) } }) @@ -91,6 +105,15 @@ export const updateFavicon = (newIcon: string) => { } export const chartTimeData: ChartTimeData = { + "1m": { + type: "1m", + expectedInterval: 1000, + label: () => t`1 minute`, + format: (timestamp: string) => hourWithSeconds(timestamp), + ticks: 3, + getOffset: (endTime: Date) => timeMinute.offset(endTime, -1), + minVersion: "0.13.0", + }, "1h": { type: "1m", expectedInterval: 60_000, @@ -278,7 +301,7 @@ export const generateToken = () => { } /** Get the hub URL from the global BESZEL object */ -export const getHubURL = () => BESZEL?.HUB_URL || window.location.origin +export const getHubURL = () => globalThis.BESZEL?.HUB_URL || window.location.origin /** Map of system IDs to their corresponding tokens (used to avoid fetching in add-system dialog) */ export const tokenMap = new Map() @@ -333,6 +356,17 @@ export const parseSemVer = (semVer = ""): SemVer => { return { major: parts?.[0] ?? 0, minor: parts?.[1] ?? 0, patch: parts?.[2] ?? 0 } } +/** Compare two parsed SemVer values. Returns a negative number if a is less than b, 0 if a is equal to b, and a positive number if a is greater than b.
*/ +export function compareSemVer(a: SemVer, b: SemVer) { + if (a.major !== b.major) { + return a.major - b.major + } + if (a.minor !== b.minor) { + return a.minor - b.minor + } + return a.patch - b.patch +} + /** Get meter state from 0-100 value. Used for color coding meters. */ export function getMeterState(value: number): MeterState { const { colorWarn = 65, colorCrit = 90 } = $userSettings.get() diff --git a/internal/site/src/types.d.ts b/internal/site/src/types.d.ts index 8656fb81..66527407 100644 --- a/internal/site/src/types.d.ts +++ b/internal/site/src/types.d.ts @@ -123,6 +123,10 @@ export interface SystemStats { drm?: number /** max disk write (mb) */ dwm?: number + /** disk I/O bytes [read, write] */ + dio?: [number, number] + /** max disk I/O bytes [read, write] */ + diom?: [number, number] /** network sent (mb) */ ns: number /** network received (mb) */ @@ -177,6 +181,14 @@ export interface ExtraFsStats { rm: number /** max write (mb) */ wm: number + /** read per second (bytes) */ + rb: number + /** write per second (bytes) */ + wb: number + /** max read per second (bytes) */ + rbm: number + /** max write per second (bytes) */ + wbm: number } export interface ContainerStatsRecord extends RecordModel { @@ -224,7 +236,7 @@ export interface AlertsHistoryRecord extends RecordModel { resolved?: string | null } -export type ChartTimes = "1h" | "12h" | "24h" | "1w" | "30d" +export type ChartTimes = "1m" | "1h" | "12h" | "24h" | "1w" | "30d" export interface ChartTimeData { [key: string]: { @@ -234,6 +246,7 @@ export interface ChartTimeData { ticks?: number format: (timestamp: string) => string getOffset: (endTime: Date) => Date + minVersion?: string } }