Skip to content

Commit 6557758

Browse files
Added LocalStorageClient and /storage/file API route
1 parent 173797a commit 6557758

File tree

6 files changed

+849
-2
lines changed

6 files changed

+849
-2
lines changed

backend/chainlit/data/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,13 @@ def get_data_layer():
5757
azure_storage_key = os.getenv("APP_AZURE_STORAGE_ACCESS_KEY")
5858
is_using_azure = bool(azure_storage_account and azure_storage_key)
5959

60+
# Local Storage
61+
local_storage_path = os.getenv("APP_LOCAL_STORAGE_PATH")
62+
is_using_local = bool(local_storage_path)
63+
6064
storage_client = None
6165

62-
if sum([is_using_s3, is_using_gcs, is_using_azure]) > 1:
66+
if sum([is_using_s3, is_using_gcs, is_using_azure, is_using_local]) > 1:
6367
warnings.warn(
6468
"Multiple storage configurations detected. Please use only one."
6569
)
@@ -92,6 +96,12 @@ def get_data_layer():
9296
storage_account=azure_storage_account,
9397
storage_key=azure_storage_key,
9498
)
99+
elif is_using_local:
100+
from chainlit.data.storage_clients.local import LocalStorageClient
101+
102+
storage_client = LocalStorageClient(
103+
storage_path=local_storage_path,
104+
)
95105

96106
_data_layer = ChainlitDataLayer(
97107
database_url=database_url, storage_client=storage_client

backend/chainlit/data/storage_clients/base.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,11 @@ async def delete_file(self, object_key: str) -> bool:
2626
@abstractmethod
2727
async def get_read_url(self, object_key: str) -> str:
2828
pass
29+
30+
async def download_file(self, object_key: str) -> tuple[bytes, str] | None:
31+
"""
32+
Optional method to download file content directly, to allow files downloads to be proxied by ChainLit backend itself
33+
34+
Returns (file_content, mime_type) if implemented, None otherwise.
35+
"""
36+
return None
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
import mimetypes
2+
import shutil
3+
from pathlib import Path
4+
from typing import Any, Dict, Union
5+
from urllib.request import pathname2url
6+
7+
from chainlit import make_async
8+
from chainlit.data.storage_clients.base import BaseStorageClient
9+
from chainlit.logger import logger
10+
11+
12+
class LocalStorageClient(BaseStorageClient):
13+
"""
14+
Class to enable local file system storage provider
15+
"""
16+
17+
def __init__(self, storage_path: str):
18+
try:
19+
self.storage_path = Path(storage_path).resolve()
20+
21+
# Create storage directory if it doesn't exist
22+
self.storage_path.mkdir(parents=True, exist_ok=True)
23+
24+
logger.info(
25+
f"LocalStorageClient initialized with path: {self.storage_path}"
26+
)
27+
except Exception as e:
28+
logger.warning(f"LocalStorageClient initialization error: {e}")
29+
raise
30+
31+
def _validate_object_key(self, object_key: str) -> Path:
32+
"""
33+
Validate object_key and ensure the resolved path is within storage directory.
34+
35+
Args:
36+
object_key: The object key to validate
37+
38+
Returns:
39+
Resolved Path object within storage directory
40+
41+
Raises:
42+
ValueError: If path traversal is detected or path is invalid
43+
"""
44+
try:
45+
# Reject absolute paths immediately
46+
if object_key.startswith("/"):
47+
logger.warning(f"Absolute path rejected: {object_key}")
48+
raise ValueError("Invalid object key: absolute paths not allowed")
49+
50+
# Normalize object_key and check for traversal patterns
51+
normalized_key = object_key.strip()
52+
if ".." in normalized_key or "\\" in normalized_key:
53+
logger.warning(f"Path traversal patterns detected: {object_key}")
54+
raise ValueError("Invalid object key: path traversal detected")
55+
56+
# Create the file path
57+
file_path = self.storage_path / normalized_key
58+
resolved_path = file_path.resolve()
59+
60+
# Ensure the resolved path is within the storage directory
61+
resolved_path.relative_to(self.storage_path)
62+
63+
return resolved_path
64+
except ValueError as e:
65+
# Re-raise ValueError as is (our custom errors)
66+
raise e
67+
except Exception as e:
68+
logger.warning(f"Path validation error for {object_key}: {e}")
69+
raise ValueError(f"Invalid object key: {e}")
70+
71+
def sync_get_read_url(self, object_key: str) -> str:
72+
try:
73+
file_path = self._validate_object_key(object_key)
74+
if file_path.exists():
75+
# Return URL pointing to the backend's storage route
76+
url_path = pathname2url(object_key)
77+
return f"/storage/file/{url_path}"
78+
else:
79+
logger.warning(f"LocalStorageClient: File not found: {object_key}")
80+
return object_key
81+
except ValueError:
82+
# Path validation failed, return object_key as fallback
83+
return object_key
84+
except Exception as e:
85+
logger.warning(f"LocalStorageClient, get_read_url error: {e}")
86+
return object_key
87+
88+
async def get_read_url(self, object_key: str) -> str:
89+
return await make_async(self.sync_get_read_url)(object_key)
90+
91+
def sync_upload_file(
92+
self,
93+
object_key: str,
94+
data: Union[bytes, str],
95+
mime: str = "application/octet-stream",
96+
overwrite: bool = True,
97+
content_disposition: str | None = None,
98+
) -> Dict[str, Any]:
99+
try:
100+
file_path = self._validate_object_key(object_key)
101+
102+
# Create parent directories if they don't exist
103+
file_path.parent.mkdir(parents=True, exist_ok=True)
104+
105+
# Check if file exists and overwrite is False
106+
if file_path.exists() and not overwrite:
107+
logger.warning(
108+
f"LocalStorageClient: File exists and overwrite=False: {object_key}"
109+
)
110+
return {}
111+
112+
# Write data to file
113+
if isinstance(data, str):
114+
file_path.write_text(data, encoding="utf-8")
115+
else:
116+
file_path.write_bytes(data)
117+
118+
# Generate URL for the uploaded file using backend's storage route
119+
relative_path = file_path.relative_to(self.storage_path)
120+
url_path = pathname2url(str(relative_path))
121+
url = f"/storage/file/{url_path}"
122+
123+
return {"object_key": object_key, "url": url}
124+
except ValueError as e:
125+
logger.warning(f"LocalStorageClient, upload_file error: {e}")
126+
return {}
127+
except Exception as e:
128+
logger.warning(f"LocalStorageClient, upload_file error: {e}")
129+
return {}
130+
131+
async def upload_file(
132+
self,
133+
object_key: str,
134+
data: Union[bytes, str],
135+
mime: str = "application/octet-stream",
136+
overwrite: bool = True,
137+
content_disposition: str | None = None,
138+
) -> Dict[str, Any]:
139+
return await make_async(self.sync_upload_file)(
140+
object_key, data, mime, overwrite, content_disposition
141+
)
142+
143+
def sync_delete_file(self, object_key: str) -> bool:
144+
try:
145+
file_path = self._validate_object_key(object_key)
146+
if file_path.exists():
147+
if file_path.is_file():
148+
file_path.unlink()
149+
elif file_path.is_dir():
150+
shutil.rmtree(file_path)
151+
return True
152+
else:
153+
logger.warning(
154+
f"LocalStorageClient: File not found for deletion: {object_key}"
155+
)
156+
return False
157+
except ValueError as e:
158+
logger.warning(f"LocalStorageClient, delete_file error: {e}")
159+
return False
160+
except Exception as e:
161+
logger.warning(f"LocalStorageClient, delete_file error: {e}")
162+
return False
163+
164+
async def delete_file(self, object_key: str) -> bool:
165+
return await make_async(self.sync_delete_file)(object_key)
166+
167+
def sync_download_file(self, object_key: str) -> tuple[bytes, str] | None:
168+
try:
169+
file_path = self._validate_object_key(object_key)
170+
if not file_path.exists() or not file_path.is_file():
171+
logger.warning(
172+
f"LocalStorageClient: File not found for download: {object_key}"
173+
)
174+
return None
175+
176+
# Get MIME type
177+
mime_type, _ = mimetypes.guess_type(str(file_path))
178+
if not mime_type:
179+
mime_type = "application/octet-stream"
180+
181+
# Read file content
182+
content = file_path.read_bytes()
183+
return (content, mime_type)
184+
except ValueError as e:
185+
logger.warning(f"LocalStorageClient, download_file error: {e}")
186+
return None
187+
except Exception as e:
188+
logger.warning(f"LocalStorageClient, download_file error: {e}")
189+
return None
190+
191+
async def download_file(self, object_key: str) -> tuple[bytes, str] | None:
192+
return await make_async(self.sync_download_file)(object_key)

backend/chainlit/server.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1636,6 +1636,98 @@ async def get_file(
16361636
raise HTTPException(status_code=404, detail="File not found")
16371637

16381638

1639+
@router.get("/storage/file/{object_key:path}")
1640+
async def get_storage_file(
1641+
object_key: str,
1642+
current_user: UserParam,
1643+
):
1644+
"""Get a file from the storage client if it supports direct downloads."""
1645+
from chainlit.data import get_data_layer
1646+
1647+
data_layer = get_data_layer()
1648+
if not data_layer or not data_layer.storage_client:
1649+
raise HTTPException(
1650+
status_code=404,
1651+
detail="Storage not configured",
1652+
)
1653+
1654+
# Validate user authentication
1655+
if not current_user:
1656+
raise HTTPException(status_code=401, detail="Unauthorized")
1657+
1658+
# Extract thread_id from object_key to validate thread ownership
1659+
# Object key patterns:
1660+
# 1. threads/{thread_id}/files/{element.id} (chainlit_data_layer)
1661+
# 2. {user_id}/{thread_id}/{element.id} (dynamodb)
1662+
# 3. {user_id}/{element.id}[/{element.name}] (sql_alchemy)
1663+
thread_id = None
1664+
1665+
# Try to extract thread_id from different patterns
1666+
parts = object_key.split("/")
1667+
if len(parts) >= 3:
1668+
if parts[0] == "threads":
1669+
# Pattern: threads/{thread_id}/files/{element.id}
1670+
thread_id = parts[1]
1671+
elif len(parts) == 3:
1672+
# Pattern: {user_id}/{thread_id}/{element.id} (dynamodb)
1673+
# We need to verify this is actually a thread_id by checking if it exists
1674+
potential_thread_id = parts[1]
1675+
try:
1676+
# Check if this looks like a thread by validating thread author
1677+
await is_thread_author(current_user.identifier, potential_thread_id)
1678+
thread_id = potential_thread_id
1679+
except HTTPException:
1680+
# Not a valid thread or user doesn't have access
1681+
pass
1682+
1683+
# If we found a thread_id, validate thread ownership
1684+
if thread_id:
1685+
await is_thread_author(current_user.identifier, thread_id)
1686+
else:
1687+
# For files without thread association (pattern 3), we should still
1688+
# validate that the user_id in the path matches the current user
1689+
if len(parts) >= 2:
1690+
user_id_in_path = parts[0]
1691+
if user_id_in_path != current_user.identifier:
1692+
raise HTTPException(
1693+
status_code=403,
1694+
detail="Access denied: file belongs to different user",
1695+
)
1696+
1697+
# Try to extract element_id and get the original filename from database
1698+
element_id = None
1699+
element_name = None
1700+
1701+
# Extract element_id from object_key patterns
1702+
if len(parts) >= 4 and parts[0] == "threads" and parts[2] == "files":
1703+
# Pattern: threads/{thread_id}/files/{element_id}
1704+
element_id = parts[3]
1705+
# Query database for element details
1706+
if thread_id and element_id:
1707+
element = await data_layer.get_element(thread_id, element_id)
1708+
if element:
1709+
element_name = element.get("name")
1710+
1711+
# Only serve files if storage client implements download_file
1712+
file_data = await data_layer.storage_client.download_file(object_key)
1713+
if file_data is None:
1714+
raise HTTPException(
1715+
status_code=404,
1716+
detail="File not found or storage client does not support direct downloads",
1717+
)
1718+
1719+
content, mime_type = file_data
1720+
1721+
# Use the original filename if available, otherwise fall back to the UUID
1722+
filename = element_name if element_name else Path(object_key).name
1723+
1724+
return Response(
1725+
content=content,
1726+
media_type=mime_type,
1727+
headers={"Content-Disposition": f"inline; filename={filename}"},
1728+
)
1729+
1730+
16391731
@router.get("/favicon")
16401732
async def get_favicon():
16411733
"""Get the favicon for the UI."""

backend/tests/conftest.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pytest
88
import pytest_asyncio
99

10+
import chainlit.data as data_module
1011
from chainlit import config
1112
from chainlit.callbacks import data_layer
1213
from chainlit.context import ChainlitContext, context_var
@@ -94,10 +95,18 @@ def mock_data_layer(monkeypatch: pytest.MonkeyPatch) -> AsyncMock:
9495

9596

9697
@pytest.fixture
97-
def mock_get_data_layer(mock_data_layer: AsyncMock, test_config: config.ChainlitConfig):
98+
def mock_get_data_layer(
99+
mock_data_layer: AsyncMock,
100+
test_config: config.ChainlitConfig,
101+
monkeypatch: pytest.MonkeyPatch,
102+
):
98103
# Instantiate mock data layer
99104
mock_get_data_layer = Mock(return_value=mock_data_layer)
100105

106+
# Clear the cached data layer so every test exercises its own factory.
107+
monkeypatch.setattr(data_module, "_data_layer", None)
108+
monkeypatch.setattr(data_module, "_data_layer_initialized", False)
109+
101110
# Configure it using @data_layer decorator
102111
return data_layer(mock_get_data_layer)
103112

0 commit comments

Comments
 (0)