Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,6 @@ docs/.env.local
data/
*.h5
*.db

# macOS
.DS_Store
94 changes: 94 additions & 0 deletions alembic/versions/20260205_0002_add_user_policies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""add_user_policies

Revision ID: 0002_user_policies
Revises: 36f9d434e95b
Create Date: 2026-02-05

This migration adds:
1. tax_benefit_model_id foreign key to policies table
2. user_policies table for user-policy associations

Note: user_id in user_policies is NOT a foreign key to users table.
It's a client-generated UUID stored in localStorage, allowing anonymous
users to save policies without authentication.
"""

from typing import Sequence, Union

import sqlalchemy as sa
import sqlmodel.sql.sqltypes

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "0002_user_policies"
down_revision: Union[str, Sequence[str], None] = "36f9d434e95b"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
"""Add user_policies table and policy.tax_benefit_model_id."""
# Add tax_benefit_model_id to policies table
op.add_column(
"policies", sa.Column("tax_benefit_model_id", sa.Uuid(), nullable=False)
)
op.create_index(
op.f("ix_policies_tax_benefit_model_id"),
"policies",
["tax_benefit_model_id"],
unique=False,
)
op.create_foreign_key(
"fk_policies_tax_benefit_model_id",
"policies",
"tax_benefit_models",
["tax_benefit_model_id"],
["id"],
)

# Create user_policies table
# Note: user_id is NOT a foreign key - it's a client-generated UUID from localStorage
op.create_table(
"user_policies",
sa.Column("user_id", sa.Uuid(), nullable=False),
sa.Column("policy_id", sa.Uuid(), nullable=False),
sa.Column("country_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("label", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column("id", sa.Uuid(), nullable=False),
sa.Column("created_at", sa.DateTime(), nullable=False),
sa.Column("updated_at", sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(["policy_id"], ["policies.id"]),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_user_policies_policy_id"),
"user_policies",
["policy_id"],
unique=False,
)
op.create_index(
op.f("ix_user_policies_user_id"), "user_policies", ["user_id"], unique=False
)
op.create_index(
op.f("ix_user_policies_country_id"),
"user_policies",
["country_id"],
unique=False,
)


def downgrade() -> None:
"""Remove user_policies table and policy.tax_benefit_model_id."""
# Drop user_policies table
op.drop_index(op.f("ix_user_policies_country_id"), table_name="user_policies")
op.drop_index(op.f("ix_user_policies_user_id"), table_name="user_policies")
op.drop_index(op.f("ix_user_policies_policy_id"), table_name="user_policies")
op.drop_table("user_policies")

# Remove tax_benefit_model_id from policies
op.drop_constraint(
"fk_policies_tax_benefit_model_id", "policies", type_="foreignkey"
)
op.drop_index(op.f("ix_policies_tax_benefit_model_id"), table_name="policies")
op.drop_column("policies", "tax_benefit_model_id")
45 changes: 24 additions & 21 deletions src/policyengine_api/agent_sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,7 @@ def openapi_to_claude_tools(spec: dict) -> list[dict]:

prop = schema_to_json_schema(spec, param_schema)
prop["description"] = (
param.get("description", "")
+ f" (in: {param_in})"
param.get("description", "") + f" (in: {param_in})"
)
properties[param_name] = prop

Expand Down Expand Up @@ -268,16 +267,18 @@ def openapi_to_claude_tools(spec: dict) -> list[dict]:
if required:
input_schema["required"] = list(set(required))

tools.append({
"name": tool_name,
"description": full_desc[:1024], # Claude has limits
"input_schema": input_schema,
"_meta": {
"path": path,
"method": method,
"parameters": operation.get("parameters", []),
},
})
tools.append(
{
"name": tool_name,
"description": full_desc[:1024], # Claude has limits
"input_schema": input_schema,
"_meta": {
"path": path,
"method": method,
"parameters": operation.get("parameters", []),
},
}
)

return tools

Expand Down Expand Up @@ -347,7 +348,9 @@ def execute_api_tool(
url, params=query_params, json=body_data, headers=headers, timeout=60
)
elif method == "delete":
resp = requests.delete(url, params=query_params, headers=headers, timeout=60)
resp = requests.delete(
url, params=query_params, headers=headers, timeout=60
)
else:
return f"Unsupported method: {method}"

Expand Down Expand Up @@ -415,9 +418,7 @@ def log(msg: str) -> None:
tool_lookup = {t["name"]: t for t in tools}

# Strip _meta from tools before sending to Claude (it doesn't need it)
claude_tools = [
{k: v for k, v in t.items() if k != "_meta"} for t in tools
]
claude_tools = [{k: v for k, v in t.items() if k != "_meta"} for t in tools]
# Add the sleep tool
claude_tools.append(SLEEP_TOOL)

Expand Down Expand Up @@ -477,11 +478,13 @@ def log(msg: str) -> None:

log(f"[TOOL_RESULT] {result[:300]}")

tool_results.append({
"type": "tool_result",
"tool_use_id": block.id,
"content": result,
})
tool_results.append(
{
"type": "tool_result",
"tool_use_id": block.id,
"content": result,
}
)

messages.append({"role": "assistant", "content": assistant_content})

Expand Down
2 changes: 2 additions & 0 deletions src/policyengine_api/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
tax_benefit_model_versions,
tax_benefit_models,
user_household_associations,
user_policies,
variables,
)

Expand All @@ -43,5 +44,6 @@
api_router.include_router(analysis.router)
api_router.include_router(agent.router)
api_router.include_router(user_household_associations.router)
api_router.include_router(user_policies.router)

__all__ = ["api_router"]
22 changes: 18 additions & 4 deletions src/policyengine_api/api/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def get_traceparent() -> str | None:
TraceContextTextMapPropagator().inject(carrier)
return carrier.get("traceparent")


router = APIRouter(prefix="/agent", tags=["agent"])


Expand Down Expand Up @@ -93,7 +94,9 @@ def _run_local_agent(
from policyengine_api.agent_sandbox import _run_agent_impl

try:
history_dicts = [{"role": m.role, "content": m.content} for m in (history or [])]
history_dicts = [
{"role": m.role, "content": m.content} for m in (history or [])
]
result = _run_agent_impl(question, api_base_url, call_id, history_dicts)
_calls[call_id]["status"] = result.get("status", "completed")
_calls[call_id]["result"] = result
Expand Down Expand Up @@ -136,9 +139,15 @@ async def run_agent(request: RunRequest) -> RunResponse:

traceparent = get_traceparent()
run_fn = modal.Function.from_name("policyengine-sandbox", "run_agent")
history_dicts = [{"role": m.role, "content": m.content} for m in request.history]
history_dicts = [
{"role": m.role, "content": m.content} for m in request.history
]
call = run_fn.spawn(
request.question, api_base_url, call_id, history_dicts, traceparent=traceparent
request.question,
api_base_url,
call_id,
history_dicts,
traceparent=traceparent,
)

_calls[call_id] = {
Expand Down Expand Up @@ -166,7 +175,12 @@ async def run_agent(request: RunRequest) -> RunResponse:
# Run in background using asyncio
loop = asyncio.get_event_loop()
loop.run_in_executor(
None, _run_local_agent, call_id, request.question, api_base_url, request.history
None,
_run_local_agent,
call_id,
request.question,
api_base_url,
request.history,
)

return RunResponse(call_id=call_id, status="running")
Expand Down
29 changes: 21 additions & 8 deletions src/policyengine_api/api/household.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,11 +300,13 @@ def _calculate_household_uk(
from pathlib import Path

import pandas as pd
from policyengine.core import Simulation
from microdf import MicroDataFrame
from policyengine.core import Simulation
from policyengine.tax_benefit_models.uk import uk_latest
from policyengine.tax_benefit_models.uk.datasets import PolicyEngineUKDataset
from policyengine.tax_benefit_models.uk.datasets import UKYearData
from policyengine.tax_benefit_models.uk.datasets import (
PolicyEngineUKDataset,
UKYearData,
)

n_people = len(people)
n_benunits = max(1, len(benunit))
Expand Down Expand Up @@ -466,7 +468,14 @@ def _run_local_household_us(

try:
result = _calculate_household_us(
people, marital_unit, family, spm_unit, tax_unit, household, year, policy_data
people,
marital_unit,
family,
spm_unit,
tax_unit,
household,
year,
policy_data,
)

# Update job with result
Expand Down Expand Up @@ -512,11 +521,13 @@ def _calculate_household_us(
from pathlib import Path

import pandas as pd
from policyengine.core import Simulation
from microdf import MicroDataFrame
from policyengine.core import Simulation
from policyengine.tax_benefit_models.us import us_latest
from policyengine.tax_benefit_models.us.datasets import PolicyEngineUSDataset
from policyengine.tax_benefit_models.us.datasets import USYearData
from policyengine.tax_benefit_models.us.datasets import (
PolicyEngineUSDataset,
USYearData,
)

n_people = len(people)
n_households = max(1, len(household))
Expand Down Expand Up @@ -672,7 +683,9 @@ def safe_convert(value):
except (ValueError, TypeError):
return str(value)

def extract_entity_outputs(entity_name: str, entity_data, n_rows: int) -> list[dict]:
def extract_entity_outputs(
entity_name: str, entity_data, n_rows: int
) -> list[dict]:
outputs = []
for i in range(n_rows):
row_dict = {}
Expand Down
29 changes: 23 additions & 6 deletions src/policyengine_api/api/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from typing import List
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import Session, select

from policyengine_api.models import (
Expand All @@ -40,6 +40,7 @@
Policy,
PolicyCreate,
PolicyRead,
TaxBenefitModel,
)
from policyengine_api.services.database import get_session

Expand Down Expand Up @@ -67,8 +68,17 @@ def create_policy(policy: PolicyCreate, session: Session = Depends(get_session))
]
}
"""
# Validate tax_benefit_model exists
tax_model = session.get(TaxBenefitModel, policy.tax_benefit_model_id)
if not tax_model:
raise HTTPException(status_code=404, detail="Tax benefit model not found")

# Create the policy
db_policy = Policy(name=policy.name, description=policy.description)
db_policy = Policy(
name=policy.name,
description=policy.description,
tax_benefit_model_id=policy.tax_benefit_model_id,
)
session.add(db_policy)
session.flush() # Get the policy ID before adding parameter values

Expand Down Expand Up @@ -112,10 +122,17 @@ def create_policy(policy: PolicyCreate, session: Session = Depends(get_session))


@router.get("/", response_model=List[PolicyRead])
def list_policies(session: Session = Depends(get_session)):
"""List all policies."""
policies = session.exec(select(Policy)).all()
return policies
def list_policies(
tax_benefit_model_id: UUID | None = Query(
None, description="Filter by tax benefit model"
),
session: Session = Depends(get_session),
):
"""List all policies, optionally filtered by tax benefit model."""
query = select(Policy)
if tax_benefit_model_id:
query = query.where(Policy.tax_benefit_model_id == tax_benefit_model_id)
return session.exec(query).all()


@router.get("/{policy_id}", response_model=PolicyRead)
Expand Down
Loading