diff --git a/.gitignore b/.gitignore index 59bba3ae6..e44e6308c 100644 --- a/.gitignore +++ b/.gitignore @@ -150,3 +150,4 @@ archive/ savedir/ output/ tool_output/ +.aider* diff --git a/pyproject.toml b/pyproject.toml index dbce2ebe4..69efa6e25 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,8 +20,7 @@ dependencies = [ "pillow>=11.0.0", "markdownify>=0.14.1", "duckduckgo-search>=6.3.7", - "python-dotenv", - "filetype" + "python-magic>=0.4.27,<0.5.0", ] [project.optional-dependencies] @@ -76,6 +75,7 @@ test = [ "python-dotenv>=1.0.1", # For test_all_docs "smolagents[all]", "rank-bm25", # For test_all_docs + "pytest-datadir", ] dev = [ "smolagents[quality,test]", diff --git a/src/smolagents/gradio_ui.py b/src/smolagents/gradio_ui.py index 645c6caf5..95f6d7099 100644 --- a/src/smolagents/gradio_ui.py +++ b/src/smolagents/gradio_ui.py @@ -17,10 +17,8 @@ import re import shutil from typing import Optional -from uuid import main import magic -from numpy import core from smolagents.agent_types import AgentAudio, AgentImage, AgentText, handle_agent_output_types from smolagents.agents import ActionStep, MultiStepAgent @@ -214,6 +212,9 @@ def upload_file(self, file, file_uploads_log, allowed_file_types=None): "application/pdf", "application/vnd.openxmlformats-officedocument.wordprocessingml.document", "text/plain", + "text/html", + "application/json", + "application/xml", ] if file is None: @@ -232,7 +233,6 @@ def upload_file(self, file, file_uploads_log, allowed_file_types=None): file_path = os.path.join(self.file_upload_folder, f"{base_name}_{counter}{ext}") counter += 1 - print("COPYING ", file_path, ext) shutil.copy(file.name, file_path) return gr.Textbox(f"File uploaded: {file_path}", visible=True), file_uploads_log + [file_path] @@ -244,7 +244,7 @@ def _validate_file_type(self, file, allowed_file_types) -> tuple[str, str]: raise Exception(f"Error reading file: {e}") if kind not in (*allowed_file_types, "inode/x-empty"): - raise Exception(f"File type disallowed 
or undetected: {kind}") + raise Exception("File type disallowed") original_name = os.path.basename(file.name) sanitized_name = re.sub(r"[^\w\-.]", "_", original_name) diff --git a/tests/test_gradio_ui.py b/tests/test_gradio_ui.py index f611b79ab..fb7f50b52 100644 --- a/tests/test_gradio_ui.py +++ b/tests/test_gradio_ui.py @@ -14,111 +14,102 @@ # limitations under the License. import os -import shutil -import tempfile -from unittest.mock import Mock, patch +from unittest.mock import Mock + +import pytest from smolagents.gradio_ui import GradioUI -class GradioUITester: - def setUp(self): - """Initialize test environment""" - self.temp_dir = tempfile.mkdtemp() - self.mock_agent = Mock() - self.ui = GradioUI(agent=self.mock_agent, file_upload_folder=self.temp_dir) - self.allowed_types = [".pdf", ".docx", ".txt"] +@pytest.fixture +def file_upload_dir(tmpdir): + return tmpdir.mkdir("file_uploads") + + +@pytest.fixture +def ui(file_upload_dir): + mock_agent = Mock() + return GradioUI(agent=mock_agent, file_upload_folder=file_upload_dir) + + +class TestGradioUI: + def test_upload_file_allows_empty_file(self, ui, tmpdir): + empty_file_path = tmpdir.join("empty.txt") + empty_file_path.write("") + + with open(empty_file_path) as f: + textbox, uploads_log = ui.upload_file(f, []) - def tearDown(self): - """Clean up test environment""" - shutil.rmtree(self.temp_dir) + assert "File uploaded:" in textbox.value + assert len(uploads_log) == 1 + assert os.path.exists(tmpdir / empty_file_path.basename) - def test_upload_file_default_types(self): - """Test default allowed file types""" - default_types = [".pdf", ".docx", ".txt"] - for file_type in default_types: - with tempfile.NamedTemporaryFile(suffix=file_type) as temp_file: - mock_file = Mock() - mock_file.name = temp_file.name + def test_upload_file_default_types_disallowed(self, ui, datadir): + with open(datadir / "image.png") as file: + textbox, uploads_log = ui.upload_file(file, []) - textbox, uploads_log = 
self.ui.upload_file(mock_file, []) + assert textbox.value == "File type disallowed" + assert len(uploads_log) == 0 - self.assertIn("File uploaded:", textbox.value) - self.assertEqual(len(uploads_log), 1) - self.assertTrue(os.path.exists(os.path.join(self.temp_dir, os.path.basename(temp_file.name)))) + @pytest.mark.parametrize("sample_file_name", ["empty.pdf", "file.txt", "sample.docx"]) + def test_upload_file_success(self, ui, sample_file_name, datadir, file_upload_dir): + with open(datadir / sample_file_name) as f: + textbox, uploads_log = ui.upload_file( + f, + [], + allowed_file_types=[ + "text/plain", + "application/pdf", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + ], + ) - def test_upload_file_default_types_disallowed(self): - """Test default disallowed file types""" - disallowed_types = [".exe", ".sh", ".py", ".jpg"] - for file_type in disallowed_types: - with tempfile.NamedTemporaryFile(suffix=file_type) as temp_file: - mock_file = Mock() - mock_file.name = temp_file.name + assert "File uploaded:" in textbox.value + assert len(uploads_log) == 1 - textbox, uploads_log = self.ui.upload_file(mock_file, []) + assert (file_upload_dir / sample_file_name).exists() + assert uploads_log[0] == file_upload_dir / sample_file_name - self.assertEqual(textbox.value, "File type disallowed") - self.assertEqual(len(uploads_log), 0) + def test_upload_file_creates_new_file_if_filenames_clash(self, ui, file_upload_dir, tmpdir): + empty_file_path = tmpdir.join("empty.txt") + empty_file_path.write("") - def test_upload_file_success(self): - """Test successful file upload scenario""" - with tempfile.NamedTemporaryFile(suffix=".txt") as temp_file: - mock_file = Mock() - mock_file.name = temp_file.name + with open(empty_file_path) as f: + _, uploads_log = ui.upload_file(f, []) - textbox, uploads_log = self.ui.upload_file(mock_file, []) + assert len(uploads_log) == 1 - self.assertIn("File uploaded:", textbox.value) - 
self.assertEqual(len(uploads_log), 1) - self.assertTrue(os.path.exists(os.path.join(self.temp_dir, os.path.basename(temp_file.name)))) - self.assertEqual(uploads_log[0], os.path.join(self.temp_dir, os.path.basename(temp_file.name))) + textbox, uploads_log = ui.upload_file(f, []) - def test_upload_file_none(self): + assert "File uploaded:" in textbox.value + assert len(uploads_log) == 1 + + # now there should be 2 files, even though file-names are duplicated + assert len(file_upload_dir.listdir()) == 2 + assert {"empty.txt", "empty_1.txt"} == {f.basename for f in file_upload_dir.listdir() if f.isfile()} + + def test_upload_file_none(self, ui): """Test scenario when no file is selected""" - textbox, uploads_log = self.ui.upload_file(None, []) + textbox, uploads_log = ui.upload_file(None, []) - self.assertEqual(textbox.value, "No file uploaded") - self.assertEqual(len(uploads_log), 0) + assert textbox.value == "No file uploaded" + assert len(uploads_log) == 0 - def test_upload_file_invalid_type(self): + def test_upload_file_invalid_type(self, ui, datadir): """Test disallowed file type""" - with tempfile.NamedTemporaryFile(suffix=".exe") as temp_file: - mock_file = Mock() - mock_file.name = temp_file.name - - textbox, uploads_log = self.ui.upload_file(mock_file, []) - - self.assertEqual(textbox.value, "File type disallowed") - self.assertEqual(len(uploads_log), 0) - - def test_upload_file_special_chars(self): - """Test scenario with special characters in filename""" - with tempfile.NamedTemporaryFile(suffix=".txt") as temp_file: - # Create a new temporary file with special characters - special_char_name = os.path.join(os.path.dirname(temp_file.name), "test@#$%^&*.txt") - shutil.copy(temp_file.name, special_char_name) - try: - mock_file = Mock() - mock_file.name = special_char_name - - with patch("shutil.copy"): - textbox, uploads_log = self.ui.upload_file(mock_file, []) - - self.assertIn("File uploaded:", textbox.value) - self.assertEqual(len(uploads_log), 1) - 
self.assertIn("test_____", uploads_log[0]) - finally: - # Clean up the special character file - if os.path.exists(special_char_name): - os.remove(special_char_name) - - def test_upload_file_custom_types(self): - """Test custom allowed file types""" - with tempfile.NamedTemporaryFile(suffix=".csv") as temp_file: - mock_file = Mock() - mock_file.name = temp_file.name - - textbox, uploads_log = self.ui.upload_file(mock_file, [], allowed_file_types=["text/csv"]) - - self.assertIn("File uploaded:", textbox.value) - self.assertEqual(len(uploads_log), 1) + with open(datadir / "empty.pdf") as file: + textbox, uploads_log = ui.upload_file(file, [], allowed_file_types=["text/plain"]) + + assert textbox.value == "File type disallowed" + assert len(uploads_log) == 0 + + def test_upload_file_special_chars(self, ui, tmpdir): + special_char_name = tmpdir.join("test@#$%^&*.txt") + special_char_name.write("something") + + with open(special_char_name) as mock_file: + textbox, uploads_log = ui.upload_file(mock_file, []) + + assert "File uploaded:" in textbox.value + assert len(uploads_log) == 1 + assert "test_____" in uploads_log[0] diff --git a/tests/test_gradio_ui/empty.pdf b/tests/test_gradio_ui/empty.pdf new file mode 100644 index 000000000..656152a3c Binary files /dev/null and b/tests/test_gradio_ui/empty.pdf differ diff --git a/tests/test_gradio_ui/file.txt b/tests/test_gradio_ui/file.txt new file mode 100644 index 000000000..4b4f223d5 --- /dev/null +++ b/tests/test_gradio_ui/file.txt @@ -0,0 +1 @@ +sample content diff --git a/tests/test_gradio_ui/image.png b/tests/test_gradio_ui/image.png new file mode 100644 index 000000000..a3b5225fc Binary files /dev/null and b/tests/test_gradio_ui/image.png differ diff --git a/tests/test_gradio_ui/sample.docx b/tests/test_gradio_ui/sample.docx new file mode 100644 index 000000000..f3b4cf287 Binary files /dev/null and b/tests/test_gradio_ui/sample.docx differ