Skip to content

Commit e3361c4

Browse files
committed
[no-relnote] Add basic CDI generate test
Signed-off-by: Evan Lezar <[email protected]>
1 parent 1615f55 commit e3361c4

File tree

5 files changed

+1009
-0
lines changed

5 files changed

+1009
-0
lines changed

cmd/nvidia-ctk/cdi/generate/generate.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ import (
2525
"github.com/urfave/cli/v2"
2626
cdi "tags.cncf.io/container-device-interface/pkg/parser"
2727

28+
"github.com/NVIDIA/go-nvml/pkg/nvml"
29+
2830
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
2931
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
3032
"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv"
@@ -60,6 +62,9 @@ type options struct {
6062
files cli.StringSlice
6163
ignorePatterns cli.StringSlice
6264
}
65+
66+
// the following are used for dependency injection during spec generation.
67+
nvmllib nvml.Interface
6368
}
6469

6570
// NewCommand constructs a generate-cdi command with the specified logger
@@ -269,6 +274,8 @@ func (m command) generateSpec(opts *options) (spec.Interface, error) {
269274
nvcdi.WithLibrarySearchPaths(opts.librarySearchPaths.Value()),
270275
nvcdi.WithCSVFiles(opts.csv.files.Value()),
271276
nvcdi.WithCSVIgnorePatterns(opts.csv.ignorePatterns.Value()),
277+
// We set the following to allow for dependency injection:
278+
nvcdi.WithNvmlLib(opts.nvmllib),
272279
)
273280
if err != nil {
274281
return nil, fmt.Errorf("failed to create CDI library: %v", err)
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
/**
2+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
**/
16+
17+
package generate
18+
19+
import (
20+
"bytes"
21+
"path/filepath"
22+
"testing"
23+
24+
"github.com/NVIDIA/go-nvml/pkg/nvml"
25+
"github.com/NVIDIA/go-nvml/pkg/nvml/mock/dgxa100"
26+
testlog "github.com/sirupsen/logrus/hooks/test"
27+
"github.com/stretchr/testify/require"
28+
29+
"github.com/NVIDIA/nvidia-container-toolkit/internal/test"
30+
)
31+
32+
func TestGenerateSpec(t *testing.T) {
33+
t.Setenv("__NVCT_TESTING_DEVICES_ARE_FILES", "true")
34+
moduleRoot, err := test.GetModuleRoot()
35+
require.NoError(t, err)
36+
37+
driverRoot := filepath.Join(moduleRoot, "testdata", "lookup", "rootfs-1")
38+
39+
logger, _ := testlog.NewNullLogger()
40+
testCases := []struct {
41+
description string
42+
options options
43+
expectedValidateError error
44+
expectedOptions options
45+
expectedError error
46+
expectedSpec string
47+
}{
48+
{
49+
description: "default",
50+
options: options{
51+
format: "yaml",
52+
mode: "nvml",
53+
vendor: "example.com",
54+
class: "device",
55+
driverRoot: driverRoot,
56+
},
57+
expectedOptions: options{
58+
format: "yaml",
59+
mode: "nvml",
60+
vendor: "example.com",
61+
class: "device",
62+
nvidiaCDIHookPath: "/usr/bin/nvidia-cdi-hook",
63+
driverRoot: driverRoot,
64+
},
65+
expectedSpec: `---
66+
cdiVersion: 0.5.0
67+
containerEdits:
68+
deviceNodes:
69+
- hostPath: /Users/elezar/dev/container-toolkit/testdata/lookup/rootfs-1/dev/nvidiactl
70+
path: /dev/nvidiactl
71+
env:
72+
- NVIDIA_VISIBLE_DEVICES=void
73+
hooks:
74+
- args:
75+
- nvidia-cdi-hook
76+
- create-symlinks
77+
- --link
78+
- libcuda.so.1::/lib/x86_64-linux-gnu/libcuda.so
79+
hookName: createContainer
80+
path: /usr/bin/nvidia-cdi-hook
81+
- args:
82+
- nvidia-cdi-hook
83+
- update-ldcache
84+
- --folder
85+
- /lib/x86_64-linux-gnu
86+
hookName: createContainer
87+
path: /usr/bin/nvidia-cdi-hook
88+
mounts:
89+
- containerPath: /lib/x86_64-linux-gnu/libcuda.so.999.88.77
90+
hostPath: /Users/elezar/dev/container-toolkit/testdata/lookup/rootfs-1/lib/x86_64-linux-gnu/libcuda.so.999.88.77
91+
options:
92+
- ro
93+
- nosuid
94+
- nodev
95+
- bind
96+
devices:
97+
- containerEdits:
98+
deviceNodes:
99+
- hostPath: /Users/elezar/dev/container-toolkit/testdata/lookup/rootfs-1/dev/nvidia0
100+
path: /dev/nvidia0
101+
name: "0"
102+
- containerEdits:
103+
deviceNodes:
104+
- hostPath: /Users/elezar/dev/container-toolkit/testdata/lookup/rootfs-1/dev/nvidia0
105+
path: /dev/nvidia0
106+
name: all
107+
kind: example.com/device
108+
`,
109+
},
110+
}
111+
112+
for _, tc := range testCases {
113+
t.Run(tc.description, func(t *testing.T) {
114+
c := command{
115+
logger: logger,
116+
}
117+
118+
err := c.validateFlags(nil, &tc.options)
119+
require.ErrorIs(t, err, tc.expectedValidateError)
120+
require.EqualValues(t, tc.expectedOptions, tc.options)
121+
122+
// Set up a mock server, reusing the DGX A100 mock.
123+
server := dgxa100.New()
124+
// Override the driver version to match the version in our mock filesystem.
125+
server.SystemGetDriverVersionFunc = func() (string, nvml.Return) {
126+
return "999.88.77", nvml.SUCCESS
127+
}
128+
// Set the device count to 1 explicitly since we only have a single device node.
129+
server.DeviceGetCountFunc = func() (int, nvml.Return) {
130+
return 1, nvml.SUCCESS
131+
}
132+
for _, d := range server.Devices {
133+
// TODO: This is not implemented in the mock.
134+
(d.(*dgxa100.Device)).GetMaxMigDeviceCountFunc = func() (int, nvml.Return) {
135+
return 0, nvml.SUCCESS
136+
}
137+
}
138+
tc.options.nvmllib = server
139+
140+
spec, err := c.generateSpec(&tc.options)
141+
require.ErrorIs(t, err, tc.expectedError)
142+
143+
var buf bytes.Buffer
144+
_, err = spec.WriteTo(&buf)
145+
require.NoError(t, err)
146+
147+
require.Equal(t, tc.expectedSpec, buf.String())
148+
})
149+
}
150+
}

0 commit comments

Comments
 (0)