Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/976.added
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a reusable local H5 worker service boundary and keep the Modal worker script as a thin adapter.
18 changes: 18 additions & 0 deletions docs/engineering/stages/build_outputs.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,24 @@ source-variable cloning, postprocessing, or writing concern. Do not place
country-specific payload mutation in `build_h5()` when it can be represented as
a postprocessor.

## Worker Chunk Execution

`LocalH5WorkerService` is the reusable Stage 4 boundary for executing one
prepared local-H5 worker chunk. It consumes a `WorkerSession`, typed
`AreaBuildRequest` objects, and a `WorkerExecutionConfig`, then returns a
structured `WorkerResult`.

`modal_app.worker_script` should remain a thin CLI/JSON adapter around this
service. It may parse legacy `--work-items` and typed `--requests-json`, prepare
the worker session, and print the legacy coordinator JSON shape, but it should
not regain build-loop, write-loop, or validation-loop logic.

For now, `WorkerResult.to_legacy_dict()` preserves the existing coordinator
contract with `completed`, `failed`, `errors`, `validation_rows`, and
`validation_summary`. New code should prefer the structured `area_results` and
`issues` fields. Removing the legacy shape and moving the coordinator off worker
subprocess JSON is a later migration step.

## Payload Postprocessors

Payload postprocessors are ordered, country- or product-specific transformations
Expand Down
241 changes: 91 additions & 150 deletions modal_app/worker_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import sys
import traceback
from pathlib import Path
from typing import Any


def parse_args(argv: list[str] | None = None):
Expand Down Expand Up @@ -148,38 +147,6 @@ def _build_publishing_inputs(*, args, run_id: str):
return worker_inputs.to_publishing_input_bundle(run_id=run_id)


def _build_kwargs_from_request(request) -> dict[str, Any]:
    """Derive the area-specific `build_h5(...)` keyword arguments for a request.

    National requests need no extra arguments. Any other area type must carry
    exactly one build filter, which is mapped as follows:

    * state/district + ``cd_geoid`` ``in`` -> ``cd_subset`` (list of strings)
    * city + ``county_fips`` ``in`` -> ``county_fips_filter`` (set of strings)

    Raises:
        ValueError: If a non-national request does not carry exactly one
            filter, or the filter does not match a supported combination.
    """

    area_type = request.area_type
    if area_type == "national":
        return {}

    filters = request.filters
    if len(filters) != 1:
        raise ValueError(
            f"{area_type} requests must carry exactly one build filter"
        )

    only_filter = filters[0]
    field = only_filter.geography_field
    op = only_filter.op
    is_membership = op == "in"

    if area_type in {"state", "district"} and field == "cd_geoid" and is_membership:
        # Filter values are normalised to strings before reaching build_h5.
        return {"cd_subset": [str(entry) for entry in only_filter.value]}

    if area_type == "city" and field == "county_fips" and is_membership:
        return {"county_fips_filter": {str(entry) for entry in only_filter.value}}

    raise ValueError(
        f"Unsupported build filter for {area_type}: "
        f"{field}:{op}"
    )


def _request_key(request) -> str:
"""Return the stable completion key used by worker/coordinator flows."""

Expand All @@ -196,20 +163,6 @@ def _work_item_key(work_item) -> str:
return f"{item_type}:{item_id}"


def _resolve_output_path(*, output_dir: Path, output_relative_path: str) -> Path:
    """Return the resolved output path for one request, confined to `output_dir`.

    Both the directory and the joined candidate are resolved non-strictly, so
    neither has to exist yet.

    Raises:
        ValueError: If the resolved candidate is not located under the
            resolved `output_dir`.
    """

    root = output_dir.resolve(strict=False)
    target = (output_dir / output_relative_path).resolve(strict=False)
    try:
        # `relative_to` raises ValueError when `target` is not under `root`.
        target.relative_to(root)
    except ValueError as exc:
        raise ValueError(
            "output_relative_path must stay within the worker output_dir"
        ) from exc
    else:
        return target


def _resolve_request_input(
*,
request_input_mode,
Expand All @@ -232,6 +185,51 @@ def _resolve_request_input(
return _request_key(request), request


def _resolve_worker_requests(
    *,
    request_input_mode,
    request_inputs,
    area_catalog,
    geography,
) -> tuple[tuple, tuple]:
    """Turn queued CLI inputs into typed requests and per-item conversion issues.

    In ``"requests"`` mode the inputs are returned unchanged as a tuple with no
    issues. Otherwise each input is converted via `_resolve_request_input`:
    a conversion that raises is recorded as a `WorkerIssue` with phase
    ``"request"``, and a conversion that yields no request is logged to stderr
    and dropped.

    Returns:
        A ``(requests, issues)`` pair of tuples.
    """

    from policyengine_us_data.build_outputs.worker_service import WorkerIssue

    # Typed requests need no conversion, so there is nothing to report.
    if request_input_mode == "requests":
        return tuple(request_inputs), ()

    resolved = []
    conversion_issues = []
    for raw_item in request_inputs:
        # Computed before the conversion so a failure can still be attributed
        # to a key.
        item_key = _work_item_key(raw_item)
        try:
            item_key, typed_request = _resolve_request_input(
                request_input_mode=request_input_mode,
                request_input=raw_item,
                area_catalog=area_catalog,
                geography=geography,
            )
        except Exception as exc:
            conversion_issues.append(
                WorkerIssue(
                    item=item_key,
                    phase="request",
                    message=str(exc),
                    traceback=traceback.format_exc(),
                )
            )
        else:
            if typed_request is not None:
                resolved.append(typed_request)
            else:
                print(
                    f"Skipping {item_key}: no matching geography in legacy work item",
                    file=sys.stderr,
                )
    return tuple(resolved), tuple(conversion_issues)


def _log_worker_session_ready(*, scope: str, session, geography) -> None:
"""Write worker-session setup details to stderr for Modal diagnostics."""

Expand All @@ -252,7 +250,6 @@ def _log_worker_session_ready(*, scope: str, session, geography) -> None:
def main(argv: list[str] | None = None):
args = parse_args(argv)

dataset_path = Path(args.dataset_path)
output_dir = Path(args.output_dir)
run_id = args.run_id or output_dir.name or "local-worker"

Expand All @@ -265,15 +262,17 @@ def main(argv: list[str] | None = None):
original_stdout = sys.stdout
sys.stdout = sys.stderr

from policyengine_us_data.calibration.publish_local_area import (
build_h5,
)
from policyengine_us_data.build_outputs.area_catalog import USAreaCatalog
from policyengine_us_data.build_outputs.requests import AreaBuildRequest
from policyengine_us_data.build_outputs.validation import (
AreaValidationService,
ValidationPolicy,
)
from policyengine_us_data.build_outputs.worker_service import (
LocalH5WorkerService,
WorkerExecutionConfig,
WorkerResult,
)
from policyengine_us_data.build_outputs.worker_session import WorkerSessionFactory

area_catalog = USAreaCatalog.default()
Expand All @@ -297,8 +296,6 @@ def main(argv: list[str] | None = None):
artifacts_dir=Path(args.artifacts_dir) if args.artifacts_dir else None,
expected_scope_fingerprint=args.scope_fingerprint,
)
weights = session.weights.values
n_records = session.weights.n_records
geography = session.geography
validation_context = session.validation_context
_log_worker_session_ready(scope=scope, session=session, geography=geography)
Expand All @@ -312,111 +309,55 @@ def main(argv: list[str] | None = None):
file=sys.stderr,
)

results = {
"completed": [],
"failed": [],
"errors": [],
"validation_rows": [],
"validation_summary": {},
}

for request_input in request_inputs:
try:
request_key = (
_work_item_key(request_input)
if request_input_mode == "work_items"
else None
)
request_key, request = _resolve_request_input(
request_input_mode=request_input_mode,
request_input=request_input,
area_catalog=area_catalog,
geography=geography,
)
if request is None:
print(
f"Skipping {request_key}: no matching geography in legacy work item",
file=sys.stderr,
)
continue
requests, request_issues = _resolve_worker_requests(
request_input_mode=request_input_mode,
request_inputs=request_inputs,
area_catalog=area_catalog,
geography=geography,
)
worker_result = LocalH5WorkerService(
validation_service=validation_service,
).execute(
session=session,
requests=requests,
config=WorkerExecutionConfig(
output_dir=output_dir,
takeup_filter=tuple(takeup_filter),
validate=not args.no_validate,
),
)
if request_issues:
worker_result = WorkerResult(
area_results=worker_result.area_results,
issues=(*request_issues, *worker_result.issues),
)

output_path = _resolve_output_path(
output_dir=output_dir,
output_relative_path=request.output_relative_path,
for area_result in worker_result.area_results:
if area_result.status == "completed":
print(f"Completed {area_result.key}", file=sys.stderr)
else:
message = (
area_result.issues[0].message if area_result.issues else "unknown error"
)
output_path.parent.mkdir(parents=True, exist_ok=True)
build_kwargs = _build_kwargs_from_request(request)
if request.area_type == "national":
n_clones_from_weights = weights.shape[0] // n_records
if n_clones_from_weights != geography.n_clones:
raise ValueError(
f"National weights have {n_clones_from_weights} clones "
f"but geography has {geography.n_clones}. "
"Use the matching saved geography artifact."
)
path = build_h5(
weights=weights,
geography=geography,
dataset_path=dataset_path,
output_path=output_path,
)
else:
path = build_h5(
weights=weights,
geography=geography,
dataset_path=dataset_path,
output_path=output_path,
takeup_filter=takeup_filter,
**build_kwargs,
)

if path:
results["completed"].append(request_key)
print(
f"Completed {request_key}",
file=sys.stderr,
)

if not args.no_validate and validation_context is not None:
try:
validation_result = validation_service.validate_request(
context=validation_context,
h5_path=str(path),
request=request,
)
v_rows = list(validation_result.rows)
results["validation_rows"].extend(v_rows)
summary = dict(validation_result.summary)
results["validation_summary"][request_key] = summary
print(
f" Validated {request_key}: "
f"{summary['n_targets']} targets, "
f"{summary['n_sanity_fail']} sanity fails, "
f"mean RAE={summary['mean_rel_abs_error']:.4f}",
file=sys.stderr,
)
except Exception as ve:
print(
f" Validation failed for {request_key}: {ve}",
file=sys.stderr,
)

except Exception as e:
results["failed"].append(request_key)
results["errors"].append(
{
"item": request_key,
"error": str(e),
"traceback": traceback.format_exc(),
}
print(f"FAILED {area_result.key}: {message}", file=sys.stderr)
if area_result.validation_status == "passed" and area_result.validation_summary:
summary = area_result.validation_summary
print(
f" Validated {area_result.key}: "
f"{summary['n_targets']} targets, "
f"{summary['n_sanity_fail']} sanity fails, "
f"mean RAE={summary['mean_rel_abs_error']:.4f}",
file=sys.stderr,
)
elif area_result.validation_status == "error" and area_result.issues:
print(
f"FAILED {request_key}: {e}",
f" Validation failed for {area_result.key}: "
f"{area_result.issues[-1].message}",
file=sys.stderr,
)

sys.stdout = original_stdout
print(json.dumps(results))
print(json.dumps(worker_result.to_legacy_dict()))


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion policyengine_us_data/build_outputs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@
artifacts, worker-scoped session and validation context setup, microsimulation
access helpers, clone selection, entity reindexing, source-variable cloning,
validated H5 payload contracts, ordered output postprocessing, one-area payload
building, and H5 writing.
building, H5 writing, and worker chunk execution.
"""
Loading
Loading