From 18c22b14a3d9af93a60421922b2d8140b81de0f8 Mon Sep 17 00:00:00 2001 From: bdular Date: Mon, 27 May 2024 10:27:55 +0200 Subject: [PATCH] Added context to the cache stores --- cache.go | 4 ++-- persist/cache.go | 7 ++++--- persist/memory.go | 7 ++++--- persist/memory_test.go | 10 ++++++---- persist/redis.go | 9 +++------ 5 files changed, 19 insertions(+), 18 deletions(-) diff --git a/cache.go b/cache.go index 1417b26..eca24d1 100644 --- a/cache.go +++ b/cache.go @@ -79,7 +79,7 @@ func cache( // read cache first { respCache := &ResponseCache{} - err := cacheStore.Get(cacheKey, &respCache) + err := cacheStore.Get(c.Request.Context(), cacheKey, &respCache) if err == nil { replyWithCache(c, cfg, respCache) cfg.hitCacheCallback(c) @@ -118,7 +118,7 @@ func cache( // only cache 2xx response if !c.IsAborted() && cacheWriter.Status() < 300 && cacheWriter.Status() >= 200 { - if err := cacheStore.Set(cacheKey, respCache, cacheDuration); err != nil { + if err := cacheStore.Set(c.Request.Context(), cacheKey, respCache, cacheDuration); err != nil { cfg.logger.Errorf("set cache key error: %s, cache key: %s", err, cacheKey) } } diff --git a/persist/cache.go b/persist/cache.go index 93fcfaa..e884bcf 100644 --- a/persist/cache.go +++ b/persist/cache.go @@ -1,6 +1,7 @@ package persist import ( + "context" "errors" "time" ) @@ -11,11 +12,11 @@ var ErrCacheMiss = errors.New("persist cache miss error") // CacheStore is the interface of a Cache backend type CacheStore interface { // Get retrieves an item from the Cache. if key does not exist in the store, return ErrCacheMiss - Get(key string, value interface{}) error + Get(ctx context.Context, key string, value interface{}) error // Set sets an item to the Cache, replacing any existing item. - Set(key string, value interface{}, expire time.Duration) error + Set(ctx context.Context, key string, value interface{}, expire time.Duration) error // Delete removes an item from the Cache. Does nothing if the key is not in the Cache. - Delete(key string) error + Delete(ctx context.Context, key string) error } diff --git a/persist/memory.go b/persist/memory.go index ca54952..5829110 100644 --- a/persist/memory.go +++ b/persist/memory.go @@ -1,6 +1,7 @@ package persist import ( + "context" "errors" "reflect" "time" @@ -27,17 +28,17 @@ func NewMemoryStore(defaultExpiration time.Duration) *MemoryStore { } // Set put key value pair to memory store, and expire after expireDuration -func (c *MemoryStore) Set(key string, value interface{}, expireDuration time.Duration) error { +func (c *MemoryStore) Set(ctx context.Context, key string, value interface{}, expireDuration time.Duration) error { return c.Cache.SetWithTTL(key, value, expireDuration) } // Delete remove key in memory store, do nothing if key doesn't exist -func (c *MemoryStore) Delete(key string) error { +func (c *MemoryStore) Delete(ctx context.Context, key string) error { return c.Cache.Remove(key) } // Get key in memory store, if key doesn't exist, return ErrCacheMiss -func (c *MemoryStore) Get(key string, value interface{}) error { +func (c *MemoryStore) Get(ctx context.Context, key string, value interface{}) error { val, err := c.Cache.Get(key) if errors.Is(err, ttlcache.ErrNotFound) { return ErrCacheMiss diff --git a/persist/memory_test.go b/persist/memory_test.go index 342097f..b7c50ff 100644 --- a/persist/memory_test.go +++ b/persist/memory_test.go @@ -1,10 +1,12 @@ package persist import ( - "github.com/stretchr/testify/require" + "context" "testing" "time" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/assert" ) @@ -12,12 +14,12 @@ func TestMemoryStore(t *testing.T) { memoryStore := NewMemoryStore(1 * time.Minute) expectVal := "123" - require.Nil(t, memoryStore.Set("test", expectVal, 1*time.Second)) + require.Nil(t, memoryStore.Set(context.TODO(), "test", expectVal, 1*time.Second)) value := "" - assert.Nil(t, memoryStore.Get("test", &value)) + assert.Nil(t, memoryStore.Get(context.TODO(), "test", &value)) assert.Equal(t, expectVal, value) time.Sleep(1 * time.Second) - assert.Equal(t, ErrCacheMiss, memoryStore.Get("test", &value)) + assert.Equal(t, ErrCacheMiss, memoryStore.Get(context.TODO(), "test", &value)) } diff --git a/persist/redis.go b/persist/redis.go index f54dd2b..0d36bb7 100644 --- a/persist/redis.go +++ b/persist/redis.go @@ -21,25 +21,22 @@ func NewRedisStore(redisClient *redis.Client) *RedisStore { } // Set put key value pair to redis, and expire after expireDuration -func (store *RedisStore) Set(key string, value interface{}, expire time.Duration) error { +func (store *RedisStore) Set(ctx context.Context, key string, value interface{}, expire time.Duration) error { payload, err := Serialize(value) if err != nil { return err } - ctx := context.TODO() return store.RedisClient.Set(ctx, key, payload, expire).Err() } // Delete remove key in redis, do nothing if key doesn't exist -func (store *RedisStore) Delete(key string) error { - ctx := context.TODO() +func (store *RedisStore) Delete(ctx context.Context, key string) error { return store.RedisClient.Del(ctx, key).Err() } // Get retrieves an item from redis, if key doesn't exist, return ErrCacheMiss -func (store *RedisStore) Get(key string, value interface{}) error { - ctx := context.TODO() +func (store *RedisStore) Get(ctx context.Context, key string, value interface{}) error { payload, err := store.RedisClient.Get(ctx, key).Bytes() if errors.Is(err, redis.Nil) {