Skip to content

Commit f709e6f

Browse files
committed
Support setting K8s service account in config file
for all Pods in RayCluster Signed-off-by: David Xia <[email protected]>
1 parent 0e64af1 commit f709e6f

File tree

6 files changed

+50
-32
lines changed

6 files changed

+50
-32
lines changed

kubectl-plugin/pkg/cmd/create/create_cluster.go

+4-6
Original file line numberDiff line numberDiff line change
@@ -206,13 +206,11 @@ func (options *CreateClusterOptions) Validate(cmd *cobra.Command) error {
206206
}
207207
}
208208
}
209-
// we must assign gke-tpu-accelerator and gke-tpu-topology in nodeSelector
210-
// if worker-tpu is not 0
211-
if options.workerTPU != "" && options.workerTPU != "0" {
212-
if err := util.ValidateTPUNodeSelector(options.numOfHosts, options.workerNodeSelectors); err != nil {
213-
return fmt.Errorf("%w", err)
214-
}
209+
210+
if err := util.ValidateTPU(&options.workerTPU, &options.numOfHosts, options.workerNodeSelectors); err != nil {
211+
return fmt.Errorf("%w", err)
215212
}
213+
216214
return nil
217215
}
218216

kubectl-plugin/pkg/cmd/create/create_workergroup.go

+2-6
Original file line numberDiff line numberDiff line change
@@ -137,12 +137,8 @@ func (options *CreateWorkerGroupOptions) Complete(cmd *cobra.Command, args []str
137137
}
138138

139139
func (options *CreateWorkerGroupOptions) Validate() error {
140-
// we must assign gke-tpu-accelerator and gke-tpu-topology in nodeSelector
141-
// if worker-tpu is not 0
142-
if options.workerTPU != "0" {
143-
if err := util.ValidateTPUNodeSelector(options.numOfHosts, options.workerNodeSelectors); err != nil {
144-
return fmt.Errorf("%w", err)
145-
}
140+
if err := util.ValidateTPU(&options.workerTPU, &options.numOfHosts, options.workerNodeSelectors); err != nil {
141+
return fmt.Errorf("%w", err)
146142
}
147143
return nil
148144
}

kubectl-plugin/pkg/util/generation/generation.go

+16-3
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@ type RayClusterConfig struct {
3333
Labels map[string]string `yaml:"labels,omitempty"`
3434
Annotations map[string]string `yaml:"annotations,omitempty"`
3535

36-
RayVersion *string `yaml:"ray-version,omitempty"`
37-
Image *string `yaml:"image,omitempty"`
36+
RayVersion *string `yaml:"ray-version,omitempty"`
37+
Image *string `yaml:"image,omitempty"`
38+
ServiceAccount *string `yaml:"service-account,omitempty"`
3839

3940
Head *Head `yaml:"head,omitempty"`
4041

@@ -242,6 +243,10 @@ func (rayClusterConfig *RayClusterConfig) generateRayClusterSpec() *rayv1ac.RayC
242243
if workerGroup.NumOfHosts != nil {
243244
workerGroupSpecs[i].WithNumOfHosts(*workerGroup.NumOfHosts)
244245
}
246+
247+
if rayClusterConfig.ServiceAccount != nil && *rayClusterConfig.ServiceAccount != "" {
248+
workerGroupSpecs[i].Template.Spec.ServiceAccountName = ptr.To(*rayClusterConfig.ServiceAccount)
249+
}
245250
}
246251

247252
rayClusterSpec := rayv1ac.RayClusterSpec().
@@ -262,6 +267,10 @@ func (rayClusterConfig *RayClusterConfig) generateRayClusterSpec() *rayv1ac.RayC
262267
corev1ac.ContainerPort().WithContainerPort(10001).WithName("client")))))).
263268
WithWorkerGroupSpecs(workerGroupSpecs...)
264269

270+
if rayClusterConfig.ServiceAccount != nil && *rayClusterConfig.ServiceAccount != "" {
271+
rayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName = ptr.To(*rayClusterConfig.ServiceAccount)
272+
}
273+
265274
if rayClusterConfig.GKE != nil {
266275
setGCSFuseOptions(rayClusterSpec, rayClusterConfig.GKE.GCSFuse)
267276
}
@@ -400,7 +409,6 @@ func newRayClusterConfigWithDefaults() *RayClusterConfig {
400409
}
401410

402411
// ParseConfigFile parses the YAML configuration file into a RayClusterConfig object
403-
404412
func ParseConfigFile(filePath string) (*RayClusterConfig, error) {
405413
if _, err := os.Stat(filePath); os.IsNotExist(err) {
406414
return nil, fmt.Errorf("config file %s does not exist", filePath)
@@ -443,6 +451,7 @@ func ValidateConfig(config *RayClusterConfig) error {
443451
workerResourceFields := map[string]*string{
444452
"cpu": workerGroup.CPU,
445453
"gpu": workerGroup.GPU,
454+
"tpu": workerGroup.TPU,
446455
"memory": workerGroup.Memory,
447456
"ephemeral-storage": workerGroup.EphemeralStorage,
448457
}
@@ -455,6 +464,10 @@ func ValidateConfig(config *RayClusterConfig) error {
455464
return fmt.Errorf("%w", err)
456465
}
457466
}
467+
468+
if err := util.ValidateTPU(workerGroup.TPU, workerGroup.NumOfHosts, workerGroup.NodeSelectors); err != nil {
469+
return fmt.Errorf("%w", err)
470+
}
458471
}
459472

460473
if config.GKE != nil && config.GKE.GCSFuse != nil {

kubectl-plugin/pkg/util/generation/generation_test.go

+6-2
Original file line numberDiff line numberDiff line change
@@ -301,8 +301,9 @@ func TestGenerateResources(t *testing.T) {
301301

302302
func TestGenerateRayClusterSpec(t *testing.T) {
303303
testRayClusterConfig := RayClusterConfig{
304-
RayVersion: ptr.To("1.2.3"),
305-
Image: ptr.To("rayproject/ray:1.2.3"),
304+
RayVersion: ptr.To("1.2.3"),
305+
Image: ptr.To("rayproject/ray:1.2.3"),
306+
ServiceAccount: ptr.To("my-service-account"),
306307
Head: &Head{
307308
CPU: ptr.To("1"),
308309
Memory: ptr.To("5Gi"),
@@ -345,6 +346,7 @@ func TestGenerateRayClusterSpec(t *testing.T) {
345346
RayStartParams: map[string]string{"dashboard-host": "0.0.0.0", "softmax": "GELU"},
346347
Template: &corev1ac.PodTemplateSpecApplyConfiguration{
347348
Spec: &corev1ac.PodSpecApplyConfiguration{
349+
ServiceAccountName: ptr.To("my-service-account"),
348350
Containers: []corev1ac.ContainerApplyConfiguration{
349351
{
350352
Name: ptr.To("ray-head"),
@@ -393,6 +395,7 @@ func TestGenerateRayClusterSpec(t *testing.T) {
393395
RayStartParams: map[string]string{"metrics-export-port": "8080"},
394396
Template: &corev1ac.PodTemplateSpecApplyConfiguration{
395397
Spec: &corev1ac.PodSpecApplyConfiguration{
398+
ServiceAccountName: ptr.To("my-service-account"),
396399
Containers: []corev1ac.ContainerApplyConfiguration{
397400
{
398401
Name: ptr.To("ray-worker"),
@@ -423,6 +426,7 @@ func TestGenerateRayClusterSpec(t *testing.T) {
423426
},
424427
Template: &corev1ac.PodTemplateSpecApplyConfiguration{
425428
Spec: &corev1ac.PodSpecApplyConfiguration{
429+
ServiceAccountName: ptr.To("my-service-account"),
426430
Containers: []corev1ac.ContainerApplyConfiguration{
427431
{
428432
Name: ptr.To("ray-worker"),

kubectl-plugin/pkg/util/validation.go

+10-6
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
"k8s.io/apimachinery/pkg/api/resource"
77
)
88

9+
const tpuDocURL = "https://cloud.google.com/kubernetes-engine/docs/concepts/plan-tpus#availability"
10+
911
func ValidateResourceQuantity(value string, name string) error {
1012
if value == "" {
1113
return nil
@@ -21,22 +23,24 @@ func ValidateResourceQuantity(value string, name string) error {
2123
return nil
2224
}
2325

24-
func ValidateTPUNodeSelector(numOfHosts int32, nodeSelector map[string]string) error {
26+
func ValidateTPU(tpu *string, numOfHosts *int32, nodeSelector map[string]string) error {
27+
if tpu == nil || *tpu == "" || *tpu == "0" {
28+
return nil
29+
}
30+
2531
// @TODO:
2632
// In the future we could validate that the accelerator and topology nodeSelectors are supported values,
2733
// and also validate the value for numOfHosts since it is deterministic based on the previous two values.
2834
// https://github.com/ray-project/kuberay/pull/3258#discussion_r2027973436
2935

30-
const docURL = "https://cloud.google.com/kubernetes-engine/docs/concepts/plan-tpus#availability"
31-
32-
if numOfHosts == 0 {
36+
if numOfHosts != nil && *numOfHosts == 0 {
3337
return fmt.Errorf("numOfHosts cannot be 0 when using TPU")
3438
}
3539
if _, ok := nodeSelector[NodeSelectorGKETPUAccelerator]; !ok {
36-
return fmt.Errorf("%s is not set in --worker-node-selectors. See %s for supported values", NodeSelectorGKETPUAccelerator, docURL)
40+
return fmt.Errorf("%s is not set in --worker-node-selectors. See %s for supported values", NodeSelectorGKETPUAccelerator, tpuDocURL)
3741
}
3842
if _, ok := nodeSelector[NodeSelectorGKETPUTopology]; !ok {
39-
return fmt.Errorf("%s is not set in --worker-node-selectors. See %s for supported values", NodeSelectorGKETPUTopology, docURL)
43+
return fmt.Errorf("%s is not set in --worker-node-selectors. See %s for supported values", NodeSelectorGKETPUTopology, tpuDocURL)
4044
}
4145
return nil
4246
}

kubectl-plugin/pkg/util/validation_test.go

+12-9
Original file line numberDiff line numberDiff line change
@@ -31,22 +31,25 @@ func TestValidateResourceQuantity(t *testing.T) {
3131

3232
func TestValidateTPUNodeSelector(t *testing.T) {
3333
tests := []struct {
34-
nodeSelector map[string]string
34+
tpu string
3535
numOfHosts int32
36+
nodeSelector map[string]string
3637
wantErr bool
3738
}{
38-
{map[string]string{}, 1, true},
39-
{map[string]string{NodeSelectorGKETPUAccelerator: "v2"}, 1, true},
40-
{map[string]string{NodeSelectorGKETPUTopology: "topology-1"}, 1, true},
41-
{map[string]string{NodeSelectorGKETPUAccelerator: "v2", NodeSelectorGKETPUTopology: "topology-1"}, 0, true},
42-
{map[string]string{NodeSelectorGKETPUAccelerator: "v2"}, 0, true},
43-
{map[string]string{NodeSelectorGKETPUTopology: "topology-1"}, 0, true},
44-
{map[string]string{NodeSelectorGKETPUAccelerator: "v2", NodeSelectorGKETPUTopology: "topology-1"}, 1, false},
39+
{"", 1, map[string]string{}, false},
40+
{"0", 1, map[string]string{}, false},
41+
{"1", 1, map[string]string{}, true},
42+
{"1", 1, map[string]string{NodeSelectorGKETPUAccelerator: "v2"}, true},
43+
{"1", 1, map[string]string{NodeSelectorGKETPUTopology: "topology-1"}, true},
44+
{"1", 0, map[string]string{NodeSelectorGKETPUAccelerator: "v2", NodeSelectorGKETPUTopology: "topology-1"}, true},
45+
{"1", 0, map[string]string{NodeSelectorGKETPUAccelerator: "v2"}, true},
46+
{"1", 0, map[string]string{NodeSelectorGKETPUTopology: "topology-1"}, true},
47+
{"1", 1, map[string]string{NodeSelectorGKETPUAccelerator: "v2", NodeSelectorGKETPUTopology: "topology-1"}, false},
4548
}
4649

4750
for _, tt := range tests {
4851
t.Run(fmt.Sprintf("%v", tt.nodeSelector), func(t *testing.T) {
49-
err := ValidateTPUNodeSelector(tt.numOfHosts, tt.nodeSelector)
52+
err := ValidateTPU(&tt.tpu, &tt.numOfHosts, tt.nodeSelector)
5053
if (err != nil) != tt.wantErr {
5154
t.Errorf("ValidateTPUNodeSelector() = %v, wantErr %v", err, tt.wantErr)
5255
}

0 commit comments

Comments
 (0)