Skip to content

Commit 743d84d

Browse files
committed
Allow cdi mode to work with --gpus flag
This change ensures that the CDI modifier also removes the NVIDIA Container Runtime Hook from the incoming spec. This aligns with what is done for CSV modifications and prevents an error when starting the container. Signed-off-by: Evan Lezar <[email protected]>
1 parent c7fec23 commit 743d84d

File tree

5 files changed

+197
-13
lines changed

5 files changed

+197
-13
lines changed

internal/modifier/csv.go

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -68,20 +68,10 @@ func NewCSVModifier(logger logger.Interface, cfg *config.Config, container image
6868
return nil, fmt.Errorf("failed to get CDI spec: %v", err)
6969
}
7070

71-
cdiModifier, err := cdi.New(
71+
return cdi.New(
7272
cdi.WithLogger(logger),
7373
cdi.WithSpec(spec.Raw()),
7474
)
75-
if err != nil {
76-
return nil, fmt.Errorf("failed to construct CDI modifier: %v", err)
77-
}
78-
79-
modifiers := Merge(
80-
nvidiaContainerRuntimeHookRemover{logger},
81-
cdiModifier,
82-
)
83-
84-
return modifiers, nil
8575
}
8676

8777
func checkRequirements(logger logger.Interface, image image.CUDA) error {

internal/modifier/hook_remover.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,13 @@ type nvidiaContainerRuntimeHookRemover struct {
3333

3434
var _ oci.SpecModifier = (*nvidiaContainerRuntimeHookRemover)(nil)
3535

36+
// NewNvidiaContainerRuntimeHookRemover creates a modifier that removes any NVIDIA Container Runtime hooks from the provided spec.
37+
func NewNvidiaContainerRuntimeHookRemover(logger logger.Interface) oci.SpecModifier {
38+
return nvidiaContainerRuntimeHookRemover{
39+
logger: logger,
40+
}
41+
}
42+
3643
// Modify removes any NVIDIA Container Runtime hooks from the provided spec
3744
func (m nvidiaContainerRuntimeHookRemover) Modify(spec *specs.Spec) error {
3845
if spec == nil {

internal/runtime/runtime_factory.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp
8585
switch modifierType {
8686
case "mode":
8787
modifiers = append(modifiers, modeModifier)
88+
case "nvidia-hook-remover":
89+
modifiers = append(modifiers, modifier.NewNvidiaContainerRuntimeHookRemover(logger))
8890
case "graphics":
8991
graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, image, driver)
9092
if err != nil {
@@ -121,10 +123,10 @@ func supportedModifierTypes(mode string) []string {
121123
switch mode {
122124
case "cdi":
123125
// For CDI mode we make no additional modifications.
124-
return []string{"mode"}
126+
return []string{"nvidia-hook-remover", "mode"}
125127
case "csv":
126128
// For CSV mode we support mode and feature-gated modification.
127-
return []string{"mode", "feature-gated"}
129+
return []string{"nvidia-hook-remover", "mode", "feature-gated"}
128130
default:
129131
return []string{"mode", "graphics", "feature-gated"}
130132
}

internal/runtime/runtime_factory_test.go

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import (
3030

3131
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
3232
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
33+
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
3334
"github.com/NVIDIA/nvidia-container-toolkit/internal/test"
3435
)
3536

@@ -165,3 +166,181 @@ func TestFactoryMethod(t *testing.T) {
165166
})
166167
}
167168
}
169+
170+
func TestNewSpecModifier(t *testing.T) {
171+
logger, _ := testlog.NewNullLogger()
172+
driver := root.New(
173+
root.WithDriverRoot("/nvidia/driver/root"),
174+
)
175+
testCases := []struct {
176+
description string
177+
config *config.Config
178+
spec *specs.Spec
179+
expectedSpec *specs.Spec
180+
}{
181+
{
182+
description: "csv mode removes nvidia-container-runtime-hook",
183+
config: &config.Config{
184+
NVIDIAContainerRuntimeConfig: config.RuntimeConfig{
185+
Mode: "csv",
186+
},
187+
},
188+
spec: &specs.Spec{
189+
Hooks: &specs.Hooks{
190+
Prestart: []specs.Hook{
191+
{
192+
Path: "/path/to/nvidia-container-runtime-hook",
193+
Args: []string{"/path/to/nvidia-container-runtime-hook", "prestart"},
194+
},
195+
},
196+
},
197+
},
198+
expectedSpec: &specs.Spec{
199+
Hooks: &specs.Hooks{
200+
Prestart: nil,
201+
},
202+
},
203+
},
204+
{
205+
description: "csv mode removes nvidia-container-toolkit",
206+
config: &config.Config{
207+
NVIDIAContainerRuntimeConfig: config.RuntimeConfig{
208+
Mode: "csv",
209+
},
210+
},
211+
spec: &specs.Spec{
212+
Hooks: &specs.Hooks{
213+
Prestart: []specs.Hook{
214+
{
215+
Path: "/path/to/nvidia-container-toolkit",
216+
Args: []string{"/path/to/nvidia-container-toolkit", "prestart"},
217+
},
218+
},
219+
},
220+
},
221+
expectedSpec: &specs.Spec{
222+
Hooks: &specs.Hooks{
223+
Prestart: nil,
224+
},
225+
},
226+
},
227+
{
228+
description: "cdi mode removes nvidia-container-runtime-hook",
229+
config: &config.Config{
230+
NVIDIAContainerRuntimeConfig: config.RuntimeConfig{
231+
Mode: "cdi",
232+
},
233+
},
234+
spec: &specs.Spec{
235+
Hooks: &specs.Hooks{
236+
Prestart: []specs.Hook{
237+
{
238+
Path: "/path/to/nvidia-container-runtime-hook",
239+
Args: []string{"/path/to/nvidia-container-runtime-hook", "prestart"},
240+
},
241+
},
242+
},
243+
},
244+
expectedSpec: &specs.Spec{
245+
Hooks: &specs.Hooks{
246+
Prestart: nil,
247+
},
248+
},
249+
},
250+
{
251+
description: "cdi mode removes nvidia-container-toolkit",
252+
config: &config.Config{
253+
NVIDIAContainerRuntimeConfig: config.RuntimeConfig{
254+
Mode: "cdi",
255+
},
256+
},
257+
spec: &specs.Spec{
258+
Hooks: &specs.Hooks{
259+
Prestart: []specs.Hook{
260+
{
261+
Path: "/path/to/nvidia-container-toolkit",
262+
Args: []string{"/path/to/nvidia-container-toolkit", "prestart"},
263+
},
264+
},
265+
},
266+
},
267+
expectedSpec: &specs.Spec{
268+
Hooks: &specs.Hooks{
269+
Prestart: nil,
270+
},
271+
},
272+
},
273+
{
274+
description: "legacy mode keeps nvidia-container-runtime-hook",
275+
config: &config.Config{
276+
NVIDIAContainerRuntimeConfig: config.RuntimeConfig{
277+
Mode: "legacy",
278+
},
279+
},
280+
spec: &specs.Spec{
281+
Hooks: &specs.Hooks{
282+
Prestart: []specs.Hook{
283+
{
284+
Path: "/path/to/nvidia-container-runtime-hook",
285+
Args: []string{"/path/to/nvidia-container-runtime-hook", "prestart"},
286+
},
287+
},
288+
},
289+
},
290+
expectedSpec: &specs.Spec{
291+
Hooks: &specs.Hooks{
292+
Prestart: []specs.Hook{
293+
{
294+
Path: "/path/to/nvidia-container-runtime-hook",
295+
Args: []string{"/path/to/nvidia-container-runtime-hook", "prestart"},
296+
},
297+
},
298+
},
299+
},
300+
},
301+
{
302+
description: "legacy mode keeps nvidia-container-toolkit",
303+
config: &config.Config{
304+
NVIDIAContainerRuntimeConfig: config.RuntimeConfig{
305+
Mode: "legacy",
306+
},
307+
},
308+
spec: &specs.Spec{
309+
Hooks: &specs.Hooks{
310+
Prestart: []specs.Hook{
311+
{
312+
Path: "/path/to/nvidia-container-toolkit",
313+
Args: []string{"/path/to/nvidia-container-toolkit", "prestart"},
314+
},
315+
},
316+
},
317+
},
318+
expectedSpec: &specs.Spec{
319+
Hooks: &specs.Hooks{
320+
Prestart: []specs.Hook{
321+
{
322+
Path: "/path/to/nvidia-container-toolkit",
323+
Args: []string{"/path/to/nvidia-container-toolkit", "prestart"},
324+
},
325+
},
326+
},
327+
},
328+
},
329+
}
330+
331+
for _, tc := range testCases {
332+
t.Run(tc.description, func(t *testing.T) {
333+
spec := &oci.SpecMock{
334+
LoadFunc: func() (*specs.Spec, error) {
335+
return tc.spec, nil
336+
},
337+
}
338+
m, err := newSpecModifier(logger, tc.config, spec, driver)
339+
require.NoError(t, err)
340+
341+
err = m.Modify(tc.spec)
342+
require.NoError(t, err)
343+
require.EqualValues(t, tc.expectedSpec, tc.spec)
344+
})
345+
}
346+
}

tests/e2e/nvidia-container-toolkit_test.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,12 @@ var _ = Describe("docker", Ordered, func() {
7474
Expect(containerOutput).To(Equal(hostOutput))
7575
})
7676

77+
It("should support automatic CDI spec generation with the --gpus flag", func(ctx context.Context) {
78+
containerOutput, _, err := r.Run("docker run --rm -i --gpus=all --runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=runtime.nvidia.com/gpu=all ubuntu nvidia-smi -L")
79+
Expect(err).ToNot(HaveOccurred())
80+
Expect(containerOutput).To(Equal(hostOutput))
81+
})
82+
7783
It("should support the --gpus flag using the nvidia-container-runtime", func(ctx context.Context) {
7884
containerOutput, _, err := r.Run("docker run --rm -i --runtime=nvidia --gpus all ubuntu nvidia-smi -L")
7985
Expect(err).ToNot(HaveOccurred())

0 commit comments

Comments
 (0)