@@ -31,11 +31,22 @@ import (
31
31
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
32
32
)
33
33
34
+ const (
35
+ automaticDeviceVendor = "runtime.nvidia.com"
36
+ automaticDeviceClass = "gpu"
37
+ automaticDeviceKind = automaticDeviceVendor + "/" + automaticDeviceClass
38
+ )
39
+
34
40
// NewCDIModifier creates an OCI spec modifier that determines the modifications to make based on the
35
41
// CDI specifications available on the system. The NVIDIA_VISIBLE_DEVICES environment variable is
36
42
// used to select the devices to include.
37
- func NewCDIModifier (logger logger.Interface , cfg * config.Config , ociSpec oci.Spec ) (oci.SpecModifier , error ) {
38
- devices , err := getDevicesFromSpec (logger , ociSpec , cfg )
43
+ func NewCDIModifier (logger logger.Interface , cfg * config.Config , ociSpec oci.Spec , isJitCDI bool ) (oci.SpecModifier , error ) {
44
+ defaultKind := cfg .NVIDIAContainerRuntimeConfig .Modes .CDI .DefaultKind
45
+ if isJitCDI {
46
+ defaultKind = automaticDeviceKind
47
+ }
48
+
49
+ devices , err := getDevicesFromSpec (logger , ociSpec , cfg , defaultKind )
39
50
if err != nil {
40
51
return nil , fmt .Errorf ("failed to get required devices from OCI specification: %v" , err )
41
52
}
@@ -65,7 +76,7 @@ func NewCDIModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spe
65
76
)
66
77
}
67
78
68
- func getDevicesFromSpec (logger logger.Interface , ociSpec oci.Spec , cfg * config.Config ) ([]string , error ) {
79
+ func getDevicesFromSpec (logger logger.Interface , ociSpec oci.Spec , cfg * config.Config , defaultKind string ) ([]string , error ) {
69
80
rawSpec , err := ociSpec .Load ()
70
81
if err != nil {
71
82
return nil , fmt .Errorf ("failed to load OCI spec: %v" , err )
@@ -83,26 +94,16 @@ func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.C
83
94
if err != nil {
84
95
return nil , err
85
96
}
86
- if cfg .AcceptDeviceListAsVolumeMounts {
87
- mountDevices := container .CDIDevicesFromMounts ()
88
- if len (mountDevices ) > 0 {
89
- return mountDevices , nil
90
- }
91
- }
92
97
93
98
var devices []string
94
- seen := make (map [string ]bool )
95
- for _ , name := range container .VisibleDevicesFromEnvVar () {
96
- if ! parser .IsQualifiedName (name ) {
97
- name = fmt .Sprintf ("%s=%s" , cfg .NVIDIAContainerRuntimeConfig .Modes .CDI .DefaultKind , name )
98
- }
99
- if seen [name ] {
100
- logger .Debugf ("Ignoring duplicate device %q" , name )
101
- continue
99
+ if cfg .AcceptDeviceListAsVolumeMounts {
100
+ devices = normalizeDeviceList (logger , defaultKind , append (container .DevicesFromMounts (), container .CDIDevicesFromMounts ()... )... )
101
+ if len (devices ) > 0 {
102
+ return devices , nil
102
103
}
103
- devices = append (devices , name )
104
104
}
105
105
106
+ devices = normalizeDeviceList (logger , defaultKind , container .VisibleDevicesFromEnvVar ()... )
106
107
if len (devices ) == 0 {
107
108
return nil , nil
108
109
}
@@ -116,6 +117,25 @@ func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.C
116
117
return nil , nil
117
118
}
118
119
120
+ func normalizeDeviceList (logger logger.Interface , defaultKind string , devices ... string ) []string {
121
+ fmt .Printf ("devices = %v\n " , devices )
122
+ seen := make (map [string ]bool )
123
+ var normalized []string
124
+ for _ , name := range devices {
125
+ if ! parser .IsQualifiedName (name ) {
126
+ name = fmt .Sprintf ("%s=%s" , defaultKind , name )
127
+ }
128
+ if seen [name ] {
129
+ logger .Debugf ("Ignoring duplicate device %q" , name )
130
+ continue
131
+ }
132
+ normalized = append (normalized , name )
133
+ seen [name ] = true
134
+ }
135
+
136
+ return normalized
137
+ }
138
+
119
139
// getAnnotationDevices returns a list of devices specified in the annotations.
120
140
// Keys starting with the specified prefixes are considered and expected to contain a comma-separated list of
121
141
// fully-qualified CDI devices names. If any device name is not fully-quality an error is returned.
@@ -156,7 +176,7 @@ func filterAutomaticDevices(devices []string) []string {
156
176
var automatic []string
157
177
for _ , device := range devices {
158
178
vendor , class , _ := parser .ParseDevice (device )
159
- if vendor == "runtime.nvidia.com" && class == "gpu" {
179
+ if vendor == automaticDeviceVendor && class == automaticDeviceClass {
160
180
automatic = append (automatic , device )
161
181
}
162
182
}
@@ -165,6 +185,8 @@ func filterAutomaticDevices(devices []string) []string {
165
185
166
186
func newAutomaticCDISpecModifier (logger logger.Interface , cfg * config.Config , devices []string ) (oci.SpecModifier , error ) {
167
187
logger .Debugf ("Generating in-memory CDI specs for devices %v" , devices )
188
+ // TODO: We should try to load the kernel modules and create the device nodes here.
189
+ // Failures should raise a warning and not error out.
168
190
spec , err := generateAutomaticCDISpec (logger , cfg , devices )
169
191
if err != nil {
170
192
return nil , fmt .Errorf ("failed to generate CDI spec: %w" , err )
0 commit comments