gpu(amd): add workaround for misreported sysfs filesize (#1799)

This commit is contained in:
henrygd
2026-03-09 13:49:37 -04:00
parent 35d0e792ad
commit 6b1ff264f2
2 changed files with 30 additions and 13 deletions

View File

@@ -33,8 +33,8 @@ func (gm *GPUManager) hasAmdSysfs() bool {
return false return false
} }
for _, vendorPath := range cards { for _, vendorPath := range cards {
vendor, err := os.ReadFile(vendorPath) vendor, err := utils.ReadStringFileLimited(vendorPath, 64)
if err == nil && strings.TrimSpace(string(vendor)) == "0x1002" { if err == nil && vendor == "0x1002" {
return true return true
} }
} }
@@ -88,12 +88,11 @@ func (gm *GPUManager) collectAmdStats() error {
// isAmdGpu checks whether a DRM card path belongs to AMD vendor ID 0x1002. // isAmdGpu checks whether a DRM card path belongs to AMD vendor ID 0x1002.
func isAmdGpu(cardPath string) bool { func isAmdGpu(cardPath string) bool {
vendorPath := filepath.Join(cardPath, "device/vendor") vendor, err := utils.ReadStringFileLimited(filepath.Join(cardPath, "device/vendor"), 64)
vendor, err := os.ReadFile(vendorPath)
if err != nil { if err != nil {
return false return false
} }
return strings.TrimSpace(string(vendor)) == "0x1002" return vendor == "0x1002"
} }
// updateAmdGpuData reads GPU metrics from sysfs and updates the GPU data map. // updateAmdGpuData reads GPU metrics from sysfs and updates the GPU data map.
@@ -155,11 +154,11 @@ func (gm *GPUManager) updateAmdGpuData(cardPath string) bool {
// readSysfsFloat reads and parses a numeric value from a sysfs file. // readSysfsFloat reads and parses a numeric value from a sysfs file.
func readSysfsFloat(path string) (float64, error) { func readSysfsFloat(path string) (float64, error) {
val, err := os.ReadFile(path) val, err := utils.ReadStringFileLimited(path, 64)
if err != nil { if err != nil {
return 0, err return 0, err
} }
return strconv.ParseFloat(strings.TrimSpace(string(val)), 64) return strconv.ParseFloat(val, 64)
} }
// normalizeHexID normalizes hex IDs by trimming spaces, lowercasing, and dropping 0x. // normalizeHexID normalizes hex IDs by trimming spaces, lowercasing, and dropping 0x.
@@ -274,16 +273,16 @@ func cacheMissingAmdgpuName(deviceID, revisionID string) {
// Falls back to showing the raw device ID if not found in the lookup table. // Falls back to showing the raw device ID if not found in the lookup table.
func getAmdGpuName(devicePath string) string { func getAmdGpuName(devicePath string) string {
// Try product_name first (works for some enterprise GPUs) // Try product_name first (works for some enterprise GPUs)
if prod, err := os.ReadFile(filepath.Join(devicePath, "product_name")); err == nil { if prod, err := utils.ReadStringFileLimited(filepath.Join(devicePath, "product_name"), 128); err == nil {
return strings.TrimSpace(string(prod)) return prod
} }
// Read PCI device ID and look it up // Read PCI device ID and look it up
if deviceID, err := os.ReadFile(filepath.Join(devicePath, "device")); err == nil { if deviceID, err := utils.ReadStringFileLimited(filepath.Join(devicePath, "device"), 64); err == nil {
id := normalizeHexID(string(deviceID)) id := normalizeHexID(deviceID)
revision := "" revision := ""
if revBytes, revErr := os.ReadFile(filepath.Join(devicePath, "revision")); revErr == nil { if rev, revErr := utils.ReadStringFileLimited(filepath.Join(devicePath, "revision"), 64); revErr == nil {
revision = normalizeHexID(string(revBytes)) revision = normalizeHexID(rev)
} }
if name, found, done := getCachedAmdgpuName(id, revision); found { if name, found, done := getCachedAmdgpuName(id, revision); found {

View File

@@ -1,6 +1,7 @@
package utils package utils
import ( import (
"io"
"math" "math"
"os" "os"
"strconv" "strconv"
@@ -50,6 +51,23 @@ func ReadStringFileOK(path string) (string, bool) {
return strings.TrimSpace(string(b)), true return strings.TrimSpace(string(b)), true
} }
// ReadStringFileLimited reads a file into a string with a maximum size (in bytes) to avoid
// allocating large buffers and potential panics with pseudo-files when the size is misreported.
func ReadStringFileLimited(path string, maxSize int) (string, error) {
f, err := os.Open(path)
if err != nil {
return "", err
}
defer f.Close()
buf := make([]byte, maxSize)
n, err := f.Read(buf)
if err != nil && err != io.EOF {
return "", err
}
return strings.TrimSpace(string(buf[:n])), nil
}
// FileExists reports whether the given path exists. // FileExists reports whether the given path exists.
func FileExists(path string) bool { func FileExists(path string) bool {
_, err := os.Stat(path) _, err := os.Stat(path)