Skip to content

Commit 16bd33d

Browse files
authored
Merge pull request #1166 from elezar/refactor-cdi-api
Refactor cdi api
2 parents d116a0b + b42048d commit 16bd33d

File tree

14 files changed

+536
-442
lines changed

14 files changed

+536
-442
lines changed

cmd/nvidia-ctk/cdi/generate/generate.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ func (m command) generateSpec(opts *options) (spec.Interface, error) {
312312
return nil, fmt.Errorf("failed to create CDI library: %v", err)
313313
}
314314

315-
deviceSpecs, err := cdilib.GetAllDeviceSpecs()
315+
deviceSpecs, err := cdilib.GetDeviceSpecsByID("all")
316316
if err != nil {
317317
return nil, fmt.Errorf("failed to create device CDI specs: %v", err)
318318
}

pkg/nvcdi/api.go

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
package nvcdi
1818

1919
import (
20-
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
2120
"tags.cncf.io/container-device-interface/pkg/cdi"
2221
"tags.cncf.io/container-device-interface/specs-go"
2322

@@ -27,14 +26,22 @@ import (
2726

2827
// Interface defines the API for the nvcdi package
2928
type Interface interface {
30-
GetSpec(...string) (spec.Interface, error)
29+
SpecGenerator
3130
GetCommonEdits() (*cdi.ContainerEdits, error)
32-
GetAllDeviceSpecs() ([]specs.Device, error)
33-
GetGPUDeviceEdits(device.Device) (*cdi.ContainerEdits, error)
34-
GetGPUDeviceSpecs(int, device.Device) ([]specs.Device, error)
35-
GetMIGDeviceEdits(device.Device, device.MigDevice) (*cdi.ContainerEdits, error)
36-
GetMIGDeviceSpecs(int, device.Device, int, device.MigDevice) ([]specs.Device, error)
3731
GetDeviceSpecsByID(...string) ([]specs.Device, error)
32+
// Deprecated: GetAllDeviceSpecs is deprecated. Use GetDeviceSpecsByID("all") instead.
33+
GetAllDeviceSpecs() ([]specs.Device, error)
34+
}
35+
36+
// A SpecGenerator is used to generate a complete CDI spec for a collected set
37+
// of devices.
38+
type SpecGenerator interface {
39+
GetSpec(...string) (spec.Interface, error)
40+
}
41+
42+
// A DeviceSpecGenerator is used to generate the specs for one or more devices.
43+
type DeviceSpecGenerator interface {
44+
GetDeviceSpecs() ([]specs.Device, error)
3845
}
3946

4047
// A HookName represents one of the predefined NVIDIA CDI hooks.

pkg/nvcdi/full-gpu-nvml.go

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,41 +19,74 @@ package nvcdi
1919
import (
2020
"fmt"
2121

22-
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
2322
"tags.cncf.io/container-device-interface/pkg/cdi"
2423
"tags.cncf.io/container-device-interface/specs-go"
2524

25+
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
26+
"github.com/NVIDIA/go-nvml/pkg/nvml"
27+
2628
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
2729
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
2830
"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/dgpu"
2931
)
3032

31-
// GetGPUDeviceSpecs returns the CDI device specs for the full GPU represented by 'device'.
32-
func (l *nvmllib) GetGPUDeviceSpecs(i int, d device.Device) ([]specs.Device, error) {
33-
edits, err := l.GetGPUDeviceEdits(d)
33+
// A fullGPUDeviceSpecGenerator generates the CDI device specifications for a
34+
// single full GPU.
35+
type fullGPUDeviceSpecGenerator struct {
36+
*nvmllib
37+
id string
38+
index int
39+
device device.Device
40+
}
41+
42+
var _ DeviceSpecGenerator = (*fullGPUDeviceSpecGenerator)(nil)
43+
44+
func (l *nvmllib) newFullGPUDeviceSpecGeneratorFromNVMLDevice(id string, nvmlDevice nvml.Device) (DeviceSpecGenerator, error) {
45+
device, err := l.devicelib.NewDevice(nvmlDevice)
3446
if err != nil {
35-
return nil, fmt.Errorf("failed to get edits for device: %v", err)
47+
return nil, err
3648
}
3749

38-
var deviceSpecs []specs.Device
39-
names, err := l.deviceNamers.GetDeviceNames(i, convert{d})
50+
index, ret := nvmlDevice.GetIndex()
51+
if ret != nvml.SUCCESS {
52+
return nil, fmt.Errorf("failed to get device index: %v", ret)
53+
}
54+
55+
e := &fullGPUDeviceSpecGenerator{
56+
nvmllib: l,
57+
id: id,
58+
index: index,
59+
device: device,
60+
}
61+
return e, nil
62+
}
63+
64+
func (l *fullGPUDeviceSpecGenerator) GetDeviceSpecs() ([]specs.Device, error) {
65+
deviceEdits, err := l.getDeviceEdits()
66+
if err != nil {
67+
return nil, fmt.Errorf("failed to get CDI device edits for identifier %q: %w", l.id, err)
68+
}
69+
70+
names, err := l.getNames()
4071
if err != nil {
41-
return nil, fmt.Errorf("failed to get device name: %v", err)
72+
return nil, fmt.Errorf("failed to get device names: %w", err)
4273
}
74+
75+
var deviceSpecs []specs.Device
4376
for _, name := range names {
44-
spec := specs.Device{
77+
deviceSpec := specs.Device{
4578
Name: name,
46-
ContainerEdits: *edits.ContainerEdits,
79+
ContainerEdits: *deviceEdits.ContainerEdits,
4780
}
48-
deviceSpecs = append(deviceSpecs, spec)
81+
deviceSpecs = append(deviceSpecs, deviceSpec)
4982
}
5083

5184
return deviceSpecs, nil
5285
}
5386

5487
// GetGPUDeviceEdits returns the CDI edits for the full GPU represented by 'device'.
55-
func (l *nvmllib) GetGPUDeviceEdits(d device.Device) (*cdi.ContainerEdits, error) {
56-
device, err := l.newFullGPUDiscoverer(d)
88+
func (l *fullGPUDeviceSpecGenerator) getDeviceEdits() (*cdi.ContainerEdits, error) {
89+
device, err := l.newFullGPUDiscoverer(l.device)
5790
if err != nil {
5891
return nil, fmt.Errorf("failed to create device discoverer: %v", err)
5992
}
@@ -66,8 +99,12 @@ func (l *nvmllib) GetGPUDeviceEdits(d device.Device) (*cdi.ContainerEdits, error
6699
return editsForDevice, nil
67100
}
68101

102+
func (l *fullGPUDeviceSpecGenerator) getNames() ([]string, error) {
103+
return l.deviceNamers.GetDeviceNames(l.index, convert{l.device})
104+
}
105+
69106
// newFullGPUDiscoverer creates a discoverer for the full GPU defined by the specified device.
70-
func (l *nvmllib) newFullGPUDiscoverer(d device.Device) (discover.Discover, error) {
107+
func (l *fullGPUDeviceSpecGenerator) newFullGPUDiscoverer(d device.Device) (discover.Discover, error) {
71108
deviceNodes, err := dgpu.NewForDevice(d,
72109
dgpu.WithDevRoot(l.devRoot),
73110
dgpu.WithLogger(l.logger),

pkg/nvcdi/gds.go

Lines changed: 7 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,23 @@ package nvcdi
1919
import (
2020
"fmt"
2121

22-
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
2322
"tags.cncf.io/container-device-interface/pkg/cdi"
2423
"tags.cncf.io/container-device-interface/specs-go"
2524

2625
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
2726
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
28-
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
2927
)
3028

3129
type gdslib nvcdilib
3230

33-
var _ Interface = (*gdslib)(nil)
31+
var _ deviceSpecGeneratorFactory = (*gdslib)(nil)
3432

35-
// GetAllDeviceSpecs returns the device specs for all available devices.
36-
func (l *gdslib) GetAllDeviceSpecs() ([]specs.Device, error) {
33+
func (l *gdslib) DeviceSpecGenerators(...string) (DeviceSpecGenerator, error) {
34+
return l, nil
35+
}
36+
37+
// GetDeviceSpecs returns the CDI device specs for a single "all" device.
38+
func (l *gdslib) GetDeviceSpecs() ([]specs.Device, error) {
3739
discoverer, err := discover.NewGDSDiscoverer(l.logger, l.driverRoot, l.devRoot)
3840
if err != nil {
3941
return nil, fmt.Errorf("failed to create GPUDirect Storage discoverer: %v", err)
@@ -55,36 +57,3 @@ func (l *gdslib) GetAllDeviceSpecs() ([]specs.Device, error) {
5557
func (l *gdslib) GetCommonEdits() (*cdi.ContainerEdits, error) {
5658
return edits.FromDiscoverer(discover.None{})
5759
}
58-
59-
// GetSpec is unsupported for the gdslib specs.
60-
// gdslib is typically wrapped by a spec that implements GetSpec.
61-
func (l *gdslib) GetSpec(...string) (spec.Interface, error) {
62-
return nil, fmt.Errorf("GetSpec is not supported")
63-
}
64-
65-
// GetGPUDeviceEdits is unsupported for the gdslib specs
66-
func (l *gdslib) GetGPUDeviceEdits(device.Device) (*cdi.ContainerEdits, error) {
67-
return nil, fmt.Errorf("GetGPUDeviceEdits is not supported")
68-
}
69-
70-
// GetGPUDeviceSpecs is unsupported for the gdslib specs
71-
func (l *gdslib) GetGPUDeviceSpecs(int, device.Device) ([]specs.Device, error) {
72-
return nil, fmt.Errorf("GetGPUDeviceSpecs is not supported")
73-
}
74-
75-
// GetMIGDeviceEdits is unsupported for the gdslib specs
76-
func (l *gdslib) GetMIGDeviceEdits(device.Device, device.MigDevice) (*cdi.ContainerEdits, error) {
77-
return nil, fmt.Errorf("GetMIGDeviceEdits is not supported")
78-
}
79-
80-
// GetMIGDeviceSpecs is unsupported for the gdslib specs
81-
func (l *gdslib) GetMIGDeviceSpecs(int, device.Device, int, device.MigDevice) ([]specs.Device, error) {
82-
return nil, fmt.Errorf("GetMIGDeviceSpecs is not supported")
83-
}
84-
85-
// GetDeviceSpecsByID returns the CDI device specs for the GPU(s) represented by
86-
// the provided identifiers, where an identifier is an index or UUID of a valid
87-
// GPU device.
88-
func (l *gdslib) GetDeviceSpecsByID(...string) ([]specs.Device, error) {
89-
return nil, fmt.Errorf("GetDeviceSpecsByID is not supported")
90-
}

pkg/nvcdi/lib-csv.go

Lines changed: 15 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -19,27 +19,33 @@ package nvcdi
1919
import (
2020
"fmt"
2121

22-
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
2322
"tags.cncf.io/container-device-interface/pkg/cdi"
2423
"tags.cncf.io/container-device-interface/specs-go"
2524

2625
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
2726
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
2827
"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra"
29-
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
3028
)
3129

3230
type csvlib nvcdilib
3331

34-
var _ Interface = (*csvlib)(nil)
32+
var _ deviceSpecGeneratorFactory = (*csvlib)(nil)
3533

36-
// GetSpec should not be called for csvlib
37-
func (l *csvlib) GetSpec(...string) (spec.Interface, error) {
38-
return nil, fmt.Errorf("unexpected call to csvlib.GetSpec()")
34+
func (l *csvlib) DeviceSpecGenerators(ids ...string) (DeviceSpecGenerator, error) {
35+
for _, id := range ids {
36+
switch id {
37+
case "all":
38+
case "0":
39+
default:
40+
return nil, fmt.Errorf("unsupported device id: %v", id)
41+
}
42+
}
43+
44+
return l, nil
3945
}
4046

41-
// GetAllDeviceSpecs returns the device specs for all available devices.
42-
func (l *csvlib) GetAllDeviceSpecs() ([]specs.Device, error) {
47+
// GetDeviceSpecs returns the CDI device specs for a single device.
48+
func (l *csvlib) GetDeviceSpecs() ([]specs.Device, error) {
4349
d, err := tegra.New(
4450
tegra.WithLogger(l.logger),
4551
tegra.WithDriverRoot(l.driverRoot),
@@ -76,33 +82,5 @@ func (l *csvlib) GetAllDeviceSpecs() ([]specs.Device, error) {
7682

7783
// GetCommonEdits generates a CDI specification that can be used for ANY devices
7884
func (l *csvlib) GetCommonEdits() (*cdi.ContainerEdits, error) {
79-
d := discover.None{}
80-
return edits.FromDiscoverer(d)
81-
}
82-
83-
// GetGPUDeviceEdits generates a CDI specification that can be used for GPU devices
84-
func (l *csvlib) GetGPUDeviceEdits(device.Device) (*cdi.ContainerEdits, error) {
85-
return nil, fmt.Errorf("GetGPUDeviceEdits is not supported for CSV files")
86-
}
87-
88-
// GetGPUDeviceSpecs returns the CDI device specs for the full GPU represented by 'device'.
89-
func (l *csvlib) GetGPUDeviceSpecs(i int, d device.Device) ([]specs.Device, error) {
90-
return nil, fmt.Errorf("GetGPUDeviceSpecs is not supported for CSV files")
91-
}
92-
93-
// GetMIGDeviceEdits generates a CDI specification that can be used for MIG devices
94-
func (l *csvlib) GetMIGDeviceEdits(device.Device, device.MigDevice) (*cdi.ContainerEdits, error) {
95-
return nil, fmt.Errorf("GetMIGDeviceEdits is not supported for CSV files")
96-
}
97-
98-
// GetMIGDeviceSpecs returns the CDI device specs for the full MIG represented by 'device'.
99-
func (l *csvlib) GetMIGDeviceSpecs(int, device.Device, int, device.MigDevice) ([]specs.Device, error) {
100-
return nil, fmt.Errorf("GetMIGDeviceSpecs is not supported for CSV files")
101-
}
102-
103-
// GetDeviceSpecsByID returns the CDI device specs for the GPU(s) represented by
104-
// the provided identifiers, where an identifier is an index or UUID of a valid
105-
// GPU device.
106-
func (l *csvlib) GetDeviceSpecsByID(...string) ([]specs.Device, error) {
107-
return nil, fmt.Errorf("GetDeviceSpecsByID is not supported for CSV files")
85+
return edits.FromDiscoverer(discover.None{})
10886
}

0 commit comments

Comments
 (0)