gpu(amd): add workaround for misreported sysfs filesize (#1799)

This commit is contained in:
henrygd
2026-03-09 13:49:37 -04:00
parent 35d0e792ad
commit 6b1ff264f2
2 changed files with 30 additions and 13 deletions

View File

@@ -33,8 +33,8 @@ func (gm *GPUManager) hasAmdSysfs() bool {
return false
}
for _, vendorPath := range cards {
vendor, err := os.ReadFile(vendorPath)
if err == nil && strings.TrimSpace(string(vendor)) == "0x1002" {
vendor, err := utils.ReadStringFileLimited(vendorPath, 64)
if err == nil && vendor == "0x1002" {
return true
}
}
@@ -88,12 +88,11 @@ func (gm *GPUManager) collectAmdStats() error {
// isAmdGpu checks whether a DRM card path belongs to AMD vendor ID 0x1002.
func isAmdGpu(cardPath string) bool {
vendorPath := filepath.Join(cardPath, "device/vendor")
vendor, err := os.ReadFile(vendorPath)
vendor, err := utils.ReadStringFileLimited(filepath.Join(cardPath, "device/vendor"), 64)
if err != nil {
return false
}
return strings.TrimSpace(string(vendor)) == "0x1002"
return vendor == "0x1002"
}
// updateAmdGpuData reads GPU metrics from sysfs and updates the GPU data map.
@@ -155,11 +154,11 @@ func (gm *GPUManager) updateAmdGpuData(cardPath string) bool {
// readSysfsFloat reads and parses a numeric value from a sysfs file.
func readSysfsFloat(path string) (float64, error) {
val, err := os.ReadFile(path)
val, err := utils.ReadStringFileLimited(path, 64)
if err != nil {
return 0, err
}
return strconv.ParseFloat(strings.TrimSpace(string(val)), 64)
return strconv.ParseFloat(val, 64)
}
// normalizeHexID normalizes hex IDs by trimming spaces, lowercasing, and dropping 0x.
@@ -274,16 +273,16 @@ func cacheMissingAmdgpuName(deviceID, revisionID string) {
// Falls back to showing the raw device ID if not found in the lookup table.
func getAmdGpuName(devicePath string) string {
// Try product_name first (works for some enterprise GPUs)
if prod, err := os.ReadFile(filepath.Join(devicePath, "product_name")); err == nil {
return strings.TrimSpace(string(prod))
if prod, err := utils.ReadStringFileLimited(filepath.Join(devicePath, "product_name"), 128); err == nil {
return prod
}
// Read PCI device ID and look it up
if deviceID, err := os.ReadFile(filepath.Join(devicePath, "device")); err == nil {
id := normalizeHexID(string(deviceID))
if deviceID, err := utils.ReadStringFileLimited(filepath.Join(devicePath, "device"), 64); err == nil {
id := normalizeHexID(deviceID)
revision := ""
if revBytes, revErr := os.ReadFile(filepath.Join(devicePath, "revision")); revErr == nil {
revision = normalizeHexID(string(revBytes))
if rev, revErr := utils.ReadStringFileLimited(filepath.Join(devicePath, "revision"), 64); revErr == nil {
revision = normalizeHexID(rev)
}
if name, found, done := getCachedAmdgpuName(id, revision); found {

View File

@@ -1,6 +1,7 @@
package utils
import (
"io"
"math"
"os"
"strconv"
@@ -50,6 +51,23 @@ func ReadStringFileOK(path string) (string, bool) {
return strings.TrimSpace(string(b)), true
}
// ReadStringFileLimited reads a file into a string with a maximum size (in bytes) to avoid
// allocating large buffers and potential panics with pseudo-files when the size is misreported.
func ReadStringFileLimited(path string, maxSize int) (string, error) {
f, err := os.Open(path)
if err != nil {
return "", err
}
defer f.Close()
buf := make([]byte, maxSize)
n, err := f.Read(buf)
if err != nil && err != io.EOF {
return "", err
}
return strings.TrimSpace(string(buf[:n])), nil
}
// FileExists reports whether the given path exists.
func FileExists(path string) bool {
_, err := os.Stat(path)