Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 2 additions & 27 deletions cmd/nvidia-container-runtime-hook/container_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,30 +148,6 @@ func getMigDevices(image image.CUDA, envvar string) *string {
return &devices
}

func (hookConfig *hookConfig) getImexChannels(image image.CUDA, privileged bool) []string {
if hookConfig.Features.IgnoreImexChannelRequests.IsEnabled() {
return nil
}

// If enabled, try and get the device list from volume mounts first
if hookConfig.AcceptDeviceListAsVolumeMounts {
devices := image.ImexChannelsFromMounts()
if len(devices) > 0 {
return devices
}
}
devices := image.ImexChannelsFromEnvVar()
if len(devices) == 0 {
return nil
}

if privileged || hookConfig.AcceptEnvvarUnprivileged {
return devices
}

return nil
}

func (hookConfig *hookConfig) getDriverCapabilities(cudaImage image.CUDA, legacyImage bool) image.DriverCapabilities {
// We use the default driver capabilities by default. This is filtered to only include the
// supported capabilities
Expand Down Expand Up @@ -223,8 +199,6 @@ func (hookConfig *hookConfig) getNvidiaConfig(image image.CUDA, privileged bool)
log.Panicln("cannot set MIG_MONITOR_DEVICES in non privileged container")
}

imexChannels := hookConfig.getImexChannels(image, privileged)

driverCapabilities := hookConfig.getDriverCapabilities(image, legacyImage).String()

requirements, err := image.GetRequirements()
Expand All @@ -236,7 +210,7 @@ func (hookConfig *hookConfig) getNvidiaConfig(image image.CUDA, privileged bool)
Devices: devices,
MigConfigDevices: migConfigDevices,
MigMonitorDevices: migMonitorDevices,
ImexChannels: imexChannels,
ImexChannels: image.ImexChannelRequests(),
DriverCapabilities: driverCapabilities,
Requirements: requirements,
}
Expand Down Expand Up @@ -273,6 +247,7 @@ func (hookConfig *hookConfig) getContainerConfig() (config *containerConfig) {
image.WithAcceptDeviceListAsVolumeMounts(hookConfig.AcceptDeviceListAsVolumeMounts),
image.WithAcceptEnvvarUnprivileged(hookConfig.AcceptEnvvarUnprivileged),
image.WithPreferredVisibleDevicesEnvVars(hookConfig.getSwarmResourceEnvvars()...),
image.WithIgnoreImexChannelRequests(hookConfig.Features.IgnoreImexChannelRequests.IsEnabled()),
)
if err != nil {
log.Panicln(err)
Expand Down
9 changes: 9 additions & 0 deletions internal/config/image/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,15 @@ func WithEnvMap(env map[string]string) Option {
}
}

// WithIgnoreImexChannelRequests sets whether per-container IMEX channel
// requests are supported.
func WithIgnoreImexChannelRequests(ignoreImexChannelRequests bool) Option {
return func(b *builder) error {
b.ignoreImexChannelRequests = ignoreImexChannelRequests
return nil
}
}

// WithLogger sets the logger to use when creating the CUDA image.
func WithLogger(logger logger.Interface) Option {
return func(b *builder) error {
Expand Down
43 changes: 39 additions & 4 deletions internal/config/image/cuda_image.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ type CUDA struct {
annotationsPrefixes []string
acceptDeviceListAsVolumeMounts bool
acceptEnvvarUnprivileged bool
ignoreImexChannelRequests bool
preferredVisibleDeviceEnvVars []string
}

Expand Down Expand Up @@ -412,17 +413,51 @@ func (m cdiDeviceMountRequest) qualifiedName() (string, error) {
return fmt.Sprintf("%s/%s=%s", parts[0], parts[1], parts[2]), nil
}

// ImexChannelsFromEnvVar returns the list of IMEX channels requested for the image.
func (i CUDA) ImexChannelsFromEnvVar() []string {
func (i CUDA) ImexChannelRequests() []string {
if i.ignoreImexChannelRequests {
return nil
}

// If enabled, try and get the device list from volume mounts first
if i.acceptDeviceListAsVolumeMounts {
volumeMountDeviceRequests := i.imexChannelsFromMounts()
if len(volumeMountDeviceRequests) > 0 {
return volumeMountDeviceRequests
}
}

// Get the Fallback to reading from the environment variable if privileges are correct
envVarDeviceRequests := i.imexChannelsFromEnvVar()
if len(envVarDeviceRequests) == 0 {
return nil
}

// If the container is privileged, or environment variable requests are
// allowed for unprivileged containers, these devices are returned.
if i.isPrivileged || i.acceptEnvvarUnprivileged {
return envVarDeviceRequests
}

// We log a warning if we are ignoring the environment variable requests.
envVars := []string{EnvVarNvidiaImexChannels}
if len(envVars) > 0 {
i.logger.Warningf("Ignoring request by environment variable(s) in unprivileged container: %v", envVars)
}

return nil
}

// imexChannelsFromEnvVar returns the list of IMEX channels requested for the image.
func (i CUDA) imexChannelsFromEnvVar() []string {
imexChannels := i.devicesFromEnvvars(EnvVarNvidiaImexChannels)
if len(imexChannels) == 1 && imexChannels[0] == "all" {
return nil
}
return imexChannels
}

// ImexChannelsFromMounts returns the list of IMEX channels requested for the image.
func (i CUDA) ImexChannelsFromMounts() []string {
// imexChannelsFromMounts returns the list of IMEX channels requested for the image.
func (i CUDA) imexChannelsFromMounts() []string {
var channels []string
for _, mountDevice := range i.requestsFromMounts() {
if !strings.HasPrefix(mountDevice, volumeMountDevicePrefixImex) {
Expand Down
2 changes: 1 addition & 1 deletion internal/config/image/cuda_image_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ func TestImexChannelsFromEnvVar(t *testing.T) {
i, err := newCUDAImageFromEnv(append(baseEnvvars, tc.env...))
require.NoError(t, err)

channels := i.ImexChannelsFromEnvVar()
channels := i.imexChannelsFromEnvVar()
require.EqualValues(t, tc.expected, channels)
})
}
Expand Down
43 changes: 36 additions & 7 deletions internal/modifier/cdi.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ func NewCDIModifier(logger logger.Interface, cfg *config.Config, image image.CUD
return nil, fmt.Errorf("requesting a CDI device with vendor 'runtime.nvidia.com' is not supported when requesting other CDI devices")
}
if len(automaticDevices) > 0 {
automaticDevices = append(automaticDevices, gatedDevices(image).DeviceRequests()...)
automaticDevices = append(automaticDevices, withUniqueDevices(gatedDevices(image)).DeviceRequests()...)
automaticDevices = append(automaticDevices, withUniqueDevices(imexDevices(image)).DeviceRequests()...)

automaticModifier, err := newAutomaticCDISpecModifier(logger, cfg, automaticDevices)
if err == nil {
return automaticModifier, nil
Expand Down Expand Up @@ -135,6 +137,17 @@ func (g gatedDevices) DeviceRequests() []string {
return devices
}

type imexDevices image.CUDA

func (d imexDevices) DeviceRequests() []string {
var devices []string
i := (image.CUDA)(d)
for _, channelID := range i.ImexChannelRequests() {
devices = append(devices, "mode=imex,id="+channelID)
}
return devices
}

// filterAutomaticDevices searches for "automatic" device names in the input slice.
// "Automatic" devices are a well-defined list of CDI device names which, when requested,
// trigger the generation of a CDI spec at runtime. This removes the need to generate a
Expand All @@ -155,17 +168,21 @@ func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, de

perModeIdentifiers := make(map[string][]string)
perModeDeviceClass := map[string]string{"auto": automaticDeviceClass}
modes := []string{"auto"}
uniqueModes := []string{"auto"}
seen := make(map[string]bool)
for _, device := range devices {
if strings.HasPrefix(device, "mode=") {
modes = append(modes, strings.TrimPrefix(device, "mode="))
continue
mode, id := getModeIdentifier(device)
if !seen[mode] {
uniqueModes = append(uniqueModes, mode)
seen[mode] = true
}
if id != "" {
perModeIdentifiers[id] = append(perModeIdentifiers[id], id)
}
perModeIdentifiers["auto"] = append(perModeIdentifiers["auto"], strings.TrimPrefix(device, automaticDevicePrefix))
}

var modifiers oci.SpecModifiers
for _, mode := range modes {
for _, mode := range uniqueModes {
cdilib, err := nvcdi.New(
nvcdi.WithLogger(logger),
nvcdi.WithNVIDIACDIHookPath(cfg.NVIDIACTKConfig.Path),
Expand Down Expand Up @@ -197,6 +214,18 @@ func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, de
return modifiers, nil
}

func getModeIdentifier(device string) (string, string) {
if !strings.HasPrefix(device, "mode=") {
return "auto", strings.TrimPrefix(device, automaticDevicePrefix)
}
parts := strings.SplitN(device, ",", 2)
mode := strings.TrimPrefix(parts[0], "mode=")
if len(parts) == 2 {
return mode, strings.TrimPrefix(parts[1], "id=")
}
return mode, ""
}

type deduplicatedDeviceRequestor struct {
deviceRequestor
}
Expand Down
1 change: 1 addition & 0 deletions internal/runtime/runtime_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ func initRuntimeModeAndImage(logger logger.Interface, cfg *config.Config, ociSpe
image.WithAcceptDeviceListAsVolumeMounts(cfg.AcceptDeviceListAsVolumeMounts),
image.WithAcceptEnvvarUnprivileged(cfg.AcceptEnvvarUnprivileged),
image.WithAnnotationsPrefixes(cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.AnnotationPrefixes),
image.WithIgnoreImexChannelRequests(cfg.Features.IgnoreImexChannelRequests.IsEnabled()),
)
if err != nil {
return "", nil, err
Expand Down