From f6d1d864706e3c7f9f77d588a03bbfeb75e266f3 Mon Sep 17 00:00:00 2001 From: Hugefiver Date: Tue, 9 Jul 2024 17:47:52 +0800 Subject: [PATCH] feat(ratelimit): add rate limit logic --- go.mod | 9 +++ go.sum | 10 ++++ main.go | 8 ++- rate.go | 166 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ssh.go | 63 ++++++++++++++++++++- 5 files changed, 253 insertions(+), 3 deletions(-) create mode 100644 rate.go diff --git a/go.mod b/go.mod index e13de27..b37e485 100644 --- a/go.mod +++ b/go.mod @@ -20,4 +20,13 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect; inirt ) +require golang.org/x/time v0.5.0 + +require ( + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/puzpuzpuz/xsync/v2 v2.5.1 // indirect + github.com/samber/lo v1.44.0 // indirect + golang.org/x/text v0.16.0 // indirect +) + // replace golang.org/x/crypto => ./third/crypto diff --git a/go.sum b/go.sum index 473f1a2..9c55a92 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -5,6 +7,10 @@ github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6 github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/puzpuzpuz/xsync/v2 v2.5.1 h1:mVGYAvzDSu52+zaGyNjC+24Xw2bQi3kTr4QJ6N9pIIU= +github.com/puzpuzpuz/xsync/v2 v2.5.1/go.mod h1:gD2H2krq/w52MfPLE+Uy64TzJDVY7lP2znR9qmR35kU= +github.com/samber/lo v1.44.0 h1:5il56KxRE+GHsm1IR+sZ/6J42NODigFiqCWpSc2dybA= +github.com/samber/lo v1.44.0/go.mod h1:RmDH9Ct32Qy3gduHQuKJ3gW1fMHAnE/fAzQuf6He5cU= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -25,6 +31,10 @@ golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.22.0 h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk= golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= +golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= +golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/main.go b/main.go index b1fde5a..00d923e 100644 --- a/main.go +++ b/main.go @@ -147,7 +147,7 @@ func main() { } } - serverConfig := ssh.ServerConfig{ + serverConfig := &ssh.ServerConfig{ Config: ssh.Config{}, NoClientAuth: false, MaxAuthTries: sc.Server.MaxTry, @@ -170,6 +170,10 @@ func main() { ) } + opt := &Option{ + SSHRateLimits: sc.Server.RateLimits, + } + // Wait goroutines wg := sync.WaitGroup{} @@ -179,7 +183,7 @@ func main() { if !sc.Server.AntiScan { log.Warn("[Sever] Anti honeypot scan DISABLED") } - StartSSHServer(&serverConfig) + StartSSHServer(serverConfig, opt) wg.Done() }() diff --git a/rate.go b/rate.go new file mode 100644 index 0000000..31b4004 --- /dev/null +++ b/rate.go @@ -0,0 +1,166 @@ +package main + +import ( + "hash/maphash" + "time" + + "github.com/cespare/xxhash/v2" + "github.com/hugefiver/fakessh/conf" + "github.com/puzpuzpuz/xsync/v2" + "golang.org/x/time/rate" +) + +type RateLimiter struct { + limiters []*rate.Limiter +} + +type Reservation struct { + reservations []*rate.Reservation + ok bool +} + +func (r Reservation) OK() bool { + return r.ok +} + +func (r Reservation) CancelAt(t time.Time) { + for _, x := range r.reservations { + x.CancelAt(t) + } +} + +func (r Reservation) Cancel() { + r.CancelAt(time.Now()) +} + +func (r Reservation) Merge(o Reservation) Reservation { + if !r.ok || !o.ok { + return Reservation{ok: false} + } + + reservations := make([]*rate.Reservation, 0, len(r.reservations)+len(o.reservations)) + return Reservation{ + reservations: append(append(reservations, r.reservations...), o.reservations...), + ok: true, + } +} + +func NewRateLimiter(cs []*conf.RateLimitConfig) *RateLimiter { + rs := make([]*rate.Limiter, len(cs)) + + for i, c := range cs { + rs[i] = rate.NewLimiter(rate.Every(c.Interval.Duration()), c.Limit) + } + return &RateLimiter{limiters: rs} +} + +func (r *RateLimiter) AllowN(n int) Reservation { + if len(r.limiters) == 0 { + return Reservation{ok: true} + } + + var taken []*rate.Reservation + now := time.Now() + for _, l := range r.limiters { + if rsv := l.ReserveN(now, n); !rsv.OK() { + for _, x := range taken { + x.CancelAt(now) + } + return Reservation{ok: false} + } else { + taken = append(taken, rsv) + } + } + return Reservation{ok: true, reservations: taken} +} + +func (r *RateLimiter) Allow() Reservation { + return r.AllowN(1) +} + +type SSHRateLimiter struct { + globalConfs []*conf.RateLimitConfig + peripConfs []*conf.RateLimitConfig + + globalRl *RateLimiter + peripRls *xsync.MapOf[string, *RateLimiter] +} + +func hashString(seed maphash.Seed, s string) uint64 { + h := xxhash.NewWithSeed(seedSize) + + _, _ = h.WriteString(s) + return h.Sum64() +} + +func NewSSHRateLimiter(global []*conf.RateLimitConfig, perip []*conf.RateLimitConfig) *SSHRateLimiter { + return &SSHRateLimiter{ + globalConfs: global, + peripConfs: perip, + globalRl: NewRateLimiter(global), + peripRls: xsync.NewTypedMapOf[string, *RateLimiter](hashString), + } +} + +func (r *SSHRateLimiter) HasPerIP() bool { + return len(r.peripConfs) > 0 +} + +func (r *SSHRateLimiter) AllowGlobal() Reservation { + return r.globalRl.Allow() +} + +func (r *SSHRateLimiter) AllowPerIP(ip string) Reservation { + if !r.HasPerIP() { + return Reservation{ok: true} + } + + var rl *RateLimiter + if v, ok := r.peripRls.Load(ip); ok { + rl = v + } else { + rl = NewRateLimiter(r.peripConfs) + rl, _ = r.peripRls.LoadOrStore(ip, rl) + } + + return rl.Allow() +} + +func (r *SSHRateLimiter) Allow(ip string) Reservation { + rsv := r.AllowGlobal() + if rsv.OK() && r.HasPerIP() { + if rsv2 := r.AllowPerIP(ip); rsv2.OK() { + return rsv.Merge(rsv2) + } + rsv.Cancel() + return Reservation{ok: false} + } + + return rsv +} + +func (r *SSHRateLimiter) CleanEmpty() (int, int) { + cleaned := 0 + kept := 0 + r.peripRls.Range(func(k string, v *RateLimiter) bool { + ok := true + for _, l := range v.limiters { + if l.Tokens() < float64(l.Burst()) { + ok = false + break + } + } + + if ok { + r.peripRls.Delete(k) + cleaned++ + } else { + kept++ + } + return true + }) + + log.Debugf("[RateLimiterClean] cleaned %d, kept %d", cleaned, kept) + + return cleaned, kept +} diff --git a/ssh.go b/ssh.go index 4db0067..69cc1f9 100644 --- a/ssh.go +++ b/ssh.go @@ -7,12 +7,57 @@ import ( "net" "time" + "github.com/hugefiver/fakessh/conf" "github.com/hugefiver/fakessh/third/ssh" + "github.com/samber/lo" ) -func StartSSHServer(config *ssh.ServerConfig) { +type Option struct { + SSHRateLimits []*conf.RateLimitConfig +} + +func StartSSHServer(config *ssh.ServerConfig, opt *Option) { port := cl.ServPort + pConf, gConf := lo.FilterReject(opt.SSHRateLimits, func(x *conf.RateLimitConfig, _ int) bool { + return x.PerIP + }) + + limiter := NewSSHRateLimiter(gConf, pConf) + + if limiter.HasPerIP() { + log.Debug("[RateLimiterClean] Start in every 5 minutes") + go func() { + const InitDuration = time.Minute * 5 + const MaxDuration = time.Hour + + currDuration := InitDuration + ticker := time.NewTicker(InitDuration) + clearCount := 0 + + for range ticker.C { + c, k := limiter.CleanEmpty() + if c == 0 { + clearCount++ + if k == 0 && clearCount >= 3 { + currDuration *= 2 + if currDuration > MaxDuration { + currDuration = MaxDuration + } + ticker.Reset(currDuration) + } else if k != 0 { + currDuration = InitDuration * 2 + ticker.Reset(currDuration) + } + } else { + clearCount = 0 + currDuration = InitDuration + ticker.Reset(currDuration) + } + } + }() + } + // Binding port listener, err := net.Listen("tcp", port) if err != nil { @@ -23,6 +68,22 @@ func StartSSHServer(config *ssh.ServerConfig) { // Handle connects for { conn, err := listener.Accept() + + var ip string + addr, ok := conn.RemoteAddr().(*net.TCPAddr) + if !ok { + ip = conn.RemoteAddr().String() + } else { + ip = addr.IP.String() + } + + pass := limiter.Allow(conn.RemoteAddr().String()).OK() + if !pass { + log.Infof("[Disconnect] out of rate limit, ip: %s", ip) + _ = conn.Close() + continue + } + if err != nil { log.Debugf("[Disconnect] failed to accept connect %v : %v", conn.RemoteAddr(), err) }