@@ -60,6 +60,17 @@ impl IPExtractor {
60
60
}
61
61
}
62
62
63
+ fn ip_from_req ( req : & HttpRequest ) -> actix_web:: Result < String > {
64
+ let ip = if let Some ( ip) = req. connection_info ( ) . realip_remote_addr ( ) {
65
+ ip. to_string ( )
66
+ } else {
67
+ println ! ( "No ip found for route {}" , req. path( ) ) ;
68
+ return Err ( crate :: error:: Error :: InternalError . into ( ) ) ;
69
+ } ;
70
+
71
+ Ok ( ip)
72
+ }
73
+
63
74
#[ derive( Clone , PartialEq , Eq , Hash , Debug ) ]
64
75
pub struct LoginRateLimitKey {
65
76
user : String ,
@@ -79,37 +90,69 @@ const REQUESTS_PER_SECOND: u32 = 5;
79
90
const REQUESTS_BURST : u32 = 25 ;
80
91
81
92
// RateLimiter for the login route
82
- pub struct LoginRateLimiter {
83
- rate_limiter : RateLimiter <
93
+ pub struct LoginRateLimiter (
94
+ RateLimiter <
84
95
LoginRateLimitKey ,
85
96
dashmap:: DashMap < LoginRateLimitKey , InMemoryState > ,
86
97
QuantaClock ,
87
98
governor:: middleware:: NoOpMiddleware < governor:: clock:: QuantaInstant > ,
88
99
> ,
89
- }
100
+ ) ;
90
101
91
102
impl LoginRateLimiter {
92
103
pub fn new ( ) -> Self {
93
- Self {
94
- rate_limiter : RateLimiter :: keyed (
95
- Quota :: per_second ( NonZeroU32 :: new ( REQUESTS_PER_SECOND ) . unwrap ( ) )
96
- . allow_burst ( NonZeroU32 :: new ( REQUESTS_BURST ) . unwrap ( ) ) ,
97
- ) ,
98
- }
104
+ Self ( RateLimiter :: keyed (
105
+ Quota :: per_second ( NonZeroU32 :: new ( REQUESTS_PER_SECOND ) . unwrap ( ) )
106
+ . allow_burst ( NonZeroU32 :: new ( REQUESTS_BURST ) . unwrap ( ) ) ,
107
+ ) )
99
108
}
100
109
101
110
pub fn check ( & self , email : String , req : & HttpRequest ) -> actix_web:: Result < ( ) > {
102
- let ip = if let Some ( ip) = req. connection_info ( ) . realip_remote_addr ( ) {
103
- ip. to_string ( )
104
- } else {
105
- println ! ( "No ip found for route {}" , req. path( ) ) ;
106
- return Err ( crate :: error:: Error :: InternalError . into ( ) ) ;
107
- } ;
111
+ let ip = ip_from_req ( req) ?;
108
112
109
113
let key = LoginRateLimitKey { user : email, ip } ;
110
- if let Err ( err) = self . rate_limiter . check_key ( & key) {
114
+ if let Err ( err) = self . 0 . check_key ( & key) {
111
115
log:: warn!( "RateLimiter triggered for {:?}" , key) ;
112
- let now = self . rate_limiter . clock ( ) . now ( ) ;
116
+ let now = self . 0 . clock ( ) . now ( ) ;
117
+
118
+ Err ( RateLimitError :: new ( err, now) . into ( ) )
119
+ } else {
120
+ Ok ( ( ) )
121
+ }
122
+ }
123
+ }
124
+
125
+ #[ derive( Clone , PartialEq , Eq , Hash , Debug ) ]
126
+ pub struct ChargerRateLimitKey {
127
+ charger_id : String ,
128
+ ip : String ,
129
+ }
130
+
131
+ // Rate limiter for all routes that get called by chargers
132
+ pub struct ChargerRateLimiter (
133
+ RateLimiter <
134
+ ChargerRateLimitKey ,
135
+ dashmap:: DashMap < ChargerRateLimitKey , InMemoryState > ,
136
+ QuantaClock ,
137
+ governor:: middleware:: NoOpMiddleware < governor:: clock:: QuantaInstant > ,
138
+ > ,
139
+ ) ;
140
+
141
+ impl ChargerRateLimiter {
142
+ pub fn new ( ) -> Self {
143
+ Self ( RateLimiter :: keyed (
144
+ Quota :: per_second ( NonZeroU32 :: new ( REQUESTS_PER_SECOND ) . unwrap ( ) )
145
+ . allow_burst ( NonZeroU32 :: new ( REQUESTS_BURST ) . unwrap ( ) ) ,
146
+ ) )
147
+ }
148
+
149
+ pub fn check ( & self , charger_id : String , req : & HttpRequest ) -> actix_web:: Result < ( ) > {
150
+ let ip = ip_from_req ( req) ?;
151
+
152
+ let key = ChargerRateLimitKey { charger_id, ip } ;
153
+ if let Err ( err) = self . 0 . check_key ( & key) {
154
+ log:: warn!( "RateLimiter triggered for {:?}" , key) ;
155
+ let now = self . 0 . clock ( ) . now ( ) ;
113
156
114
157
Err ( RateLimitError :: new ( err, now) . into ( ) )
115
158
} else {
@@ -155,6 +198,8 @@ impl ResponseError for RateLimitError {
155
198
mod tests {
156
199
use actix_web:: test;
157
200
201
+ use crate :: rate_limit:: ChargerRateLimiter ;
202
+
158
203
use super :: LoginRateLimiter ;
159
204
160
205
#[ actix_web:: test]
@@ -195,4 +240,43 @@ mod tests {
195
240
let ret = limiter. check ( email. clone ( ) , & req) ;
196
241
assert ! ( ret. is_ok( ) ) ;
197
242
}
243
+
244
+ #[ actix_web:: test]
245
+ async fn test_charger_rate_limiter ( ) {
246
+ let limiter = ChargerRateLimiter :: new ( ) ;
247
+ let req = test:: TestRequest :: get ( )
248
+ . uri ( "/login" )
249
+ . insert_header ( ( "X-Forwarded-For" , "123.123.123.2" ) )
250
+ . to_http_request ( ) ;
251
+ let email = uuid:: Uuid :: new_v4 ( ) . to_string ( ) ;
252
+
253
+ let ret = limiter. check ( email. clone ( ) , & req) ;
254
+ assert ! ( ret. is_ok( ) ) ;
255
+
256
+ let ret = limiter. check ( email. clone ( ) , & req) ;
257
+ assert ! ( ret. is_ok( ) ) ;
258
+
259
+ let ret = limiter. check ( email. clone ( ) , & req) ;
260
+ assert ! ( ret. is_ok( ) ) ;
261
+
262
+ let ret = limiter. check ( email. clone ( ) , & req) ;
263
+ assert ! ( ret. is_ok( ) ) ;
264
+
265
+ let ret = limiter. check ( email. clone ( ) , & req) ;
266
+ assert ! ( ret. is_ok( ) ) ;
267
+
268
+ let ret = limiter. check ( email. clone ( ) , & req) ;
269
+ assert ! ( ret. is_err( ) ) ;
270
+
271
+ let email2 = uuid:: Uuid :: new_v4 ( ) . to_string ( ) ;
272
+ let ret = limiter. check ( email2. clone ( ) , & req) ;
273
+ assert ! ( ret. is_ok( ) ) ;
274
+
275
+ let req = test:: TestRequest :: get ( )
276
+ . uri ( "/login" )
277
+ . insert_header ( ( "X-Forwarded-For" , "123.123.123.3" ) )
278
+ . to_http_request ( ) ;
279
+ let ret = limiter. check ( email. clone ( ) , & req) ;
280
+ assert ! ( ret. is_ok( ) ) ;
281
+ }
198
282
}
0 commit comments