Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions pkg/epp/scheduling/local_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@ import (
const (
kvCacheScorerEnablementEnvVar = "ENABLE_KVCACHE_AWARE_SCORER"
loadAwareScorerEnablementEnvVar = "ENABLE_LOAD_AWARE_SCORER"
prefixScorerEnablementEnvVar = "ENABLE_PREFIX_AWARE_SCORER"
pdFilterEnablementEnvVar = "ENABLE_PD_FILTER"

kvCacheScorerWeightEnvVar = "KVCACHE_AWARE_SCORER_WEIGHT"
loadAwareScorerWeightEnvVar = "LOAD_AWARE_SCORER_WEIGHT"
prefixScorerWeightEnvVar = "PREFIX_AWARE_SCORER_WEIGHT"
)

func init() {
Expand All @@ -44,6 +46,7 @@ func setDefaultConfig() {
// this configuration is a temporary state, it should be better streamlined.
setLoadAwareScorer()
setKVCacheAwareScorer()
setPrefixScorer()

defaultConfig.picker = picker.NewMaxScorePicker()
}
Expand Down Expand Up @@ -81,3 +84,20 @@ func setKVCacheAwareScorer() {
defaultConfig.scorers[kvCacheScorer] = kvCacheScorerWeight
loggerDebug.Info("Initialized KVCacheAwareScorer", "weight", kvCacheScorerWeight)
}

func setPrefixScorer() {
ctx := context.Background()
loggerDebug := log.FromContext(ctx).WithName("scheduler_config").V(logutil.DEBUG)
Comment on lines +89 to +90
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems off?
should you be getting the logger fields/configuration from an existing context? If you create a new context, it won't have any existing fields inherited from context

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, but this entire configuration is off. Propagating context here roots it deeper into the codebase, I'd prefer living with the current state until fully refactored. Does that sound ok?


if envutil.GetEnvString(prefixScorerEnablementEnvVar, "false", loggerDebug) != "true" {
loggerDebug.Info("Skipping PrefixScorer creation as it is not enabled")
return
}

prefixScorerWeight := envutil.GetEnvInt(prefixScorerWeightEnvVar, 1, loggerDebug)
prefixScorer := scorer.NewPrefixAwareScorer(nil)
defaultConfig.scorers[prefixScorer] = prefixScorerWeight // TODO: make configurable
defaultConfig.postResponsePlugins = append(defaultConfig.postResponsePlugins, prefixScorer)

loggerDebug.Info("Initialized PrefixAwareScorer", "weight", prefixScorerWeight)
}
130 changes: 130 additions & 0 deletions pkg/epp/scheduling/plugins/scorer/prefix_aware_scorer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
Copyright 2025 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package scorer
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since the package name is scorer and is used when referring to the objects outside the package, consider dropping the Scorer from function and other variables:
for example: scorer.PrefixAware instead of scorer.PrefixAwareScorer, scorer.NewPrefixAware instead of scorer.NewPrefixAwareScorer etc.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll defer this to a followup PR since this is relevant to other scorers too.


import (
"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)

const prefixAwareScorerName = "prefix-aware-scorer"

// PrefixAwareScorer is a routing scorer that scores pods based on the longest prefix match
// between the request's prompt and stored prefixes. The score is normalized between 0 and 1,
// where 1 represents the longest matching prefix.
type PrefixAwareScorer struct {
prefixStore *PrefixStore
}

var _ plugins.Scorer = &PrefixAwareScorer{}

// NewPrefixAwareScorer creates a new PrefixAwareScorer with the given
// PrefixStoreConfig. If the config is nil, default is used.
func NewPrefixAwareScorer(config *PrefixStoreConfig) *PrefixAwareScorer {
return &PrefixAwareScorer{
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other functions assume it is not nil (e.g., L57 below), suggest checking it here (change func signature to also return an error)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The parameter is PrefixStoreConfig and not PrefixStore, which can be nil in its use. I think you missed the Config part.

prefixStore: NewPrefixStore(config),
}
}

func (s *PrefixAwareScorer) Name() string {
return "prefix-aware-scorer"
}

// Score scores the target pods based on the longest prefix match.
func (s *PrefixAwareScorer) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 {
loggerDebug := log.FromContext(ctx).WithName(prefixAwareScorerName).V(logutil.DEBUG)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see here in contrast to the earlier comment on log.FromContext()

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Explained in the said comment - but reaffirming agreement.

if ctx.Req == nil {
loggerDebug.Info("Request is nil, skipping scoring")
return nil
}

scores := s.prefixStore.FindMatchingPods(ctx.Req.Prompt, ctx.Req.Model)
loggerDebug.Info("Got pod scores", "scores", scores)

if len(scores) == 0 {
loggerDebug.Info("No scores found for pods")
return nil
}

podToKey := func(pod types.Pod) (string, bool) {
if pod.GetPod() == nil {
return "", false
}

return pod.GetPod().NamespacedName.String(), true
}

return indexedScoresToNormalizedScoredPods(pods, podToKey, scores)
}

// PostResponse implements the PostResponsePlugin interface.
// It adds the prefix to the PrefixStore for the given pod.
func (s *PrefixAwareScorer) PostResponse(ctx *types.SchedulingContext, pod types.Pod) {
debugLogger := log.FromContext(ctx).WithName(prefixAwareScorerName).V(logutil.DEBUG)

if ctx.Req == nil {
debugLogger.Info("Request is nil, skipping PostResponse")
return
}

if pod.GetPod() == nil {
debugLogger.Info("Pod is nil, skipping PostResponse", "req", ctx.Req, "pod", pod)
return
}

if err := s.prefixStore.AddEntry(ctx.Req.Model, ctx.Req.Prompt, &pod.GetPod().NamespacedName); err != nil {
debugLogger.Error(err, "Failed to add entry to prefix store", "req", ctx.Req, "pod", pod)
return
}
}

// GetPrefixStore returns the scorer's PrefixStore.
func (s *PrefixAwareScorer) GetPrefixStore() *PrefixStore {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unclear why this is an exported method.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For tests, since they reside in a different package.

return s.prefixStore
}

// podToKey is a function type that converts a Pod to a string key.
// It returns the key and a boolean indicating success.
type podToKeyFunc func(pod types.Pod) (string, bool)

func indexedScoresToNormalizedScoredPods(pods []types.Pod, podToKey podToKeyFunc,
scores map[string]int) map[types.Pod]float64 {
scoredPods := make(map[types.Pod]float64)
minScore, maxScore := getMinMax(scores)

for _, pod := range pods {
key, ok := podToKey(pod)
if !ok {
continue
}

if score, ok := scores[key]; ok {
if minScore == maxScore {
scoredPods[pod] = 1.0
continue
}

scoredPods[pod] = float64(score-minScore) / float64(maxScore-minScore)
} else {
scoredPods[pod] = 0.0
}
}

return scoredPods
}
156 changes: 156 additions & 0 deletions pkg/epp/scheduling/plugins/scorer/prefix_aware_scorer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
/*
Copyright 2025 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package scorer_test

import (
"context"
k8stypes "k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/controller-runtime/pkg/log"
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/scorer"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
"testing"
)

func TestPrefixAwareScorer(t *testing.T) {
ctx := context.Background()
logger := log.FromContext(ctx)
ctx = log.IntoContext(ctx, logger)

// Create test pods
pod1 := &types.PodMetrics{
Pod: &backendmetrics.Pod{
NamespacedName: k8stypes.NamespacedName{
Name: "pod1",
Namespace: "default",
},
},
Metrics: &backendmetrics.Metrics{},
}
pod2 := &types.PodMetrics{
Pod: &backendmetrics.Pod{
NamespacedName: k8stypes.NamespacedName{
Name: "pod2",
Namespace: "default",
},
},
Metrics: &backendmetrics.Metrics{},
}

tests := []struct {
name string
weight float64
prompt string
modelName string
prefixToAdd string
podToAdd k8stypes.NamespacedName
prefixModel string // Model name to use when adding the prefix
expectedScores map[types.Pod]float64
}{
{
name: "no prompt",
weight: 1.0,
prompt: "",
modelName: "model1",
prefixToAdd: "hello",
podToAdd: pod1.Pod.NamespacedName,
prefixModel: "model1",
expectedScores: map[types.Pod]float64{}, // No prompt means zero scores
},
{
name: "exact prefix match",
weight: 1.0,
prompt: "hello world",
modelName: "model1",
prefixToAdd: "hello",
podToAdd: pod1.Pod.NamespacedName,
prefixModel: "model1",
expectedScores: map[types.Pod]float64{
pod1: 1.0,
pod2: 0.0,
}, // pod1 matches, pod2 doesn't
},
{
name: "no prefix match",
weight: 1.0,
prompt: "goodbye",
modelName: "model1",
prefixToAdd: "hello",
podToAdd: pod1.Pod.NamespacedName,
prefixModel: "model1",
expectedScores: map[types.Pod]float64{}, // No matching prefix
},
{
name: "different model name",
weight: 1.0,
prompt: "hello world",
modelName: "model2", // Try to find with model2
prefixToAdd: "hello",
podToAdd: pod1.Pod.NamespacedName,
prefixModel: "model1", // But prefix was added with model1
expectedScores: map[types.Pod]float64{}, // Model name mismatch should result in no match
},
{
name: "custom weight",
weight: 0.5,
prompt: "hello world",
modelName: "model1",
prefixToAdd: "hello",
podToAdd: pod1.Pod.NamespacedName,
prefixModel: "model1",
expectedScores: map[types.Pod]float64{
pod1: 0.5, // Pod1 matches with weight
pod2: 0.0, // Pod2 doesn't match
}, // Weight affects score
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset prefix store for each test
config := scorer.DefaultPrefixStoreConfig()
config.BlockSize = 5 // set small chunking for testing

s := scorer.NewPrefixAwareScorer(config)

// Add prefix if specified
if tt.prefixToAdd != "" {
err := s.GetPrefixStore().AddEntry(tt.prefixModel,
tt.prefixToAdd, &tt.podToAdd)
if err != nil {
t.Fatalf("Failed to add prefix: %v", err)
}
}

// Create test context
sCtx := types.NewSchedulingContext(ctx, &types.LLMRequest{
Prompt: tt.prompt,
ResolvedTargetModel: tt.modelName,
}, []types.Pod{}, 0)

// Score pods
pods := []types.Pod{pod1, pod2}
scores := s.Score(sCtx, pods)

for p, score := range scores {
if score != tt.expectedScores[p] {
t.Errorf("Pod %v: expected score %v, got %v", p, tt.expectedScores[p], score)
}
}
})
}
}
Loading