Skip to content

Commit cc9c609

Browse files
committed
Default to jit-cdi mode in the nvidia runtime
Signed-off-by: Evan Lezar <[email protected]>
1 parent 4cfb7e1 commit cc9c609

File tree

6 files changed

+108
-40
lines changed

6 files changed

+108
-40
lines changed

cmd/nvidia-container-runtime/main_test.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,8 @@ func TestGoodInput(t *testing.T) {
125125
// Check config.json for NVIDIA prestart hook
126126
spec, err = cfg.getRuntimeSpec()
127127
require.NoError(t, err, "should be no errors when reading and parsing spec from config.json")
128-
require.NotEmpty(t, spec.Hooks, "there should be hooks in config.json")
129-
require.Equal(t, 1, nvidiaHookCount(spec.Hooks), "exactly one nvidia prestart hook should be inserted correctly into config.json")
128+
require.Empty(t, spec.Hooks, "there should be no hooks in config.json")
129+
require.Equal(t, 0, nvidiaHookCount(spec.Hooks), "no nvidia prestart hook should be inserted into config.json")
130130
}
131131

132132
// NVIDIA prestart hook already present in config file
@@ -171,8 +171,8 @@ func TestDuplicateHook(t *testing.T) {
171171
// Check config.json for NVIDIA prestart hook
172172
spec, err = cfg.getRuntimeSpec()
173173
require.NoError(t, err, "should be no errors when reading and parsing spec from config.json")
174-
require.NotEmpty(t, spec.Hooks, "there should be hooks in config.json")
175-
require.Equal(t, 1, nvidiaHookCount(spec.Hooks), "exactly one nvidia prestart hook should be inserted correctly into config.json")
174+
require.Empty(t, spec.Hooks, "there should be no hooks in config.json")
175+
require.Equal(t, 0, nvidiaHookCount(spec.Hooks), "no nvidia prestart hook should be inserted into config.json")
176176
}
177177

178178
// addNVIDIAHook is a basic wrapper for an addHookModifier that is used for

internal/info/auto.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ const (
3030
RuntimeModeLegacy = RuntimeMode("legacy")
3131
RuntimeModeCSV = RuntimeMode("csv")
3232
RuntimeModeCDI = RuntimeMode("cdi")
33+
RuntimeModeJitCDI = RuntimeMode("jit-cdi")
3334
)
3435

3536
// ResolveAutoMode determines the correct mode for the platform if set to "auto"
@@ -57,9 +58,9 @@ func resolveMode(logger logger.Interface, mode string, image image.CUDA, propert
5758

5859
switch nvinfo.ResolvePlatform() {
5960
case info.PlatformNVML, info.PlatformWSL:
60-
return RuntimeModeLegacy
61+
return RuntimeModeJitCDI
6162
case info.PlatformTegra:
6263
return RuntimeModeCSV
6364
}
64-
return RuntimeModeLegacy
65+
return RuntimeModeJitCDI
6566
}

internal/info/auto_test.go

+17-12
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,15 @@ func TestResolveAutoMode(t *testing.T) {
4444
expectedMode: "not-auto",
4545
},
4646
{
47-
description: "no info defaults to legacy",
47+
description: "legacy resolves to legacy",
48+
mode: "legacy",
49+
expectedMode: "legacy",
50+
},
51+
{
52+
description: "no info defaults to jit-cdi",
4853
mode: "auto",
4954
info: map[string]bool{},
50-
expectedMode: "legacy",
55+
expectedMode: "jit-cdi",
5156
},
5257
{
5358
description: "non-nvml, non-tegra, nvgpu resolves to csv",
@@ -80,14 +85,14 @@ func TestResolveAutoMode(t *testing.T) {
8085
expectedMode: "csv",
8186
},
8287
{
83-
description: "nvml, non-tegra, non-nvgpu resolves to legacy",
88+
description: "nvml, non-tegra, non-nvgpu resolves to jit-cdi",
8489
mode: "auto",
8590
info: map[string]bool{
8691
"nvml": true,
8792
"tegra": false,
8893
"nvgpu": false,
8994
},
90-
expectedMode: "legacy",
95+
expectedMode: "jit-cdi",
9196
},
9297
{
9398
description: "nvml, non-tegra, nvgpu resolves to csv",
@@ -100,14 +105,14 @@ func TestResolveAutoMode(t *testing.T) {
100105
expectedMode: "csv",
101106
},
102107
{
103-
description: "nvml, tegra, non-nvgpu resolves to legacy",
108+
description: "nvml, tegra, non-nvgpu resolves to jit-cdi",
104109
mode: "auto",
105110
info: map[string]bool{
106111
"nvml": true,
107112
"tegra": true,
108113
"nvgpu": false,
109114
},
110-
expectedMode: "legacy",
115+
expectedMode: "jit-cdi",
111116
},
112117
{
113118
description: "nvml, tegra, nvgpu resolves to csv",
@@ -136,7 +141,7 @@ func TestResolveAutoMode(t *testing.T) {
136141
},
137142
},
138143
{
139-
description: "at least one non-cdi device resolves to legacy",
144+
description: "at least one non-cdi device resolves to jit-cdi",
140145
mode: "auto",
141146
envmap: map[string]string{
142147
"NVIDIA_VISIBLE_DEVICES": "nvidia.com/gpu=0,0",
@@ -146,7 +151,7 @@ func TestResolveAutoMode(t *testing.T) {
146151
"tegra": false,
147152
"nvgpu": false,
148153
},
149-
expectedMode: "legacy",
154+
expectedMode: "jit-cdi",
150155
},
151156
{
152157
description: "at least one non-cdi device resolves to csv",
@@ -170,7 +175,7 @@ func TestResolveAutoMode(t *testing.T) {
170175
expectedMode: "cdi",
171176
},
172177
{
173-
description: "cdi mount and non-CDI devices resolves to legacy",
178+
description: "cdi mount and non-CDI devices resolves to jit-cdi",
174179
mode: "auto",
175180
mounts: []string{
176181
"/var/run/nvidia-container-devices/cdi/nvidia.com/gpu/0",
@@ -181,10 +186,10 @@ func TestResolveAutoMode(t *testing.T) {
181186
"tegra": false,
182187
"nvgpu": false,
183188
},
184-
expectedMode: "legacy",
189+
expectedMode: "jit-cdi",
185190
},
186191
{
187-
description: "cdi mount and non-CDI envvar resolves to legacy",
192+
description: "cdi mount and non-CDI envvar resolves to jit-cdi",
188193
mode: "auto",
189194
envmap: map[string]string{
190195
"NVIDIA_VISIBLE_DEVICES": "0",
@@ -197,7 +202,7 @@ func TestResolveAutoMode(t *testing.T) {
197202
"tegra": false,
198203
"nvgpu": false,
199204
},
200-
expectedMode: "legacy",
205+
expectedMode: "jit-cdi",
201206
},
202207
}
203208

internal/modifier/cdi.go

+41-19
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,22 @@ import (
3131
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
3232
)
3333

34+
const (
35+
automaticDeviceVendor = "runtime.nvidia.com"
36+
automaticDeviceClass = "gpu"
37+
automaticDeviceKind = automaticDeviceVendor + "/" + automaticDeviceClass
38+
)
39+
3440
// NewCDIModifier creates an OCI spec modifier that determines the modifications to make based on the
3541
// CDI specifications available on the system. The NVIDIA_VISIBLE_DEVICES environment variable is
3642
// used to select the devices to include.
37-
func NewCDIModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec) (oci.SpecModifier, error) {
38-
devices, err := getDevicesFromSpec(logger, ociSpec, cfg)
43+
func NewCDIModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec, isJitCDI bool) (oci.SpecModifier, error) {
44+
defaultKind := cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.DefaultKind
45+
if isJitCDI {
46+
defaultKind = automaticDeviceKind
47+
}
48+
49+
devices, err := getDevicesFromSpec(logger, ociSpec, cfg, defaultKind)
3950
if err != nil {
4051
return nil, fmt.Errorf("failed to get required devices from OCI specification: %v", err)
4152
}
@@ -65,7 +76,7 @@ func NewCDIModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spe
6576
)
6677
}
6778

68-
func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.Config) ([]string, error) {
79+
func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.Config, defaultKind string) ([]string, error) {
6980
rawSpec, err := ociSpec.Load()
7081
if err != nil {
7182
return nil, fmt.Errorf("failed to load OCI spec: %v", err)
@@ -83,26 +94,16 @@ func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.C
8394
if err != nil {
8495
return nil, err
8596
}
86-
if cfg.AcceptDeviceListAsVolumeMounts {
87-
mountDevices := container.CDIDevicesFromMounts()
88-
if len(mountDevices) > 0 {
89-
return mountDevices, nil
90-
}
91-
}
9297

9398
var devices []string
94-
seen := make(map[string]bool)
95-
for _, name := range container.VisibleDevicesFromEnvVar() {
96-
if !parser.IsQualifiedName(name) {
97-
name = fmt.Sprintf("%s=%s", cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.DefaultKind, name)
98-
}
99-
if seen[name] {
100-
logger.Debugf("Ignoring duplicate device %q", name)
101-
continue
99+
if cfg.AcceptDeviceListAsVolumeMounts {
100+
devices = normalizeDeviceList(logger, defaultKind, append(container.DevicesFromMounts(), container.CDIDevicesFromMounts()...)...)
101+
if len(devices) > 0 {
102+
return devices, nil
102103
}
103-
devices = append(devices, name)
104104
}
105105

106+
devices = normalizeDeviceList(logger, defaultKind, container.VisibleDevicesFromEnvVar()...)
106107
if len(devices) == 0 {
107108
return nil, nil
108109
}
@@ -116,6 +117,25 @@ func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.C
116117
return nil, nil
117118
}
118119

120+
func normalizeDeviceList(logger logger.Interface, defaultKind string, devices ...string) []string {
121+
fmt.Printf("devices = %v\n", devices)
122+
seen := make(map[string]bool)
123+
var normalized []string
124+
for _, name := range devices {
125+
if !parser.IsQualifiedName(name) {
126+
name = fmt.Sprintf("%s=%s", defaultKind, name)
127+
}
128+
if seen[name] {
129+
logger.Debugf("Ignoring duplicate device %q", name)
130+
continue
131+
}
132+
normalized = append(normalized, name)
133+
seen[name] = true
134+
}
135+
136+
return normalized
137+
}
138+
119139
// getAnnotationDevices returns a list of devices specified in the annotations.
120140
// Keys starting with the specified prefixes are considered and expected to contain a comma-separated list of
121141
// fully-qualified CDI device names. If any device name is not fully-qualified an error is returned.
@@ -156,7 +176,7 @@ func filterAutomaticDevices(devices []string) []string {
156176
var automatic []string
157177
for _, device := range devices {
158178
vendor, class, _ := parser.ParseDevice(device)
159-
if vendor == "runtime.nvidia.com" && class == "gpu" {
179+
if vendor == automaticDeviceVendor && class == automaticDeviceClass {
160180
automatic = append(automatic, device)
161181
}
162182
}
@@ -165,6 +185,8 @@ func filterAutomaticDevices(devices []string) []string {
165185

166186
func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, devices []string) (oci.SpecModifier, error) {
167187
logger.Debugf("Generating in-memory CDI specs for devices %v", devices)
188+
// TODO: We should try to load the kernel modules and create the device nodes here.
189+
// Failures should raise a warning and not error out.
168190
spec, err := generateAutomaticCDISpec(logger, cfg, devices)
169191
if err != nil {
170192
return nil, fmt.Errorf("failed to generate CDI spec: %w", err)

internal/modifier/cdi_test.go

+40
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,49 @@ import (
2020
"fmt"
2121
"testing"
2222

23+
"github.com/opencontainers/runtime-spec/specs-go"
24+
testlog "github.com/sirupsen/logrus/hooks/test"
2325
"github.com/stretchr/testify/require"
26+
27+
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
28+
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
2429
)
2530

31+
func TestGetDevicesFromSpec(t *testing.T) {
32+
logger, _ := testlog.NewNullLogger()
33+
testCases := []struct {
34+
description string
35+
spec *specs.Spec
36+
config *config.Config
37+
defaultKind string
38+
expectedDevices []string
39+
}{
40+
{
41+
description: "NVIDIA_VISIBLE_DEVICES=all",
42+
spec: &specs.Spec{
43+
Process: &specs.Process{
44+
Env: []string{"NVIDIA_VISIBLE_DEVICES=all"},
45+
},
46+
},
47+
config: func() *config.Config {
48+
c, _ := config.GetDefault()
49+
return c
50+
}(),
51+
defaultKind: "runtime.nvidia.com/gpu",
52+
expectedDevices: []string{"runtime.nvidia.com/gpu=all"},
53+
},
54+
}
55+
56+
for _, tc := range testCases {
57+
t.Run(tc.description, func(t *testing.T) {
58+
devices, err := getDevicesFromSpec(logger, oci.NewMemorySpec(tc.spec), tc.config, tc.defaultKind)
59+
require.NoError(t, err)
60+
61+
require.EqualValues(t, tc.expectedDevices, devices)
62+
})
63+
}
64+
}
65+
2666
func TestGetAnnotationDevices(t *testing.T) {
2767
testCases := []struct {
2868
description string

internal/runtime/runtime_factory.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@ func newModeModifier(logger logger.Interface, mode info.RuntimeMode, cfg *config
111111
return modifier.NewStableRuntimeModifier(logger, cfg.NVIDIAContainerRuntimeHookConfig.Path), nil
112112
case info.RuntimeModeCSV:
113113
return modifier.NewCSVModifier(logger, cfg, image)
114-
case info.RuntimeModeCDI:
115-
return modifier.NewCDIModifier(logger, cfg, ociSpec)
114+
case info.RuntimeModeCDI, info.RuntimeModeJitCDI:
115+
return modifier.NewCDIModifier(logger, cfg, ociSpec, mode == info.RuntimeModeJitCDI)
116116
}
117117

118118
return nil, fmt.Errorf("invalid runtime mode: %v", cfg.NVIDIAContainerRuntimeConfig.Mode)
@@ -121,7 +121,7 @@ func newModeModifier(logger logger.Interface, mode info.RuntimeMode, cfg *config
121121
// supportedModifierTypes returns the modifiers supported for a specific runtime mode.
122122
func supportedModifierTypes(mode info.RuntimeMode) []string {
123123
switch mode {
124-
case info.RuntimeModeCDI:
124+
case info.RuntimeModeCDI, info.RuntimeModeJitCDI:
125125
// For CDI mode we make no additional modifications.
126126
return []string{"nvidia-hook-remover", "mode"}
127127
case info.RuntimeModeCSV:

0 commit comments

Comments
 (0)