@@ -13,8 +13,36 @@ const messageSchema = z.object({
13
13
14
14
// TODO: Overhaul
15
15
16
+ let connections = new Map < string , any > ( ) ;
17
+
18
+ type RateLimit = {
19
+ count : number ;
20
+ } ;
21
+
22
+ const generalRateLimit = {
23
+ timeWindowMs : 10000 ,
24
+ count : 10 ,
25
+ } ;
26
+ const globalChatRateLimit : RateLimit = {
27
+ count : 1 ,
28
+ } ;
29
+
16
30
function initWebsocket ( wss : WebSocketServer ) {
17
31
wss . on ( "connection" , async ( ws , req ) => {
32
+ if ( ! req . socket . remoteAddress ) {
33
+ ws . close ( ) ;
34
+ return ;
35
+ }
36
+ const ip = req . socket . remoteAddress ;
37
+ let connection = connections . get ( ip ) ;
38
+ if ( ! connection ) {
39
+ connection = {
40
+ count : 0 ,
41
+ lastMessageTime : Date . now ( ) ,
42
+ } ;
43
+ connections . set ( ip , connection ) ;
44
+ }
45
+
18
46
let sessionId = url . parse ( req . url ! , true ) . query . session ! ;
19
47
if ( ! sessionId || Array . isArray ( sessionId ) ) {
20
48
ws . close ( ) ;
@@ -31,7 +59,15 @@ function initWebsocket(wss: WebSocketServer) {
31
59
32
60
ws . on ( "message" , async ( reqData : string ) => {
33
61
try {
34
- await validateSession ( ws , sessionId ) ;
62
+ const exceededRateLimit = await handleRateLimit ( connection ! ) ;
63
+ if ( exceededRateLimit ) {
64
+ return ws . send ( JSON . stringify ( { error : { message : "Rate limit exceeded, slow down" } } ) ) ;
65
+ }
66
+
67
+ const sessionValid = await validateSession ( sessionId ) ;
68
+ if ( ! sessionValid ) {
69
+ return ws . send ( JSON . stringify ( { error : { message : "Invalid session" } } ) ) ;
70
+ }
35
71
36
72
let reqJson = JSON . parse ( reqData ) ;
37
73
let parsed = messageSchema . safeParse ( reqJson ) ;
@@ -43,6 +79,10 @@ function initWebsocket(wss: WebSocketServer) {
43
79
44
80
switch ( type ) {
45
81
case "globalMessage" : {
82
+ const exceededRateLimit = await checkRateLimit ( connection ! , globalChatRateLimit ) ;
83
+ if ( exceededRateLimit ) {
84
+ return ws . send ( JSON . stringify ( { error : { message : "Rate limit exceeded, slow down" } } ) ) ;
85
+ }
46
86
handleGlobalMessage ( ws , wss , data , user ) ;
47
87
break ;
48
88
}
@@ -123,19 +163,45 @@ async function handleChatMessage(ws: WebSocket, data: any, user: User) {
123
163
}
124
164
}
125
165
126
- async function validateSession ( ws : WebSocket , sessionId : string | string [ ] ) {
166
+ async function validateSession ( sessionId : string | string [ ] ) : Promise < boolean > {
127
167
if ( Array . isArray ( sessionId ) ) {
128
- ws . close ( ) ;
129
- return ;
168
+ return false ;
130
169
}
131
170
const { session, user } = await lucia . validateSession ( sessionId ) ;
132
171
if ( ! session || ! user ) {
133
- ws . close ( ) ;
134
- return ;
172
+ return false ;
173
+ } else {
174
+ return true ;
175
+ }
176
+ }
177
+
178
+ async function handleRateLimit ( connection : any ) : Promise < boolean > {
179
+ const now = Date . now ( ) ;
180
+ const elapsedTime = now - connection . lastMessageTime ;
181
+
182
+ if ( elapsedTime < generalRateLimit . timeWindowMs ) {
183
+ connection . count ++ ;
184
+ if ( connection . count > generalRateLimit . count ) {
185
+ return true ;
186
+ }
187
+ } else {
188
+ connection . count = 1 ;
189
+ connection . lastMessageTime = now ;
190
+ }
191
+ return false ;
192
+ }
193
+
194
+ async function checkRateLimit ( connection : any , rateLimit : RateLimit ) : Promise < boolean > {
195
+ console . log ( connection . count > rateLimit . count ) ;
196
+ if ( connection . count > rateLimit . count ) {
197
+ return true ;
198
+ } else {
199
+ return false ;
135
200
}
136
201
}
137
202
138
203
function broadcast ( wss : WebSocketServer , data : any , skip : WebSocket | null ) {
204
+ wss . clients ;
139
205
wss . clients . forEach ( ( client ) => {
140
206
if ( skip && client == skip ) return ;
141
207
if ( client . readyState === WebSocket . OPEN ) {
0 commit comments