Skip to content

Commit

Permalink
add unit test for rate limit config
Browse files Browse the repository at this point in the history
  • Loading branch information
hugefiver committed Jul 10, 2024
1 parent bd92470 commit 12c27c6
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 79 deletions.
77 changes: 6 additions & 71 deletions conf/conf.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,13 @@ import (
"fmt"
"io"
"os"
"strconv"
"strings"
"time"

"github.com/hugefiver/fakessh/modules/gitserver"
"github.com/hugefiver/fakessh/utils"
"github.com/pelletier/go-toml/v2"
)

type Duration time.Duration

func (d *Duration) UnmarshalText(text []byte) error {
if n, err := strconv.ParseFloat(string(text), 64); err == nil {
*d = Duration(time.Duration(n*1000) * time.Millisecond)
return nil
}

if n, err := time.ParseDuration(string(text)); err == nil {
*d = Duration(n)
return nil
}
return fmt.Errorf("cannot unmarshal %q into a Duration", text)
}

func (d Duration) Duration() time.Duration {
return time.Duration(d)
}
type RateLimitConfig = utils.RateLimitConfig

type AppConfig struct {
BaseConfig
Expand Down Expand Up @@ -67,12 +48,6 @@ type BaseConfig struct {
} `toml:"key"`
}

type RateLimitConfig struct {
Interval Duration `toml:"interval"`
Limit int `toml:"limit"`
PerIP bool `toml:"per_ip,omitempty"`
}

type ModulesConfig struct {
GitServer gitserver.Config `toml:"gitserver"`
}
Expand Down Expand Up @@ -196,51 +171,11 @@ func MergeConfig(c *AppConfig, f *FlagArgsStruct, set StringSet) error {
}

if len(f.RateLimits) > 0 {
for _, s := range f.RateLimits {
// format "interval:limit"
// or "interval:limit:perip"/"interval:limit:p" or "interval:limit:global"/"interval:limit:g", default global
// or "interval:limit;interval:limit" for multiple rate limits in one string
rs := strings.Split(s, ";")
for _, r := range rs {
r = strings.TrimSpace(r)
if r == "" {
continue
}

r, err := parseRateLimit(r)
if err != nil {
return err
}

c.Server.RateLimits = append(c.Server.RateLimits, r)
}
rs, err := utils.ParseCmdlineRateLimits(f.RateLimits)
if err != nil {
return err
}
c.Server.RateLimits = append(c.Server.RateLimits, rs...)
}
return nil
}

func parseRateLimit(s string) (*RateLimitConfig, error) {
parts := strings.Split(s, ":")
if len(parts) != 2 && len(parts) != 3 {
return nil, fmt.Errorf("invalid rate limit string: '%s', expected format: `interval:limit[:tag]`", s)
}

perip := false
if len(parts) == 3 {
switch strings.ToLower(parts[2]) {
case "p", "perip":
perip = true
default:
}
}

var interval Duration
if err := interval.UnmarshalText([]byte(parts[0])); err != nil {
return nil, err
}
limit, err := strconv.Atoi(parts[1])
if err != nil {
return nil, err
}
return &RateLimitConfig{interval, limit, perip}, nil
}
19 changes: 11 additions & 8 deletions rate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@ import (
"time"

"github.com/hugefiver/fakessh/conf"
"github.com/hugefiver/fakessh/utils"
)

func TestRateLimiter(t *testing.T) {
t.Parallel()

rl := NewRateLimiter([]*conf.RateLimitConfig{
{Limit: 3, Interval: conf.Duration(time.Second)},
{Limit: 5, Interval: conf.Duration(time.Second * 10)},
{Limit: 3, Interval: utils.Duration(time.Second)},
{Limit: 5, Interval: utils.Duration(time.Second * 10)},
})

r1 := rl.Allow()
Expand Down Expand Up @@ -44,12 +45,14 @@ func TestRateLimiter(t *testing.T) {
func TestSSHRateLimiter(t *testing.T) {
t.Parallel()

rl := NewSSHRateLimiter([]*conf.RateLimitConfig{
{Limit: 3, Interval: conf.Duration(time.Second)},
{Limit: 5, Interval: conf.Duration(time.Second * 10)},
}, []*conf.RateLimitConfig{
{Limit: 1, Interval: conf.Duration(time.Second), PerIP: true},
})
rl := NewSSHRateLimiter(
[]*conf.RateLimitConfig{
{Limit: 3, Interval: utils.Duration(time.Second)},
{Limit: 5, Interval: utils.Duration(time.Second * 10)},
},
[]*conf.RateLimitConfig{
{Limit: 1, Interval: utils.Duration(time.Second), PerIP: true},
})

x1 := rl.Allow("1")
x2 := rl.Allow("1")
Expand Down
83 changes: 83 additions & 0 deletions utils/rateconf.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package utils

import (
"fmt"
"strconv"
"strings"
"time"
)

type RateLimitConfig struct {
Interval Duration `toml:"interval"`
Limit int `toml:"limit"`
PerIP bool `toml:"per_ip,omitempty"`
}

type Duration time.Duration

func (d *Duration) UnmarshalText(text []byte) error {
if n, err := strconv.ParseFloat(string(text), 64); err == nil {
*d = Duration(time.Duration(n*1000) * time.Millisecond)
return nil
}

if n, err := time.ParseDuration(string(text)); err == nil {
*d = Duration(n)
return nil
}
return fmt.Errorf("cannot unmarshal %q into a Duration", text)
}

func (d Duration) Duration() time.Duration {
return time.Duration(d)
}

func ParseRateLimit(s string) (*RateLimitConfig, error) {
parts := strings.Split(s, ":")
if len(parts) != 2 && len(parts) != 3 {
return nil, fmt.Errorf("invalid rate limit string: '%s', expected format: `interval:limit[:tag]`", s)
}

perip := false
if len(parts) == 3 {
switch strings.ToLower(parts[2]) {
case "p", "perip":
perip = true
default:
}
}

var interval Duration
if err := interval.UnmarshalText([]byte(parts[0])); err != nil {
return nil, err
}
limit, err := strconv.Atoi(parts[1])
if err != nil {
return nil, err
}
return &RateLimitConfig{interval, limit, perip}, nil
}

func ParseCmdlineRateLimits(ss []string) ([]*RateLimitConfig, error) {
var ret []*RateLimitConfig
for _, s := range ss {
// format "interval:limit"
// or "interval:limit:perip"/"interval:limit:p" or "interval:limit:global"/"interval:limit:g", default global
// or "interval:limit;interval:limit" for multiple rate limits in one string
rs := strings.Split(s, ";")
for _, r := range rs {
r = strings.TrimSpace(r)
if r == "" {
continue
}

r, err := ParseRateLimit(r)
if err != nil {
return nil, err
}

ret = append(ret, r)
}
}
return ret, nil
}
116 changes: 116 additions & 0 deletions utils/rateconf_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package utils

import (
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestParseRateLimit(t *testing.T) {
tests := []struct {
input string
expected *RateLimitConfig
err bool
}{
{
input: "10:10",
expected: &RateLimitConfig{
Interval: Duration(time.Second * 10),
Limit: 10,
PerIP: false,
},
},
{
input: "1m:10",
expected: &RateLimitConfig{
Interval: Duration(time.Second * 60),
Limit: 10,
PerIP: false,
},
},
{
input: "10s:10",
expected: &RateLimitConfig{
Interval: Duration(time.Second * 10),
Limit: 10,
PerIP: false,
},
},
{
input: "10s:10:p",
expected: &RateLimitConfig{
Interval: Duration(time.Second * 10),
Limit: 10,
PerIP: true,
},
},
{
input: "10s:10:perip",
expected: &RateLimitConfig{
Interval: Duration(time.Second * 10),
Limit: 10,
PerIP: true,
},
},
{
input: "invalid:10",
expected: nil,
err: true,
},
{
input: "10s:invalid",
expected: nil,
err: true,
},
}

for _, tt := range tests {
actual, err := ParseRateLimit(tt.input)
assert.Equal(t, tt.expected, actual)
assert.Equal(t, tt.err, err != nil, "expect error: %v, got: %v", tt.err, err)
}
}

func TestParseCmdlineRateLimits(t *testing.T) {
tests := []struct {
input []string
expected []*RateLimitConfig
err bool
}{
{
input: []string{"10s:10"},
expected: []*RateLimitConfig{{Interval: Duration(10 * time.Second), Limit: 10}},
},
{
input: []string{"10s:10:perip", "10s:20:global"},
expected: []*RateLimitConfig{
{Interval: Duration(10 * time.Second), Limit: 10, PerIP: true},
{Interval: Duration(10 * time.Second), Limit: 20},
},
},
{
input: []string{"10s:10;10s:20"},
expected: []*RateLimitConfig{
{Interval: Duration(10 * time.Second), Limit: 10},
{Interval: Duration(10 * time.Second), Limit: 20},
},
},
{
input: []string{"invalid"},
err: true,
},
{
input: []string{"10s:invalid"},
err: true,
},
}

for _, tt := range tests {
t.Run("", func(t *testing.T) {
got, err := ParseCmdlineRateLimits(tt.input)
assert.Equal(t, tt.expected, got)
assert.Equal(t, tt.err, err != nil, "expect error: %v, got: %v", tt.err, err)
})
}
}

0 comments on commit 12c27c6

Please sign in to comment.