Skip to content

Commit f981460

Browse files
committed
gpu: add pci device id whitelist support
By defining whitelisted PCI IDs, it's possible to only select some GPUs per host. For example, on a desktop with integrated and discrete graphics, GPU plugin can only register the discrete one. Signed-off-by: Tuomas Katila <[email protected]>
1 parent 60f41aa commit f981460

File tree

8 files changed

+234
-0
lines changed

8 files changed

+234
-0
lines changed

cmd/gpu_plugin/gpu_plugin.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ const (
6969

7070
type cliOptions struct {
7171
preferredAllocationPolicy string
72+
whitelistIDs string
7273
sharedDevNum int
7374
temperatureLimit int
7475
enableMonitoring bool
@@ -204,6 +205,30 @@ func packedPolicy(req *pluginapi.ContainerPreferredAllocationRequest) []string {
204205
return deviceIds
205206
}
206207

208+
func parsePCIDeviceIDs(whitelist string) ([]string, error) {
209+
var deviceIDs []string
210+
211+
r := regexp.MustCompile(`^0x[0-9a-f]{4}$`)
212+
213+
for id := range strings.SplitSeq(whitelist, ",") {
214+
id = strings.TrimSpace(id)
215+
if id == "" {
216+
klog.Errorf("Empty PCI device ID given")
217+
218+
return []string{}, os.ErrInvalid
219+
}
220+
if !r.MatchString(id) {
221+
klog.Errorf("Invalid PCI device ID %s", id)
222+
223+
return []string{}, os.ErrInvalid
224+
}
225+
226+
deviceIDs = append(deviceIDs, id)
227+
}
228+
229+
return deviceIDs, nil
230+
}
231+
207232
func (dp *devicePlugin) pciAddressForCard(cardPath, cardName string) (string, error) {
208233
linkPath, err := os.Readlink(cardPath)
209234
if err != nil {
@@ -585,6 +610,23 @@ func (dp *devicePlugin) filterOutInvalidCards(files []fs.DirEntry) []fs.DirEntry
585610
continue
586611
}
587612

613+
// Skip if the device is not in the whitelist.
614+
if len(dp.options.whitelistIDs) > 0 {
615+
pciID, err := pciDeviceIDForCard(path.Join(dp.sysfsDir, f.Name()))
616+
if err != nil {
617+
klog.Warningf("Failed to get PCI ID for device %s: %+v", f.Name(), err)
618+
619+
continue
620+
}
621+
622+
if !strings.Contains(dp.options.whitelistIDs, pciID) {
623+
klog.V(4).Infof("Skipping device %s (%s), not in whitelist: %s",
624+
f.Name(), pciID, dp.options.whitelistIDs)
625+
626+
continue
627+
}
628+
}
629+
588630
filtered = append(filtered, f)
589631
}
590632

@@ -723,6 +765,7 @@ func main() {
723765
flag.IntVar(&opts.sharedDevNum, "shared-dev-num", 1, "number of containers sharing the same GPU device")
724766
flag.IntVar(&opts.temperatureLimit, "temp-limit", 100, "temperature limit at which device is marked unhealthy")
725767
flag.StringVar(&opts.preferredAllocationPolicy, "allocation-policy", "none", "modes of allocating GPU devices: balanced, packed and none")
768+
flag.StringVar(&opts.whitelistIDs, "whitelist-ids", "", "comma-separated list of device IDs to whitelist (e.g. 0x49c5,0x49c6)")
726769
flag.Parse()
727770

728771
if opts.sharedDevNum < 1 {
@@ -736,6 +779,18 @@ func main() {
736779
os.Exit(1)
737780
}
738781

782+
if opts.whitelistIDs != "" {
783+
if whiteListIDs, err := parsePCIDeviceIDs(opts.whitelistIDs); err != nil {
784+
klog.Error("Failed to parse whitelist-ids: ", err)
785+
786+
os.Exit(1)
787+
} else {
788+
klog.V(2).Infof("Whitelisted device IDs: %q", whiteListIDs)
789+
790+
opts.whitelistIDs = strings.Join(whiteListIDs, ",")
791+
}
792+
}
793+
739794
klog.V(1).Infof("GPU device plugin started with %s preferred allocation policy", opts.preferredAllocationPolicy)
740795

741796
plugin := newDevicePlugin(prefix+sysfsDrmDirectory, prefix+devfsDriDirectory, opts)

cmd/gpu_plugin/gpu_plugin_test.go

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,66 @@ func TestScan(t *testing.T) {
361361
expectedI915Devs: 1,
362362
expectedI915Monitors: 1,
363363
},
364+
{
365+
name: "two devices with only one whitelisted",
366+
sysfsdirs: []string{"card0/device/drm/card0", "card0/device/drm/controlD64", "card1/device/drm/card1"},
367+
sysfsfiles: map[string][]byte{
368+
"card0/device/vendor": []byte("0x8086"),
369+
"card0/device/device": []byte("0x1234"),
370+
"card1/device/vendor": []byte("0x8086"),
371+
"card1/device/device": []byte("0x9876"),
372+
},
373+
symlinkfiles: map[string]string{
374+
"card0/device/driver": "drivers/xe",
375+
"card1/device/driver": "drivers/i915",
376+
},
377+
devfsdirs: []string{
378+
"card0",
379+
"by-path/pci-0000:00:00.0-card",
380+
"by-path/pci-0000:00:00.0-render",
381+
"card1",
382+
"by-path/pci-0000:00:01.0-card",
383+
"by-path/pci-0000:00:01.0-render",
384+
},
385+
options: cliOptions{enableMonitoring: true, whitelistIDs: "0x1234"},
386+
expectedXeDevs: 1,
387+
expectedXeMonitors: 1,
388+
expectedI915Devs: 0,
389+
expectedI915Monitors: 0,
390+
},
391+
{
392+
name: "three devices with two whitelisted",
393+
sysfsdirs: []string{"card0/device/drm/card0", "card0/device/drm/controlD64", "card1/device/drm/card1", "card2/device/drm/card2"},
394+
sysfsfiles: map[string][]byte{
395+
"card0/device/vendor": []byte("0x8086"),
396+
"card0/device/device": []byte("0x1234"),
397+
"card1/device/vendor": []byte("0x8086"),
398+
"card1/device/device": []byte("0x9876"),
399+
"card2/device/vendor": []byte("0x8086"),
400+
"card2/device/device": []byte("0x0101"),
401+
},
402+
symlinkfiles: map[string]string{
403+
"card0/device/driver": "drivers/xe",
404+
"card1/device/driver": "drivers/i915",
405+
"card2/device/driver": "drivers/i915",
406+
},
407+
devfsdirs: []string{
408+
"card0",
409+
"by-path/pci-0000:00:00.0-card",
410+
"by-path/pci-0000:00:00.0-render",
411+
"card1",
412+
"by-path/pci-0000:00:01.0-card",
413+
"by-path/pci-0000:00:01.0-render",
414+
"card2",
415+
"by-path/pci-0000:00:02.0-card",
416+
"by-path/pci-0000:00:02.0-render",
417+
},
418+
options: cliOptions{enableMonitoring: true, whitelistIDs: "0x1234,0x9876"},
419+
expectedXeDevs: 1,
420+
expectedXeMonitors: 1,
421+
expectedI915Devs: 1,
422+
expectedI915Monitors: 1,
423+
},
364424
{
365425
name: "sriov-1-pf-no-vfs + monitoring",
366426
sysfsdirs: []string{"card0/device/drm/card0", "card0/device/drm/controlD64"},
@@ -1048,3 +1108,73 @@ func TestCDIDeviceInclusion(t *testing.T) {
10481108
t.Error("Invalid count for device (xe)")
10491109
}
10501110
}
1111+
1112+
func TestParsePCIDeviceIDs(t *testing.T) {
1113+
tests := []struct {
1114+
name string
1115+
input string
1116+
want []string
1117+
wantError bool
1118+
}{
1119+
{
1120+
name: "valid single ID",
1121+
input: "0x1234",
1122+
want: []string{"0x1234"},
1123+
wantError: false,
1124+
},
1125+
{
1126+
name: "valid multiple IDs",
1127+
input: "0x1234,0x5678,0x9abc",
1128+
want: []string{"0x1234", "0x5678", "0x9abc"},
1129+
wantError: false,
1130+
},
1131+
{
1132+
name: "valid IDs with spaces",
1133+
input: " 0x1234 , 0x5678 ",
1134+
want: []string{"0x1234", "0x5678"},
1135+
wantError: false,
1136+
},
1137+
{
1138+
name: "empty string",
1139+
input: "",
1140+
want: []string{},
1141+
wantError: true,
1142+
},
1143+
{
1144+
name: "invalid ID format",
1145+
input: "0x1234,abcd",
1146+
want: []string{},
1147+
wantError: true,
1148+
},
1149+
{
1150+
name: "invalid hex length",
1151+
input: "0x123,0x5678",
1152+
want: []string{},
1153+
wantError: true,
1154+
},
1155+
{
1156+
name: "extra comma",
1157+
input: "0x1234,",
1158+
want: []string{},
1159+
wantError: true,
1160+
},
1161+
{
1162+
name: "capita hex",
1163+
input: "0xAA12,",
1164+
want: []string{},
1165+
wantError: true,
1166+
},
1167+
}
1168+
1169+
for _, tt := range tests {
1170+
t.Run(tt.name, func(t *testing.T) {
1171+
got, err := parsePCIDeviceIDs(tt.input)
1172+
if (err != nil) != tt.wantError {
1173+
t.Errorf("parsePCIDeviceIDs() error = %v, wantError %v", err, tt.wantError)
1174+
}
1175+
if !reflect.DeepEqual(got, tt.want) {
1176+
t.Errorf("parsePCIDeviceIDs() = %v, want %v", got, tt.want)
1177+
}
1178+
})
1179+
}
1180+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
apiVersion: apps/v1
2+
kind: DaemonSet
3+
metadata:
4+
name: intel-gpu-plugin
5+
spec:
6+
template:
7+
spec:
8+
containers:
9+
- name: intel-gpu-plugin
10+
args:
11+
- "-v=4"
12+
- "-whitelist-ids=0x56a6,0x56a5,0x56a1,0x56a0,0x5694,0x5693,0x5692,0x5691,0x5690,0x56b3,0x56b2,0x56a4,0x56a3,0x5697,0x5696,0x5695,0x56b1,0x56b0,0x56a2,0x56ba,0x56bc,0x56bd,0x56bb"
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
resources:
2+
- ../../base
3+
patches:
4+
- path: add-args.yaml

deployments/operator/crd/bases/deviceplugin.intel.com_gpudeviceplugins.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,12 @@ spec:
132132
type: string
133133
type: object
134134
type: array
135+
whiteListIDs:
136+
description: |-
137+
WhiteListIDs is a comma-separated list of PCI IDs of GPU devices that should only be advertised by the plugin.
138+
If not set, all devices are advertised.
139+
The list can contain IDs in the form of '0x1234,0x49a4,0x50b4'.
140+
type: string
135141
type: object
136142
status:
137143
description: GpuDevicePluginStatus defines the observed state of GpuDevicePlugin.

pkg/apis/deviceplugin/v1/gpudeviceplugin_types.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@ type GpuDevicePluginSpec struct {
5353
// EnableMonitoring enables the monitoring resource ('i915_monitoring')
5454
// which gives access to all GPU devices on given node. Typically used with Intel XPU-Manager.
5555
EnableMonitoring bool `json:"enableMonitoring,omitempty"`
56+
57+
// WhiteListIDs is a comma-separated list of PCI IDs of GPU devices that should only be advertised by the plugin.
58+
// If not set, all devices are advertised.
59+
// The list can contain IDs in the form of '0x1234,0x49a4,0x50b4'.
60+
WhiteListIDs string `json:"whiteListIDs,omitempty"`
5661
}
5762

5863
// GpuDevicePluginStatus defines the observed state of GpuDevicePlugin.

pkg/apis/deviceplugin/v1/gpudeviceplugin_webhook.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,20 @@ package v1
1616

1717
import (
1818
"fmt"
19+
"regexp"
20+
"strings"
1921

2022
ctrl "sigs.k8s.io/controller-runtime"
2123

2224
"github.com/intel/intel-device-plugins-for-kubernetes/pkg/controllers"
2325
)
2426

27+
var pciIDRegex regexp.Regexp
28+
2529
// SetupWebhookWithManager sets up a webhook for GpuDevicePlugin custom resources.
2630
func (r *GpuDevicePlugin) SetupWebhookWithManager(mgr ctrl.Manager) error {
31+
pciIDRegex = *regexp.MustCompile(`^0x[0-9a-f]{4}$`)
32+
2733
return ctrl.NewWebhookManagedBy(mgr).
2834
For(r).
2935
WithDefaulter(&commonDevicePluginDefaulter{
@@ -44,5 +50,17 @@ func (r *GpuDevicePlugin) validatePlugin(ref *commonDevicePluginValidator) error
4450
return fmt.Errorf("%w: PreferredAllocationPolicy is valid only when setting sharedDevNum > 1", errValidation)
4551
}
4652

53+
if r.Spec.WhiteListIDs != "" {
54+
for id := range strings.SplitSeq(r.Spec.WhiteListIDs, ",") {
55+
if id == "" {
56+
return fmt.Errorf("%w: Empty PCI Device ID in WhiteListIDs", errValidation)
57+
}
58+
59+
if !pciIDRegex.MatchString(id) {
60+
return fmt.Errorf("%w: Invalid PCI Device ID: %s", errValidation, id)
61+
}
62+
}
63+
}
64+
4765
return validatePluginImage(r.Spec.Image, ref.expectedImage, &ref.expectedVersion)
4866
}

pkg/controllers/gpu/controller.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,5 +277,9 @@ func getPodArgs(gdp *devicepluginv1.GpuDevicePlugin) []string {
277277
args = append(args, "-allocation-policy", "none")
278278
}
279279

280+
if gdp.Spec.WhiteListIDs != "" {
281+
args = append(args, "-whitelist-ids", gdp.Spec.WhiteListIDs)
282+
}
283+
280284
return args
281285
}

0 commit comments

Comments
 (0)