Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion backend/chainlit/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,13 @@ def get_data_layer():
azure_storage_key = os.getenv("APP_AZURE_STORAGE_ACCESS_KEY")
is_using_azure = bool(azure_storage_account and azure_storage_key)

# Local Storage
local_storage_path = os.getenv("APP_LOCAL_STORAGE_PATH")
is_using_local = bool(local_storage_path)

storage_client = None

if sum([is_using_s3, is_using_gcs, is_using_azure]) > 1:
if sum([is_using_s3, is_using_gcs, is_using_azure, is_using_local]) > 1:
warnings.warn(
"Multiple storage configurations detected. Please use only one."
)
Expand Down Expand Up @@ -92,6 +96,12 @@ def get_data_layer():
storage_account=azure_storage_account,
storage_key=azure_storage_key,
)
elif is_using_local:
from chainlit.data.storage_clients.local import LocalStorageClient

storage_client = LocalStorageClient(
storage_path=local_storage_path,
)

_data_layer = ChainlitDataLayer(
database_url=database_url, storage_client=storage_client
Expand Down
8 changes: 8 additions & 0 deletions backend/chainlit/data/storage_clients/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,11 @@ async def delete_file(self, object_key: str) -> bool:
@abstractmethod
async def get_read_url(self, object_key: str) -> str:
    """Return a URL from which the object identified by ``object_key`` can be read."""
    pass

async def download_file(self, object_key: str) -> tuple[bytes, str] | None:
    """
    Optionally download file content directly, allowing file downloads to be
    proxied by the Chainlit backend itself.

    Subclasses that support direct downloads override this; the base
    implementation signals "not supported" by returning None.

    Args:
        object_key: Storage key identifying the file.

    Returns:
        (file_content, mime_type) if implemented, None otherwise.
    """
    return None
192 changes: 192 additions & 0 deletions backend/chainlit/data/storage_clients/local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
import mimetypes
import shutil
from pathlib import Path
from typing import Any, Dict, Union
from urllib.request import pathname2url

from chainlit import make_async
from chainlit.data.storage_clients.base import BaseStorageClient
from chainlit.logger import logger


class LocalStorageClient(BaseStorageClient):
    """
    Storage client that persists element files on the local file system.

    Files are stored under ``storage_path``. Read URLs point at the backend's
    ``/storage/file/{object_key}`` route, which proxies downloads through
    ``download_file``.
    """

    def __init__(self, storage_path: str):
        """
        Args:
            storage_path: Root directory for stored files; created if missing.
        """
        try:
            self.storage_path = Path(storage_path).resolve()

            # Create storage directory if it doesn't exist
            self.storage_path.mkdir(parents=True, exist_ok=True)

            logger.info(
                f"LocalStorageClient initialized with path: {self.storage_path}"
            )
        except Exception as e:
            logger.warning(f"LocalStorageClient initialization error: {e}")
            raise

    def _validate_object_key(self, object_key: str) -> Path:
        """
        Validate object_key and ensure the resolved path is within storage directory.

        Args:
            object_key: The object key to validate

        Returns:
            Resolved Path object within storage directory

        Raises:
            ValueError: If path traversal is detected or path is invalid
        """
        try:
            # Reject absolute paths immediately
            if object_key.startswith("/"):
                logger.warning(f"Absolute path rejected: {object_key}")
                raise ValueError("Invalid object key: absolute paths not allowed")

            # Normalize object_key and check for traversal patterns
            normalized_key = object_key.strip()

            # Reject empty keys: an empty key resolves to the storage root
            # itself (relative_to succeeds with "."), so e.g.
            # delete_file("") could wipe the entire storage tree.
            if not normalized_key:
                logger.warning("Empty object key rejected")
                raise ValueError("Invalid object key: empty key not allowed")

            if ".." in normalized_key or "\\" in normalized_key:
                logger.warning(f"Path traversal patterns detected: {object_key}")
                raise ValueError("Invalid object key: path traversal detected")

            # Create the file path
            file_path = self.storage_path / normalized_key
            resolved_path = file_path.resolve()

            # Ensure the resolved path is within the storage directory
            resolved_path.relative_to(self.storage_path)

            # Defense in depth: never hand back the storage root itself.
            if resolved_path == self.storage_path:
                raise ValueError("Invalid object key: resolves to storage root")

            return resolved_path
        except ValueError:
            # Re-raise our custom validation errors unchanged.
            raise
        except Exception as e:
            logger.warning(f"Path validation error for {object_key}: {e}")
            raise ValueError(f"Invalid object key: {e}")

    def sync_get_read_url(self, object_key: str) -> str:
        """Return a backend-proxied URL for the object, or the key as fallback."""
        try:
            file_path = self._validate_object_key(object_key)
            if file_path.exists():
                # Return URL pointing to the backend's storage route
                url_path = pathname2url(object_key)
                return f"/storage/file/{url_path}"
            else:
                logger.warning(f"LocalStorageClient: File not found: {object_key}")
                return object_key
        except ValueError:
            # Path validation failed, return object_key as fallback
            return object_key
        except Exception as e:
            logger.warning(f"LocalStorageClient, get_read_url error: {e}")
            return object_key

    async def get_read_url(self, object_key: str) -> str:
        return await make_async(self.sync_get_read_url)(object_key)

    def sync_upload_file(
        self,
        object_key: str,
        data: Union[bytes, str],
        mime: str = "application/octet-stream",
        overwrite: bool = True,
        content_disposition: str | None = None,
    ) -> Dict[str, Any]:
        """Write ``data`` under ``object_key`` and return its key and URL.

        ``mime`` and ``content_disposition`` are accepted for interface
        compatibility but are not persisted by the local backend; the MIME
        type is re-guessed from the filename on download.
        Returns an empty dict on any failure.
        """
        try:
            file_path = self._validate_object_key(object_key)

            # Create parent directories if they don't exist
            file_path.parent.mkdir(parents=True, exist_ok=True)

            # Check if file exists and overwrite is False
            if file_path.exists() and not overwrite:
                logger.warning(
                    f"LocalStorageClient: File exists and overwrite=False: {object_key}"
                )
                return {}

            # Write data to file
            if isinstance(data, str):
                file_path.write_text(data, encoding="utf-8")
            else:
                file_path.write_bytes(data)

            # Generate URL for the uploaded file using backend's storage route
            relative_path = file_path.relative_to(self.storage_path)
            url_path = pathname2url(str(relative_path))
            url = f"/storage/file/{url_path}"

            return {"object_key": object_key, "url": url}
        except Exception as e:
            # ValueError (validation) and I/O errors are handled identically.
            logger.warning(f"LocalStorageClient, upload_file error: {e}")
            return {}

    async def upload_file(
        self,
        object_key: str,
        data: Union[bytes, str],
        mime: str = "application/octet-stream",
        overwrite: bool = True,
        content_disposition: str | None = None,
    ) -> Dict[str, Any]:
        return await make_async(self.sync_upload_file)(
            object_key, data, mime, overwrite, content_disposition
        )

    def sync_delete_file(self, object_key: str) -> bool:
        """Delete the file (or directory tree) at ``object_key``.

        Returns True on success, False if the path is missing or invalid.
        """
        try:
            file_path = self._validate_object_key(object_key)
            if file_path.exists():
                if file_path.is_file():
                    file_path.unlink()
                elif file_path.is_dir():
                    # Keys may map to directories (e.g. per-element folders);
                    # remove the whole subtree.
                    shutil.rmtree(file_path)
                return True
            else:
                logger.warning(
                    f"LocalStorageClient: File not found for deletion: {object_key}"
                )
                return False
        except Exception as e:
            # ValueError (validation) and I/O errors are handled identically.
            logger.warning(f"LocalStorageClient, delete_file error: {e}")
            return False

    async def delete_file(self, object_key: str) -> bool:
        return await make_async(self.sync_delete_file)(object_key)

    def sync_download_file(self, object_key: str) -> tuple[bytes, str] | None:
        """Read the file for ``object_key``.

        Returns:
            (content_bytes, mime_type) — MIME type guessed from the filename,
            falling back to application/octet-stream — or None on failure.
        """
        try:
            file_path = self._validate_object_key(object_key)
            if not file_path.exists() or not file_path.is_file():
                logger.warning(
                    f"LocalStorageClient: File not found for download: {object_key}"
                )
                return None

            # Get MIME type
            mime_type, _ = mimetypes.guess_type(str(file_path))
            if not mime_type:
                mime_type = "application/octet-stream"

            # Read file content
            content = file_path.read_bytes()
            return (content, mime_type)
        except Exception as e:
            # ValueError (validation) and I/O errors are handled identically.
            logger.warning(f"LocalStorageClient, download_file error: {e}")
            return None

    async def download_file(self, object_key: str) -> tuple[bytes, str] | None:
        return await make_async(self.sync_download_file)(object_key)
92 changes: 92 additions & 0 deletions backend/chainlit/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1636,6 +1636,98 @@ async def get_file(
raise HTTPException(status_code=404, detail="File not found")


@router.get("/storage/file/{object_key:path}")
async def get_storage_file(
    object_key: str,
    current_user: UserParam,
):
    """Serve a file from the storage client if it supports direct downloads.

    Authorizes the request by extracting a thread id (or owning user id)
    embedded in the object key, then proxies the bytes returned by the
    storage client's ``download_file``.

    Raises:
        HTTPException: 404 if storage is unconfigured or the file is missing,
            401 if unauthenticated, 403 if the file belongs to another user.
    """
    from chainlit.data import get_data_layer

    data_layer = get_data_layer()
    if not data_layer or not data_layer.storage_client:
        raise HTTPException(
            status_code=404,
            detail="Storage not configured",
        )

    # Validate user authentication
    if not current_user:
        raise HTTPException(status_code=401, detail="Unauthorized")

    # Extract thread_id from object_key to validate thread ownership
    # Object key patterns:
    # 1. threads/{thread_id}/files/{element.id} (chainlit_data_layer)
    # 2. {user_id}/{thread_id}/{element.id} (dynamodb)
    # 3. {user_id}/{element.id}[/{element.name}] (sql_alchemy)
    thread_id = None

    # Try to extract thread_id from different patterns
    parts = object_key.split("/")
    if len(parts) >= 3:
        if parts[0] == "threads":
            # Pattern: threads/{thread_id}/files/{element.id}
            thread_id = parts[1]
        elif len(parts) == 3:
            # Pattern: {user_id}/{thread_id}/{element.id} (dynamodb)
            # We need to verify this is actually a thread_id by checking if it exists
            potential_thread_id = parts[1]
            try:
                # Check if this looks like a thread by validating thread author
                await is_thread_author(current_user.identifier, potential_thread_id)
                thread_id = potential_thread_id
            except HTTPException:
                # Not a valid thread or user doesn't have access
                pass

    # If we found a thread_id, validate thread ownership
    if thread_id:
        await is_thread_author(current_user.identifier, thread_id)
    else:
        # For files without thread association (pattern 3), we should still
        # validate that the user_id in the path matches the current user
        if len(parts) >= 2:
            user_id_in_path = parts[0]
            if user_id_in_path != current_user.identifier:
                raise HTTPException(
                    status_code=403,
                    detail="Access denied: file belongs to different user",
                )

    # Try to extract element_id and get the original filename from database
    element_id = None
    element_name = None

    # Extract element_id from object_key patterns
    if len(parts) >= 4 and parts[0] == "threads" and parts[2] == "files":
        # Pattern: threads/{thread_id}/files/{element_id}
        element_id = parts[3]
        # Query database for element details
        if thread_id and element_id:
            element = await data_layer.get_element(thread_id, element_id)
            if element:
                element_name = element.get("name")

    # Only serve files if storage client implements download_file
    file_data = await data_layer.storage_client.download_file(object_key)
    if file_data is None:
        raise HTTPException(
            status_code=404,
            detail="File not found or storage client does not support direct downloads",
        )

    content, mime_type = file_data

    # Use the original filename if available, otherwise fall back to the UUID
    filename = element_name if element_name else Path(object_key).name

    # Strip quotes and CR/LF so the header stays well-formed and cannot be
    # used for header injection, then actually use the computed filename
    # (previously the header contained a literal placeholder).
    safe_filename = (
        str(filename).replace('"', "").replace("\r", "").replace("\n", "")
    )

    return Response(
        content=content,
        media_type=mime_type,
        headers={"Content-Disposition": f'inline; filename="{safe_filename}"'},
    )


@router.get("/favicon")
async def get_favicon():
"""Get the favicon for the UI."""
Expand Down
11 changes: 10 additions & 1 deletion backend/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest
import pytest_asyncio

import chainlit.data as data_module
from chainlit import config
from chainlit.callbacks import data_layer
from chainlit.context import ChainlitContext, context_var
Expand Down Expand Up @@ -94,10 +95,18 @@ def mock_data_layer(monkeypatch: pytest.MonkeyPatch) -> AsyncMock:


@pytest.fixture
def mock_get_data_layer(mock_data_layer: AsyncMock, test_config: config.ChainlitConfig):
def mock_get_data_layer(
    mock_data_layer: AsyncMock,
    test_config: config.ChainlitConfig,
    monkeypatch: pytest.MonkeyPatch,
):
    """Fixture that registers a mocked data-layer factory via ``@data_layer``.

    Resets the module-level data-layer cache first so that each test's
    ``get_data_layer()`` call exercises this fixture's factory instead of a
    value cached by a previous test.
    """
    # Instantiate mock data layer
    mock_get_data_layer = Mock(return_value=mock_data_layer)

    # Clear the cached data layer so every test exercises its own factory.
    monkeypatch.setattr(data_module, "_data_layer", None)
    monkeypatch.setattr(data_module, "_data_layer_initialized", False)

    # Configure it using @data_layer decorator
    return data_layer(mock_get_data_layer)

Expand Down
Loading
Loading