Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 109 additions & 57 deletions backend/agent_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import sys
from typing import Any, Dict, List, Optional

from model_backends import get_backend, make_backend_importer

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -538,6 +540,59 @@ def analyse_microdata(
return {"error": str(e)}


def compute(
    operation: str,
    data: List[float],
    other: Optional[List[float]] = None,
) -> Dict[str, Any]:
    """Apply a small numeric operation to one or two numeric series.

    Retained for older tests/tool dispatch paths.

    Supported values of ``operation``:

    * ``"diff"``: successive differences ``data[i + 1] - data[i]``.
    * ``"pct_change"``: successive percentage changes relative to the
      earlier value; a step whose earlier value is 0 yields 0.
    * ``"mean"`` / ``"sum"``: scalar mean or total of ``data``.
    * ``"subtract"`` / ``"divide"``: element-wise ``data`` against
      ``other``; a zero divisor yields 0.
    * ``"marginal_rate"``: ``100 * (change in data) / (change in other)``
      between successive points; a step where ``other`` does not change
      yields 0.

    Args:
        operation: Name of the operation to perform (see above).
        data: Primary series. Must be non-empty.
        other: Second series, required for ``subtract``, ``divide`` and
            ``marginal_rate``; must have the same length as ``data``.

    Returns:
        ``{"result": value}`` on success, where ``value`` is a list for
        the series operations and a number for ``mean``/``sum``.
        ``{"error": message}`` when the input is rejected, the operation
        is unknown, or the computation raises.
    """
    if not data:
        return {"error": "data must be non-empty"}

    try:
        if operation == "diff":
            return {
                "result": [nxt - cur for cur, nxt in zip(data, data[1:])]
            }
        if operation == "pct_change":
            # An earlier value of 0 has no defined percentage change;
            # report 0 for that step instead of dividing by zero.
            return {
                "result": [
                    0 if cur == 0 else (nxt - cur) / cur * 100
                    for cur, nxt in zip(data, data[1:])
                ]
            }
        if operation == "mean":
            return {"result": sum(data) / len(data)}
        if operation == "sum":
            return {"result": sum(data)}
        if operation not in {"subtract", "divide", "marginal_rate"}:
            return {"error": f"Unknown operation: {operation}"}

        # Everything below works on two aligned series.
        if other is None:
            return {"error": f"{operation} requires other"}
        if len(data) != len(other):
            return {"error": "data and other must have the same length"}

        if operation == "subtract":
            return {"result": [a - b for a, b in zip(data, other)]}
        if operation == "divide":
            return {
                "result": [0 if b == 0 else a / b for a, b in zip(data, other)]
            }

        # marginal_rate: pair each consecutive step of `data` with the
        # matching step of `other`; a flat step in `other` yields 0.
        steps = zip(zip(data, data[1:]), zip(other, other[1:]))
        return {
            "result": [
                0 if o_next == o_cur else 100 * (d_next - d_cur) / (o_next - o_cur)
                for (d_cur, d_next), (o_cur, o_next) in steps
            ]
        }
    except Exception as e:
        # Tool boundary: report the failure to the caller as data rather
        # than letting it propagate.
        return {"error": str(e)}


def generate_chart(
chart_type: str, title: str, data: List[Dict[str, Any]], x_field: str, y_fields: List[str],
x_label: Optional[str] = None, y_label: Optional[str] = None,
Expand Down Expand Up @@ -578,29 +633,10 @@ def generate_chart(
return {"error": str(e)}


def run_python(code: str) -> Dict[str, Any]:
"""Execute Python code with the PolicyEngine UK compiled interface preloaded.

The code should assign its final result to a variable called `result`.
The environment includes the official Python wrapper so runs are easy to
reproduce outside the chat app.
"""
import math
def _safe_builtins_for_backend(backend_id: str) -> Dict[str, Any]:
import builtins as _builtins
_ensure_compiled_package_importable()
import pandas as pd
import policyengine_uk_compiled as pe

from policyengine_uk_compiled import (
Simulation,
StructuralReform,
Parameters,
aggregate_microdata,
combine_microdata,
capabilities,
ensure_dataset,
)

backend = get_backend(backend_id)
safe_names = (
"range", "len", "int", "float", "str", "bool", "list", "dict",
"tuple", "set", "zip", "enumerate", "map", "filter", "sorted",
Expand All @@ -609,37 +645,33 @@ def run_python(code: str) -> Dict[str, Any]:
"print", "any", "all", "pow", "divmod", "complex", "type",
"dir", "hasattr", "getattr",
)
safe_builtins = {k: getattr(_builtins, k) for k in safe_names if hasattr(_builtins, k)}
safe_builtins = {
k: getattr(_builtins, k) for k in safe_names if hasattr(_builtins, k)
}
safe_builtins["__import__"] = make_backend_importer(backend)
return safe_builtins

try:
import numpy as np
except ImportError:
np = None

def run_python(code: str, backend_id: str = "uk_compiled") -> Dict[str, Any]:
"""Execute Python code with the selected model backend preloaded.

The code should assign its final result to a variable called `result`.
The backend adapter controls the model-specific globals made available.
"""
backend = get_backend(backend_id)
safe_builtins = _safe_builtins_for_backend(backend.id)

output_lines: List[str] = []
def safe_print(*args, **kwargs):
output_lines.append(" ".join(str(a) for a in args))

safe_builtins["print"] = safe_print
safe_builtins["__import__"] = _safe_import

allowed_globals: Dict[str, Any] = {
"__builtins__": safe_builtins,
"math": math,
"json": json,
"pd": pd,
"pe": pe,
"Simulation": Simulation,
"StructuralReform": StructuralReform,
"Parameters": Parameters,
"aggregate_microdata": aggregate_microdata,
"combine_microdata": combine_microdata,
"capabilities": capabilities,
"ensure_dataset": ensure_dataset,
}
if np is not None:
allowed_globals["np"] = np
allowed_globals["numpy"] = np
try:
allowed_globals = backend.execution_globals()
except Exception as e:
return {"error": f"Backend import failed for '{backend.id}': {type(e).__name__}: {e}"}
allowed_globals["__builtins__"] = safe_builtins

try:
exec(code, allowed_globals)
Expand All @@ -656,6 +688,7 @@ def safe_print(*args, **kwargs):
if not response:
response["result"] = None
response["note"] = "No 'result' variable was set and nothing was printed."
response["backend"] = backend.id
return response


Expand Down Expand Up @@ -685,10 +718,16 @@ def _run_generator(code: str) -> Dict[str, Any]:
return result


def execute_tool(tool_name: str, tool_input: Dict[str, Any]) -> Dict[str, Any]:
logger.info(f"[TOOLS] Executing {tool_name}")
def execute_tool(
tool_name: str,
tool_input: Dict[str, Any],
backend_id: str = "uk_compiled",
) -> Dict[str, Any]:
logger.info(f"[TOOLS] Executing {tool_name} with backend={backend_id}")
tools = {
"run_python": run_python,
"compute": compute,
"generate_chart": generate_chart,
}
if tool_name not in tools:
return {"error": f"Unknown tool: {tool_name}"}
Expand All @@ -698,6 +737,8 @@ def execute_tool(tool_name: str, tool_input: Dict[str, Any]) -> Dict[str, Any]:
logger.info(f"[TOOLS] Running generator for {tool_name}")
tool_input = _run_generator(tool_input["generator"])
logger.info(f"[TOOLS] Generator produced keys: {list(tool_input.keys())}")
if tool_name == "run_python" and "backend_id" not in tool_input:
tool_input = {**tool_input, "backend_id": backend_id}
result = tools[tool_name](**tool_input)
logger.info(f"[TOOLS] {tool_name} completed")
return result
Expand All @@ -706,16 +747,27 @@ def execute_tool(tool_name: str, tool_input: Dict[str, Any]) -> Dict[str, Any]:
return {"error": str(e)}


TOOL_DEFINITIONS = [
{
"name": "run_python",
"description": "Execute reproducible Python code using the official PolicyEngine UK compiled interface. The environment preloads `policyengine_uk_compiled` as `pe`, plus `Simulation`, `Parameters`, `StructuralReform`, `aggregate_microdata`, `combine_microdata`, `capabilities`, `ensure_dataset`, `pd`, `np`, `json`, and `math`. Assign the final answer to `result` and use `print()` for intermediate output.",
"input_schema": {
"type": "object",
"properties": {
"code": {"type": "string", "description": "Python code to execute. Must assign the final answer to `result`. Use the preloaded PolicyEngine interface directly, for example: `sim = Simulation(year=2025)` or `policy = Parameters.model_validate({...})`."},
def get_tool_definitions(backend_id: str = "uk_compiled") -> List[Dict[str, Any]]:
backend = get_backend(backend_id)
return [
{
"name": "run_python",
"description": backend.tool_description(),
"input_schema": {
"type": "object",
"properties": {
"code": {
"type": "string",
"description": (
"Python code to execute. Must assign the final answer "
"to `result`. Use the preloaded model interface directly."
),
},
},
"required": ["code"],
},
"required": ["code"],
},
},
]
]


TOOL_DEFINITIONS = get_tool_definitions("uk_compiled")
Loading