Skip to content

Commit f5680dd

Browse files
authored
Merge pull request #948 from elezar/add-compat-lib-hook
Add CUDA forward compatibility hook
2 parents 6b037a0 + c1bac28 commit f5680dd

File tree

9 files changed

+540
-4
lines changed

9 files changed

+540
-4
lines changed

cmd/nvidia-cdi-hook/commands/commands.go

+2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121

2222
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/chmod"
2323
symlinks "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/create-symlinks"
24+
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/cudacompat"
2425
ldcache "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/update-ldcache"
2526
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
2627
)
@@ -32,5 +33,6 @@ func New(logger logger.Interface) []*cli.Command {
3233
ldcache.NewCommand(logger),
3334
symlinks.NewCommand(logger),
3435
chmod.NewCommand(logger),
36+
cudacompat.NewCommand(logger),
3537
}
3638
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/**
2+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
**/
16+
17+
package cudacompat
18+
19+
import (
20+
"os"
21+
"path/filepath"
22+
23+
"github.com/moby/sys/symlink"
24+
)
25+
26+
// containerRoot is the path of a container's root filesystem.
// Its methods look up and resolve paths relative to that root.
type containerRoot string
28+
29+
// hasPath checks whether the specified path exists in the root.
30+
func (r containerRoot) hasPath(path string) bool {
31+
resolved, err := r.resolve(path)
32+
if err != nil {
33+
return false
34+
}
35+
if _, err := os.Stat(resolved); err != nil && os.IsNotExist(err) {
36+
return false
37+
}
38+
return true
39+
}
40+
41+
// globFiles matches the specified pattern in the root.
42+
// The files that match must be regular files.
43+
func (r containerRoot) globFiles(pattern string) ([]string, error) {
44+
patternPath, err := r.resolve(pattern)
45+
if err != nil {
46+
return nil, err
47+
}
48+
matches, err := filepath.Glob(patternPath)
49+
if err != nil {
50+
return nil, err
51+
}
52+
var files []string
53+
for _, match := range matches {
54+
info, err := os.Lstat(match)
55+
if err != nil {
56+
return nil, err
57+
}
58+
// Ignore symlinks.
59+
if info.Mode()&os.ModeSymlink != 0 {
60+
continue
61+
}
62+
// Ignore directories.
63+
if info.IsDir() {
64+
continue
65+
}
66+
files = append(files, match)
67+
}
68+
return files, nil
69+
}
70+
71+
// resolve returns the absolute path including root path.
72+
// Symlinks are resolved, but are guaranteed to resolve in the root.
73+
func (r containerRoot) resolve(path string) (string, error) {
74+
absolute := filepath.Clean(filepath.Join(string(r), path))
75+
return symlink.FollowSymlinkInScope(absolute, string(r))
76+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
/**
2+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
**/
16+
17+
package cudacompat
18+
19+
import (
20+
"fmt"
21+
"os"
22+
"path/filepath"
23+
"strconv"
24+
"strings"
25+
26+
"github.com/urfave/cli/v2"
27+
28+
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
29+
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
30+
)
31+
32+
const (
	// cudaCompatPath is the directory, relative to the container root, that is
	// checked for CUDA forward compatibility libraries.
	cudaCompatPath = "/usr/local/cuda/compat"

	// cudaCompatLdsoconfdFilenamePattern is the filename pattern used for the
	// file created in ld.so.conf.d that references the CUDA compat path.
	// The 00-compat prefix is chosen so that these libraries take precedence
	// over other libraries on the system.
	cudaCompatLdsoconfdFilenamePattern = "00-compat-*.conf"
)
40+
41+
type command struct {
42+
logger logger.Interface
43+
}
44+
45+
// options holds the values of the command line flags of the hook.
type options struct {
	// hostDriverVersion is set by the --host-driver-version flag.
	hostDriverVersion string
	// containerSpec is set by the hidden --container-spec flag and is the path
	// to the OCI container spec; empty or "-" means the spec is read from STDIN.
	containerSpec string
}
49+
50+
// NewCommand constructs a cuda-compat command with the specified logger
51+
func NewCommand(logger logger.Interface) *cli.Command {
52+
c := command{
53+
logger: logger,
54+
}
55+
return c.build()
56+
}
57+
58+
// build the enable-cuda-compat command
59+
func (m command) build() *cli.Command {
60+
cfg := options{}
61+
62+
// Create the 'enable-cuda-compat' command
63+
c := cli.Command{
64+
Name: "enable-cuda-compat",
65+
Usage: "This hook ensures that the folder containing the CUDA compat libraries is added to the ldconfig search path if required.",
66+
Before: func(c *cli.Context) error {
67+
return m.validateFlags(c, &cfg)
68+
},
69+
Action: func(c *cli.Context) error {
70+
return m.run(c, &cfg)
71+
},
72+
}
73+
74+
c.Flags = []cli.Flag{
75+
&cli.StringFlag{
76+
Name: "host-driver-version",
77+
Usage: "Specify the host driver version. If the CUDA compat libraries detected in the container do not have a higher MAJOR version, the hook is a no-op.",
78+
Destination: &cfg.hostDriverVersion,
79+
},
80+
&cli.StringFlag{
81+
Name: "container-spec",
82+
Hidden: true,
83+
Category: "testing-only",
84+
Usage: "Specify the path to the OCI container spec. If empty or '-' the spec will be read from STDIN",
85+
Destination: &cfg.containerSpec,
86+
},
87+
}
88+
89+
return &c
90+
}
91+
92+
func (m command) validateFlags(_ *cli.Context, cfg *options) error {
93+
return nil
94+
}
95+
96+
func (m command) run(_ *cli.Context, cfg *options) error {
97+
if cfg.hostDriverVersion == "" {
98+
return nil
99+
}
100+
101+
s, err := oci.LoadContainerState(cfg.containerSpec)
102+
if err != nil {
103+
return fmt.Errorf("failed to load container state: %w", err)
104+
}
105+
106+
containerRootDir, err := s.GetContainerRoot()
107+
if err != nil {
108+
return fmt.Errorf("failed to determined container root: %w", err)
109+
}
110+
111+
containerForwardCompatDir, err := m.getContainerForwardCompatDir(containerRoot(containerRootDir), cfg.hostDriverVersion)
112+
if err != nil {
113+
return fmt.Errorf("failed to get container forward compat directory: %w", err)
114+
}
115+
if containerForwardCompatDir == "" {
116+
return nil
117+
}
118+
119+
return m.createLdsoconfdFile(containerRoot(containerRootDir), cudaCompatLdsoconfdFilenamePattern, containerForwardCompatDir)
120+
}
121+
122+
func (m command) getContainerForwardCompatDir(containerRoot containerRoot, hostDriverVersion string) (string, error) {
123+
if hostDriverVersion == "" {
124+
m.logger.Debugf("Host driver version not specified")
125+
return "", nil
126+
}
127+
if !containerRoot.hasPath(cudaCompatPath) {
128+
m.logger.Debugf("No CUDA forward compatibility libraries directory in container")
129+
return "", nil
130+
}
131+
if !containerRoot.hasPath("/etc/ld.so.cache") {
132+
m.logger.Debugf("The container does not have an LDCache")
133+
return "", nil
134+
}
135+
136+
libs, err := containerRoot.globFiles(filepath.Join(cudaCompatPath, "libcuda.so.*.*"))
137+
if err != nil {
138+
m.logger.Warningf("Failed to find CUDA compat library: %w", err)
139+
return "", nil
140+
}
141+
142+
if len(libs) == 0 {
143+
m.logger.Debugf("No CUDA forward compatibility libraries container")
144+
return "", nil
145+
}
146+
147+
if len(libs) != 1 {
148+
m.logger.Warningf("Unexpected number of CUDA compat libraries in container: %v", libs)
149+
return "", nil
150+
}
151+
152+
compatDriverVersion := strings.TrimPrefix(filepath.Base(libs[0]), "libcuda.so.")
153+
compatMajor, err := extractMajorVersion(compatDriverVersion)
154+
if err != nil {
155+
return "", fmt.Errorf("failed to extract major version from %q: %v", compatDriverVersion, err)
156+
}
157+
158+
driverMajor, err := extractMajorVersion(hostDriverVersion)
159+
if err != nil {
160+
return "", fmt.Errorf("failed to extract major version from %q: %v", hostDriverVersion, err)
161+
}
162+
163+
if driverMajor >= compatMajor {
164+
m.logger.Debugf("Compat major version is not greater than the host driver major version (%v >= %v)", hostDriverVersion, compatDriverVersion)
165+
return "", nil
166+
}
167+
168+
resolvedCompatDir := strings.TrimPrefix(filepath.Dir(libs[0]), string(containerRoot))
169+
return resolvedCompatDir, nil
170+
}
171+
172+
// createLdsoconfdFile creates a file at /etc/ld.so.conf.d/ in the specified root.
173+
// The file is created at /etc/ld.so.conf.d/{{ .pattern }} using `CreateTemp` and
174+
// contains the specified directories on each line.
175+
func (m command) createLdsoconfdFile(in containerRoot, pattern string, dirs ...string) error {
176+
if len(dirs) == 0 {
177+
m.logger.Debugf("No directories to add to /etc/ld.so.conf")
178+
return nil
179+
}
180+
181+
ldsoconfdDir, err := in.resolve("/etc/ld.so.conf.d")
182+
if err != nil {
183+
return err
184+
}
185+
if err := os.MkdirAll(ldsoconfdDir, 0755); err != nil {
186+
return fmt.Errorf("failed to create ld.so.conf.d: %w", err)
187+
}
188+
189+
configFile, err := os.CreateTemp(ldsoconfdDir, pattern)
190+
if err != nil {
191+
return fmt.Errorf("failed to create config file: %w", err)
192+
}
193+
defer configFile.Close()
194+
195+
m.logger.Debugf("Adding directories %v to %v", dirs, configFile.Name())
196+
197+
added := make(map[string]bool)
198+
for _, dir := range dirs {
199+
if added[dir] {
200+
continue
201+
}
202+
_, err = configFile.WriteString(fmt.Sprintf("%s\n", dir))
203+
if err != nil {
204+
return fmt.Errorf("failed to update config file: %w", err)
205+
}
206+
added[dir] = true
207+
}
208+
209+
// The created file needs to be world readable for the cases where the container is run as a non-root user.
210+
if err := configFile.Chmod(0644); err != nil {
211+
return fmt.Errorf("failed to chmod config file: %w", err)
212+
}
213+
214+
return nil
215+
}
216+
217+
// extractMajorVersion returns the major component of a dot-separated version
// string as an int. The major component is everything before the first "."
// (or the whole string if there is none); an error is returned if it is not
// a valid integer.
func extractMajorVersion(version string) (int, error) {
	major := version
	if dot := strings.IndexByte(version, '.'); dot >= 0 {
		major = version[:dot]
	}
	return strconv.Atoi(major)
}

0 commit comments

Comments
 (0)