Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit a0e02c0

Browse files
committed
addressed review comments
Signed-off-by: Maroon Ayoub <[email protected]>
1 parent b852c92 commit a0e02c0

File tree

3 files changed

+37
-17
lines changed

3 files changed

+37
-17
lines changed

pkg/epp/scheduling/local_config.go

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ import (
2020
"context"
2121

2222
"sigs.k8s.io/controller-runtime/pkg/log"
23-
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/filter"
2423
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/picker"
2524
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/scorer"
2625
envutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env"
@@ -31,7 +30,7 @@ const (
3130
kvCacheScorerEnablementEnvVar = "ENABLE_KVCACHE_AWARE_SCORER"
3231
loadAwareScorerEnablementEnvVar = "ENABLE_LOAD_AWARE_SCORER"
3332
prefixScorerEnablementEnvVar = "ENABLE_PREFIX_AWARE_SCORER"
34-
pdFilterEnablementEnvVar = "ENABLE_PD_FILTER"
33+
pdFilterEnablementEnvVar = "ENABLE_PD_FILTER"
3534

3635
kvCacheScorerWeightEnvVar = "KVCACHE_AWARE_SCORER_WEIGHT"
3736
loadAwareScorerWeightEnvVar = "LOAD_AWARE_SCORER_WEIGHT"
@@ -86,19 +85,6 @@ func setKVCacheAwareScorer() {
8685
loggerDebug.Info("Initialized KVCacheAwareScorer", "weight", kvCacheScorerWeight)
8786
}
8887

89-
// setPDFilter appends filter.PDFilter to defaultConfig.filters when the
// environment variable named by pdFilterEnablementEnvVar reads "true".
// For any other value it logs that creation was skipped and returns without
// touching the configuration.
// NOTE(review): the "-" gutter markers show this function being removed by
// this commit.
func setPDFilter() {
90-
ctx := context.Background()
91-
loggerDebug := log.FromContext(ctx).WithName("scheduler_config").V(logutil.DEBUG)
92-
93-
// Presumably GetEnvString falls back to the "false" default when the variable
// is unset, which would make the filter opt-in - confirm against envutil.
if envutil.GetEnvString(pdFilterEnablementEnvVar, "false", loggerDebug) != "true" {
94-
loggerDebug.Info("Skipping PDFilter creation as it is not enabled")
95-
return
96-
}
97-
98-
defaultConfig.filters = append(defaultConfig.filters, filter.PDFilter)
99-
loggerDebug.Info("Initialized PDFilter")
100-
}
101-
10288
func setPrefixScorer() {
10389
ctx := context.Background()
10490
loggerDebug := log.FromContext(ctx).WithName("scheduler_config").V(logutil.DEBUG)

pkg/epp/scheduling/plugins/scorer/prefix_aware_scorer.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,33 @@ func (s *PrefixAwareScorer) PostResponse(ctx *types.SchedulingContext, pod types
9898
// GetPrefixStore returns the PrefixStore held in the scorer's prefixStore field.
func (s *PrefixAwareScorer) GetPrefixStore() *PrefixStore {
9999
return s.prefixStore
100100
}
101+
102+
// podToKeyFunc is a function type that converts a Pod to a string key.
103+
// It returns the key and a boolean indicating success.
104+
type podToKeyFunc func(pod types.Pod) (string, bool)
105+
106+
// indexedScoresToNormalizedScoredPods maps each pod to a float64 score derived
// from the integer scores indexed by the pod's key.
//
// For every pod in pods, podToKey is used to obtain the lookup key:
//   - pods for which podToKey reports failure are left out of the result;
//   - pods whose key is absent from scores receive 0.0;
//   - pods whose key is present receive (score-minScore)/(maxScore-minScore),
//     or 1.0 when minScore equals maxScore.
//
// NOTE(review): presumably getMinMax returns the smallest and largest values in
// scores, which would keep results within [0, 1] - confirm against its definition.
func indexedScoresToNormalizedScoredPods(pods []types.Pod, podToKey podToKeyFunc,
107+
scores map[string]int) map[types.Pod]float64 {
108+
scoredPods := make(map[types.Pod]float64)
109+
minScore, maxScore := getMinMax(scores)
110+
111+
for _, pod := range pods {
112+
key, ok := podToKey(pod)
113+
if !ok {
114+
// No key could be derived for this pod, so it gets no entry at all.
continue
115+
}
116+
117+
if score, ok := scores[key]; ok {
118+
// Equal bounds would make the normalization below divide by zero;
// such pods are given the full score instead.
if minScore == maxScore {
119+
scoredPods[pod] = 1.0
120+
continue
121+
}
122+
123+
scoredPods[pod] = float64(score-minScore) / float64(maxScore-minScore)
124+
} else {
125+
// The pod has a key but no recorded score.
scoredPods[pod] = 0.0
126+
}
127+
}
128+
129+
return scoredPods
130+
}

pkg/epp/scheduling/plugins/scorer/prefix_store.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ func NewPrefixStore(config *PrefixStoreConfig) *PrefixStore {
8989

9090
// AddEntry adds a new entry to the prefix store.
9191
func (s *PrefixStore) AddEntry(modelName string, prompt string, pod *types.NamespacedName) error {
92-
if prompt == "" || pod == nil {
92+
if prompt == "" || pod == nil || len(prompt) < s.blockSize /* skip if prompt is too short */ {
9393
return nil
9494
}
9595

@@ -111,7 +111,7 @@ func (s *PrefixStore) AddEntry(modelName string, prompt string, pod *types.Names
111111
for start := 0; start < len(prompt); start += s.blockSize {
112112
end := start + s.blockSize
113113
if end > len(prompt) {
114-
end = len(prompt)
114+
break // skip partial blocks
115115
}
116116

117117
// Compute the hash for the current block
@@ -142,6 +142,10 @@ func (s *PrefixStore) AddEntry(modelName string, prompt string, pod *types.Names
142142
// FindMatchingPods finds all pods that match the given prompt and model name.
143143
// It returns a map of pods and the number of blocks they match.
144144
func (s *PrefixStore) FindMatchingPods(prompt, modelName string) map[string]int {
145+
if prompt == "" || modelName == "" || len(prompt) < s.blockSize /* skip if prompt is too short */ {
146+
return nil
147+
}
148+
145149
s.RLock()
146150
cache, ok := s.store[modelName] // cache is thread-safe
147151
s.RUnlock()

0 commit comments

Comments
 (0)