Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 36 additions & 33 deletions internal/controller/tensorfusionworkload_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/client-go/tools/record"
"k8s.io/kubernetes/pkg/controller"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"
Expand All @@ -38,7 +39,6 @@ import (
"github.com/NexusGPU/tensor-fusion/internal/portallocator"
"github.com/NexusGPU/tensor-fusion/internal/utils"
"github.com/NexusGPU/tensor-fusion/internal/worker"
"github.com/samber/lo"
)

// TensorFusionWorkloadReconciler reconciles a TensorFusionWorkload object
Expand Down Expand Up @@ -78,9 +78,7 @@ func (r *TensorFusionWorkloadReconciler) Reconcile(ctx context.Context, req ctrl
return ctrl.Result{}, fmt.Errorf("list pods: %w", err)
}
// only calculate state based on not deleted pods, otherwise will cause wrong total replica count
podList.Items = lo.Filter(podList.Items, func(pod corev1.Pod, _ int) bool {
return pod.DeletionTimestamp.IsZero()
})
activePods := filterActivePods(podList)

// handle finalizer
shouldReturn, err := utils.HandleFinalizer(ctx, workload, r.Client, func(ctx context.Context, workload *tfv1.TensorFusionWorkload) (bool, error) {
Expand All @@ -91,10 +89,10 @@ func (r *TensorFusionWorkloadReconciler) Reconcile(ctx context.Context, req ctrl

// fixed replica mode which created by user, should trigger pod deletion and stop scale up
// when all pods are deleted, finalizer will be removed
if len(podList.Items) == 0 {
if len(activePods) == 0 {
return true, nil
}
if err := r.scaleDownWorkers(ctx, workload, podList.Items); err != nil {
if err := r.scaleDownWorkers(ctx, workload, activePods); err != nil {
return false, err
}
return false, nil
Expand Down Expand Up @@ -140,13 +138,13 @@ func (r *TensorFusionWorkloadReconciler) Reconcile(ctx context.Context, req ctrl
// In this mode, allow any Pod select connection to connect to any worker,
// to achieve a sub-pool for lower costs when CPU side scaling frequency is high
if !workload.Spec.IsDynamicReplica() {
err := r.reconcileScaling(ctx, workload, podList, workerGenerator, podTemplateHash)
err := r.reconcileScaling(ctx, workload, activePods, workerGenerator, podTemplateHash)
if err != nil {
return ctrl.Result{}, err
}
}

if err := r.updateStatus(ctx, workload, podList.Items); err != nil {
if err := r.updateStatus(ctx, workload, activePods); err != nil {
return ctrl.Result{}, err
}

Expand All @@ -157,23 +155,22 @@ func (r *TensorFusionWorkloadReconciler) Reconcile(ctx context.Context, req ctrl
func (r *TensorFusionWorkloadReconciler) reconcileScaling(
ctx context.Context,
workload *tfv1.TensorFusionWorkload,
podList *corev1.PodList,
activePods []*corev1.Pod,
workerGenerator *worker.WorkerGenerator,
podTemplateHash string,
) error {
log := log.FromContext(ctx)
// Check if there are any Pods using the old podTemplateHash and delete them if any
if len(podList.Items) > 0 {
if len(activePods) > 0 {
// make oldest pod first, to delete from oldest to latest outdated pod
sort.Slice(podList.Items, func(i, j int) bool {
return podList.Items[i].CreationTimestamp.Before(&podList.Items[j].CreationTimestamp)
sort.Slice(activePods, func(i, j int) bool {
return activePods[i].CreationTimestamp.Before(&activePods[j].CreationTimestamp)
})

var outdatedPods []corev1.Pod
for i := range podList.Items {
pod := &podList.Items[i]
var outdatedPods []*corev1.Pod
for _, pod := range activePods {
if pod.Labels[constants.LabelKeyPodTemplateHash] != podTemplateHash {
outdatedPods = append(outdatedPods, *pod)
outdatedPods = append(outdatedPods, pod)
}
}

Expand All @@ -194,7 +191,7 @@ func (r *TensorFusionWorkloadReconciler) reconcileScaling(
}

// Count current replicas
currentReplicas := int32(len(podList.Items))
currentReplicas := int32(len(activePods))
log.Info("Current replicas", "count", currentReplicas, "desired", desiredReplicas)

// Update workload status
Expand All @@ -205,26 +202,23 @@ func (r *TensorFusionWorkloadReconciler) reconcileScaling(
}
}

diff := currentReplicas - desiredReplicas
// Scale up if needed
if currentReplicas < desiredReplicas {
if diff < 0 {
log.Info("Scaling up workers", "from", currentReplicas, "to", desiredReplicas)

// Calculate how many pods need to be added
podsToAdd := int(desiredReplicas - currentReplicas)
if err := r.scaleUpWorkers(ctx, workerGenerator, workload, podsToAdd, podTemplateHash); err != nil {
if err := r.scaleUpWorkers(ctx, workerGenerator, workload, int(-diff), podTemplateHash); err != nil {
return fmt.Errorf("scale up workers: %w", err)
}
} else if currentReplicas > desiredReplicas {
} else if diff > 0 {
log.Info("Scaling down workers", "from", currentReplicas, "to", desiredReplicas)

// Sort pods by creation time (oldest first)
sort.Slice(podList.Items, func(i, j int) bool {
return podList.Items[i].CreationTimestamp.Before(&podList.Items[j].CreationTimestamp)
})
// No need to sort if we are about to delete all pods
if diff < int32(len(activePods)) {
sort.Sort(controller.ActivePods(activePods))
}

// Calculate how many pods need to be removed
podsToRemove := int(currentReplicas - desiredReplicas)
if err := r.scaleDownWorkers(ctx, workload, podList.Items[:podsToRemove]); err != nil {
if err := r.scaleDownWorkers(ctx, workload, activePods[:diff]); err != nil {
return err
}
}
Expand Down Expand Up @@ -259,10 +253,9 @@ func (r *TensorFusionWorkloadReconciler) tryStartWorker(
}

// scaleDownWorkers handles the scaling down of worker pods
func (r *TensorFusionWorkloadReconciler) scaleDownWorkers(ctx context.Context, workload *tfv1.TensorFusionWorkload, pods []corev1.Pod) error {
func (r *TensorFusionWorkloadReconciler) scaleDownWorkers(ctx context.Context, workload *tfv1.TensorFusionWorkload, pods []*corev1.Pod) error {
log := log.FromContext(ctx)
for i := range pods {
podToDelete := &pods[i]
for _, podToDelete := range pods {
log.Info("Scaling down worker pod", "name", podToDelete.Name, "workload", workload.Name)

// If it's already being deleting, should avoid call delete multiple times
Expand Down Expand Up @@ -316,7 +309,7 @@ func (r *TensorFusionWorkloadReconciler) scaleUpWorkers(ctx context.Context, wor
func (r *TensorFusionWorkloadReconciler) updateStatus(
ctx context.Context,
workload *tfv1.TensorFusionWorkload,
pods []corev1.Pod,
pods []*corev1.Pod,
) error {
log := log.FromContext(ctx)
readyReplicas := int32(0)
Expand Down Expand Up @@ -396,6 +389,16 @@ func (r *TensorFusionWorkloadReconciler) updateStatus(
return nil
}

func filterActivePods(podList *corev1.PodList) []*corev1.Pod {
var activePods []*corev1.Pod
for _, pod := range podList.Items {
if pod.DeletionTimestamp.IsZero() {
activePods = append(activePods, &pod)
}
}
return activePods
}

// SetupWithManager sets up the controller with the Manager.
func (r *TensorFusionWorkloadReconciler) SetupWithManager(mgr ctrl.Manager) error {
return ctrl.NewControllerManagedBy(mgr).
Expand Down