Skip to content

Commit a988339

Browse files
committed
Add cuda-compat hook to allow compat libs to be discovered
This change adds an nvidia-cdi-hook cuda-compat hook that checks the container for cuda compat libs and updates /etc/ld.so.conf.d to include their parent folder if their driver major version is sufficient. This allows CUDA Forward Compatibility to be used when this is not available through the libnvidia-container. Signed-off-by: Evan Lezar <[email protected]>
1 parent 92472bd commit a988339

File tree

3 files changed

+350
-0
lines changed

3 files changed

+350
-0
lines changed

cmd/nvidia-cdi-hook/commands/commands.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121

2222
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/chmod"
2323
symlinks "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/create-symlinks"
24+
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/cudacompat"
2425
ldcache "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/update-ldcache"
2526
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
2627
)
@@ -32,5 +33,6 @@ func New(logger logger.Interface) []*cli.Command {
3233
ldcache.NewCommand(logger),
3334
symlinks.NewCommand(logger),
3435
chmod.NewCommand(logger),
36+
cudacompat.NewCommand(logger),
3537
}
3638
}
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
/**
2+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
**/
16+
17+
package cudacompat
18+
19+
import (
	"fmt"
	"os"
	"path/filepath"
	"strconv"
	"strings"

	"github.com/moby/sys/symlink"
	"github.com/urfave/cli/v2"

	"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
	"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
)
31+
32+
// cudaCompatPath is the folder, relative to the container root, that is
// searched for CUDA compat libraries and added to the ldconfig search path.
const cudaCompatPath = "/usr/local/cuda/compat"
35+
36+
type command struct {
37+
logger logger.Interface
38+
}
39+
40+
// options collects the values of the cuda-compat command line flags.
type options struct {
	// driverVersion is the host driver version (--driver-version).
	driverVersion string
	// containerSpec is the path to the OCI container spec (--container-spec).
	containerSpec string
}
44+
45+
// NewCommand constructs an cuda-compat command with the specified logger
46+
func NewCommand(logger logger.Interface) *cli.Command {
47+
c := command{
48+
logger: logger,
49+
}
50+
return c.build()
51+
}
52+
53+
// build the cuda-compat command
54+
func (m command) build() *cli.Command {
55+
cfg := options{}
56+
57+
// Create the 'cuda-compat' command
58+
c := cli.Command{
59+
Name: "cuda-compat",
60+
Usage: "This hook ensures that the folder containing the CUDA compat libraries is added to the ldconfig search path if required.",
61+
Before: func(c *cli.Context) error {
62+
return m.validateFlags(c, &cfg)
63+
},
64+
Action: func(c *cli.Context) error {
65+
return m.run(c, &cfg)
66+
},
67+
}
68+
69+
c.Flags = []cli.Flag{
70+
&cli.StringFlag{
71+
Name: "driver-version",
72+
Usage: "Specify the host driver version. If the CUDA compat libraries detected in the container do not have a higher MAJOR version, the hook is a no-op.",
73+
Destination: &cfg.driverVersion,
74+
},
75+
&cli.StringFlag{
76+
Name: "container-spec",
77+
Hidden: true,
78+
Category: "testing-only",
79+
Usage: "Specify the path to the OCI container spec. If empty or '-' the spec will be read from STDIN",
80+
Destination: &cfg.containerSpec,
81+
},
82+
}
83+
84+
return &c
85+
}
86+
87+
func (m command) validateFlags(c *cli.Context, cfg *options) error {
88+
return nil
89+
}
90+
91+
func (m command) run(c *cli.Context, cfg *options) error {
92+
s, err := oci.LoadContainerState(cfg.containerSpec)
93+
if err != nil {
94+
return fmt.Errorf("failed to load container state: %w", err)
95+
}
96+
97+
containerRoot, err := s.GetContainerRoot()
98+
if err != nil {
99+
return fmt.Errorf("failed to determined container root: %w", err)
100+
}
101+
102+
return m.updateCUDACompatLibs(root(containerRoot), cfg)
103+
}
104+
105+
func (m command) updateCUDACompatLibs(containerRoot root, cfg *options) error {
106+
if !containerRoot.hasPath(cudaCompatPath) {
107+
return nil
108+
}
109+
110+
if !containerRoot.hasPath("/etc/ld.so.cache") {
111+
// If there is no ldcache in the container, the hook is a no-op.
112+
return nil
113+
}
114+
if !containerRoot.hasPath("/etc/ld.so.conf.d") {
115+
// If the /etc/ld.so.conf.d folder does not exist in the container, the hook is a no-op.
116+
return nil
117+
}
118+
119+
libs, err := containerRoot.glob(filepath.Join(cudaCompatPath, "libcuda.so.*.*"))
120+
if err != nil {
121+
m.logger.Warningf("Failed to find CUDA compat library: %w", err)
122+
return nil
123+
}
124+
125+
if len(libs) == 0 {
126+
return nil
127+
}
128+
129+
if len(libs) != 1 {
130+
m.logger.Warningf("Unexpected number of CUDA compat libraries: %v", libs)
131+
return nil
132+
}
133+
134+
compatVersion := strings.TrimPrefix(filepath.Base(libs[0]), "libcuda.so.")
135+
compatMajor := strings.SplitN(compatVersion, ".", 2)[0]
136+
137+
driverVersion := cfg.driverVersion
138+
driverMajor := strings.SplitN(driverVersion, ".", 2)[0]
139+
140+
if driverMajor < compatMajor {
141+
return m.createConfig(string(containerRoot), []string{cudaCompatPath})
142+
}
143+
return nil
144+
}
145+
146+
// A root is used to add basic path functionality to a string.
147+
type root string
148+
149+
// hasPath checks whether the specified path exists in the root.
150+
func (r root) hasPath(path string) bool {
151+
resolved, err := r.resolve(path)
152+
if err != nil {
153+
return false
154+
}
155+
if _, err := os.Stat(resolved); err != nil && os.IsNotExist(err) {
156+
return false
157+
}
158+
return true
159+
}
160+
161+
// glob matches the specified pattern in the root.
162+
func (r root) glob(pattern string) ([]string, error) {
163+
patternPath, err := r.resolve(pattern)
164+
if err != nil {
165+
return nil, err
166+
}
167+
return filepath.Glob(patternPath)
168+
}
169+
170+
// resolve returns the absolute path including root path.
171+
// Symlinks are resolved, but are guaranteed to resolve in the root.
172+
func (r root) resolve(path string) (string, error) {
173+
absolute := filepath.Clean(filepath.Join(string(r), path))
174+
return symlink.FollowSymlinkInScope(absolute, string(r))
175+
}
176+
177+
// createConfig creates (or updates) /etc/ld.so.conf.d/00-compat-<RANDOM_STRING>.conf in the container
178+
// to include the required paths.
179+
// Note that the 00-compat prefix is chosen to ensure that these libraries have
180+
// a higher precedence than other libraries on the system.
181+
func (m command) createConfig(root string, folders []string) error {
182+
if len(folders) == 0 {
183+
m.logger.Debugf("No folders to add to /etc/ld.so.conf")
184+
return nil
185+
}
186+
187+
if err := os.MkdirAll(filepath.Join(root, "/etc/ld.so.conf.d"), 0755); err != nil {
188+
return fmt.Errorf("failed to create ld.so.conf.d: %w", err)
189+
}
190+
191+
configFile, err := os.CreateTemp(filepath.Join(root, "/etc/ld.so.conf.d"), "00-compat-*.conf")
192+
if err != nil {
193+
return fmt.Errorf("failed to create config file: %w", err)
194+
}
195+
defer configFile.Close()
196+
197+
m.logger.Debugf("Adding folders %v to %v", folders, configFile.Name())
198+
199+
configured := make(map[string]bool)
200+
for _, folder := range folders {
201+
if configured[folder] {
202+
continue
203+
}
204+
_, err = configFile.WriteString(fmt.Sprintf("%s\n", folder))
205+
if err != nil {
206+
return fmt.Errorf("failed to update ld.so.conf.d: %w", err)
207+
}
208+
configured[folder] = true
209+
}
210+
211+
// The created file needs to be world readable for the cases where the container is run as a non-root user.
212+
if err := os.Chmod(configFile.Name(), 0644); err != nil {
213+
return fmt.Errorf("failed to chmod config file: %w", err)
214+
}
215+
216+
return nil
217+
}
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
/*
2+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
*/
16+
17+
package cudacompat
18+
19+
import (
20+
"os"
21+
"path/filepath"
22+
"testing"
23+
24+
testlog "github.com/sirupsen/logrus/hooks/test"
25+
"github.com/stretchr/testify/require"
26+
)
27+
28+
func TestCompatLibs(t *testing.T) {
29+
logger, _ := testlog.NewNullLogger()
30+
31+
testCases := []struct {
32+
description string
33+
contents map[string]string
34+
options options
35+
expectedCompatContents string
36+
}{
37+
{
38+
description: "empty root",
39+
options: options{
40+
driverVersion: "222.55.66",
41+
},
42+
},
43+
{
44+
description: "compat lib is newer; no ldcache",
45+
contents: map[string]string{
46+
"/usr/local/cuda/compat/libcuda.so.333.88.99": "",
47+
},
48+
options: options{
49+
driverVersion: "222.55.66",
50+
},
51+
},
52+
{
53+
description: "compat lib is newer; ldcache",
54+
contents: map[string]string{
55+
"/etc/ld.so.cache": "",
56+
"/etc/ld.so.conf.d/.hidden": "",
57+
"/usr/local/cuda/compat/libcuda.so.333.88.99": "",
58+
},
59+
options: options{
60+
driverVersion: "222.55.66",
61+
},
62+
expectedCompatContents: "/usr/local/cuda/compat\n",
63+
},
64+
{
65+
description: "compat lib is older; ldcache",
66+
contents: map[string]string{
67+
"/etc/ld.so.cache": "",
68+
"/etc/ld.so.conf.d/.hidden": "",
69+
"/usr/local/cuda/compat/libcuda.so.111.88.99": "",
70+
},
71+
options: options{
72+
driverVersion: "222.55.66",
73+
},
74+
expectedCompatContents: "",
75+
},
76+
{
77+
description: "compat lib has same major version; ldcache",
78+
contents: map[string]string{
79+
"/etc/ld.so.cache": "",
80+
"/etc/ld.so.conf.d/.hidden": "",
81+
"/usr/local/cuda/compat/libcuda.so.222.88.99": "",
82+
},
83+
options: options{
84+
driverVersion: "222.55.66",
85+
},
86+
expectedCompatContents: "",
87+
},
88+
{
89+
description: "driver version empty; ldcache",
90+
contents: map[string]string{
91+
"/etc/ld.so.cache": "",
92+
"/etc/ld.so.conf.d/.hidden": "",
93+
"/usr/local/cuda/compat/libcuda.so.222.88.99": "",
94+
},
95+
options: options{
96+
driverVersion: "",
97+
},
98+
expectedCompatContents: "/usr/local/cuda/compat\n",
99+
},
100+
}
101+
102+
for _, tc := range testCases {
103+
t.Run(tc.description, func(t *testing.T) {
104+
containerRoot := t.TempDir()
105+
for name, contents := range tc.contents {
106+
target := filepath.Join(containerRoot, name)
107+
require.NoError(t, os.MkdirAll(filepath.Dir(target), 0755))
108+
require.NoError(t, os.WriteFile(target, []byte(contents), 0600))
109+
}
110+
111+
c := command{
112+
logger: logger,
113+
}
114+
err := c.updateCUDACompatLibs(root(containerRoot), &tc.options)
115+
require.NoError(t, err)
116+
117+
matches, err := filepath.Glob(filepath.Join(containerRoot, "/etc/ld.so.conf.d/00-compat-*.conf"))
118+
require.NoError(t, err)
119+
120+
if tc.expectedCompatContents == "" {
121+
require.Empty(t, matches)
122+
} else {
123+
require.Len(t, matches, 1)
124+
contents, err := os.ReadFile(matches[0])
125+
require.NoError(t, err)
126+
127+
require.EqualValues(t, tc.expectedCompatContents, string(contents))
128+
}
129+
})
130+
}
131+
}

0 commit comments

Comments
 (0)