Skip to content

Commit

Permalink
test: Rewrite tests using pytest and organize into a class structure
Browse files Browse the repository at this point in the history
  • Loading branch information
sysradium committed Feb 18, 2025
1 parent 8809bd9 commit 532d4fa
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 95 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,4 @@ archive/
savedir/
output/
tool_output/
.aider*
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ dependencies = [
"pillow>=11.0.0",
"markdownify>=0.14.1",
"duckduckgo-search>=6.3.7",
"python-dotenv",
"filetype"
"python-magic>=0.4.27,<0.5.0",
]

[project.optional-dependencies]
Expand Down Expand Up @@ -76,6 +75,7 @@ test = [
"python-dotenv>=1.0.1", # For test_all_docs
"smolagents[all]",
"rank-bm25", # For test_all_docs
"pytest-datadir",
]
dev = [
"smolagents[quality,test]",
Expand Down
5 changes: 1 addition & 4 deletions src/smolagents/gradio_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@
import re
import shutil
from typing import Optional
from uuid import main

import magic
from numpy import core

from smolagents.agent_types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
from smolagents.agents import ActionStep, MultiStepAgent
Expand Down Expand Up @@ -232,7 +230,6 @@ def upload_file(self, file, file_uploads_log, allowed_file_types=None):
file_path = os.path.join(self.file_upload_folder, f"{base_name}_{counter}{ext}")
counter += 1

print("COPYING ", file_path, ext)
shutil.copy(file.name, file_path)

return gr.Textbox(f"File uploaded: {file_path}", visible=True), file_uploads_log + [file_path]
Expand All @@ -244,7 +241,7 @@ def _validate_file_type(self, file, allowed_file_types) -> tuple[str, str]:
raise Exception(f"Error reading file: {e}")

if kind not in (*allowed_file_types, "inode/x-empty"):
raise Exception(f"File type disallowed or undetected: {kind}")
raise Exception("File type disallowed")

original_name = os.path.basename(file.name)
sanitized_name = re.sub(r"[^\w\-.]", "_", original_name)
Expand Down
169 changes: 80 additions & 89 deletions tests/test_gradio_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,111 +14,102 @@
# limitations under the License.

import os
import shutil
import tempfile
from unittest.mock import Mock, patch
from unittest.mock import Mock

import pytest

from smolagents.gradio_ui import GradioUI


class GradioUITester:
def setUp(self):
"""Initialize test environment"""
self.temp_dir = tempfile.mkdtemp()
self.mock_agent = Mock()
self.ui = GradioUI(agent=self.mock_agent, file_upload_folder=self.temp_dir)
self.allowed_types = [".pdf", ".docx", ".txt"]
@pytest.fixture
def file_upload_dir(tmpdir):
return tmpdir.mkdir("file_uploads")


@pytest.fixture
def ui(file_upload_dir):
mock_agent = Mock()
return GradioUI(agent=mock_agent, file_upload_folder=file_upload_dir)


class TestGradioUI:
def test_upload_file_allows_empty_file(self, ui, tmpdir):
empty_file_path = tmpdir.join("empty.txt")
empty_file_path.write("")

with open(empty_file_path) as f:
textbox, uploads_log = ui.upload_file(f, [])

def tearDown(self):
"""Clean up test environment"""
shutil.rmtree(self.temp_dir)
assert "File uploaded:" in textbox.value
assert len(uploads_log) == 1
assert os.path.exists(tmpdir / empty_file_path.basename)

def test_upload_file_default_types(self):
"""Test default allowed file types"""
default_types = [".pdf", ".docx", ".txt"]
for file_type in default_types:
with tempfile.NamedTemporaryFile(suffix=file_type) as temp_file:
mock_file = Mock()
mock_file.name = temp_file.name
def test_upload_file_default_types_disallowed(self, ui, datadir):
with open(datadir / "image.png") as file:
textbox, uploads_log = ui.upload_file(file, [])

textbox, uploads_log = self.ui.upload_file(mock_file, [])
assert textbox.value == "File type disallowed"
assert len(uploads_log) == 0

self.assertIn("File uploaded:", textbox.value)
self.assertEqual(len(uploads_log), 1)
self.assertTrue(os.path.exists(os.path.join(self.temp_dir, os.path.basename(temp_file.name))))
@pytest.mark.parametrize("sample_file_name", ["empty.pdf", "file.txt", "sample.docx"])
def test_upload_file_success(self, ui, sample_file_name, datadir, file_upload_dir):
with open(datadir / sample_file_name) as f:
textbox, uploads_log = ui.upload_file(
f,
[],
allowed_file_types=[
"text/plain",
"application/pdf",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
],
)

def test_upload_file_default_types_disallowed(self):
"""Test default disallowed file types"""
disallowed_types = [".exe", ".sh", ".py", ".jpg"]
for file_type in disallowed_types:
with tempfile.NamedTemporaryFile(suffix=file_type) as temp_file:
mock_file = Mock()
mock_file.name = temp_file.name
assert "File uploaded:" in textbox.value
assert len(uploads_log) == 1

textbox, uploads_log = self.ui.upload_file(mock_file, [])
assert (file_upload_dir / sample_file_name).exists()
assert uploads_log[0] == file_upload_dir / sample_file_name

self.assertEqual(textbox.value, "File type disallowed")
self.assertEqual(len(uploads_log), 0)
def test_upload_file_creates_new_file_if_filenames_clash(self, ui, file_upload_dir, tmpdir):
empty_file_path = tmpdir.join("empty.txt")
empty_file_path.write("")

def test_upload_file_success(self):
"""Test successful file upload scenario"""
with tempfile.NamedTemporaryFile(suffix=".txt") as temp_file:
mock_file = Mock()
mock_file.name = temp_file.name
with open(empty_file_path) as f:
_, uploads_log = ui.upload_file(f, [])

textbox, uploads_log = self.ui.upload_file(mock_file, [])
assert len(uploads_log) == 1

self.assertIn("File uploaded:", textbox.value)
self.assertEqual(len(uploads_log), 1)
self.assertTrue(os.path.exists(os.path.join(self.temp_dir, os.path.basename(temp_file.name))))
self.assertEqual(uploads_log[0], os.path.join(self.temp_dir, os.path.basename(temp_file.name)))
textbox, uploads_log = ui.upload_file(f, [])

def test_upload_file_none(self):
assert "File uploaded:" in textbox.value
assert len(uploads_log) == 1

# now there should be 2 files, even though file-names are duplicated
assert len(file_upload_dir.listdir()) == 2
assert {"empty.txt", "empty_1.txt"} == {f.basename for f in file_upload_dir.listdir() if f.isfile()}

def test_upload_file_none(self, ui):
"""Test scenario when no file is selected"""
textbox, uploads_log = self.ui.upload_file(None, [])
textbox, uploads_log = ui.upload_file(None, [])

self.assertEqual(textbox.value, "No file uploaded")
self.assertEqual(len(uploads_log), 0)
assert textbox.value == "No file uploaded"
assert len(uploads_log) == 0

def test_upload_file_invalid_type(self):
def test_upload_file_invalid_type(self, ui, datadir):
"""Test disallowed file type"""
with tempfile.NamedTemporaryFile(suffix=".exe") as temp_file:
mock_file = Mock()
mock_file.name = temp_file.name

textbox, uploads_log = self.ui.upload_file(mock_file, [])

self.assertEqual(textbox.value, "File type disallowed")
self.assertEqual(len(uploads_log), 0)

def test_upload_file_special_chars(self):
"""Test scenario with special characters in filename"""
with tempfile.NamedTemporaryFile(suffix=".txt") as temp_file:
# Create a new temporary file with special characters
special_char_name = os.path.join(os.path.dirname(temp_file.name), "test@#$%^&*.txt")
shutil.copy(temp_file.name, special_char_name)
try:
mock_file = Mock()
mock_file.name = special_char_name

with patch("shutil.copy"):
textbox, uploads_log = self.ui.upload_file(mock_file, [])

self.assertIn("File uploaded:", textbox.value)
self.assertEqual(len(uploads_log), 1)
self.assertIn("test_____", uploads_log[0])
finally:
# Clean up the special character file
if os.path.exists(special_char_name):
os.remove(special_char_name)

def test_upload_file_custom_types(self):
"""Test custom allowed file types"""
with tempfile.NamedTemporaryFile(suffix=".csv") as temp_file:
mock_file = Mock()
mock_file.name = temp_file.name

textbox, uploads_log = self.ui.upload_file(mock_file, [], allowed_file_types=["text/csv"])

self.assertIn("File uploaded:", textbox.value)
self.assertEqual(len(uploads_log), 1)
with open(datadir / "empty.pdf") as file:
textbox, uploads_log = ui.upload_file(file, [], allowed_file_types=["text/plain"])

assert textbox.value == "File type disallowed"
assert len(uploads_log) == 0

def test_upload_file_special_chars(self, ui, tmpdir):
special_char_name = tmpdir.join("test@#$%^&*.txt")
special_char_name.write("something")

with open(special_char_name) as mock_file:
textbox, uploads_log = ui.upload_file(mock_file, [])

assert "File uploaded:" in textbox.value
assert len(uploads_log) == 1
assert "test_____" in uploads_log[0]
Binary file added tests/test_gradio_ui/empty.pdf
Binary file not shown.
1 change: 1 addition & 0 deletions tests/test_gradio_ui/file.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
sample content
Binary file added tests/test_gradio_ui/image.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/test_gradio_ui/sample.docx
Binary file not shown.

0 comments on commit 532d4fa

Please sign in to comment.