Skip to content

Commit 5ecb08b

Browse files
committed
WIP: Add an NRI plugin to filter which NUMA nodes a container sees
Signed-off-by: Kevin Klues <[email protected]>
1 parent 42e47a7 commit 5ecb08b

File tree

4 files changed

+497
-0
lines changed

4 files changed

+497
-0
lines changed

cmd/nvidia-device-plugin/main.go

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@ import (
3636
spec "github.com/NVIDIA/k8s-device-plugin/api/config/v1"
3737
"github.com/NVIDIA/k8s-device-plugin/internal/info"
3838
"github.com/NVIDIA/k8s-device-plugin/internal/logger"
39+
"github.com/NVIDIA/k8s-device-plugin/internal/nri"
3940
"github.com/NVIDIA/k8s-device-plugin/internal/plugin"
4041
"github.com/NVIDIA/k8s-device-plugin/internal/rm"
4142
"github.com/NVIDIA/k8s-device-plugin/internal/watch"
@@ -248,16 +249,25 @@ func start(c *cli.Context, o *options) error {
248249
var started bool
249250
var restartTimeout <-chan time.Time
250251
var plugins []plugin.Interface
252+
var numaPlugin *nri.NUMAFilterPlugin
251253
restart:
252254
// If we are restarting, stop plugins from previous run.
253255
if started {
254256
err := stopPlugins(plugins)
255257
if err != nil {
256258
return fmt.Errorf("error stopping plugins from previous run: %v", err)
257259
}
260+
numaPlugin.Stop()
258261
}
259262

260263
klog.Info("Starting Plugins.")
264+
265+
// Start the NRI plugin
266+
nriPlugin := nri.NewNUMAFilterPlugin()
267+
if err := nriPlugin.Start(c.Context); err != nil {
268+
klog.Fatalf("Error starting NRI plugin: %v", err)
269+
}
270+
261271
plugins, restartPlugins, err := startPlugins(c, o)
262272
if err != nil {
263273
return fmt.Errorf("error starting plugins: %v", err)

deployments/helm/nvidia-device-plugin/templates/daemonset-device-plugin.yml

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -212,6 +212,8 @@ spec:
212212
mountPath: /mps
213213
- name: cdi-root
214214
mountPath: /var/run/cdi
215+
- name: nri-root
216+
mountPath: /var/run/nri
215217
{{- if $options.hasConfigMap }}
216218
- name: available-configs
217219
mountPath: /available-configs
@@ -242,6 +244,10 @@ spec:
242244
hostPath:
243245
path: /var/run/cdi
244246
type: DirectoryOrCreate
247+
- name: nri-root
248+
hostPath:
249+
path: /var/run/nri
250+
type: DirectoryOrCreate
245251
{{- if $options.hasConfigMap }}
246252
- name: available-configs
247253
configMap:

internal/nri/discoverer.go

Lines changed: 179 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,179 @@
1+
/*
2+
* Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package nri
18+
19+
import (
20+
"fmt"
21+
"strconv"
22+
"strings"
23+
24+
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
25+
"github.com/NVIDIA/go-nvml/pkg/nvml"
26+
"github.com/containerd/nri/pkg/api"
27+
"k8s.io/klog/v2"
28+
)
29+
30+
// DeviceDiscoverer defines the interface for discovering assigned devices
31+
type DeviceDiscoverer interface {
32+
// GetAssignedDevices returns the list of GPU and MIG device UUIDs assigned to a container
33+
GetAssignedDevices(c *api.Container) ([]string, error)
34+
// GetAllDevices returns all GPU and MIG devices in the system
35+
GetAllDevices() (map[string]*Device, error)
36+
}
37+
38+
// Device represents a GPU or MIG device with cached UUID and index
39+
type Device struct {
40+
nvml.Device
41+
uuid string
42+
index string
43+
}
44+
45+
// deviceDiscoverer is the stateless DeviceDiscoverer implementation; it
// derives a container's devices from its NVIDIA_VISIBLE_DEVICES environment
// variable.
type deviceDiscoverer struct{}
47+
48+
// NewDeviceDiscoverer creates a new device discoverer
49+
func NewDeviceDiscoverer() DeviceDiscoverer {
50+
return &deviceDiscoverer{}
51+
}
52+
53+
// GetAssignedDevices implements DeviceDiscoverer interface
54+
func (d *deviceDiscoverer) GetAssignedDevices(c *api.Container) ([]string, error) {
55+
// Look for NVIDIA_VISIBLE_DEVICES environment variable
56+
for _, env := range c.Env {
57+
if strings.HasPrefix(env, "NVIDIA_VISIBLE_DEVICES=") {
58+
// Get the value after the equals sign
59+
value := strings.TrimPrefix(env, "NVIDIA_VISIBLE_DEVICES=")
60+
61+
switch value {
62+
case "", "void":
63+
return nil, nil
64+
case "all":
65+
uuids, err := d.getUUIDsFromAllDevices()
66+
if err != nil {
67+
return nil, fmt.Errorf("failed to get UUIDs from all devices: %w", err)
68+
}
69+
return uuids, nil
70+
default:
71+
ids := strings.Split(value, ",")
72+
uuids, err := d.processDeviceIDs(ids)
73+
if err != nil {
74+
return nil, fmt.Errorf("failed to process device IDs as UUID or Index: %w", err)
75+
}
76+
return uuids, nil
77+
}
78+
}
79+
}
80+
return nil, nil
81+
}
82+
83+
// GetAllDevices implements DeviceDiscoverer interface
84+
func (d *deviceDiscoverer) GetAllDevices() (map[string]*Device, error) {
85+
// Initialize NVML
86+
nvmlLib := nvml.New()
87+
if ret := nvmlLib.Init(); ret != nvml.SUCCESS {
88+
return nil, fmt.Errorf("failed to initialize NVML: %w", ret)
89+
}
90+
defer func() {
91+
if ret := nvmlLib.Shutdown(); ret != nvml.SUCCESS {
92+
klog.Warning("Failed to shutdown NVML", "error", ret)
93+
}
94+
}()
95+
96+
// Create the nvlib device interface
97+
nvlib := device.New(nvmlLib)
98+
99+
devices, err := nvlib.GetDevices()
100+
if err != nil {
101+
return nil, fmt.Errorf("failed to get devices from nvlib: %w", err)
102+
}
103+
104+
allDevices := make(map[string]*Device)
105+
for i, dev := range devices {
106+
// Add the GPU device
107+
uuid, ret := dev.GetUUID()
108+
if ret != nvml.SUCCESS {
109+
return nil, fmt.Errorf("failed to get GPU device UUID: %v", ret)
110+
}
111+
allDevices[uuid] = &Device{Device: dev, uuid: uuid, index: strconv.Itoa(i)}
112+
113+
// Add MIG devices
114+
migs, err := dev.GetMigDevices()
115+
if err != nil {
116+
return nil, fmt.Errorf("failed to get MIG devices: %w", err)
117+
}
118+
for j, mig := range migs {
119+
// Convert MIG device to NVML device
120+
uuid, ret := mig.GetUUID()
121+
if ret != nvml.SUCCESS {
122+
return nil, fmt.Errorf("failed to get MIG device UUID: %v", ret)
123+
}
124+
migDevice, ret := nvmlLib.DeviceGetHandleByUUID(uuid)
125+
if ret != nvml.SUCCESS {
126+
return nil, fmt.Errorf("failed to get MIG device handle: %v", ret)
127+
}
128+
allDevices[uuid] = &Device{Device: migDevice, uuid: uuid, index: fmt.Sprintf("%d:%d", i, j)}
129+
}
130+
}
131+
return allDevices, nil
132+
}
133+
134+
// getUUIDsFromAllDevices returns a list of UUIDs from all devices
135+
func (d *deviceDiscoverer) getUUIDsFromAllDevices() ([]string, error) {
136+
devices, err := d.GetAllDevices()
137+
if err != nil {
138+
return nil, fmt.Errorf("failed to get all devices: %w", err)
139+
}
140+
var uuids []string
141+
for uuid := range devices {
142+
uuids = append(uuids, uuid)
143+
}
144+
return uuids, nil
145+
}
146+
147+
// processDeviceIDs processes a comma-separated list of UUIDs or indices
148+
func (d *deviceDiscoverer) processDeviceIDs(ids []string) ([]string, error) {
149+
var uuids []string
150+
devices, err := d.GetAllDevices()
151+
if err != nil {
152+
return nil, fmt.Errorf("failed to get all devices: %w", err)
153+
}
154+
155+
for _, id := range ids {
156+
// Check if the ID is a UUID or an index
157+
if strings.Contains(id, ":") {
158+
// Convert MIG index to UUID
159+
for _, device := range devices {
160+
if device.index == id {
161+
uuids = append(uuids, device.uuid)
162+
break
163+
}
164+
}
165+
} else if _, err := strconv.Atoi(id); err == nil {
166+
// Convert GPU index to UUID
167+
for _, device := range devices {
168+
if device.index == id {
169+
uuids = append(uuids, device.uuid)
170+
break
171+
}
172+
}
173+
} else {
174+
// Assume it's a UUID
175+
uuids = append(uuids, id)
176+
}
177+
}
178+
return uuids, nil
179+
}

0 commit comments

Comments
 (0)