Skip to content

Commit

Permalink
Merge branch 'huggingface:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
colesmcintosh authored Feb 19, 2025
2 parents 2f12c59 + 1f998f9 commit 8f941e6
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 44 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ jobs:
uv run pytest ./tests/test_tools.py
if: ${{ success() || failure() }}

- name: Tool validation tests
run: |
uv run pytest ./tests/test_tool_validation.py
if: ${{ success() || failure() }}

- name: Types tests
run: |
uv run pytest ./tests/test_types.py
Expand Down
19 changes: 5 additions & 14 deletions docs/source/en/tutorials/inspect_runs.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ Here's how it then looks like on the platform:
First install the required packages. Here we install [Phoenix by Arize AI](https://github.com/Arize-ai/phoenix) because that's a good solution to collect and inspect the logs, but there are other OpenTelemetry-compatible platforms that you could use for this collection & inspection part.

```shell
pip install smolagents
pip install arize-phoenix opentelemetry-sdk opentelemetry-exporter-otlp openinference-instrumentation-smolagents
pip install 'smolagents[telemetry]'
```

Then run the collector in the background.
Expand All @@ -55,22 +54,14 @@ Then run the collector in the background.
python -m phoenix.server.main serve
```

Finally, set up `SmolagentsInstrumentor` to trace your agents and send the traces to Phoenix at the endpoint defined below.
Finally, set up `SmolagentsInstrumentor` to trace your agents and send the traces to Phoenix default endpoint.

```python
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor

from phoenix.otel import register
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor

endpoint = "http://0.0.0.0:6006/v1/traces"
trace_provider = TracerProvider()
trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter(endpoint)))

SmolagentsInstrumentor().instrument(tracer_provider=trace_provider)
register()
SmolagentsInstrumentor().instrument()
```
Then you can run your agents!

Expand Down
16 changes: 6 additions & 10 deletions examples/inspect_multiagent_run.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from phoenix.otel import register


register()
SmolagentsInstrumentor().instrument(skip_dep_check=True)


from smolagents import (
CodeAgent,
Expand All @@ -12,13 +15,6 @@
)


# Let's setup the instrumentation first

trace_provider = TracerProvider()
trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter("http://0.0.0.0:6006/v1/traces")))

SmolagentsInstrumentor().instrument(tracer_provider=trace_provider, skip_dep_check=True)

# Then we run the agentic part!
model = HfApiModel()

Expand Down
48 changes: 30 additions & 18 deletions src/smolagents/default_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,48 +138,62 @@ class GoogleSearchTool(Tool):
}
output_type = "string"

def __init__(self):
def __init__(self, provider: str = "serpapi"):
super().__init__(self)
import os

self.serpapi_key = os.getenv("SERPAPI_API_KEY")
self.provider = provider
if provider == "serpapi":
self.organic_key = "organic_results"
api_key_env_name = "SERPAPI_API_KEY"
else:
self.organic_key = "organic"
api_key_env_name = "SERPER_API_KEY"
self.api_key = os.getenv(api_key_env_name)
if self.api_key is None:
raise ValueError(f"Missing API key. Make sure you have '{api_key_env_name}' in your env variables.")

def forward(self, query: str, filter_year: Optional[int] = None) -> str:
import requests

if self.serpapi_key is None:
raise ValueError("Missing SerpAPI key. Make sure you have 'SERPAPI_API_KEY' in your env variables.")

params = {
"engine": "google",
"q": query,
"api_key": self.serpapi_key,
"google_domain": "google.com",
}
if self.provider == "serpapi":
params = {
"q": query,
"api_key": self.api_key,
"engine": "google",
"google_domain": "google.com",
}
base_url = "https://serpapi.com/search.json"
else:
params = {
"q": query,
"api_key": self.api_key,
}
base_url = "https://google.serper.dev/search"
if filter_year is not None:
params["tbs"] = f"cdr:1,cd_min:01/01/{filter_year},cd_max:12/31/{filter_year}"

response = requests.get("https://serpapi.com/search.json", params=params)
response = requests.get(base_url, params=params)

if response.status_code == 200:
results = response.json()
else:
raise ValueError(response.json())

if "organic_results" not in results.keys():
if self.organic_key not in results.keys():
if filter_year is not None:
raise Exception(
f"No results found for query: '{query}' with filtering on year={filter_year}. Use a less restrictive query or do not filter on year."
)
else:
raise Exception(f"No results found for query: '{query}'. Use a less restrictive query.")
if len(results["organic_results"]) == 0:
if len(results[self.organic_key]) == 0:
year_filter_message = f" with filter year={filter_year}" if filter_year is not None else ""
return f"No results found for '{query}'{year_filter_message}. Try with a more general query, or remove the year filter."

web_snippets = []
if "organic_results" in results:
for idx, page in enumerate(results["organic_results"]):
if self.organic_key in results:
for idx, page in enumerate(results[self.organic_key]):
date_published = ""
if "date" in page:
date_published = "\nDate published: " + page["date"]
Expand All @@ -193,8 +207,6 @@ def forward(self, query: str, filter_year: Optional[int] = None) -> str:
snippet = "\n" + page["snippet"]

redacted_version = f"{idx}. [{page['title']}]({page['link']}){date_published}{source}\n{snippet}"

redacted_version = redacted_version.replace("Your browser can't play this video.", "")
web_snippets.append(redacted_version)

return "## Search Results\n" + "\n\n".join(web_snippets)
Expand Down
10 changes: 8 additions & 2 deletions src/smolagents/local_python_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1229,8 +1229,14 @@ def evaluate_ast(
# For loop -> execute the loop
return evaluate_for(expression, *common_params)
elif isinstance(expression, ast.FormattedValue):
# Formatted value (part of f-string) -> evaluate the content and return
return evaluate_ast(expression.value, *common_params)
# Formatted value (part of f-string) -> evaluate the content and format it
value = evaluate_ast(expression.value, *common_params)
# Early return if no format spec
if not expression.format_spec:
return value
# Apply format specification
format_spec = evaluate_ast(expression.format_spec, *common_params)
return format(value, format_spec)
elif isinstance(expression, ast.If):
# If -> execute the right branch
return evaluate_if(expression, *common_params)
Expand Down
16 changes: 16 additions & 0 deletions tests/test_local_python_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,22 @@ def test_evaluate_f_string(self):
assert result == "This is x: 3."
self.assertDictEqualNoPrint(state, {"x": 3, "text": "This is x: 3.", "_operations_count": 6})

def test_evaluate_f_string_with_format(self):
code = "text = f'This is x: {x:.2f}.'"
state = {"x": 3.336}
result, _ = evaluate_python_code(code, {}, state=state)
assert result == "This is x: 3.34."
self.assertDictEqualNoPrint(state, {"x": 3.336, "text": "This is x: 3.34.", "_operations_count": 8})

def test_evaluate_f_string_with_complex_format(self):
code = "text = f'This is x: {x:>{width}.{precision}f}.'"
state = {"x": 3.336, "width": 10, "precision": 2}
result, _ = evaluate_python_code(code, {}, state=state)
assert result == "This is x: 3.34."
self.assertDictEqualNoPrint(
state, {"x": 3.336, "width": 10, "precision": 2, "text": "This is x: 3.34.", "_operations_count": 14}
)

def test_evaluate_if(self):
code = "if x <= 3:\n y = 2\nelse:\n y = 5"
state = {"x": 3}
Expand Down
9 changes: 9 additions & 0 deletions tests/test_tool_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import pytest

from smolagents.default_tools import DuckDuckGoSearchTool, GoogleSearchTool, VisitWebpageTool
from smolagents.tool_validation import validate_tool_attributes


@pytest.mark.parametrize("tool_class", [DuckDuckGoSearchTool, GoogleSearchTool, VisitWebpageTool])
def test_validate_tool_attributes(tool_class):
assert validate_tool_attributes(tool_class) is None, f"failed for {tool_class.name} tool"

0 comments on commit 8f941e6

Please sign in to comment.