Skip to content

Commit e436533

Browse files
authored
Merge pull request #968 from elezar/allow-hooks-disable
Allow enable-cuda-compat hook to be disabled in CDI spec generation
2 parents ef0b16b + 0f299c3 commit e436533

File tree

9 files changed

+95
-38
lines changed

9 files changed

+95
-38
lines changed

cmd/nvidia-ctk-installer/container/toolkit/toolkit_test.go

-6
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,6 @@ containerEdits:
8080
- libcuda.so.1::/lib/x86_64-linux-gnu/libcuda.so
8181
hookName: createContainer
8282
path: {{ .toolkitRoot }}/nvidia-cdi-hook
83-
- args:
84-
- nvidia-cdi-hook
85-
- enable-cuda-compat
86-
- --host-driver-version=999.88.77
87-
hookName: createContainer
88-
path: {{ .toolkitRoot }}/nvidia-cdi-hook
8983
- args:
9084
- nvidia-cdi-hook
9185
- update-ldcache

pkg/nvcdi/api.go

+10
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,13 @@ type Interface interface {
3535
GetMIGDeviceSpecs(int, device.Device, int, device.MigDevice) ([]specs.Device, error)
3636
GetDeviceSpecsByID(...string) ([]specs.Device, error)
3737
}
38+
39+
// A HookName refers to one of the predefined set of CDI hooks that may be
40+
// included in the generated CDI specification.
41+
type HookName string
42+
43+
const (
44+
// HookEnableCudaCompat refers to the hook used to enable CUDA Forward Compatibility.
45+
// This was added with v1.17.5 of the NVIDIA Container Toolkit.
46+
HookEnableCudaCompat = HookName("enable-cuda-compat")
47+
)

pkg/nvcdi/common-nvml.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ func (l *nvmllib) newCommonNVMLDiscoverer() (discover.Discover, error) {
4141
l.logger.Warningf("failed to create discoverer for graphics mounts: %v", err)
4242
}
4343

44-
driverFiles, err := NewDriverDiscoverer(l.logger, l.driver, l.nvidiaCDIHookPath, l.ldconfigPath, l.nvmllib)
44+
driverFiles, err := l.NewDriverDiscoverer()
4545
if err != nil {
4646
return nil, fmt.Errorf("failed to create discoverer for driver files: %v", err)
4747
}

pkg/nvcdi/driver-nvml.go

+34-28
Original file line numberDiff line numberDiff line change
@@ -34,41 +34,41 @@ import (
3434

3535
// NewDriverDiscoverer creates a discoverer for the libraries and binaries associated with a driver installation.
3636
// The supplied NVML Library is used to query the expected driver version.
37-
func NewDriverDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCDIHookPath string, ldconfigPath string, nvmllib nvml.Interface) (discover.Discover, error) {
38-
if r := nvmllib.Init(); r != nvml.SUCCESS {
37+
func (l *nvmllib) NewDriverDiscoverer() (discover.Discover, error) {
38+
if r := l.nvmllib.Init(); r != nvml.SUCCESS {
3939
return nil, fmt.Errorf("failed to initialize NVML: %v", r)
4040
}
4141
defer func() {
42-
if r := nvmllib.Shutdown(); r != nvml.SUCCESS {
43-
logger.Warningf("failed to shutdown NVML: %v", r)
42+
if r := l.nvmllib.Shutdown(); r != nvml.SUCCESS {
43+
l.logger.Warningf("failed to shutdown NVML: %v", r)
4444
}
4545
}()
4646

47-
version, r := nvmllib.SystemGetDriverVersion()
47+
version, r := l.nvmllib.SystemGetDriverVersion()
4848
if r != nvml.SUCCESS {
4949
return nil, fmt.Errorf("failed to determine driver version: %v", r)
5050
}
5151

52-
return newDriverVersionDiscoverer(logger, driver, nvidiaCDIHookPath, ldconfigPath, version)
52+
return (*nvcdilib)(l).newDriverVersionDiscoverer(version)
5353
}
5454

55-
func newDriverVersionDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCDIHookPath, ldconfigPath, version string) (discover.Discover, error) {
56-
libraries, err := NewDriverLibraryDiscoverer(logger, driver, nvidiaCDIHookPath, ldconfigPath, version)
55+
func (l *nvcdilib) newDriverVersionDiscoverer(version string) (discover.Discover, error) {
56+
libraries, err := l.NewDriverLibraryDiscoverer(version)
5757
if err != nil {
5858
return nil, fmt.Errorf("failed to create discoverer for driver libraries: %v", err)
5959
}
6060

61-
ipcs, err := discover.NewIPCDiscoverer(logger, driver.Root)
61+
ipcs, err := discover.NewIPCDiscoverer(l.logger, l.driver.Root)
6262
if err != nil {
6363
return nil, fmt.Errorf("failed to create discoverer for IPC sockets: %v", err)
6464
}
6565

66-
firmwares, err := NewDriverFirmwareDiscoverer(logger, driver.Root, version)
66+
firmwares, err := NewDriverFirmwareDiscoverer(l.logger, l.driver.Root, version)
6767
if err != nil {
6868
return nil, fmt.Errorf("failed to create discoverer for GSP firmware: %v", err)
6969
}
7070

71-
binaries := NewDriverBinariesDiscoverer(logger, driver.Root)
71+
binaries := NewDriverBinariesDiscoverer(l.logger, l.driver.Root)
7272

7373
d := discover.Merge(
7474
libraries,
@@ -81,35 +81,41 @@ func newDriverVersionDiscoverer(logger logger.Interface, driver *root.Driver, nv
8181
}
8282

8383
// NewDriverLibraryDiscoverer creates a discoverer for the libraries associated with the specified driver version.
84-
func NewDriverLibraryDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCDIHookPath, ldconfigPath, version string) (discover.Discover, error) {
85-
libraryPaths, err := getVersionLibs(logger, driver, version)
84+
func (l *nvcdilib) NewDriverLibraryDiscoverer(version string) (discover.Discover, error) {
85+
libraryPaths, err := getVersionLibs(l.logger, l.driver, version)
8686
if err != nil {
8787
return nil, fmt.Errorf("failed to get libraries for driver version: %v", err)
8888
}
8989

9090
libraries := discover.NewMounts(
91-
logger,
91+
l.logger,
9292
lookup.NewFileLocator(
93-
lookup.WithLogger(logger),
94-
lookup.WithRoot(driver.Root),
93+
lookup.WithLogger(l.logger),
94+
lookup.WithRoot(l.driver.Root),
9595
),
96-
driver.Root,
96+
l.driver.Root,
9797
libraryPaths,
9898
)
9999

100-
// TODO: The following should use the version directly.
101-
cudaCompatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(logger, nvidiaCDIHookPath, driver)
102-
updateLDCache, _ := discover.NewLDCacheUpdateHook(logger, libraries, nvidiaCDIHookPath, ldconfigPath)
100+
var discoverers []discover.Discover
103101

104-
d := discover.Merge(
105-
discover.WithDriverDotSoSymlinks(
106-
libraries,
107-
version,
108-
nvidiaCDIHookPath,
109-
),
110-
cudaCompatLibHookDiscoverer,
111-
updateLDCache,
102+
driverDotSoSymlinksDiscoverer := discover.WithDriverDotSoSymlinks(
103+
libraries,
104+
version,
105+
l.nvidiaCDIHookPath,
112106
)
107+
discoverers = append(discoverers, driverDotSoSymlinksDiscoverer)
108+
109+
if l.HookIsSupported(HookEnableCudaCompat) {
110+
// TODO: The following should use the version directly.
111+
cudaCompatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(l.logger, l.nvidiaCDIHookPath, l.driver)
112+
discoverers = append(discoverers, cudaCompatLibHookDiscoverer)
113+
}
114+
115+
updateLDCache, _ := discover.NewLDCacheUpdateHook(l.logger, libraries, l.nvidiaCDIHookPath, l.ldconfigPath)
116+
discoverers = append(discoverers, updateLDCache)
117+
118+
d := discover.Merge(discoverers...)
113119

114120
return d, nil
115121
}

pkg/nvcdi/hooks.go

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/**
2+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
**/
16+
17+
package nvcdi
18+
19+
// disabledHooks allows individual hooks to be disabled.
20+
type disabledHooks map[HookName]bool
21+
22+
// HookIsSupported checks whether a hook of the specified name is supported.
23+
// Hooks must be explicitly disabled, meaning that if no disabled hooks are specified,
24+
// all hooks are supported.
25+
func (l *nvcdilib) HookIsSupported(h HookName) bool {
26+
if len(l.disabledHooks) == 0 {
27+
return true
28+
}
29+
return !l.disabledHooks[h]
30+
}

pkg/nvcdi/lib.go

+7-1
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,15 @@ type nvcdilib struct {
5454
infolib info.Interface
5555

5656
mergedDeviceOptions []transform.MergedDeviceOption
57+
58+
disabledHooks disabledHooks
5759
}
5860

5961
// New creates a new nvcdi library
6062
func New(opts ...Option) (Interface, error) {
61-
l := &nvcdilib{}
63+
l := &nvcdilib{
64+
disabledHooks: make(disabledHooks),
65+
}
6266
for _, opt := range opts {
6367
opt(l)
6468
}
@@ -140,6 +144,8 @@ func New(opts ...Option) (Interface, error) {
140144
if l.vendor == "" {
141145
l.vendor = "management.nvidia.com"
142146
}
147+
// Management containers in general do not require CUDA Forward compatibility.
148+
l.disabledHooks[HookEnableCudaCompat] = true
143149
lib = (*managementlib)(l)
144150
case ModeNvml:
145151
lib = (*nvmllib)(l)

pkg/nvcdi/management.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ func (m *managementlib) GetCommonEdits() (*cdi.ContainerEdits, error) {
8080
return nil, fmt.Errorf("failed to get CUDA version: %v", err)
8181
}
8282

83-
driver, err := newDriverVersionDiscoverer(m.logger, m.driver, m.nvidiaCDIHookPath, m.ldconfigPath, version)
83+
driver, err := (*nvcdilib)(m).newDriverVersionDiscoverer(version)
8484
if err != nil {
8585
return nil, fmt.Errorf("failed to create driver library discoverer: %v", err)
8686
}

pkg/nvcdi/options.go

+11
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,14 @@ func WithLibrarySearchPaths(paths []string) Option {
155155
o.librarySearchPaths = paths
156156
}
157157
}
158+
159+
// WithDisabledHook allows specific hooks to be disabled.
160+
// This option can be specified multiple times, once for each hook to disable.
161+
func WithDisabledHook(hook HookName) Option {
162+
return func(o *nvcdilib) {
163+
if o.disabledHooks == nil {
164+
o.disabledHooks = make(map[HookName]bool)
165+
}
166+
o.disabledHooks[hook] = true
167+
}
168+
}

0 commit comments

Comments
 (0)