diff --git a/agent/smart.go b/agent/smart.go index 383c7508..c17d9b77 100644 --- a/agent/smart.go +++ b/agent/smart.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "os/exec" + "slices" "strconv" "strings" "sync" @@ -129,6 +130,19 @@ func (sm *SmartManager) GetCurrentData() map[string]smart.SmartData { // If scan fails, return error // If scan succeeds, parse the output and update the SmartDevices slice func (sm *SmartManager) ScanDevices() error { + if configuredDevices, ok := GetEnv("SMART_DEVICES"); ok { + config := strings.TrimSpace(configuredDevices) + if config == "" { + return errNoValidSmartData + } + slog.Info("SMART_DEVICES", "config", config) + + if err := sm.parseConfiguredDevices(config); err != nil { + return err + } + return nil + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -146,6 +160,103 @@ func (sm *SmartManager) ScanDevices() error { return nil } +func (sm *SmartManager) parseConfiguredDevices(config string) error { + entries := strings.Split(config, ",") + devices := make([]*DeviceInfo, 0, len(entries)) + for _, entry := range entries { + entry = strings.TrimSpace(entry) + if entry == "" { + continue + } + + parts := strings.SplitN(entry, ":", 2) + + name := strings.TrimSpace(parts[0]) + if name == "" { + return fmt.Errorf("invalid SMART_DEVICES entry %q: device name is required", entry) + } + + devType := "" + if len(parts) == 2 { + devType = strings.ToLower(strings.TrimSpace(parts[1])) + } + + devices = append(devices, &DeviceInfo{ + Name: name, + Type: devType, + }) + } + + if len(devices) == 0 { + sm.Lock() + sm.SmartDevices = nil + sm.Unlock() + return errNoValidSmartData + } + + sm.Lock() + sm.SmartDevices = devices + sm.Unlock() + return nil +} + +// detectDeviceType extracts the device type reported in smartctl JSON output. +func detectDeviceType(output []byte) string { + var payload struct { + Device struct { + Type string `json:"type"` + } `json:"device"` + } + + if err := json.Unmarshal(output, &payload); err != nil { + return "" + } + + return strings.ToLower(payload.Device.Type) +} + +// parseSmartOutput attempts each SMART parser, optionally detecting the type when +// it is not provided, and updates the device info when a parser succeeds. +func (sm *SmartManager) parseSmartOutput(deviceInfo *DeviceInfo, output []byte) bool { + deviceType := strings.ToLower(deviceInfo.Type) + + if deviceType == "" { + if detected := detectDeviceType(output); detected != "" { + deviceType = detected + deviceInfo.Type = detected + } + } + + parsers := []struct { + Type string + Parse func([]byte) (bool, int) + Alias []string + }{ + {Type: "nvme", Parse: sm.parseSmartForNvme, Alias: []string{"sntasmedia"}}, + {Type: "sat", Parse: sm.parseSmartForSata, Alias: []string{"ata"}}, + {Type: "scsi", Parse: sm.parseSmartForScsi}, + } + + for _, parser := range parsers { + if deviceType != "" && deviceType != parser.Type { + aliasMatched := slices.Contains(parser.Alias, deviceType) + if !aliasMatched { + continue + } + } + + hasData, _ := parser.Parse(output) + if hasData { + if deviceInfo.Type == "" { + deviceInfo.Type = parser.Type + } + return true + } + } + + return false +} + // CollectSmart collects SMART data for a device // Collect data using `smartctl -d -aj /dev/` when device type is known // Always attempts to parse output even if command fails, as some data may still be available @@ -181,16 +292,7 @@ func (sm *SmartManager) CollectSmart(deviceInfo *DeviceInfo) error { output, err = cmd.CombinedOutput() } - hasValidData := false - - switch deviceInfo.Type { - case "scsi": - hasValidData, _ = sm.parseSmartForScsi(output) - case "sat", "ata": - hasValidData, _ = sm.parseSmartForSata(output) - case "nvme", "sntasmedia": - hasValidData, _ = sm.parseSmartForNvme(output) - } + hasValidData := sm.parseSmartOutput(deviceInfo, output) if !hasValidData { if err != nil { diff --git a/agent/smart_test.go b/agent/smart_test.go index 70020d65..48d8409e 100644 --- a/agent/smart_test.go +++ b/agent/smart_test.go @@ -165,6 +165,55 @@ func TestDevicesSnapshotReturnsCopy(t *testing.T) { assert.Len(t, snapshot, 2) } +func TestScanDevicesWithEnvOverride(t *testing.T) { + t.Setenv("SMART_DEVICES", "/dev/sda:sat, /dev/nvme0:nvme") + + sm := &SmartManager{ + SmartDataMap: make(map[string]*smart.SmartData), + } + + err := sm.ScanDevices() + require.NoError(t, err) + + require.Len(t, sm.SmartDevices, 2) + assert.Equal(t, "/dev/sda", sm.SmartDevices[0].Name) + assert.Equal(t, "sat", sm.SmartDevices[0].Type) + assert.Equal(t, "/dev/nvme0", sm.SmartDevices[1].Name) + assert.Equal(t, "nvme", sm.SmartDevices[1].Type) +} + +func TestScanDevicesWithEnvOverrideInvalid(t *testing.T) { + t.Setenv("SMART_DEVICES", ":sat") + + sm := &SmartManager{ + SmartDataMap: make(map[string]*smart.SmartData), + } + + err := sm.ScanDevices() + require.Error(t, err) +} + +func TestScanDevicesWithEnvOverrideEmpty(t *testing.T) { + t.Setenv("SMART_DEVICES", " ") + + sm := &SmartManager{ + SmartDataMap: make(map[string]*smart.SmartData), + } + + err := sm.ScanDevices() + assert.ErrorIs(t, err, errNoValidSmartData) + assert.Empty(t, sm.SmartDevices) +} + +func TestSmartctlArgsWithoutType(t *testing.T) { + device := &DeviceInfo{Name: "/dev/sda"} + + sm := &SmartManager{} + + args := sm.smartctlArgs(device, true) + assert.Equal(t, []string{"-aj", "-n", "standby", "/dev/sda"}, args) +} + func TestSmartctlArgs(t *testing.T) { sm := &SmartManager{}