Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use filetype for better type detection in GradIO #569

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ dependencies = [
"pillow>=11.0.0",
"markdownify>=0.14.1",
"duckduckgo-search>=6.3.7",
"python-dotenv"
"python-magic>=0.4.27,<0.5.0",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is adding this import really necessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@aymeric-roucher

It is needed if we want to use python-magic for filetype detection. If we don't want then the most of the pr is pointless.

"python-dotenv",
]

[project.optional-dependencies]
Expand Down Expand Up @@ -79,6 +80,7 @@ test = [
"python-dotenv>=1.0.1", # For test_all_docs
"smolagents[all]",
"rank-bm25", # For test_all_docs
"pytest-datadir",
]
dev = [
"smolagents[quality,test]",
Expand Down Expand Up @@ -111,4 +113,4 @@ lines-after-imports = 2

[project.scripts]
smolagent = "smolagents.cli:main"
webagent = "smolagents.vision_web_browser:main"
webagent = "smolagents.vision_web_browser:main"
53 changes: 40 additions & 13 deletions src/smolagents/gradio_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import shutil
from typing import Optional

import magic

from smolagents.agent_types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
from smolagents.agents import ActionStep, MultiStepAgent
from smolagents.memory import MemoryStep
Expand Down Expand Up @@ -213,32 +215,57 @@ def interact_with_agent(self, prompt, messages, session_state):

def upload_file(self, file, file_uploads_log, allowed_file_types=None):
"""
Handle file uploads, default allowed types are .pdf, .docx, and .txt
Secure file upload handling with real MIME-type validation.
Uses `filetype` to dynamically determine the correct file extension.
"""
import gradio as gr

if allowed_file_types is None:
allowed_file_types = [
"application/pdf",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"text/plain",
"text/html",
"application/json",
"applicaion/xml",
]

if file is None:
return gr.Textbox(value="No file uploaded", visible=True), file_uploads_log

if allowed_file_types is None:
allowed_file_types = [".pdf", ".docx", ".txt"]
try:
base_name, ext = self._validate_file_type(file, allowed_file_types)
except Exception as e:
return gr.Textbox(value=str(e), visible=True), file_uploads_log

file_ext = os.path.splitext(file.name)[1].lower()
if file_ext not in allowed_file_types:
return gr.Textbox("File type disallowed", visible=True), file_uploads_log
file_path = os.path.join(self.file_upload_folder, f"{base_name}{ext}")

# Sanitize file name
original_name = os.path.basename(file.name)
sanitized_name = re.sub(
r"[^\w\-.]", "_", original_name
) # Replace any non-alphanumeric, non-dash, or non-dot characters with underscores
# Prevent overwriting files by appending a counter if needed
counter = 1
while os.path.exists(file_path):
file_path = os.path.join(self.file_upload_folder, f"{base_name}_{counter}{ext}")
counter += 1

# Save the uploaded file to the specified folder
file_path = os.path.join(self.file_upload_folder, os.path.basename(sanitized_name))
shutil.copy(file.name, file_path)

return gr.Textbox(f"File uploaded: {file_path}", visible=True), file_uploads_log + [file_path]

def _validate_file_type(self, file, allowed_file_types) -> tuple[str, str]:
try:
kind = magic.from_file(file.name, mime=True)
except Exception as e:
raise Exception(f"Error reading file: {e}")

if kind not in (*allowed_file_types, "inode/x-empty"):
raise Exception("File type disallowed")

original_name = os.path.basename(file.name)
sanitized_name = re.sub(r"[^\w\-.]", "_", original_name)

base_name, ext = os.path.splitext(sanitized_name)

return (base_name, ext)

def log_user_message(self, text_input, file_uploads_log):
import gradio as gr

Expand Down
170 changes: 80 additions & 90 deletions tests/test_gradio_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,112 +14,102 @@
# limitations under the License.

import os
import shutil
import tempfile
import unittest
from unittest.mock import Mock, patch
from unittest.mock import Mock

import pytest

from smolagents.gradio_ui import GradioUI


class GradioUITester(unittest.TestCase):
def setUp(self):
"""Initialize test environment"""
self.temp_dir = tempfile.mkdtemp()
self.mock_agent = Mock()
self.ui = GradioUI(agent=self.mock_agent, file_upload_folder=self.temp_dir)
self.allowed_types = [".pdf", ".docx", ".txt"]
@pytest.fixture
def file_upload_dir(tmpdir):
return tmpdir.mkdir("file_uploads")


@pytest.fixture
def ui(file_upload_dir):
mock_agent = Mock()
return GradioUI(agent=mock_agent, file_upload_folder=file_upload_dir)


class TestGradioUI:
def test_upload_file_allows_empty_file(self, ui, tmpdir):
empty_file_path = tmpdir.join("empty.txt")
empty_file_path.write("")

with open(empty_file_path) as f:
textbox, uploads_log = ui.upload_file(f, [])

def tearDown(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why remove this method?

Copy link
Contributor Author

@sysradium sysradium Mar 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because it is no-longer needed. Each test receives its own file_upload_dir, tmpdir fixture which are automatically wiped out by pytest itself.

"""Clean up test environment"""
shutil.rmtree(self.temp_dir)
assert "File uploaded:" in textbox.value
assert len(uploads_log) == 1
assert os.path.exists(tmpdir / empty_file_path.basename)

def test_upload_file_default_types(self):
"""Test default allowed file types"""
default_types = [".pdf", ".docx", ".txt"]
for file_type in default_types:
with tempfile.NamedTemporaryFile(suffix=file_type) as temp_file:
mock_file = Mock()
mock_file.name = temp_file.name
def test_upload_file_default_types_disallowed(self, ui, datadir):
with open(datadir / "image.png") as file:
textbox, uploads_log = ui.upload_file(file, [])

textbox, uploads_log = self.ui.upload_file(mock_file, [])
assert textbox.value == "File type disallowed"
assert len(uploads_log) == 0

self.assertIn("File uploaded:", textbox.value)
self.assertEqual(len(uploads_log), 1)
self.assertTrue(os.path.exists(os.path.join(self.temp_dir, os.path.basename(temp_file.name))))
@pytest.mark.parametrize("sample_file_name", ["empty.pdf", "file.txt", "sample.docx"])
def test_upload_file_success(self, ui, sample_file_name, datadir, file_upload_dir):
with open(datadir / sample_file_name) as f:
textbox, uploads_log = ui.upload_file(
f,
[],
allowed_file_types=[
"text/plain",
"application/pdf",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
],
)

def test_upload_file_default_types_disallowed(self):
"""Test default disallowed file types"""
disallowed_types = [".exe", ".sh", ".py", ".jpg"]
for file_type in disallowed_types:
with tempfile.NamedTemporaryFile(suffix=file_type) as temp_file:
mock_file = Mock()
mock_file.name = temp_file.name
assert "File uploaded:" in textbox.value
assert len(uploads_log) == 1

textbox, uploads_log = self.ui.upload_file(mock_file, [])
assert (file_upload_dir / sample_file_name).exists()
assert uploads_log[0] == file_upload_dir / sample_file_name

self.assertEqual(textbox.value, "File type disallowed")
self.assertEqual(len(uploads_log), 0)
def test_upload_file_creates_new_file_if_filenames_clash(self, ui, file_upload_dir, tmpdir):
empty_file_path = tmpdir.join("empty.txt")
empty_file_path.write("")

def test_upload_file_success(self):
"""Test successful file upload scenario"""
with tempfile.NamedTemporaryFile(suffix=".txt") as temp_file:
mock_file = Mock()
mock_file.name = temp_file.name
with open(empty_file_path) as f:
_, uploads_log = ui.upload_file(f, [])

textbox, uploads_log = self.ui.upload_file(mock_file, [])
assert len(uploads_log) == 1

self.assertIn("File uploaded:", textbox.value)
self.assertEqual(len(uploads_log), 1)
self.assertTrue(os.path.exists(os.path.join(self.temp_dir, os.path.basename(temp_file.name))))
self.assertEqual(uploads_log[0], os.path.join(self.temp_dir, os.path.basename(temp_file.name)))
textbox, uploads_log = ui.upload_file(f, [])

def test_upload_file_none(self):
assert "File uploaded:" in textbox.value
assert len(uploads_log) == 1

# now there should be 2 files, even though file-names are duplicated
assert len(file_upload_dir.listdir()) == 2
assert {"empty.txt", "empty_1.txt"} == {f.basename for f in file_upload_dir.listdir() if f.isfile()}

def test_upload_file_none(self, ui):
"""Test scenario when no file is selected"""
textbox, uploads_log = self.ui.upload_file(None, [])
textbox, uploads_log = ui.upload_file(None, [])

self.assertEqual(textbox.value, "No file uploaded")
self.assertEqual(len(uploads_log), 0)
assert textbox.value == "No file uploaded"
assert len(uploads_log) == 0

def test_upload_file_invalid_type(self):
def test_upload_file_invalid_type(self, ui, datadir):
"""Test disallowed file type"""
with tempfile.NamedTemporaryFile(suffix=".exe") as temp_file:
mock_file = Mock()
mock_file.name = temp_file.name

textbox, uploads_log = self.ui.upload_file(mock_file, [])

self.assertEqual(textbox.value, "File type disallowed")
self.assertEqual(len(uploads_log), 0)

def test_upload_file_special_chars(self):
"""Test scenario with special characters in filename"""
with tempfile.NamedTemporaryFile(suffix=".txt") as temp_file:
# Create a new temporary file with special characters
special_char_name = os.path.join(os.path.dirname(temp_file.name), "test@#$%^&*.txt")
shutil.copy(temp_file.name, special_char_name)
try:
mock_file = Mock()
mock_file.name = special_char_name

with patch("shutil.copy"):
textbox, uploads_log = self.ui.upload_file(mock_file, [])

self.assertIn("File uploaded:", textbox.value)
self.assertEqual(len(uploads_log), 1)
self.assertIn("test_____", uploads_log[0])
finally:
# Clean up the special character file
if os.path.exists(special_char_name):
os.remove(special_char_name)

def test_upload_file_custom_types(self):
"""Test custom allowed file types"""
with tempfile.NamedTemporaryFile(suffix=".csv") as temp_file:
mock_file = Mock()
mock_file.name = temp_file.name

textbox, uploads_log = self.ui.upload_file(mock_file, [], allowed_file_types=[".csv"])

self.assertIn("File uploaded:", textbox.value)
self.assertEqual(len(uploads_log), 1)
with open(datadir / "empty.pdf") as file:
textbox, uploads_log = ui.upload_file(file, [], allowed_file_types=["text/plain"])

assert textbox.value == "File type disallowed"
assert len(uploads_log) == 0

def test_upload_file_special_chars(self, ui, tmpdir):
special_char_name = tmpdir.join("test@#$%^&*.txt")
special_char_name.write("something")

with open(special_char_name) as mock_file:
textbox, uploads_log = ui.upload_file(mock_file, [])

assert "File uploaded:" in textbox.value
assert len(uploads_log) == 1
assert "test_____" in uploads_log[0]
Binary file added tests/test_gradio_ui/empty.pdf
Binary file not shown.
1 change: 1 addition & 0 deletions tests/test_gradio_ui/file.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
sample content
Binary file added tests/test_gradio_ui/image.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/test_gradio_ui/sample.docx
Binary file not shown.