-
Notifications
You must be signed in to change notification settings - Fork 3
Implemented PrefixAwareScorer Based On Ricardo's Work #118
Changes from all commits
e45e31c
073069a
9e30e07
a481c85
53c550d
d7f20fe
b852c92
a0e02c0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
/* | ||
Copyright 2025 The Kubernetes Authors. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
|
||
package scorer | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. since the package name is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll defer this to a followup PR since this is relevant to other scorers too. |
||
|
||
import ( | ||
"sigs.k8s.io/controller-runtime/pkg/log" | ||
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" | ||
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" | ||
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" | ||
) | ||
|
||
const prefixAwareScorerName = "prefix-aware-scorer" | ||
|
||
// PrefixAwareScorer is a routing scorer that scores pods based on the longest prefix match | ||
// between the request's prompt and stored prefixes. The score is normalized between 0 and 1, | ||
// where 1 represents the longest matching prefix. | ||
type PrefixAwareScorer struct { | ||
prefixStore *PrefixStore | ||
} | ||
|
||
var _ plugins.Scorer = &PrefixAwareScorer{} | ||
|
||
// NewPrefixAwareScorer creates a new PrefixAwareScorer with the given | ||
// PrefixStoreConfig. If the config is nil, default is used. | ||
func NewPrefixAwareScorer(config *PrefixStoreConfig) *PrefixAwareScorer { | ||
return &PrefixAwareScorer{ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Other functions assume it is not nil (e.g., L57 below), suggest checking it here (change func signature to also return an error) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The parameter is |
||
prefixStore: NewPrefixStore(config), | ||
} | ||
} | ||
|
||
func (s *PrefixAwareScorer) Name() string { | ||
return "prefix-aware-scorer" | ||
} | ||
|
||
// Score scores the target pods based on the longest prefix match. | ||
func (s *PrefixAwareScorer) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 { | ||
loggerDebug := log.FromContext(ctx).WithName(prefixAwareScorerName).V(logutil.DEBUG) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. see here in contrast to the earlier comment on There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Explained in the said comment - but reaffirming agreement. |
||
if ctx.Req == nil { | ||
loggerDebug.Info("Request is nil, skipping scoring") | ||
return nil | ||
} | ||
|
||
scores := s.prefixStore.FindMatchingPods(ctx.Req.Prompt, ctx.Req.Model) | ||
loggerDebug.Info("Got pod scores", "scores", scores) | ||
vMaroon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if len(scores) == 0 { | ||
loggerDebug.Info("No scores found for pods") | ||
return nil | ||
} | ||
|
||
podToKey := func(pod types.Pod) (string, bool) { | ||
if pod.GetPod() == nil { | ||
return "", false | ||
} | ||
|
||
return pod.GetPod().NamespacedName.String(), true | ||
} | ||
|
||
return indexedScoresToNormalizedScoredPods(pods, podToKey, scores) | ||
} | ||
|
||
// PostResponse implements the PostResponsePlugin interface. | ||
// It adds the prefix to the PrefixStore for the given pod. | ||
func (s *PrefixAwareScorer) PostResponse(ctx *types.SchedulingContext, pod types.Pod) { | ||
debugLogger := log.FromContext(ctx).WithName(prefixAwareScorerName).V(logutil.DEBUG) | ||
|
||
if ctx.Req == nil { | ||
debugLogger.Info("Request is nil, skipping PostResponse") | ||
return | ||
} | ||
|
||
if pod.GetPod() == nil { | ||
debugLogger.Info("Pod is nil, skipping PostResponse", "req", ctx.Req, "pod", pod) | ||
return | ||
} | ||
|
||
if err := s.prefixStore.AddEntry(ctx.Req.Model, ctx.Req.Prompt, &pod.GetPod().NamespacedName); err != nil { | ||
debugLogger.Error(err, "Failed to add entry to prefix store", "req", ctx.Req, "pod", pod) | ||
return | ||
} | ||
} | ||
|
||
// GetPrefixStore returns the scorer's PrefixStore. | ||
func (s *PrefixAwareScorer) GetPrefixStore() *PrefixStore { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. unclear why this is an exported method. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For tests, since they reside in a different package. |
||
return s.prefixStore | ||
} | ||
|
||
// podToKey is a function type that converts a Pod to a string key. | ||
// It returns the key and a boolean indicating success. | ||
type podToKeyFunc func(pod types.Pod) (string, bool) | ||
|
||
func indexedScoresToNormalizedScoredPods(pods []types.Pod, podToKey podToKeyFunc, | ||
scores map[string]int) map[types.Pod]float64 { | ||
scoredPods := make(map[types.Pod]float64) | ||
minScore, maxScore := getMinMax(scores) | ||
|
||
for _, pod := range pods { | ||
key, ok := podToKey(pod) | ||
if !ok { | ||
continue | ||
} | ||
|
||
if score, ok := scores[key]; ok { | ||
if minScore == maxScore { | ||
scoredPods[pod] = 1.0 | ||
continue | ||
} | ||
|
||
scoredPods[pod] = float64(score-minScore) / float64(maxScore-minScore) | ||
} else { | ||
scoredPods[pod] = 0.0 | ||
} | ||
} | ||
|
||
return scoredPods | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
/* | ||
Copyright 2025 The Kubernetes Authors. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
|
||
package scorer_test | ||
|
||
import ( | ||
"context" | ||
k8stypes "k8s.io/apimachinery/pkg/types" | ||
"sigs.k8s.io/controller-runtime/pkg/log" | ||
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" | ||
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/scorer" | ||
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" | ||
"testing" | ||
) | ||
|
||
func TestPrefixAwareScorer(t *testing.T) { | ||
ctx := context.Background() | ||
logger := log.FromContext(ctx) | ||
ctx = log.IntoContext(ctx, logger) | ||
|
||
// Create test pods | ||
pod1 := &types.PodMetrics{ | ||
Pod: &backendmetrics.Pod{ | ||
NamespacedName: k8stypes.NamespacedName{ | ||
Name: "pod1", | ||
Namespace: "default", | ||
}, | ||
}, | ||
Metrics: &backendmetrics.Metrics{}, | ||
} | ||
pod2 := &types.PodMetrics{ | ||
Pod: &backendmetrics.Pod{ | ||
NamespacedName: k8stypes.NamespacedName{ | ||
Name: "pod2", | ||
Namespace: "default", | ||
}, | ||
}, | ||
Metrics: &backendmetrics.Metrics{}, | ||
} | ||
|
||
tests := []struct { | ||
name string | ||
weight float64 | ||
prompt string | ||
modelName string | ||
prefixToAdd string | ||
podToAdd k8stypes.NamespacedName | ||
prefixModel string // Model name to use when adding the prefix | ||
expectedScores map[types.Pod]float64 | ||
}{ | ||
{ | ||
name: "no prompt", | ||
weight: 1.0, | ||
prompt: "", | ||
modelName: "model1", | ||
prefixToAdd: "hello", | ||
podToAdd: pod1.Pod.NamespacedName, | ||
prefixModel: "model1", | ||
expectedScores: map[types.Pod]float64{}, // No prompt means zero scores | ||
}, | ||
{ | ||
name: "exact prefix match", | ||
weight: 1.0, | ||
prompt: "hello world", | ||
modelName: "model1", | ||
prefixToAdd: "hello", | ||
podToAdd: pod1.Pod.NamespacedName, | ||
prefixModel: "model1", | ||
expectedScores: map[types.Pod]float64{ | ||
pod1: 1.0, | ||
pod2: 0.0, | ||
}, // pod1 matches, pod2 doesn't | ||
}, | ||
{ | ||
name: "no prefix match", | ||
weight: 1.0, | ||
prompt: "goodbye", | ||
modelName: "model1", | ||
prefixToAdd: "hello", | ||
podToAdd: pod1.Pod.NamespacedName, | ||
prefixModel: "model1", | ||
expectedScores: map[types.Pod]float64{}, // No matching prefix | ||
}, | ||
{ | ||
name: "different model name", | ||
weight: 1.0, | ||
prompt: "hello world", | ||
modelName: "model2", // Try to find with model2 | ||
prefixToAdd: "hello", | ||
podToAdd: pod1.Pod.NamespacedName, | ||
prefixModel: "model1", // But prefix was added with model1 | ||
expectedScores: map[types.Pod]float64{}, // Model name mismatch should result in no match | ||
}, | ||
{ | ||
name: "custom weight", | ||
weight: 0.5, | ||
prompt: "hello world", | ||
modelName: "model1", | ||
prefixToAdd: "hello", | ||
podToAdd: pod1.Pod.NamespacedName, | ||
prefixModel: "model1", | ||
expectedScores: map[types.Pod]float64{ | ||
pod1: 0.5, // Pod1 matches with weight | ||
pod2: 0.0, // Pod2 doesn't match | ||
}, // Weight affects score | ||
}, | ||
} | ||
|
||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
// Reset prefix store for each test | ||
config := scorer.DefaultPrefixStoreConfig() | ||
config.BlockSize = 5 // set small chunking for testing | ||
|
||
s := scorer.NewPrefixAwareScorer(config) | ||
|
||
// Add prefix if specified | ||
if tt.prefixToAdd != "" { | ||
err := s.GetPrefixStore().AddEntry(tt.prefixModel, | ||
tt.prefixToAdd, &tt.podToAdd) | ||
if err != nil { | ||
t.Fatalf("Failed to add prefix: %v", err) | ||
} | ||
} | ||
|
||
// Create test context | ||
sCtx := types.NewSchedulingContext(ctx, &types.LLMRequest{ | ||
Prompt: tt.prompt, | ||
ResolvedTargetModel: tt.modelName, | ||
}, []types.Pod{}, 0) | ||
|
||
// Score pods | ||
pods := []types.Pod{pod1, pod2} | ||
scores := s.Score(sCtx, pods) | ||
|
||
for p, score := range scores { | ||
if score != tt.expectedScores[p] { | ||
t.Errorf("Pod %v: expected score %v, got %v", p, tt.expectedScores[p], score) | ||
} | ||
} | ||
}) | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems off?
should you be getting the logger fields/configuration from an existing context? If you create a new context, it won't have any existing fields inherited from context
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree, but this entire configuration is off. Propagating context here roots it deeper into the codebase, I'd prefer living with the current state until fully refactored. Does that sound ok?