@@ -25,6 +25,7 @@ import (
25
25
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
26
26
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
27
27
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
28
+ "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
28
29
"github.com/NVIDIA/nvidia-container-toolkit/internal/modifier/cdi"
29
30
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
30
31
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi"
@@ -34,7 +35,7 @@ import (
34
35
// NewCDIModifier creates an OCI spec modifier that determines the modifications to make based on the
35
36
// CDI specifications available on the system. The NVIDIA_VISIBLE_DEVICES environment variable is
36
37
// used to select the devices to include.
37
- func NewCDIModifier (logger logger.Interface , cfg * config.Config , ociSpec oci.Spec ) (oci.SpecModifier , error ) {
38
+ func NewCDIModifier (logger logger.Interface , cfg * config.Config , driver * root. Driver , ociSpec oci.Spec ) (oci.SpecModifier , error ) {
38
39
devices , err := getDevicesFromSpec (logger , ociSpec , cfg )
39
40
if err != nil {
40
41
return nil , fmt .Errorf ("failed to get required devices from OCI specification: %v" , err )
@@ -50,7 +51,7 @@ func NewCDIModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spe
50
51
return nil , fmt .Errorf ("requesting a CDI device with vendor 'runtime.nvidia.com' is not supported when requesting other CDI devices" )
51
52
}
52
53
if len (automaticDevices ) > 0 {
53
- automaticModifier , err := newAutomaticCDISpecModifier (logger , cfg , automaticDevices )
54
+ automaticModifier , err := newAutomaticCDISpecModifier (logger , cfg , driver , automaticDevices )
54
55
if err == nil {
55
56
return automaticModifier , nil
56
57
}
@@ -163,9 +164,9 @@ func filterAutomaticDevices(devices []string) []string {
163
164
return automatic
164
165
}
165
166
166
- func newAutomaticCDISpecModifier (logger logger.Interface , cfg * config.Config , devices []string ) (oci.SpecModifier , error ) {
167
+ func newAutomaticCDISpecModifier (logger logger.Interface , cfg * config.Config , driver * root. Driver , devices []string ) (oci.SpecModifier , error ) {
167
168
logger .Debugf ("Generating in-memory CDI specs for devices %v" , devices )
168
- spec , err := generateAutomaticCDISpec (logger , cfg , devices )
169
+ spec , err := generateAutomaticCDISpec (logger , cfg , driver , devices )
169
170
if err != nil {
170
171
return nil , fmt .Errorf ("failed to generate CDI spec: %w" , err )
171
172
}
@@ -180,7 +181,7 @@ func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, de
180
181
return cdiModifier , nil
181
182
}
182
183
183
- func generateAutomaticCDISpec (logger logger.Interface , cfg * config.Config , devices []string ) (spec.Interface , error ) {
184
+ func generateAutomaticCDISpec (logger logger.Interface , cfg * config.Config , driver * root. Driver , devices []string ) (spec.Interface , error ) {
184
185
cdilib , err := nvcdi .New (
185
186
nvcdi .WithLogger (logger ),
186
187
nvcdi .WithNVIDIACDIHookPath (cfg .NVIDIACTKConfig .Path ),
@@ -192,6 +193,11 @@ func generateAutomaticCDISpec(logger logger.Interface, cfg *config.Config, devic
192
193
return nil , fmt .Errorf ("failed to construct CDI library: %w" , err )
193
194
}
194
195
196
+ // TODO: Consider moving this into the nvcdi API.
197
+ if err := driver .LoadKernelModules (cfg .NVIDIAContainerRuntimeConfig .Modes .JitCDI .LoadKernelModules ... ); err != nil {
198
+ logger .Warningf ("Ignoring error(s) loading kernel modules: %v" , err )
199
+ }
200
+
195
201
identifiers := []string {}
196
202
for _ , device := range devices {
197
203
_ , _ , id := parser .ParseDevice (device )
0 commit comments