@@ -356,8 +356,82 @@ async def _forward_data_to_target(self, data: bytes) -> None:
356
356
pipeline_output = pipeline_output .reconstruct ()
357
357
self .target_transport .write (pipeline_output )
358
358
359
+ def _has_complete_body (self ) -> bool :
360
+ """
361
+ Check if we have received the complete request body based on Content-Length header.
362
+
363
+ We check the headers from the buffer instead of using self.request.headers on purpose
364
+ because with CONNECT requests, the whole request arrives in the data and is stored in
365
+ the buffer.
366
+ """
367
+ try :
368
+ # For the initial CONNECT request
369
+ if not self .headers_parsed and self .request and self .request .method == "CONNECT" :
370
+ return True
371
+
372
+ # For subsequent requests or non-CONNECT requests, parse the method from the buffer
373
+ try :
374
+ first_line = self .buffer [: self .buffer .index (b"\r \n " )].decode ("utf-8" )
375
+ method = first_line .split ()[0 ]
376
+ except (ValueError , IndexError ):
377
+ # Haven't received the complete request line yet
378
+ return False
379
+
380
+ if method != "POST" : # do we need to check for other methods? PUT?
381
+ return True
382
+
383
+ # Parse headers from the buffer instead of using self.request.headers
384
+ headers_dict = {}
385
+ try :
386
+ headers_end = self .buffer .index (b"\r \n \r \n " )
387
+ if headers_end <= 0 : # Ensure we have a valid headers section
388
+ return False
389
+
390
+ headers = self .buffer [:headers_end ].split (b"\r \n " )
391
+ if len (headers ) <= 1 : # Ensure we have headers after the request line
392
+ return False
393
+
394
+ for header in headers [1 :]: # Skip the request line
395
+ if not header : # Skip empty lines
396
+ continue
397
+ try :
398
+ name , value = header .decode ("utf-8" ).split (":" , 1 )
399
+ headers_dict [name .strip ().lower ()] = value .strip ()
400
+ except ValueError :
401
+ # Skip malformed headers
402
+ continue
403
+ except ValueError :
404
+ # Haven't received the complete headers yet
405
+ return False
406
+
407
+ # TODO: Add proper support for chunked transfer encoding
408
+ # For now, just pass through and let the pipeline handle it
409
+ if "transfer-encoding" in headers_dict :
410
+ return True
411
+
412
+ try :
413
+ content_length = int (headers_dict .get ("content-length" ))
414
+ except (ValueError , TypeError ):
415
+ # Content-Length header is required for POST requests without chunked encoding
416
+ logger .error ("Missing or invalid Content-Length header in POST request" )
417
+ return False
418
+
419
+ body_start = headers_end + 4 # Add safety check for buffer length
420
+ if body_start >= len (self .buffer ):
421
+ return False
422
+
423
+ current_body_length = len (self .buffer ) - body_start
424
+ return current_body_length >= content_length
425
+ except Exception as e :
426
+ logger .error (f"Error checking body completion: { e } " )
427
+ return False
428
+
359
429
def data_received (self , data : bytes ) -> None :
360
- """Handle received data from client"""
430
+ """
431
+ Handle received data from client. Since we need to process the complete body
432
+ through our pipeline before forwarding, we accumulate the entire request first.
433
+ """
434
+ logger .info (f"Received data from { self .peername } : { data } " )
361
435
try :
362
436
if not self ._check_buffer_size (data ):
363
437
self .send_error_response (413 , b"Request body too large" )
@@ -370,10 +444,17 @@ def data_received(self, data: bytes) -> None:
370
444
if self .headers_parsed :
371
445
if self .request .method == "CONNECT" :
372
446
self .handle_connect ()
447
+ self .buffer .clear ()
373
448
else :
449
+ # Only process the request once we have the complete body
374
450
asyncio .create_task (self .handle_http_request ())
375
451
else :
376
- asyncio .create_task (self ._forward_data_to_target (data ))
452
+ if self ._has_complete_body ():
453
+ # Process the complete request through the pipeline
454
+ complete_request = bytes (self .buffer )
455
+ logger .debug (f"Complete request: { complete_request } " )
456
+ self .buffer .clear ()
457
+ asyncio .create_task (self ._forward_data_to_target (complete_request ))
377
458
378
459
except Exception as e :
379
460
logger .error (f"Error processing received data: { e } " )
0 commit comments