@@ -13,10 +13,18 @@ const messageSchema = z.object({
13
13
14
14
// TODO: Overhaul
15
15
16
- const connections = new Map < string , any > ( ) ;
17
- const limiter = {
18
- timeWindowMs : 60 * 1000 ,
19
- limit : 100 ,
16
+ let connections = new Map < string , any > ( ) ;
17
+
18
+ type RateLimit = {
19
+ count : number ;
20
+ } ;
21
+
22
+ const generalRateLimit = {
23
+ timeWindowMs : 10000 ,
24
+ count : 10 ,
25
+ } ;
26
+ const globalChatRateLimit : RateLimit = {
27
+ count : 1 ,
20
28
} ;
21
29
22
30
function initWebsocket ( wss : WebSocketServer ) {
@@ -51,12 +59,15 @@ function initWebsocket(wss: WebSocketServer) {
51
59
52
60
ws . on ( "message" , async ( reqData : string ) => {
53
61
try {
54
- const exceededRateLimit = await checkRateLimit ( connection ! ) ;
62
+ const exceededRateLimit = await handleRateLimit ( connection ! ) ;
55
63
if ( exceededRateLimit ) {
56
- return ws . send ( JSON . stringify ( { error : { message : "Rate limit exceeded" } } ) ) ;
64
+ return ws . send ( JSON . stringify ( { error : { message : "Rate limit exceeded, slow down " } } ) ) ;
57
65
}
58
66
59
- await validateSession ( ws , sessionId ) ;
67
+ const sessionValid = await validateSession ( sessionId ) ;
68
+ if ( ! sessionValid ) {
69
+ return ws . send ( JSON . stringify ( { error : { message : "Invalid session" } } ) ) ;
70
+ }
60
71
61
72
let reqJson = JSON . parse ( reqData ) ;
62
73
let parsed = messageSchema . safeParse ( reqJson ) ;
@@ -68,7 +79,10 @@ function initWebsocket(wss: WebSocketServer) {
68
79
69
80
switch ( type ) {
70
81
case "globalMessage" : {
71
- // TODO: Add global chat rate limits
82
+ const exceededRateLimit = await checkRateLimit ( connection ! , globalChatRateLimit ) ;
83
+ if ( exceededRateLimit ) {
84
+ return ws . send ( JSON . stringify ( { error : { message : "Rate limit exceeded, slow down" } } ) ) ;
85
+ }
72
86
handleGlobalMessage ( ws , wss , data , user ) ;
73
87
break ;
74
88
}
@@ -149,26 +163,25 @@ async function handleChatMessage(ws: WebSocket, data: any, user: User) {
149
163
}
150
164
}
151
165
152
- async function validateSession ( ws : WebSocket , sessionId : string | string [ ] ) {
166
+ async function validateSession ( sessionId : string | string [ ] ) : Promise < boolean > {
153
167
if ( Array . isArray ( sessionId ) ) {
154
- ws . close ( ) ;
155
- return ;
168
+ return false ;
156
169
}
157
170
const { session, user } = await lucia . validateSession ( sessionId ) ;
158
171
if ( ! session || ! user ) {
159
- ws . close ( ) ;
160
- return ;
172
+ return false ;
173
+ } else {
174
+ return true ;
161
175
}
162
176
}
163
177
164
- async function checkRateLimit ( connection : any ) : Promise < boolean > {
165
- console . log ( connection ) ;
178
+ async function handleRateLimit ( connection : any ) : Promise < boolean > {
166
179
const now = Date . now ( ) ;
167
180
const elapsedTime = now - connection . lastMessageTime ;
168
181
169
- if ( elapsedTime < limiter . timeWindowMs ) {
182
+ if ( elapsedTime < generalRateLimit . timeWindowMs ) {
170
183
connection . count ++ ;
171
- if ( connection . count > limiter . limit ) {
184
+ if ( connection . count > generalRateLimit . count ) {
172
185
return true ;
173
186
}
174
187
} else {
@@ -178,6 +191,15 @@ async function checkRateLimit(connection: any): Promise<boolean> {
178
191
return false ;
179
192
}
180
193
194
+ async function checkRateLimit ( connection : any , rateLimit : RateLimit ) : Promise < boolean > {
195
+ console . log ( connection . count > rateLimit . count ) ;
196
+ if ( connection . count > rateLimit . count ) {
197
+ return true ;
198
+ } else {
199
+ return false ;
200
+ }
201
+ }
202
+
181
203
function broadcast ( wss : WebSocketServer , data : any , skip : WebSocket | null ) {
182
204
wss . clients ;
183
205
wss . clients . forEach ( ( client ) => {
0 commit comments