From 473cb7f437a337c8873495786719f31bffad6067 Mon Sep 17 00:00:00 2001 From: henrygd Date: Tue, 28 Oct 2025 15:38:47 -0400 Subject: [PATCH] merge SMART_DEVICES with devices returned from smartctl scan --- agent/smart.go | 138 +++++++++++++++++++++++++++++++++----------- agent/smart_test.go | 40 +++++++++++-- 2 files changed, 141 insertions(+), 37 deletions(-) diff --git a/agent/smart.go b/agent/smart.go index c9bb82d6..f8defd97 100644 --- a/agent/smart.go +++ b/agent/smart.go @@ -136,17 +136,19 @@ func (sm *SmartManager) ScanDevices(force bool) error { } sm.lastScanTime = time.Now() - if configuredDevices, ok := GetEnv("SMART_DEVICES"); ok { - config := strings.TrimSpace(configuredDevices) + var configuredDevices []*DeviceInfo + if configuredRaw, ok := GetEnv("SMART_DEVICES"); ok { + config := strings.TrimSpace(configuredRaw) if config == "" { return errNoValidSmartData } slog.Info("SMART_DEVICES", "config", config) - if err := sm.parseConfiguredDevices(config); err != nil { + parsedDevices, err := sm.parseConfiguredDevices(config) + if err != nil { return err } - return nil + configuredDevices = parsedDevices } ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) @@ -155,18 +157,34 @@ func (sm *SmartManager) ScanDevices(force bool) error { cmd := exec.CommandContext(ctx, "smartctl", "--scan", "-j") output, err := cmd.Output() + var ( + scanErr error + scannedDevices []*DeviceInfo + hasValidScan bool + ) + if err != nil { - return err + scanErr = err + } else { + scannedDevices, hasValidScan = sm.parseScan(output) + if !hasValidScan { + scanErr = errNoValidSmartData + } } - hasValidData := sm.parseScan(output) - if !hasValidData { + finalDevices := mergeDeviceLists(scannedDevices, configuredDevices) + sm.updateSmartDevices(finalDevices) + if len(finalDevices) == 0 { + if scanErr != nil { + return scanErr + } return errNoValidSmartData } + return nil } -func (sm *SmartManager) parseConfiguredDevices(config string) error { +func (sm *SmartManager) 
parseConfiguredDevices(config string) ([]*DeviceInfo, error) { entries := strings.Split(config, ",") devices := make([]*DeviceInfo, 0, len(entries)) for _, entry := range entries { @@ -179,7 +197,7 @@ func (sm *SmartManager) parseConfiguredDevices(config string) error { name := strings.TrimSpace(parts[0]) if name == "" { - return fmt.Errorf("invalid SMART_DEVICES entry %q: device name is required", entry) + return nil, fmt.Errorf("invalid SMART_DEVICES entry %q: device name is required", entry) } devType := "" @@ -194,16 +212,10 @@ func (sm *SmartManager) parseConfiguredDevices(config string) error { } if len(devices) == 0 { - sm.Lock() - sm.SmartDevices = nil - sm.Unlock() - return errNoValidSmartData + return nil, errNoValidSmartData } - sm.Lock() - sm.SmartDevices = devices - sm.Unlock() - return nil + return devices, nil } // detectDeviceType extracts the device type reported in smartctl JSON output. @@ -345,12 +357,8 @@ func (sm *SmartManager) hasDataForDevice(deviceName string) bool { return false } -// parseScan parses the output of smartctl --scan -j and updates the SmartDevices slice -func (sm *SmartManager) parseScan(output []byte) bool { - sm.Lock() - defer sm.Unlock() - - sm.SmartDevices = make([]*DeviceInfo, 0) +// parseScan parses the output of smartctl --scan -j and returns the discovered devices. 
+func (sm *SmartManager) parseScan(output []byte) ([]*DeviceInfo, bool) { scan := &scanOutput{} if err := json.Unmarshal(output, scan); err != nil { @@ -362,33 +370,97 @@ func (sm *SmartManager) parseScan(output []byte) bool { return false } - scannedDeviceNameMap := make(map[string]bool, len(scan.Devices)) - + devices := make([]*DeviceInfo, 0, len(scan.Devices)) for _, device := range scan.Devices { - deviceInfo := &DeviceInfo{ + // slog.Info("found device during scan", "name", device.Name, "type", device.Type, "protocol", device.Protocol) + devices = append(devices, &DeviceInfo{ Name: device.Name, Type: device.Type, InfoName: device.InfoName, Protocol: device.Protocol, - } - sm.SmartDevices = append(sm.SmartDevices, deviceInfo) - scannedDeviceNameMap[device.Name] = true + }) } - // remove cached entries whose device path no longer appears in the scan + + return devices, true +} + +// mergeDeviceLists combines scanned and configured SMART devices, preferring +// configured SMART_DEVICES when both sources reference the same device. 
+func mergeDeviceLists(scanned, configured []*DeviceInfo) []*DeviceInfo { + if len(scanned) == 0 && len(configured) == 0 { + return nil + } + + finalDevices := make([]*DeviceInfo, 0, len(scanned)+len(configured)) + deviceIndex := make(map[string]*DeviceInfo, len(scanned)+len(configured)) + + for _, dev := range scanned { + if dev == nil || dev.Name == "" { + continue + } + copyDev := *dev + finalDevices = append(finalDevices, &copyDev) + deviceIndex[copyDev.Name] = finalDevices[len(finalDevices)-1] + } + + for _, dev := range configured { + if dev == nil || dev.Name == "" { + continue + } + + if existing, ok := deviceIndex[dev.Name]; ok { + if dev.Type != "" { + existing.Type = dev.Type + } + if dev.InfoName != "" { + existing.InfoName = dev.InfoName + } + if dev.Protocol != "" { + existing.Protocol = dev.Protocol + } + continue + } + + copyDev := *dev + finalDevices = append(finalDevices, &copyDev) + deviceIndex[copyDev.Name] = finalDevices[len(finalDevices)-1] + } + + return finalDevices +} + +// updateSmartDevices replaces the cached device list and prunes SMART data +// entries whose backing device no longer exists. 
+func (sm *SmartManager) updateSmartDevices(devices []*DeviceInfo) { + sm.Lock() + defer sm.Unlock() + + sm.SmartDevices = devices + + if len(sm.SmartDataMap) == 0 { + return + } + + validNames := make(map[string]struct{}, len(devices)) + for _, device := range devices { + if device == nil || device.Name == "" { + continue + } + validNames[device.Name] = struct{}{} + } + for key, data := range sm.SmartDataMap { if data == nil { delete(sm.SmartDataMap, key) continue } - if _, ok := scannedDeviceNameMap[data.DiskName]; ok { + if _, ok := validNames[data.DiskName]; ok { continue } delete(sm.SmartDataMap, key) } - - return true } // isVirtualDevice checks if a device is a virtual disk that should be filtered out diff --git a/agent/smart_test.go b/agent/smart_test.go index bb0ab424..cc1acebc 100644 --- a/agent/smart_test.go +++ b/agent/smart_test.go @@ -159,7 +159,7 @@ func TestScanDevicesWithEnvOverride(t *testing.T) { SmartDataMap: make(map[string]*smart.SmartData), } - err := sm.ScanDevices() + err := sm.ScanDevices(true) require.NoError(t, err) require.Len(t, sm.SmartDevices, 2) @@ -176,7 +176,7 @@ func TestScanDevicesWithEnvOverrideInvalid(t *testing.T) { SmartDataMap: make(map[string]*smart.SmartData), } - err := sm.ScanDevices() + err := sm.ScanDevices(true) require.Error(t, err) } @@ -187,7 +187,7 @@ func TestScanDevicesWithEnvOverrideEmpty(t *testing.T) { SmartDataMap: make(map[string]*smart.SmartData), } - err := sm.ScanDevices() + err := sm.ScanDevices(true) assert.ErrorIs(t, err, errNoValidSmartData) assert.Empty(t, sm.SmartDevices) } @@ -315,9 +315,11 @@ func TestParseScan(t *testing.T) { ] }`) - hasData := sm.parseScan(scanJSON) + devices, hasData := sm.parseScan(scanJSON) assert.True(t, hasData) + sm.updateSmartDevices(devices) + require.Len(t, sm.SmartDevices, 2) assert.Equal(t, "/dev/sda", sm.SmartDevices[0].Name) assert.Equal(t, "sat", sm.SmartDevices[0].Type) @@ -331,6 +333,36 @@ func TestParseScan(t *testing.T) { assert.False(t, staleExists, "stale 
smart data entry should be removed when device path disappears") } +func TestMergeDeviceListsPrefersConfigured(t *testing.T) { + scanned := []*DeviceInfo{ + {Name: "/dev/sda", Type: "sat", InfoName: "scan-info", Protocol: "ATA"}, + {Name: "/dev/nvme0", Type: "nvme"}, + } + + configured := []*DeviceInfo{ + {Name: "/dev/sda", Type: "sat-override"}, + {Name: "/dev/sdb", Type: "sat"}, + } + + merged := mergeDeviceLists(scanned, configured) + require.Len(t, merged, 3) + + byName := make(map[string]*DeviceInfo, len(merged)) + for _, dev := range merged { + byName[dev.Name] = dev + } + + require.Contains(t, byName, "/dev/sda") + assert.Equal(t, "sat-override", byName["/dev/sda"].Type, "configured type should override scanned type") + assert.Equal(t, "scan-info", byName["/dev/sda"].InfoName, "scan metadata should be preserved when config does not provide it") + + require.Contains(t, byName, "/dev/nvme0") + assert.Equal(t, "nvme", byName["/dev/nvme0"].Type) + + require.Contains(t, byName, "/dev/sdb") + assert.Equal(t, "sat", byName["/dev/sdb"].Type) +} + func assertAttrValue(t *testing.T, attributes []*smart.SmartAttribute, name string, expected uint64) { t.Helper() attr := findAttr(attributes, name)