diff --git a/agent/gpu_amd_linux.go b/agent/gpu_amd_linux.go index ab809e90..6af0784c 100644 --- a/agent/gpu_amd_linux.go +++ b/agent/gpu_amd_linux.go @@ -33,8 +33,8 @@ func (gm *GPUManager) hasAmdSysfs() bool { return false } for _, vendorPath := range cards { - vendor, err := os.ReadFile(vendorPath) - if err == nil && strings.TrimSpace(string(vendor)) == "0x1002" { + vendor, err := utils.ReadStringFileLimited(vendorPath, 64) + if err == nil && vendor == "0x1002" { return true } } @@ -88,12 +88,11 @@ func (gm *GPUManager) collectAmdStats() error { // isAmdGpu checks whether a DRM card path belongs to AMD vendor ID 0x1002. func isAmdGpu(cardPath string) bool { - vendorPath := filepath.Join(cardPath, "device/vendor") - vendor, err := os.ReadFile(vendorPath) + vendor, err := utils.ReadStringFileLimited(filepath.Join(cardPath, "device/vendor"), 64) if err != nil { return false } - return strings.TrimSpace(string(vendor)) == "0x1002" + return vendor == "0x1002" } // updateAmdGpuData reads GPU metrics from sysfs and updates the GPU data map. @@ -155,11 +154,11 @@ func (gm *GPUManager) updateAmdGpuData(cardPath string) bool { // readSysfsFloat reads and parses a numeric value from a sysfs file. func readSysfsFloat(path string) (float64, error) { - val, err := os.ReadFile(path) + val, err := utils.ReadStringFileLimited(path, 64) if err != nil { return 0, err } - return strconv.ParseFloat(strings.TrimSpace(string(val)), 64) + return strconv.ParseFloat(val, 64) } // normalizeHexID normalizes hex IDs by trimming spaces, lowercasing, and dropping 0x. @@ -274,16 +273,16 @@ func cacheMissingAmdgpuName(deviceID, revisionID string) { // Falls back to showing the raw device ID if not found in the lookup table. func getAmdGpuName(devicePath string) string { // Try product_name first (works for some enterprise GPUs) - if prod, err := os.ReadFile(filepath.Join(devicePath, "product_name")); err == nil { - return strings.TrimSpace(string(prod)) + if prod, err := utils.ReadStringFileLimited(filepath.Join(devicePath, "product_name"), 128); err == nil { + return prod } // Read PCI device ID and look it up - if deviceID, err := os.ReadFile(filepath.Join(devicePath, "device")); err == nil { - id := normalizeHexID(string(deviceID)) + if deviceID, err := utils.ReadStringFileLimited(filepath.Join(devicePath, "device"), 64); err == nil { + id := normalizeHexID(deviceID) revision := "" - if revBytes, revErr := os.ReadFile(filepath.Join(devicePath, "revision")); revErr == nil { - revision = normalizeHexID(string(revBytes)) + if rev, revErr := utils.ReadStringFileLimited(filepath.Join(devicePath, "revision"), 64); revErr == nil { + revision = normalizeHexID(rev) } if name, found, done := getCachedAmdgpuName(id, revision); found { diff --git a/agent/utils/utils.go b/agent/utils/utils.go index 86b4567a..1941a909 100644 --- a/agent/utils/utils.go +++ b/agent/utils/utils.go @@ -1,6 +1,7 @@ package utils import ( + "io" "math" "os" "strconv" @@ -50,6 +51,23 @@ func ReadStringFileOK(path string) (string, bool) { return strings.TrimSpace(string(b)), true } +// ReadStringFileLimited reads a file into a string with a maximum size (in bytes) to avoid +// allocating large buffers and potential panics with pseudo-files when the size is misreported. +func ReadStringFileLimited(path string, maxSize int) (string, error) { + f, err := os.Open(path) + if err != nil { + return "", err + } + defer f.Close() + + buf := make([]byte, maxSize) + n, err := f.Read(buf) + if err != nil && err != io.EOF { + return "", err + } + return strings.TrimSpace(string(buf[:n])), nil +} + // FileExists reports whether the given path exists. func FileExists(path string) bool { _, err := os.Stat(path)