Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit c5fd826

Browse files
committed
implemented PrefixAwareScorer based on Ricardo's work
1 parent 49f92e9 commit c5fd826

File tree

10 files changed

+423
-656
lines changed

10 files changed

+423
-656
lines changed

pkg/epp/scheduling/local_config.go

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,13 @@ import (
3030
const (
3131
kvCacheScorerEnablementEnvVar = "ENABLE_KVCACHE_AWARE_SCORER"
3232
loadAwareScorerEnablementEnvVar = "ENABLE_LOAD_AWARE_SCORER"
33-
pdFilterEnablementEnvVar = "ENABLE_PD_FILTER"
33+
prefixScorerEnablementEnvVar = "ENABLE_PREFIX_AWARE_SCORER"
34+
35+
pdFilterEnablementEnvVar = "ENABLE_PD_FILTER"
3436

3537
kvCacheScorerWeightEnvVar = "KVCACHE_AWARE_SCORER_WEIGHT"
3638
loadAwareScorerWeightEnvVar = "LOAD_AWARE_SCORER_WEIGHT"
39+
prefixScorerWeightEnvVar = "PREFIX_AWARE_SCORER_WEIGHT"
3740
)
3841

3942
func init() {
@@ -46,6 +49,7 @@ func setDefaultConfig() {
4649
setLoadAwareScorer()
4750
setKVCacheAwareScorer()
4851
setPDFilter()
52+
setPrefixScorer()
4953

5054
defaultConfig.picker = picker.NewMaxScorePicker()
5155
}
@@ -96,3 +100,20 @@ func setPDFilter() {
96100
defaultConfig.filters = append(defaultConfig.filters, filter.PDFilter)
97101
loggerDebug.Info("Initialized PDFilter")
98102
}
103+
104+
func setPrefixScorer() {
105+
ctx := context.Background()
106+
loggerDebug := log.FromContext(ctx).WithName("scheduler_config").V(logutil.DEBUG)
107+
108+
if envutil.GetEnvString(prefixScorerEnablementEnvVar, "false", loggerDebug) != "true" {
109+
loggerDebug.Info("Skipping PrefixScorer creation as it is not enabled")
110+
return
111+
}
112+
113+
prefixScorerWeight := envutil.GetEnvInt(prefixScorerWeightEnvVar, 1, loggerDebug)
114+
prefixScorer := scorer.NewPrefixAwareScorer(nil)
115+
defaultConfig.scorers[prefixScorer] = prefixScorerWeight // TODO: make configurable
116+
defaultConfig.postResponsePlugins = append(defaultConfig.postResponsePlugins, prefixScorer)
117+
118+
loggerDebug.Info("Initialized PrefixAwareScorer", "weight", prefixScorerWeight)
119+
}

pkg/epp/scheduling/plugins/scorer/kvcache-aware-scorer.go

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,20 @@ func (s *KVCacheAwareScorer) Score(ctx *types.SchedulingContext, pods []types.Po
9797
}
9898
loggerDebug.Info("Got pod scores", "scores", scores)
9999

100-
return indexerScoresToNormalizedScoredPods(pods, scores)
100+
if len(scores) == 0 {
101+
loggerDebug.Info("No scores found for pods")
102+
return nil
103+
}
104+
105+
podToKey := func(pod types.Pod) (string, bool) {
106+
metricsPod := pod.GetPod()
107+
if metricsPod == nil {
108+
return "", false
109+
}
110+
return metricsPod.Address, true
111+
}
112+
113+
return indexedScoresToNormalizedScoredPods(pods, podToKey, scores)
101114
}
102115

103116
func getMinMax(scores map[string]int) (int, int) {
@@ -116,17 +129,21 @@ func getMinMax(scores map[string]int) (int, int) {
116129
return minScore, maxScore
117130
}
118131

119-
func indexerScoresToNormalizedScoredPods(pods []types.Pod, scores map[string]int) map[types.Pod]float64 {
132+
// podToKey is a function type that converts a Pod to a string key.
133+
// It returns the key and a boolean indicating success.
134+
type podToKeyFunc func(pod types.Pod) (string, bool)
135+
136+
func indexedScoresToNormalizedScoredPods(pods []types.Pod, podToKey podToKeyFunc, scores map[string]int) map[types.Pod]float64 {
120137
scoredPods := make(map[types.Pod]float64)
121138
minScore, maxScore := getMinMax(scores)
122139

123140
for _, pod := range pods {
124-
metricsPod := pod.GetPod()
125-
if metricsPod == nil {
141+
key, ok := podToKey(pod)
142+
if !ok {
126143
continue
127144
}
128145

129-
if score, ok := scores[metricsPod.Address]; ok {
146+
if score, ok := scores[key]; ok {
130147
if minScore == maxScore {
131148
scoredPods[pod] = 1.0
132149
continue
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package scorer
18+
19+
import (
20+
"sigs.k8s.io/controller-runtime/pkg/log"
21+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
22+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
23+
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
24+
)
25+
26+
const prefixAwareScorerName = "prefix-aware-scorer"
27+
28+
// PrefixAwareScorer is a routing scorer that scores pods based on the longest prefix match
29+
// between the request's prompt and stored prefixes. The score is normalized between 0 and 1,
30+
// where 1 represents the longest matching prefix.
31+
type PrefixAwareScorer struct {
32+
prefixStore *PrefixStore
33+
}
34+
35+
var _ plugins.Scorer = &PrefixAwareScorer{}
36+
37+
// NewPrefixAwareScorer creates a new PrefixAwareScorer with the given
38+
// PrefixStoreConfig. If the config is nil, default is used.
39+
func NewPrefixAwareScorer(config *PrefixStoreConfig) *PrefixAwareScorer {
40+
return &PrefixAwareScorer{
41+
prefixStore: NewPrefixStore(config),
42+
}
43+
}
44+
45+
func (s *PrefixAwareScorer) Name() string {
46+
return "prefix-aware-scorer"
47+
}
48+
49+
// Score scores the target pods based on the longest prefix match.
50+
func (s *PrefixAwareScorer) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 {
51+
loggerDebug := log.FromContext(ctx).WithName(prefixAwareScorerName).V(logutil.DEBUG)
52+
if ctx.Req == nil {
53+
loggerDebug.Info("Request is nil, skipping scoring")
54+
return nil
55+
}
56+
57+
scores := s.prefixStore.FindMatchingPods(ctx.Req.Prompt, ctx.Req.Model)
58+
loggerDebug.Info("Got pod scores", "scores", scores)
59+
60+
if len(scores) == 0 {
61+
loggerDebug.Info("No scores found for pods")
62+
return nil
63+
}
64+
65+
podToKey := func(pod types.Pod) (string, bool) {
66+
if pod.GetPod() == nil {
67+
return "", false
68+
}
69+
70+
return pod.GetPod().NamespacedName.String(), true
71+
}
72+
73+
return indexedScoresToNormalizedScoredPods(pods, podToKey, scores)
74+
}
75+
76+
// PostResponse implements the PostResponsePlugin interface.
77+
// It adds the prefix to the PrefixStore for the given pod.
78+
func (s *PrefixAwareScorer) PostResponse(ctx *types.SchedulingContext, pod types.Pod) {
79+
debugLogger := log.FromContext(ctx).WithName(prefixAwareScorerName).V(logutil.DEBUG)
80+
81+
if ctx.Req == nil {
82+
debugLogger.Info("Request is nil, skipping PostResponse")
83+
return
84+
}
85+
86+
if pod.GetPod() == nil {
87+
debugLogger.Info("Pod is nil, skipping PostResponse", "req", ctx.Req, "pod", pod)
88+
return
89+
}
90+
91+
if err := s.prefixStore.AddEntry(ctx.Req.Model, ctx.Req.Prompt, &pod.GetPod().NamespacedName); err != nil {
92+
debugLogger.Error(err, "Failed to add entry to prefix store", "req", ctx.Req, "pod", pod)
93+
return
94+
}
95+
}
96+
97+
// GetPrefixStore returns the scorer's PrefixStore.
98+
func (s *PrefixAwareScorer) GetPrefixStore() *PrefixStore {
99+
return s.prefixStore
100+
}

pkg/epp/scheduling/prefix_aware_scorer_test.go renamed to pkg/epp/scheduling/plugins/scorer/prefix_aware_scorer_test.go

Lines changed: 43 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -14,32 +14,23 @@ See the License for the specific language governing permissions and
1414
limitations under the License.
1515
*/
1616

17-
package scheduling
17+
package scorer_test
1818

1919
import (
2020
"context"
21-
"testing"
22-
"time"
23-
2421
k8stypes "k8s.io/apimachinery/pkg/types"
2522
"sigs.k8s.io/controller-runtime/pkg/log"
2623
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
24+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/scorer"
2725
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
26+
"testing"
2827
)
2928

3029
func TestPrefixAwareScorer(t *testing.T) {
3130
ctx := context.Background()
3231
logger := log.FromContext(ctx)
3332
ctx = log.IntoContext(ctx, logger)
3433

35-
// Create a prefix store with test configuration
36-
prefixStore := NewPrefixStore(PrefixStoreConfig{
37-
MaxEntries: 100,
38-
MinPrefixLen: 3,
39-
MaxPrefixLen: 10,
40-
EntryTTL: 1 * time.Hour,
41-
})
42-
4334
// Create test pods
4435
pod1 := &types.PodMetrics{
4536
Pod: &backendmetrics.Pod{
@@ -68,7 +59,7 @@ func TestPrefixAwareScorer(t *testing.T) {
6859
prefixToAdd string
6960
podToAdd k8stypes.NamespacedName
7061
prefixModel string // Model name to use when adding the prefix
71-
expectedScores []float64
62+
expectedScores map[types.Pod]float64
7263
}{
7364
{
7465
name: "no prompt",
@@ -78,17 +69,20 @@ func TestPrefixAwareScorer(t *testing.T) {
7869
prefixToAdd: "hello",
7970
podToAdd: pod1.Pod.NamespacedName,
8071
prefixModel: "model1",
81-
expectedScores: []float64{0, 0}, // No prompt means zero scores
72+
expectedScores: map[types.Pod]float64{}, // No prompt means zero scores
8273
},
8374
{
84-
name: "exact prefix match",
85-
weight: 1.0,
86-
prompt: "hello world",
87-
modelName: "model1",
88-
prefixToAdd: "hello",
89-
podToAdd: pod1.Pod.NamespacedName,
90-
prefixModel: "model1",
91-
expectedScores: []float64{1.0, 0}, // pod1 matches, pod2 doesn't
75+
name: "exact prefix match",
76+
weight: 1.0,
77+
prompt: "hello world",
78+
modelName: "model1",
79+
prefixToAdd: "hello",
80+
podToAdd: pod1.Pod.NamespacedName,
81+
prefixModel: "model1",
82+
expectedScores: map[types.Pod]float64{
83+
pod1: 1.0,
84+
pod2: 0.0,
85+
}, // pod1 matches, pod2 doesn't
9286
},
9387
{
9488
name: "no prefix match",
@@ -98,7 +92,7 @@ func TestPrefixAwareScorer(t *testing.T) {
9892
prefixToAdd: "hello",
9993
podToAdd: pod1.Pod.NamespacedName,
10094
prefixModel: "model1",
101-
expectedScores: []float64{0, 0}, // No matching prefix
95+
expectedScores: map[types.Pod]float64{}, // No matching prefix
10296
},
10397
{
10498
name: "different model name",
@@ -107,63 +101,54 @@ func TestPrefixAwareScorer(t *testing.T) {
107101
modelName: "model2", // Try to find with model2
108102
prefixToAdd: "hello",
109103
podToAdd: pod1.Pod.NamespacedName,
110-
prefixModel: "model1", // But prefix was added with model1
111-
expectedScores: []float64{0, 0}, // Model name mismatch should result in no match
104+
prefixModel: "model1", // But prefix was added with model1
105+
expectedScores: map[types.Pod]float64{}, // Model name mismatch should result in no match
112106
},
113107
{
114-
name: "custom weight",
115-
weight: 0.5,
116-
prompt: "hello world",
117-
modelName: "model1",
118-
prefixToAdd: "hello",
119-
podToAdd: pod1.Pod.NamespacedName,
120-
prefixModel: "model1",
121-
expectedScores: []float64{0.5, 0}, // Weight affects score
108+
name: "custom weight",
109+
weight: 0.5,
110+
prompt: "hello world",
111+
modelName: "model1",
112+
prefixToAdd: "hello",
113+
podToAdd: pod1.Pod.NamespacedName,
114+
prefixModel: "model1",
115+
expectedScores: map[types.Pod]float64{
116+
pod1: 0.5, // Pod1 matches with weight
117+
pod2: 0.0, // Pod2 doesn't match
118+
}, // Weight affects score
122119
},
123120
}
124121

125122
for _, tt := range tests {
126123
t.Run(tt.name, func(t *testing.T) {
127124
// Reset prefix store for each test
128-
prefixStore = NewPrefixStore(PrefixStoreConfig{
129-
MaxEntries: 100,
130-
MinPrefixLen: 3,
131-
MaxPrefixLen: 10,
132-
EntryTTL: 1 * time.Hour,
133-
})
125+
config := scorer.DefaultPrefixStoreConfig()
126+
config.BlockSize = 5 // set small chunking for testing
127+
128+
s := scorer.NewPrefixAwareScorer(config)
134129

135130
// Add prefix if specified
136131
if tt.prefixToAdd != "" {
137-
err := prefixStore.AddPrefix(ctx, tt.prefixToAdd, tt.podToAdd, tt.prefixModel)
132+
err := s.GetPrefixStore().AddEntry(tt.prefixModel,
133+
tt.prefixToAdd, &tt.podToAdd)
138134
if err != nil {
139135
t.Fatalf("Failed to add prefix: %v", err)
140136
}
141137
}
142138

143-
// Create scorer with test weight
144-
scorer := NewPrefixAwareScorer(tt.weight, prefixStore)
145-
146139
// Create test context
147-
sCtx := types.NewContext(ctx, &types.LLMRequest{
140+
sCtx := types.NewSchedulingContext(ctx, &types.LLMRequest{
148141
Prompt: tt.prompt,
149142
ResolvedTargetModel: tt.modelName,
150-
}, []*types.PodMetrics{})
143+
}, []types.Pod{}, 0)
151144

152145
// Score pods
153-
pods := []*types.PodMetrics{pod1, pod2}
154-
scores, err := scorer.ScoreTargets(sCtx, pods)
155-
if err != nil {
156-
t.Fatalf("Unexpected error: %v", err)
157-
}
158-
159-
// Verify scores
160-
if len(scores) != len(tt.expectedScores) {
161-
t.Fatalf("Expected %d scores, got %d", len(tt.expectedScores), len(scores))
162-
}
146+
pods := []types.Pod{pod1, pod2}
147+
scores := s.Score(sCtx, pods)
163148

164-
for i, score := range scores {
165-
if score.Score != tt.expectedScores[i] {
166-
t.Errorf("Pod %d: expected score %v, got %v", i, tt.expectedScores[i], score.Score)
149+
for p, score := range scores {
150+
if score != tt.expectedScores[p] {
151+
t.Errorf("Pod %v: expected score %v, got %v", p, tt.expectedScores[p], score)
167152
}
168153
}
169154
})

0 commit comments

Comments
 (0)