Skip to content

Commit e32a94a

Browse files
committed
[no-relnote] Add basic CDI generate test
Signed-off-by: Evan Lezar <[email protected]>
1 parent c1577a4 commit e32a94a

File tree

1 file changed

+158
-0
lines changed

1 file changed

+158
-0
lines changed
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
/**
2+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
**/
16+
17+
package generate
18+
19+
import (
20+
"bytes"
21+
"path/filepath"
22+
"strings"
23+
"testing"
24+
25+
"github.com/NVIDIA/go-nvml/pkg/nvml"
26+
"github.com/NVIDIA/go-nvml/pkg/nvml/mock/dgxa100"
27+
testlog "github.com/sirupsen/logrus/hooks/test"
28+
"github.com/stretchr/testify/require"
29+
30+
"github.com/NVIDIA/nvidia-container-toolkit/internal/test"
31+
)
32+
33+
func TestGenerateSpec(t *testing.T) {
34+
t.Setenv("__NVCT_TESTING_DEVICES_ARE_FILES", "true")
35+
moduleRoot, err := test.GetModuleRoot()
36+
require.NoError(t, err)
37+
38+
driverRoot := filepath.Join(moduleRoot, "testdata", "lookup", "rootfs-1")
39+
40+
logger, _ := testlog.NewNullLogger()
41+
testCases := []struct {
42+
description string
43+
options options
44+
expectedValidateError error
45+
expectedOptions options
46+
expectedError error
47+
expectedSpec string
48+
}{
49+
{
50+
description: "default",
51+
options: options{
52+
format: "yaml",
53+
mode: "nvml",
54+
vendor: "example.com",
55+
class: "device",
56+
driverRoot: driverRoot,
57+
},
58+
expectedOptions: options{
59+
format: "yaml",
60+
mode: "nvml",
61+
vendor: "example.com",
62+
class: "device",
63+
nvidiaCDIHookPath: "/usr/bin/nvidia-cdi-hook",
64+
driverRoot: driverRoot,
65+
},
66+
expectedSpec: `---
67+
cdiVersion: 0.5.0
68+
containerEdits:
69+
deviceNodes:
70+
- hostPath: {{ .driverRoot }}/dev/nvidiactl
71+
path: /dev/nvidiactl
72+
env:
73+
- NVIDIA_VISIBLE_DEVICES=void
74+
hooks:
75+
- args:
76+
- nvidia-cdi-hook
77+
- create-symlinks
78+
- --link
79+
- libcuda.so.1::/lib/x86_64-linux-gnu/libcuda.so
80+
hookName: createContainer
81+
path: /usr/bin/nvidia-cdi-hook
82+
- args:
83+
- nvidia-cdi-hook
84+
- cuda-compat
85+
- --driver-version
86+
- 999.88.77
87+
hookName: createContainer
88+
path: /usr/bin/nvidia-cdi-hook
89+
- args:
90+
- nvidia-cdi-hook
91+
- update-ldcache
92+
- --folder
93+
- /lib/x86_64-linux-gnu
94+
hookName: createContainer
95+
path: /usr/bin/nvidia-cdi-hook
96+
mounts:
97+
- containerPath: /lib/x86_64-linux-gnu/libcuda.so.999.88.77
98+
hostPath: {{ .driverRoot }}/lib/x86_64-linux-gnu/libcuda.so.999.88.77
99+
options:
100+
- ro
101+
- nosuid
102+
- nodev
103+
- bind
104+
devices:
105+
- containerEdits:
106+
deviceNodes:
107+
- hostPath: {{ .driverRoot }}/dev/nvidia0
108+
path: /dev/nvidia0
109+
name: "0"
110+
- containerEdits:
111+
deviceNodes:
112+
- hostPath: {{ .driverRoot }}/dev/nvidia0
113+
path: /dev/nvidia0
114+
name: all
115+
kind: example.com/device
116+
`,
117+
},
118+
}
119+
120+
for _, tc := range testCases {
121+
t.Run(tc.description, func(t *testing.T) {
122+
c := command{
123+
logger: logger,
124+
}
125+
126+
err := c.validateFlags(nil, &tc.options)
127+
require.ErrorIs(t, err, tc.expectedValidateError)
128+
require.EqualValues(t, tc.expectedOptions, tc.options)
129+
130+
// Set up a mock server, reusing the DGX A100 mock.
131+
server := dgxa100.New()
132+
// Override the driver version to match the version in our mock filesystem.
133+
server.SystemGetDriverVersionFunc = func() (string, nvml.Return) {
134+
return "999.88.77", nvml.SUCCESS
135+
}
136+
// Set the device count to 1 explicitly since we only have a single device node.
137+
server.DeviceGetCountFunc = func() (int, nvml.Return) {
138+
return 1, nvml.SUCCESS
139+
}
140+
for _, d := range server.Devices {
141+
// TODO: This is not implemented in the mock.
142+
(d.(*dgxa100.Device)).GetMaxMigDeviceCountFunc = func() (int, nvml.Return) {
143+
return 0, nvml.SUCCESS
144+
}
145+
}
146+
tc.options.nvmllib = server
147+
148+
spec, err := c.generateSpec(&tc.options)
149+
require.ErrorIs(t, err, tc.expectedError)
150+
151+
var buf bytes.Buffer
152+
_, err = spec.WriteTo(&buf)
153+
require.NoError(t, err)
154+
155+
require.Equal(t, strings.ReplaceAll(tc.expectedSpec, "{{ .driverRoot }}", driverRoot), buf.String())
156+
})
157+
}
158+
}

0 commit comments

Comments
 (0)