Skip to content

Commit e343382

Browse files
committed
Adding websocket support & refactoring
Amend: requirements.txt
1 parent 40d22fa commit e343382

File tree

7 files changed

+159
-57
lines changed

7 files changed

+159
-57
lines changed

lib/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .duplex import Duplex

lib/classes.py renamed to lib/duplex.py

+35-17
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import weakref
33
from fastapi import Request
44
from dataclasses import dataclass
5-
from fastapi.responses import StreamingResponse
65

76

87
@dataclass
@@ -16,52 +15,71 @@ class Duplex:
1615

1716
instances = weakref.WeakValueDictionary()
1817

19-
def __init__(self, stream, identifier: str, file: File, wait_for_client: bool):
20-
self.stream = stream
18+
def __init__(self, identifier: str, file: File, stream=None):
2119
self.identifier = identifier
2220
self.file = file
23-
self.queue = asyncio.Queue(1 if wait_for_client else 0)
21+
self.queue = asyncio.Queue(1)
2422
self.client_connected = asyncio.Event()
2523

2624
@staticmethod
27-
def get_upload_details(request: Request):
28-
stream = request.stream
29-
identifier = request.path_params.get('identifier')
25+
def get_file_from_request(request: Request):
3026
file = File(
3127
name=request.path_params.get('file_name'),
3228
size=int(request.headers.get('content-length')),
3329
content_type=request.headers.get('content-type')
3430
)
35-
return stream, identifier, file
31+
return file
32+
33+
@staticmethod
34+
def get_file_from_header(header: dict):
35+
file = File(
36+
name=header['file_name'],
37+
size=int(header['file_size']),
38+
content_type=header['file_type']
39+
)
40+
return file
41+
42+
def get_file_info(self):
43+
return self.file.name, self.file.size, self.file.content_type
3644

3745
@classmethod
38-
def from_upload(cls, request: Request):
39-
stream, identifier, file = cls.get_upload_details(request)
40-
duplex = cls(stream, identifier, file, wait_for_client=True)
46+
def create_duplex(cls, identifier: str, file: File):
47+
duplex = cls(identifier, file)
48+
cls.instances[identifier] = duplex
49+
return duplex
50+
51+
@classmethod
52+
def create_duplex_ws(cls, identifier: str, name: str, size: int, type: str):
53+
file = File(name=name, size=size, content_type=type)
54+
duplex = cls(identifier, file)
4155
cls.instances[identifier] = duplex
4256
return duplex
4357

4458
@classmethod
45-
def from_identifer(cls, identifier: str):
59+
def get(cls, identifier: str):
4660
if duplex := cls.instances.get(identifier):
4761
return duplex
4862
else:
4963
raise KeyError(f"Duplex '{identifier}' not found.")
5064

5165
def get_file_info(self):
5266
return self.file.name, self.file.size, self.file.content_type
67+
68+
async def wait_for_empty_queue(self, seconds=600):
69+
while not self.queue.empty() and seconds > 0:
70+
await asyncio.sleep(1)
71+
seconds -= 1
5372

54-
async def transfer(self):
73+
async def transfer(self, stream):
5574
bytes_read = 0
5675

57-
async for chunk in self.stream():
76+
async for chunk in stream:
5877
bytes_read += len(chunk)
5978
await self.queue.put(chunk)
6079

6180
await self.queue.put(None)
62-
63-
while not self.queue.empty():
64-
await asyncio.sleep(0.5)
81+
await self.wait_for_empty_queue()
82+
return bytes_read
6583

6684
async def receive(self):
6785
while True:

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ pydantic-core>=2.10.1
1414
sniffio>=1.3.0
1515
starlette>=0.27.0
1616
typing-extensions>=4.8.0
17+
websockets>=12.0

views/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .http import router as http_router
2+
from .websockets import router as ws_router

views/http.py

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import asyncio
2+
from fastapi import Request, APIRouter
3+
from fastapi.responses import Response, StreamingResponse, PlainTextResponse
4+
5+
from lib import Duplex
6+
7+
router = APIRouter()
8+
9+
10+
@router.put("/{identifier}/{file_name}")
11+
async def http_upload(request: Request, identifier: str, file_name: str):
12+
uid = identifier
13+
print(f"{uid} - HTTP transfer request." )
14+
15+
file = Duplex.get_file_from_request(request)
16+
duplex = Duplex.create_duplex(identifier, file)
17+
18+
print(f"{uid} - Waiting for client to connect...")
19+
await duplex.client_connected.wait()
20+
21+
print(f"{uid} - Client connected. Transfering...")
22+
await duplex.transfer(request.stream())
23+
24+
print(f"{uid} - Transfer complete.")
25+
return Response(status_code=200)
26+
27+
28+
@router.get("/{identifier}")
29+
async def http_download(identifier: str):
30+
uid = identifier
31+
print(f"{uid} - HTTP download request." )
32+
33+
try:
34+
duplex = Duplex.get(identifier)
35+
except KeyError:
36+
return PlainTextResponse("File not found", status_code=404)
37+
38+
print(f"{uid} - Notifying client is connected.")
39+
duplex.client_connected.set()
40+
await asyncio.sleep(0.5)
41+
42+
file_name, file_size, file_type = duplex.get_file_info()
43+
44+
print(f"{uid} - Starting download.")
45+
return StreamingResponse(
46+
duplex.receive(),
47+
media_type=file_type,
48+
headers={"Content-Disposition": f"attachment; filename={file_name}", "Content-Length": str(file_size)}
49+
)

views/websockets.py

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import asyncio
2+
from fastapi import WebSocket, APIRouter
3+
4+
from lib import Duplex
5+
6+
7+
router = APIRouter()
8+
9+
10+
@router.websocket("/send/{identifier}")
11+
async def websocket_upload(websocket: WebSocket, identifier: str):
12+
uid = identifier
13+
await websocket.accept()
14+
print(f"{uid} - Websocket transfer request." )
15+
16+
header = await websocket.receive_json()
17+
18+
try:
19+
file = Duplex.get_file_from_header(header)
20+
except KeyError:
21+
print(f"{uid} - Invalid header: {header}")
22+
return
23+
24+
duplex = Duplex.create_duplex(uid, file)
25+
26+
await duplex.client_connected.wait()
27+
await websocket.send_text(f"Go for file chunks")
28+
29+
print(f"{uid} - Starting transfer...")
30+
await duplex.transfer(websocket.iter_bytes())
31+
32+
print(f"{uid} - Transfer complete.")
33+
34+
35+
@router.websocket("/receive/{identifier}")
36+
async def websocket_download(websocket: WebSocket, identifier: str):
37+
uid = identifier
38+
await websocket.accept()
39+
print(f"{uid} - Websocket download request." )
40+
41+
try:
42+
duplex = Duplex.get(identifier)
43+
except KeyError:
44+
print(f"{uid} - File not found.")
45+
await websocket.send_text("File not found")
46+
return
47+
48+
file_name, file_size, file_type = duplex.get_file_info()
49+
await websocket.send_json({'file_name': file_name, 'file_size': file_size, 'file_type': file_type})
50+
51+
print(f"{uid} - Waiting for go-ahead...")
52+
while (msg := await websocket.receive_text()) != "Go for file chunks":
53+
print(f"{uid} - Unexpected message: {msg}")
54+
55+
print(f"{uid} - Notifying client is connected.")
56+
duplex.client_connected.set()
57+
await asyncio.sleep(0.5)
58+
59+
print(f"{uid} - Starting transfer...")
60+
async for chunk in duplex.receive():
61+
await websocket.send_bytes(chunk)
62+
await websocket.send_bytes(b'')
63+
print(f"{uid} - Transfer complete.")
64+

webapp.py

+7-40
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
import asyncio
21
from pathlib import Path
3-
from fastapi import FastAPI, Request
2+
from fastapi import FastAPI
43
from fastapi.staticfiles import StaticFiles
5-
from fastapi.responses import StreamingResponse, FileResponse, PlainTextResponse, Response
6-
7-
from lib.classes import Duplex
4+
from fastapi.responses import FileResponse
85

6+
from views import http_router, ws_router
97

108
app = FastAPI()
9+
app.include_router(http_router)
10+
app.include_router(ws_router)
1111

1212

1313
@app.get('/')
@@ -25,43 +25,10 @@ async def get_health():
2525
return {"status": "ok"}
2626

2727

28-
@app.put("/{identifier}/{file_name}")
29-
async def upload_file(request: Request, identifier: str, file_name: str):
30-
duplex = Duplex.from_upload(request)
31-
id_ = duplex.identifier
32-
33-
print(f"[{id_}] Waiting for client to connect...")
34-
await duplex.client_connected.wait()
35-
36-
print(f"[{id_}] Client connected. Transfering...")
37-
await duplex.transfer()
38-
39-
print(f"[{id_}] Transfer complete.")
40-
return Response(status_code=200)
41-
42-
43-
@app.get("/{identifier}")
44-
async def get_file(identifier: str):
45-
try:
46-
duplex = Duplex.from_identifer(identifier)
47-
file_name, file_size, file_type = duplex.get_file_info()
48-
except KeyError:
49-
return PlainTextResponse("File not found", status_code=404)
50-
51-
duplex.client_connected.set()
52-
await asyncio.sleep(0.5)
53-
54-
return StreamingResponse(
55-
duplex.receive(),
56-
media_type=file_type,
57-
headers={"Content-Disposition": f"attachment; filename={file_name}", "Content-Length": str(file_size)}
58-
)
59-
60-
61-
# Mount local static directory for HTML
28+
# Mount local static for HTML
6229
app.mount('/static', StaticFiles(directory='static', html=True), name='static')
6330

64-
# Mount remote disk if present or local static for CSS
31+
# Mount remote if present or local static for CSS
6532
if Path('/extra').exists():
6633
app.mount('/css', StaticFiles(directory='/extra'), name='css')
6734
else:

0 commit comments

Comments
 (0)