Skip to content

feat :Add GetActivePods to handle/datastore and remove deleted pod from prefix-cache scorer #1376

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions cmd/epp/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ func (r *Runner) Run(ctx context.Context) error {
}
}

err = r.parsePluginsConfiguration(ctx)
err = r.parsePluginsConfiguration(ctx, datastore)
if err != nil {
setupLog.Error(err, "Failed to parse plugins configuration")
return err
Expand Down Expand Up @@ -411,7 +411,7 @@ func (r *Runner) registerInTreePlugins() {
plugins.Register(testfilter.HeaderBasedTestingFilterType, testfilter.HeaderBasedTestingFilterFactory)
}

func (r *Runner) parsePluginsConfiguration(ctx context.Context) error {
func (r *Runner) parsePluginsConfiguration(ctx context.Context, ds datastore.Datastore) error {
if *configText == "" && *configFile == "" {
return nil // configuring through code, not through file
}
Expand All @@ -430,8 +430,9 @@ func (r *Runner) parsePluginsConfiguration(ctx context.Context) error {
}

r.registerInTreePlugins()
handle := plugins.NewEppHandle(ctx)
handle := plugins.NewEppHandle(ctx, ds.GetActivePods)
config, err := loader.LoadConfig(configBytes, handle, logger)

if err != nil {
return fmt.Errorf("failed to load the configuration - %w", err)
}
Expand Down
11 changes: 11 additions & 0 deletions pkg/epp/datastore/datastore.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ type Datastore interface {
PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics
PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool
PodDelete(namespacedName types.NamespacedName)
GetActivePods() []types.NamespacedName
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to add this, we can use PodList.


// Clears the store state, happens when the pool gets deleted.
Clear()
Expand Down Expand Up @@ -225,6 +226,16 @@ func (ds *datastore) PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool {
return ok
}

// GetActivePods returns a list of all active pods.
func (ds *datastore) GetActivePods() []types.NamespacedName {
var namespacedNames []types.NamespacedName
ds.pods.Range(func(k, _ any) bool {
namespacedNames = append(namespacedNames, k.(types.NamespacedName))
return true
})
return namespacedNames
}

func (ds *datastore) PodDelete(namespacedName types.NamespacedName) {
v, ok := ds.pods.LoadAndDelete(namespacedName)
if ok {
Expand Down
16 changes: 15 additions & 1 deletion pkg/epp/plugins/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package plugins
import (
"context"
"fmt"

"k8s.io/apimachinery/pkg/types"
)

// Handle provides plugins a set of standard data and tools to work with
Expand All @@ -27,6 +29,9 @@ type Handle interface {
Context() context.Context

HandlePlugins

// GetActivePods returns a list of all active pods
GetActivePods() []types.NamespacedName
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: In go, it's more idiomatic to use ActivePods

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider returning a map[types.NamespacedName]bool instead of a list so the plugin doesn't need to convert a list to a map/set

}

// HandlePlugins defines a set of APIs to work with instantiated plugins
Expand All @@ -44,10 +49,14 @@ type HandlePlugins interface {
GetAllPluginsWithNames() map[string]Plugin
}

// GetActivePodsFunc is a function that returns a list of all active pods.
type GetActivePodsFunc func() []types.NamespacedName

// eppHandle is an implementation of the interface plugins.Handle
type eppHandle struct {
ctx context.Context
HandlePlugins
getActivePods GetActivePodsFunc
}

// Context returns a context the plugins can use, if they need one
Expand Down Expand Up @@ -84,7 +93,12 @@ func (h *eppHandlePlugins) GetAllPluginsWithNames() map[string]Plugin {
return h.plugins
}

func NewEppHandle(ctx context.Context) Handle {
// GetActivePods returns a function that returns a list of all active pods
func (h *eppHandle) GetActivePods() []types.NamespacedName {
return h.getActivePods()
}

func NewEppHandle(ctx context.Context, getActivePods GetActivePodsFunc) Handle {
return &eppHandle{
ctx: ctx,
HandlePlugins: &eppHandlePlugins{
Expand Down
20 changes: 20 additions & 0 deletions pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,23 @@ func (i *indexer) ReportLRUSize(interval time.Duration) {
i.mu.RUnlock()
}
}

// RemovePod removes a pod and its associated entries from the indexer.
func (i *indexer) RemovePod(pod ServerID) {
i.mu.RLock()
lruCache, exists := i.podToLRU[pod]
i.mu.RUnlock()

if !exists {
return
}

// Remove all hashes associated with the pod from hashToPods (triggers eviction callbacks).
for _, hash := range lruCache.Keys() {
lruCache.Remove(hash)
}

i.mu.Lock()
delete(i.podToLRU, pod)
i.mu.Unlock()
}
60 changes: 60 additions & 0 deletions pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,63 @@ func TestIndexer_AddAndGet(t *testing.T) {
servers = i.Get(BlockHash(4))
assert.Empty(t, servers, "Cache should not contain non-existent hash")
}

func TestIndexer_RemovePodAndEviction(t *testing.T) {
const indexerSize = 10

i := newIndexer(indexerSize)

server1 := ServerID{Namespace: "default", Name: "server1"}
server2 := ServerID{Namespace: "default", Name: "server2"}

// Add indexerSize hashes to both servers
var hashes []BlockHash
for j := 0; j < indexerSize; j++ {
h := BlockHash(j)
hashes = append(hashes, h)
i.Add([]BlockHash{h}, server1)
i.Add([]BlockHash{h}, server2)
}

// Ensure all entries are added
assert.Equal(t, indexerSize, i.podToLRU[server1].Len(), "server1 should have 10 entries")
assert.Equal(t, indexerSize, i.podToLRU[server2].Len(), "server2 should have 10 entries")

// Ensure each hash in hashToPods maps to both server1 and server2
for _, h := range hashes {
pods := i.hashToPods[h]
assert.Len(t, pods, 2, "Each hash should be associated with exactly 2 pods")
assert.Contains(t, pods, server1, "hash should be associated with server1")
assert.Contains(t, pods, server2, "hash should be associated with server2")
}

// Add indexerSize hash to server1 → should evict BlockHash(0)
evictedHash := BlockHash(0)
newHash := BlockHash(indexerSize)
i.Add([]BlockHash{newHash}, server1)

// server1 LRU should still be at max capacity
assert.Equal(t, indexerSize, i.podToLRU[server1].Len(), "server1 LRU should maintain max size")

// BlockHash(0) should no longer have server1 in hashToPods
pods := i.Get(evictedHash)
assert.NotContains(t, pods, server1, "server1 should be evicted from hashToPods for hash 0")
assert.Contains(t, pods, server2, "server2 should still have hash 0")

// Remove server2
i.RemovePod(server2)

// hashToPods for hash 0 should now be empty
pods = i.Get(evictedHash)
assert.NotContains(t, pods, server2, "server2 should be removed from hash 0")
assert.Empty(t, pods, "hash 0 should have no pods after both eviction and removal")

// All remaining hashes should map only to server1
for hash, pods := range i.hashToPods {
assert.Len(t, pods, 1, "hash %v should have only 1 pod after server2 removal", hash)
assert.Contains(t, pods, server1, "hash %v should only contain server1", hash)
}

// Ensure hashToPods contains exactly indexerSize hashes (post-eviction and server2 removal)
assert.Len(t, i.hashToPods, indexerSize, "hashToPods should contain %d hashes after cleanup", indexerSize)
}
52 changes: 50 additions & 2 deletions pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"encoding/binary"
"encoding/json"
"fmt"
"time"

"github.com/cespare/xxhash/v2"
k8stypes "k8s.io/apimachinery/pkg/types"
Expand Down Expand Up @@ -55,6 +56,11 @@ const (
PrefixCachePluginType = "prefix-cache-scorer"
)

const (
PodActiveCheckInterval = 1 * time.Minute
PodInactivityTimeout = 5 * time.Minute
)

var DefaultConfig = Config{
HashBlockSize: DefaultHashBlockSize,
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
Expand Down Expand Up @@ -84,6 +90,7 @@ type podSet map[ServerID]struct{}
type Indexer interface {
Get(hash BlockHash) podSet
Add(hashes []BlockHash, server ServerID)
RemovePod(server ServerID)
}

// BlockHash is a hash of the block of request body.
Expand Down Expand Up @@ -125,7 +132,7 @@ var _ framework.Scorer = &Plugin{}
var _ framework.PostCycle = &Plugin{}

// PrefixCachePluginFactory defines the factory function for Prefix plugin.
func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) {
parameters := Config{
HashBlockSize: DefaultHashBlockSize,
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
Expand All @@ -138,7 +145,9 @@ func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, _ plug
}
}

return New(parameters).WithName(name), nil
p := New(parameters).WithName(name)
go p.StartPodActiveWatcher(handle.Context(), handle)
return p, nil
}

// New initializes a new prefix Plugin and returns its pointer.
Expand Down Expand Up @@ -239,6 +248,45 @@ func (m *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map
return res
}

// StartPodActiveWatcher starts a goroutine that watches for active pods.
func (m *Plugin) StartPodActiveWatcher(ctx context.Context, handle plugins.Handle) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we call this something like CleanUpInactivePods to be more descriptive?

logger := log.FromContext(ctx).V(logutil.VERBOSE)

ticker := time.NewTicker(PodActiveCheckInterval)
defer ticker.Stop()

podLastSeen := make(map[ServerID]time.Time)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am thinking of a much simpler implementation: iterate over current pods in the podToLRU in the indexer, and if that pod doesn't exist in ActivePods, remove. And this approach does not need a PodInactivityTimeout.


for {
select {
case <-ctx.Done():
return
case <-ticker.C:
now := time.Now()
activePods := handle.GetActivePods()

// Track active pods
activeSet := make(map[ServerID]struct{}, len(activePods))
for _, np := range activePods {
id := ServerID(np)
activeSet[id] = struct{}{}
podLastSeen[id] = now
}

// Remove stale pods
for pod, lastSeen := range podLastSeen {
if _, stillActive := activeSet[pod]; !stillActive {
if now.Sub(lastSeen) > PodInactivityTimeout {
m.indexer.RemovePod(pod)
delete(podLastSeen, pod)
logger.Info("Removed inactive pod from prefix cache", "pod", pod)
}
}
}
}
}
}

// hashPrompt divides the prompt into blocks and calculate the prefix cache for each block.
// hash(0) is the hash of the model name, since different models generally don't share prefix cache.
// For block i, hash(i) = hash(block i content, hash(i-1)).
Expand Down
6 changes: 6 additions & 0 deletions test/utils/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package utils
import (
"context"

k8stypes "k8s.io/apimachinery/pkg/types"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
)

Expand All @@ -33,6 +35,10 @@ func (h *testHandle) Context() context.Context {
return h.ctx
}

func (h *testHandle) GetActivePods() []k8stypes.NamespacedName {
return []k8stypes.NamespacedName{}
}

type testHandlePlugins struct {
plugins map[string]plugins.Plugin
}
Expand Down