Skip to content

Commit d99a67e

Browse files
committed
[no-relnote] Add basic CDI generate test
Signed-off-by: Evan Lezar <[email protected]>
1 parent f15964f commit d99a67e

File tree

5 files changed

+1010
-0
lines changed

5 files changed

+1010
-0
lines changed

cmd/nvidia-ctk/cdi/generate/generate.go

+7
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ import (
2525
"github.com/urfave/cli/v2"
2626
cdi "tags.cncf.io/container-device-interface/pkg/parser"
2727

28+
"github.com/NVIDIA/go-nvml/pkg/nvml"
29+
2830
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
2931
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
3032
"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv"
@@ -60,6 +62,9 @@ type options struct {
6062
files cli.StringSlice
6163
ignorePatterns cli.StringSlice
6264
}
65+
66+
// the following are used for dependency injection during spec generation.
67+
nvmllib nvml.Interface
6368
}
6469

6570
// NewCommand constructs a generate-cdi command with the specified logger
@@ -269,6 +274,8 @@ func (m command) generateSpec(opts *options) (spec.Interface, error) {
269274
nvcdi.WithLibrarySearchPaths(opts.librarySearchPaths.Value()),
270275
nvcdi.WithCSVFiles(opts.csv.files.Value()),
271276
nvcdi.WithCSVIgnorePatterns(opts.csv.ignorePatterns.Value()),
277+
// We set the following to allow for dependency injection:
278+
nvcdi.WithNvmlLib(opts.nvmllib),
272279
)
273280
if err != nil {
274281
return nil, fmt.Errorf("failed to create CDI library: %v", err)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
/**
2+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
**/
16+
17+
package generate
18+
19+
import (
20+
"bytes"
21+
"path/filepath"
22+
"strings"
23+
"testing"
24+
25+
"github.com/NVIDIA/go-nvml/pkg/nvml"
26+
"github.com/NVIDIA/go-nvml/pkg/nvml/mock/dgxa100"
27+
testlog "github.com/sirupsen/logrus/hooks/test"
28+
"github.com/stretchr/testify/require"
29+
30+
"github.com/NVIDIA/nvidia-container-toolkit/internal/test"
31+
)
32+
33+
func TestGenerateSpec(t *testing.T) {
34+
t.Setenv("__NVCT_TESTING_DEVICES_ARE_FILES", "true")
35+
moduleRoot, err := test.GetModuleRoot()
36+
require.NoError(t, err)
37+
38+
driverRoot := filepath.Join(moduleRoot, "testdata", "lookup", "rootfs-1")
39+
40+
logger, _ := testlog.NewNullLogger()
41+
testCases := []struct {
42+
description string
43+
options options
44+
expectedValidateError error
45+
expectedOptions options
46+
expectedError error
47+
expectedSpec string
48+
}{
49+
{
50+
description: "default",
51+
options: options{
52+
format: "yaml",
53+
mode: "nvml",
54+
vendor: "example.com",
55+
class: "device",
56+
driverRoot: driverRoot,
57+
},
58+
expectedOptions: options{
59+
format: "yaml",
60+
mode: "nvml",
61+
vendor: "example.com",
62+
class: "device",
63+
nvidiaCDIHookPath: "/usr/bin/nvidia-cdi-hook",
64+
driverRoot: driverRoot,
65+
},
66+
expectedSpec: `---
67+
cdiVersion: 0.5.0
68+
containerEdits:
69+
deviceNodes:
70+
- hostPath: {{ .driverRoot }}/dev/nvidiactl
71+
path: /dev/nvidiactl
72+
env:
73+
- NVIDIA_VISIBLE_DEVICES=void
74+
hooks:
75+
- args:
76+
- nvidia-cdi-hook
77+
- create-symlinks
78+
- --link
79+
- libcuda.so.1::/lib/x86_64-linux-gnu/libcuda.so
80+
hookName: createContainer
81+
path: /usr/bin/nvidia-cdi-hook
82+
- args:
83+
- nvidia-cdi-hook
84+
- update-ldcache
85+
- --folder
86+
- /lib/x86_64-linux-gnu
87+
hookName: createContainer
88+
path: /usr/bin/nvidia-cdi-hook
89+
mounts:
90+
- containerPath: /lib/x86_64-linux-gnu/libcuda.so.999.88.77
91+
hostPath: {{ .driverRoot }}/lib/x86_64-linux-gnu/libcuda.so.999.88.77
92+
options:
93+
- ro
94+
- nosuid
95+
- nodev
96+
- bind
97+
devices:
98+
- containerEdits:
99+
deviceNodes:
100+
- hostPath: {{ .driverRoot }}/dev/nvidia0
101+
path: /dev/nvidia0
102+
name: "0"
103+
- containerEdits:
104+
deviceNodes:
105+
- hostPath: {{ .driverRoot }}/dev/nvidia0
106+
path: /dev/nvidia0
107+
name: all
108+
kind: example.com/device
109+
`,
110+
},
111+
}
112+
113+
for _, tc := range testCases {
114+
t.Run(tc.description, func(t *testing.T) {
115+
c := command{
116+
logger: logger,
117+
}
118+
119+
err := c.validateFlags(nil, &tc.options)
120+
require.ErrorIs(t, err, tc.expectedValidateError)
121+
require.EqualValues(t, tc.expectedOptions, tc.options)
122+
123+
// Set up a mock server, reusing the DGX A100 mock.
124+
server := dgxa100.New()
125+
// Override the driver version to match the version in our mock filesystem.
126+
server.SystemGetDriverVersionFunc = func() (string, nvml.Return) {
127+
return "999.88.77", nvml.SUCCESS
128+
}
129+
// Set the device count to 1 explicitly since we only have a single device node.
130+
server.DeviceGetCountFunc = func() (int, nvml.Return) {
131+
return 1, nvml.SUCCESS
132+
}
133+
for _, d := range server.Devices {
134+
// TODO: This is not implemented in the mock.
135+
(d.(*dgxa100.Device)).GetMaxMigDeviceCountFunc = func() (int, nvml.Return) {
136+
return 0, nvml.SUCCESS
137+
}
138+
}
139+
tc.options.nvmllib = server
140+
141+
spec, err := c.generateSpec(&tc.options)
142+
require.ErrorIs(t, err, tc.expectedError)
143+
144+
var buf bytes.Buffer
145+
_, err = spec.WriteTo(&buf)
146+
require.NoError(t, err)
147+
148+
require.Equal(t, strings.ReplaceAll(tc.expectedSpec, "{{ .driverRoot }}", driverRoot), buf.String())
149+
})
150+
}
151+
}

0 commit comments

Comments
 (0)