From 12c27c68cc23a3b5fa70b41502becfe31d0eba24 Mon Sep 17 00:00:00 2001 From: Hugefiver Date: Wed, 10 Jul 2024 16:02:08 +0800 Subject: [PATCH] add unit test for rate limit config --- conf/conf.go | 77 +++------------------------ rate_test.go | 19 ++++--- utils/rateconf.go | 83 +++++++++++++++++++++++++++++ utils/rateconf_test.go | 116 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 216 insertions(+), 79 deletions(-) create mode 100644 utils/rateconf.go create mode 100644 utils/rateconf_test.go diff --git a/conf/conf.go b/conf/conf.go index 2dbaeb8..e97829c 100644 --- a/conf/conf.go +++ b/conf/conf.go @@ -4,32 +4,13 @@ import ( "fmt" "io" "os" - "strconv" - "strings" - "time" "github.com/hugefiver/fakessh/modules/gitserver" + "github.com/hugefiver/fakessh/utils" "github.com/pelletier/go-toml/v2" ) -type Duration time.Duration - -func (d *Duration) UnmarshalText(text []byte) error { - if n, err := strconv.ParseFloat(string(text), 64); err == nil { - *d = Duration(time.Duration(n*1000) * time.Millisecond) - return nil - } - - if n, err := time.ParseDuration(string(text)); err == nil { - *d = Duration(n) - return nil - } - return fmt.Errorf("cannot unmarshal %q into a Duration", text) -} - -func (d Duration) Duration() time.Duration { - return time.Duration(d) -} +type RateLimitConfig = utils.RateLimitConfig type AppConfig struct { BaseConfig @@ -67,12 +48,6 @@ type BaseConfig struct { } `toml:"key"` } -type RateLimitConfig struct { - Interval Duration `toml:"interval"` - Limit int `toml:"limit"` - PerIP bool `toml:"per_ip,omitempty"` -} - type ModulesConfig struct { GitServer gitserver.Config `toml:"gitserver"` } @@ -196,51 +171,11 @@ func MergeConfig(c *AppConfig, f *FlagArgsStruct, set StringSet) error { } if len(f.RateLimits) > 0 { - for _, s := range f.RateLimits { - // format "interval:limit" - // or "interval:limit:perip"/"interval:limit:p" or "interval:limit:global"/"interval:limit:g", default global - // or "interval:limit;interval:limit" for multiple rate limits in one string - rs := strings.Split(s, ";") - for _, r := range rs { - r = strings.TrimSpace(r) - if r == "" { - continue - } - - r, err := parseRateLimit(r) - if err != nil { - return err - } - - c.Server.RateLimits = append(c.Server.RateLimits, r) - } + rs, err := utils.ParseCmdlineRateLimits(f.RateLimits) + if err != nil { + return err } + c.Server.RateLimits = append(c.Server.RateLimits, rs...) } return nil } - -func parseRateLimit(s string) (*RateLimitConfig, error) { - parts := strings.Split(s, ":") - if len(parts) != 2 && len(parts) != 3 { - return nil, fmt.Errorf("invalid rate limit string: '%s', expected format: `interval:limit[:tag]`", s) - } - - perip := false - if len(parts) == 3 { - switch strings.ToLower(parts[2]) { - case "p", "perip": - perip = true - default: - } - } - - var interval Duration - if err := interval.UnmarshalText([]byte(parts[0])); err != nil { - return nil, err - } - limit, err := strconv.Atoi(parts[1]) - if err != nil { - return nil, err - } - return &RateLimitConfig{interval, limit, perip}, nil -} diff --git a/rate_test.go b/rate_test.go index 9755e1f..11df550 100644 --- a/rate_test.go +++ b/rate_test.go @@ -5,14 +5,15 @@ import ( "time" "github.com/hugefiver/fakessh/conf" + "github.com/hugefiver/fakessh/utils" ) func TestRateLimiter(t *testing.T) { t.Parallel() rl := NewRateLimiter([]*conf.RateLimitConfig{ - {Limit: 3, Interval: conf.Duration(time.Second)}, - {Limit: 5, Interval: conf.Duration(time.Second * 10)}, + {Limit: 3, Interval: utils.Duration(time.Second)}, + {Limit: 5, Interval: utils.Duration(time.Second * 10)}, }) r1 := rl.Allow() @@ -44,12 +45,14 @@ func TestRateLimiter(t *testing.T) { func TestSSHRateLimiter(t *testing.T) { t.Parallel() - rl := NewSSHRateLimiter([]*conf.RateLimitConfig{ - {Limit: 3, Interval: conf.Duration(time.Second)}, - {Limit: 5, Interval: conf.Duration(time.Second * 10)}, - }, []*conf.RateLimitConfig{ - {Limit: 1, Interval: conf.Duration(time.Second), PerIP: true}, - }) + rl := NewSSHRateLimiter( + []*conf.RateLimitConfig{ + {Limit: 3, Interval: utils.Duration(time.Second)}, + {Limit: 5, Interval: utils.Duration(time.Second * 10)}, + }, + []*conf.RateLimitConfig{ + {Limit: 1, Interval: utils.Duration(time.Second), PerIP: true}, + }) x1 := rl.Allow("1") x2 := rl.Allow("1") diff --git a/utils/rateconf.go b/utils/rateconf.go new file mode 100644 index 0000000..7c7b374 --- /dev/null +++ b/utils/rateconf.go @@ -0,0 +1,83 @@ +package utils + +import ( + "fmt" + "strconv" + "strings" + "time" +) + +type RateLimitConfig struct { + Interval Duration `toml:"interval"` + Limit int `toml:"limit"` + PerIP bool `toml:"per_ip,omitempty"` +} + +type Duration time.Duration + +func (d *Duration) UnmarshalText(text []byte) error { + if n, err := strconv.ParseFloat(string(text), 64); err == nil { + *d = Duration(time.Duration(n*1000) * time.Millisecond) + return nil + } + + if n, err := time.ParseDuration(string(text)); err == nil { + *d = Duration(n) + return nil + } + return fmt.Errorf("cannot unmarshal %q into a Duration", text) +} + +func (d Duration) Duration() time.Duration { + return time.Duration(d) +} + +func ParseRateLimit(s string) (*RateLimitConfig, error) { + parts := strings.Split(s, ":") + if len(parts) != 2 && len(parts) != 3 { + return nil, fmt.Errorf("invalid rate limit string: '%s', expected format: `interval:limit[:tag]`", s) + } + + perip := false + if len(parts) == 3 { + switch strings.ToLower(parts[2]) { + case "p", "perip": + perip = true + default: + } + } + + var interval Duration + if err := interval.UnmarshalText([]byte(parts[0])); err != nil { + return nil, err + } + limit, err := strconv.Atoi(parts[1]) + if err != nil { + return nil, err + } + return &RateLimitConfig{interval, limit, perip}, nil +} + +func ParseCmdlineRateLimits(ss []string) ([]*RateLimitConfig, error) { + var ret []*RateLimitConfig + for _, s := range ss { + // format "interval:limit" + // or "interval:limit:perip"/"interval:limit:p" or "interval:limit:global"/"interval:limit:g", default global + // or "interval:limit;interval:limit" for multiple rate limits in one string + rs := strings.Split(s, ";") + for _, r := range rs { + r = strings.TrimSpace(r) + if r == "" { + continue + } + + r, err := ParseRateLimit(r) + if err != nil { + return nil, err + } + + ret = append(ret, r) + } + } + return ret, nil +} diff --git a/utils/rateconf_test.go b/utils/rateconf_test.go new file mode 100644 index 0000000..115c2a9 --- /dev/null +++ b/utils/rateconf_test.go @@ -0,0 +1,116 @@ +package utils + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestParseRateLimit(t *testing.T) { + tests := []struct { + input string + expected *RateLimitConfig + err bool + }{ + { + input: "10:10", + expected: &RateLimitConfig{ + Interval: Duration(time.Second * 10), + Limit: 10, + PerIP: false, + }, + }, + { + input: "1m:10", + expected: &RateLimitConfig{ + Interval: Duration(time.Second * 60), + Limit: 10, + PerIP: false, + }, + }, + { + input: "10s:10", + expected: &RateLimitConfig{ + Interval: Duration(time.Second * 10), + Limit: 10, + PerIP: false, + }, + }, + { + input: "10s:10:p", + expected: &RateLimitConfig{ + Interval: Duration(time.Second * 10), + Limit: 10, + PerIP: true, + }, + }, + { + input: "10s:10:perip", + expected: &RateLimitConfig{ + Interval: Duration(time.Second * 10), + Limit: 10, + PerIP: true, + }, + }, + { + input: "invalid:10", + expected: nil, + err: true, + }, + { + input: "10s:invalid", + expected: nil, + err: true, + }, + } + + for _, tt := range tests { + actual, err := ParseRateLimit(tt.input) + assert.Equal(t, tt.expected, actual) + assert.Equal(t, tt.err, err != nil, "expect error: %v, got: %v", tt.err, err) + } +} + +func TestParseCmdlineRateLimits(t *testing.T) { + tests := []struct { + input []string + expected []*RateLimitConfig + err bool + }{ + { + input: []string{"10s:10"}, + expected: []*RateLimitConfig{{Interval: Duration(10 * time.Second), Limit: 10}}, + }, + { + input: []string{"10s:10:perip", "10s:20:global"}, + expected: []*RateLimitConfig{ + {Interval: Duration(10 * time.Second), Limit: 10, PerIP: true}, + {Interval: Duration(10 * time.Second), Limit: 20}, + }, + }, + { + input: []string{"10s:10;10s:20"}, + expected: []*RateLimitConfig{ + {Interval: Duration(10 * time.Second), Limit: 10}, + {Interval: Duration(10 * time.Second), Limit: 20}, + }, + }, + { + input: []string{"invalid"}, + err: true, + }, + { + input: []string{"10s:invalid"}, + err: true, + }, + } + + for _, tt := range tests { + t.Run("", func(t *testing.T) { + got, err := ParseCmdlineRateLimits(tt.input) + assert.Equal(t, tt.expected, got) + assert.Equal(t, tt.err, err != nil, "expect error: %v, got: %v", tt.err, err) + }) + } +}