Skip to content

Commit

Permalink
cache: re-enable Redis cluster support (#924)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghaoz authored Jan 12, 2025
1 parent 5a123ca commit 310a28f
Show file tree
Hide file tree
Showing 6 changed files with 256 additions and 16 deletions.
19 changes: 19 additions & 0 deletions .github/workflows/build_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,20 @@ jobs:
MYSQL_ROOT_PASSWORD: password

steps:
- name: Install pre-requisites
uses: awalsh128/cache-apt-pkgs-action@latest
with:
packages: redis-tools

- name: Setup Redis cluster
run: |
git clone https://github.com/gorse-cloud/redis-stack.git
docker compose -f redis-stack/docker-compose.yml --project-directory redis-stack up -d
for i in {1..5}; do
redis-cli -p 7005 ping | grep PONG && break
sleep 10
done
- name: Set up Go 1.23.x
uses: actions/setup-go@v4
with:
Expand All @@ -259,6 +273,11 @@ jobs:
env:
MYSQL_URI: mysql://root:password@tcp(localhost:${{ job.services.mariadb.ports[3306] }})/

- name: Test Redis cluster
run: go test ./storage/cache -run ^TestRedis
env:
REDIS_URI: redis+cluster://localhost:7005

golangci:
name: lint
runs-on: ubuntu-latest
Expand Down
1 change: 1 addition & 0 deletions config/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# The database for caching, support Redis, MySQL, Postgres and MongoDB:
# redis://<user>:<password>@<host>:<port>/<db_number>
# rediss://<user>:<password>@<host>:<port>/<db_number>
# redis+cluster://<user>:<password>@<host1>:<port1>,<host2>:<port2>,...,<hostN>:<portN>
# mysql://[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...&paramN=valueN]
# postgres://bob:[email protected]:5432/mydb?sslmode=verify-full
# postgresql://bob:[email protected]:5432/mydb?sslmode=verify-full
Expand Down
14 changes: 14 additions & 0 deletions storage/cache/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,20 @@ func Open(path, tablePrefix string, opts ...storage.Option) (Database, error) {
return nil, errors.Trace(err)
}
return database, nil
} else if strings.HasPrefix(path, storage.RedisClusterPrefix) {
opt, err := ParseRedisClusterURL(path)
if err != nil {
return nil, err
}
opt.Protocol = 2
database := new(Redis)
database.client = redis.NewClusterClient(opt)
database.TablePrefix = storage.TablePrefix(tablePrefix)
if err = redisotel.InstrumentTracing(database.client, redisotel.WithAttributes(semconv.DBSystemRedis)); err != nil {
log.Logger().Error("failed to add tracing for redis", zap.Error(err))
return nil, errors.Trace(err)
}
return database, nil
} else if strings.HasPrefix(path, storage.MongoPrefix) || strings.HasPrefix(path, storage.MongoSrvPrefix) {
// connect to database
database := new(MongoDB)
Expand Down
191 changes: 186 additions & 5 deletions storage/cache/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"encoding/base64"
"fmt"
"io"
"net/url"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -88,7 +89,13 @@ func (r *Redis) Init() error {

func (r *Redis) Scan(work func(string) error) error {
ctx := context.Background()
return r.scan(ctx, r.client, work)
if clusterClient, isCluster := r.client.(*redis.ClusterClient); isCluster {
return clusterClient.ForEachMaster(ctx, func(ctx context.Context, client *redis.Client) error {
return r.scan(ctx, client, work)
})
} else {
return r.scan(ctx, r.client, work)
}
}

func (r *Redis) scan(ctx context.Context, client redis.UniversalClient, work func(string) error) error {
Expand All @@ -115,10 +122,16 @@ func (r *Redis) scan(ctx context.Context, client redis.UniversalClient, work fun

func (r *Redis) Purge() error {
ctx := context.Background()
return r.purge(ctx, r.client)
if clusterClient, isCluster := r.client.(*redis.ClusterClient); isCluster {
return clusterClient.ForEachMaster(ctx, func(ctx context.Context, client *redis.Client) error {
return r.purge(ctx, client, isCluster)
})
} else {
return r.purge(ctx, r.client, isCluster)
}
}

func (r *Redis) purge(ctx context.Context, client redis.UniversalClient) error {
func (r *Redis) purge(ctx context.Context, client redis.UniversalClient, isCluster bool) error {
var (
result []string
cursor uint64
Expand All @@ -130,8 +143,20 @@ func (r *Redis) purge(ctx context.Context, client redis.UniversalClient) error {
return errors.Trace(err)
}
if len(result) > 0 {
if err = client.Del(ctx, result...).Err(); err != nil {
return errors.Trace(err)
if isCluster {
p := client.Pipeline()
for _, key := range result {
if err = p.Del(ctx, key).Err(); err != nil {
return errors.Trace(err)
}
}
if _, err = p.Exec(ctx); err != nil {
return errors.Trace(err)
}
} else {
if err = client.Del(ctx, result...).Err(); err != nil {
return errors.Trace(err)
}
}
}
if cursor == 0 {
Expand Down Expand Up @@ -488,3 +513,159 @@ func escape(s string) string {
)
return r.Replace(s)
}

func ParseRedisClusterURL(redisURL string) (*redis.ClusterOptions, error) {
options := &redis.ClusterOptions{}
uri := redisURL

var err error
if strings.HasPrefix(redisURL, storage.RedisClusterPrefix) {
uri = uri[len(storage.RedisClusterPrefix):]
} else {
return nil, fmt.Errorf("scheme must be \"redis+cluster\"")
}

if idx := strings.Index(uri, "@"); idx != -1 {
userInfo := uri[:idx]
uri = uri[idx+1:]

username := userInfo
var password string

if idx := strings.Index(userInfo, ":"); idx != -1 {
username = userInfo[:idx]
password = userInfo[idx+1:]
}

// Validate and process the username.
if strings.Contains(username, "/") {
return nil, fmt.Errorf("unescaped slash in username")
}
options.Username, err = url.PathUnescape(username)
if err != nil {
return nil, errors.Wrap(err, fmt.Errorf("invalid username"))
}

// Validate and process the password.
if strings.Contains(password, ":") {
return nil, fmt.Errorf("unescaped colon in password")
}
if strings.Contains(password, "/") {
return nil, fmt.Errorf("unescaped slash in password")
}
options.Password, err = url.PathUnescape(password)
if err != nil {
return nil, errors.Wrap(err, fmt.Errorf("invalid password"))
}
}

// fetch the hosts field
hosts := uri
if idx := strings.IndexAny(uri, "/?@"); idx != -1 {
if uri[idx] == '@' {
return nil, fmt.Errorf("unescaped @ sign in user info")
}
hosts = uri[:idx]
}

options.Addrs = strings.Split(hosts, ",")
uri = uri[len(hosts):]
if len(uri) > 0 && uri[0] == '/' {
uri = uri[1:]
}

// grab connection arguments from URI
connectionArgsFromQueryString, err := extractQueryArgsFromURI(uri)
if err != nil {
return nil, err
}
for _, pair := range connectionArgsFromQueryString {
err = addOption(options, pair)
if err != nil {
return nil, err
}
}

return options, nil
}

func extractQueryArgsFromURI(uri string) ([]string, error) {
if len(uri) == 0 {
return nil, nil
}

if uri[0] != '?' {
return nil, errors.New("must have a ? separator between path and query")
}

uri = uri[1:]
if len(uri) == 0 {
return nil, nil
}
return strings.FieldsFunc(uri, func(r rune) bool { return r == ';' || r == '&' }), nil
}

type optionHandler struct {
int *int
bool *bool
duration *time.Duration
}

func addOption(options *redis.ClusterOptions, pair string) error {
kv := strings.SplitN(pair, "=", 2)
if len(kv) != 2 || kv[0] == "" {
return fmt.Errorf("invalid option")
}

key, err := url.QueryUnescape(kv[0])
if err != nil {
return errors.Wrap(err, errors.Errorf("invalid option key %q", kv[0]))
}

value, err := url.QueryUnescape(kv[1])
if err != nil {
return errors.Wrap(err, errors.Errorf("invalid option value %q", kv[1]))
}

handlers := map[string]optionHandler{
"max_retries": {int: &options.MaxRetries},
"min_retry_backoff": {duration: &options.MinRetryBackoff},
"max_retry_backoff": {duration: &options.MaxRetryBackoff},
"dial_timeout": {duration: &options.DialTimeout},
"read_timeout": {duration: &options.ReadTimeout},
"write_timeout": {duration: &options.WriteTimeout},
"pool_fifo": {bool: &options.PoolFIFO},
"pool_size": {int: &options.PoolSize},
"pool_timeout": {duration: &options.PoolTimeout},
"min_idle_conns": {int: &options.MinIdleConns},
"max_idle_conns": {int: &options.MaxIdleConns},
"conn_max_idle_time": {duration: &options.ConnMaxIdleTime},
"conn_max_lifetime": {duration: &options.ConnMaxLifetime},
}

lowerKey := strings.ToLower(key)
if handler, ok := handlers[lowerKey]; ok {
if handler.int != nil {
*handler.int, err = strconv.Atoi(value)
if err != nil {
return errors.Wrap(err, fmt.Errorf("invalid '%s' value: %q", key, value))
}
} else if handler.duration != nil {
*handler.duration, err = time.ParseDuration(value)
if err != nil {
return errors.Wrap(err, fmt.Errorf("invalid '%s' value: %q", key, value))
}
} else if handler.bool != nil {
*handler.bool, err = strconv.ParseBool(value)
if err != nil {
return errors.Wrap(err, fmt.Errorf("invalid '%s' value: %q", key, value))
}
} else {
return fmt.Errorf("redis: unexpected option: %s", key)
}
} else {
return fmt.Errorf("redis: unexpected option: %s", key)
}

return nil
}
24 changes: 24 additions & 0 deletions storage/cache/redis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,27 @@ func BenchmarkRedis(b *testing.B) {
// benchmark
benchmark(b, database)
}

func TestParseRedisClusterURL(t *testing.T) {
options, err := ParseRedisClusterURL("redis+cluster://username:[email protected]:6379,127.0.0.1:6380,127.0.0.1:6381/?" +
"max_retries=1000&dial_timeout=1h&pool_fifo=true")
if assert.NoError(t, err) {
assert.Equal(t, "username", options.Username)
assert.Equal(t, "password", options.Password)
assert.Equal(t, []string{"127.0.0.1:6379", "127.0.0.1:6380", "127.0.0.1:6381"}, options.Addrs)
assert.Equal(t, 1000, options.MaxRetries)
assert.Equal(t, time.Hour, options.DialTimeout)
assert.True(t, options.PoolFIFO)
}

_, err = ParseRedisClusterURL("redis://")
assert.Error(t, err)
_, err = ParseRedisClusterURL("redis+cluster://username:[email protected]:6379/?max_retries=a")
assert.Error(t, err)
_, err = ParseRedisClusterURL("redis+cluster://username:[email protected]:6379/?dial_timeout=a")
assert.Error(t, err)
_, err = ParseRedisClusterURL("redis+cluster://username:[email protected]:6379/?pool_fifo=a")
assert.Error(t, err)
_, err = ParseRedisClusterURL("redis+cluster://username:[email protected]:6379/?a=1")
assert.Error(t, err)
}
23 changes: 12 additions & 11 deletions storage/scheme.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,18 @@ import (
)

const (
MySQLPrefix = "mysql://"
MongoPrefix = "mongodb://"
MongoSrvPrefix = "mongodb+srv://"
PostgresPrefix = "postgres://"
PostgreSQLPrefix = "postgresql://"
ClickhousePrefix = "clickhouse://"
CHHTTPPrefix = "chhttp://"
CHHTTPSPrefix = "chhttps://"
SQLitePrefix = "sqlite://"
RedisPrefix = "redis://"
RedissPrefix = "rediss://"
MySQLPrefix = "mysql://"
MongoPrefix = "mongodb://"
MongoSrvPrefix = "mongodb+srv://"
PostgresPrefix = "postgres://"
PostgreSQLPrefix = "postgresql://"
ClickhousePrefix = "clickhouse://"
CHHTTPPrefix = "chhttp://"
CHHTTPSPrefix = "chhttps://"
SQLitePrefix = "sqlite://"
RedisPrefix = "redis://"
RedissPrefix = "rediss://"
RedisClusterPrefix = "redis+cluster://"
)

func AppendURLParams(rawURL string, params []lo.Tuple2[string, string]) (string, error) {
Expand Down

0 comments on commit 310a28f

Please sign in to comment.