Skip to content

Commit

Permalink
use filetype for better type detection in GradIO
Browse files Browse the repository at this point in the history
  • Loading branch information
sysradium committed Feb 9, 2025
1 parent 63adfcd commit 2349ff7
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 25 deletions.
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ dependencies = [
"pillow>=11.0.0",
"markdownify>=0.14.1",
"duckduckgo-search>=6.3.7",
"python-dotenv"
"python-dotenv",
"filetype"
]

[project.optional-dependencies]
Expand Down Expand Up @@ -104,4 +105,4 @@ lines-after-imports = 2

[project.scripts]
smolagent = "smolagents.cli:main"
webagent = "smolagents.vision_web_browser:main"
webagent = "smolagents.vision_web_browser:main"
54 changes: 31 additions & 23 deletions src/smolagents/gradio_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import mimetypes
import os
import re
import shutil
from typing import Optional

import filetype

from smolagents.agent_types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
from smolagents.agents import ActionStep, MultiStepAgent
from smolagents.memory import MemoryStep
Expand Down Expand Up @@ -210,39 +211,46 @@ def upload_file(
],
):
"""
Handle file uploads, default allowed types are .pdf, .docx, and .txt
Secure file upload handling with real MIME-type validation.
Uses `filetype` to dynamically determine the correct file extension.
"""
import gradio as gr

if file is None:
return gr.Textbox("No file uploaded", visible=True), file_uploads_log

# Read a small part of the file to determine type
try:
mime_type, _ = mimetypes.guess_type(file.name)
kind = filetype.guess(file.name) # Detect file type
except Exception as e:
return gr.Textbox(f"Error: {e}", visible=True), file_uploads_log
return gr.Textbox(f"Error reading file: {e}", visible=True), file_uploads_log

if mime_type not in allowed_file_types:
return gr.Textbox("File type disallowed", visible=True), file_uploads_log
if not kind or kind.mime not in allowed_file_types:
return gr.Textbox(
f"File type disallowed or undetected: {kind.mime if kind else 'Unknown'}", visible=True
), file_uploads_log

# Sanitize file name
# Sanitize filename (keep only safe characters)
original_name = os.path.basename(file.name)
sanitized_name = re.sub(
r"[^\w\-.]", "_", original_name
) # Replace any non-alphanumeric, non-dash, or non-dot characters with underscores

type_to_ext = {}
for ext, t in mimetypes.types_map.items():
if t not in type_to_ext:
type_to_ext[t] = ext

# Ensure the extension correlates to the mime type
sanitized_name = sanitized_name.split(".")[:-1]
sanitized_name.append("" + type_to_ext[mime_type])
sanitized_name = "".join(sanitized_name)

# Save the uploaded file to the specified folder
file_path = os.path.join(self.file_upload_folder, os.path.basename(sanitized_name))
sanitized_name = re.sub(r"[^\w\-.]", "_", original_name) # Replace unsafe chars

# Get correct extension from `filetype`
correct_ext = f".{kind.extension}"

# Ensure correct extension is used
base_name, _ = os.path.splitext(sanitized_name) # Remove existing extension
sanitized_name = f"{base_name}{correct_ext}" # Assign correct extension

# Define full file path
file_path = os.path.join(self.file_upload_folder, sanitized_name)

# Prevent overwriting files by appending a counter if needed
counter = 1
while os.path.exists(file_path):
file_path = os.path.join(self.file_upload_folder, f"{base_name}_{counter}{correct_ext}")
counter += 1

# Save the uploaded file securely
shutil.copy(file.name, file_path)

return gr.Textbox(f"File uploaded: {file_path}", visible=True), file_uploads_log + [file_path]
Expand Down

0 comments on commit 2349ff7

Please sign in to comment.