@@ -105,42 +105,72 @@ func (wg *WorkerGenerator) GenerateWorkerPod(
105
105
}, nil
106
106
}
107
107
108
- func SelectWorker (ctx context.Context , k8sClient client.Client , workloadName string , workerStatuses []tfv1.WorkerStatus ) (* tfv1.WorkerStatus , error ) {
109
- if len (workerStatuses ) == 0 {
108
+ func SelectWorker (
109
+ ctx context.Context ,
110
+ k8sClient client.Client ,
111
+ workload * tfv1.TensorFusionWorkload ,
112
+ maxSkew int32 ,
113
+ ) (* tfv1.WorkerStatus , error ) {
114
+ if len (workload .Status .WorkerStatuses ) == 0 {
110
115
return nil , fmt .Errorf ("no available worker" )
111
116
}
112
- usageMapping := make (map [string ]int , len (workerStatuses ))
113
- for _ , workerStatus := range workerStatuses {
117
+ usageMapping := make (map [string ]int , len (workload . Status . WorkerStatuses ))
118
+ for _ , workerStatus := range workload . Status . WorkerStatuses {
114
119
usageMapping [workerStatus .WorkerName ] = 0
115
120
}
116
121
117
122
connectionList := tfv1.TensorFusionConnectionList {}
118
- if err := k8sClient .List (ctx , & connectionList , client.MatchingLabels {constants .WorkloadKey : workloadName }); err != nil {
123
+ if err := k8sClient .List (ctx , & connectionList , client.MatchingLabels {constants .WorkloadKey : workload . Name }); err != nil {
119
124
return nil , fmt .Errorf ("list TensorFusionConnection: %w" , err )
120
125
}
121
126
122
127
for _ , connection := range connectionList .Items {
123
128
if connection .Status .WorkerName != "" {
124
- continue
129
+ usageMapping [ connection . Status . WorkerName ] ++
125
130
}
126
- usageMapping [connection .Status .WorkerName ]++
127
131
}
128
132
129
- var minUsageWorker * tfv1.WorkerStatus
130
- // Initialize with max int value
133
+ // First find the minimum usage
131
134
minUsage := int (^ uint (0 ) >> 1 )
132
- for _ , workerStatus := range workerStatuses {
135
+ // Initialize with max int value
136
+ for _ , workerStatus := range workload .Status .WorkerStatuses {
133
137
if workerStatus .WorkerPhase == tfv1 .WorkerFailed {
134
138
continue
135
139
}
136
140
usage := usageMapping [workerStatus .WorkerName ]
137
141
if usage < minUsage {
138
142
minUsage = usage
139
- minUsageWorker = & workerStatus
140
143
}
141
144
}
142
- if minUsageWorker == nil {
145
+
146
+ // Collect all eligible workers that are within maxSkew of the minimum usage
147
+ var eligibleWorkers []* tfv1.WorkerStatus
148
+ for _ , workerStatus := range workload .Status .WorkerStatuses {
149
+ if workerStatus .WorkerPhase == tfv1 .WorkerFailed {
150
+ continue
151
+ }
152
+ usage := usageMapping [workerStatus .WorkerName ]
153
+ // Worker is eligible if its usage is within maxSkew of the minimum usage
154
+ if usage <= minUsage + int (maxSkew ) {
155
+ eligibleWorkers = append (eligibleWorkers , & workerStatus )
156
+ }
157
+ }
158
+
159
+ if len (eligibleWorkers ) == 0 {
143
160
return nil , fmt .Errorf ("no available worker" )
144
161
}
145
- return minUsageWorker , nil
162
+
163
+ // Choose the worker with the minimum usage among eligible workers
164
+ selectedWorker := eligibleWorkers [0 ]
165
+ selectedUsage := usageMapping [selectedWorker .WorkerName ]
166
+ for i := 1 ; i < len (eligibleWorkers ); i ++ {
167
+ worker := eligibleWorkers [i ]
168
+ usage := usageMapping [worker .WorkerName ]
169
+ if usage < selectedUsage {
170
+ selectedWorker = worker
171
+ selectedUsage = usage
172
+ }
173
+ }
174
+
175
+ return selectedWorker , nil
146
176
}
0 commit comments