Skip to content

Commit b5d9cdf

Browse files
committed
feat: accept camelCase keys in input TypedDicts
Synthesize a `<Name>CamelDict` sibling for every input TypedDict so users can pass API-shaped dicts and still satisfy the type checker. Closes #756.
1 parent cc1ae18 commit b5d9cdf

6 files changed

Lines changed: 569 additions & 19 deletions

File tree

.rules.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ Docstrings are written on sync clients and **automatically copied** to async cli
7575

7676
`src/apify_client/_models.py` and `src/apify_client/_typeddicts.py` are **auto-generated** — do not edit them manually. Every Pydantic model and TypedDict comes from the OpenAPI spec.
7777

78+
Each input-side TypedDict ships in two casings: snake_case (`RequestDict`) and camelCase (`RequestCamelDict`). The camel variants are synthesized in `scripts/postprocess_generated_models.py` from the Pydantic `Field(alias=...)` map; resource-client signatures union both with the Pydantic model so users can pass either casing without losing type-checker support.
79+
7880
- Generated by `datamodel-code-generator` from the OpenAPI spec at `https://docs.apify.com/api/openapi.json` (config in `pyproject.toml` under `[tool.datamodel-codegen]`, aliases in `datamodel_codegen_aliases.json`)
7981
- After generation, `scripts/postprocess_generated_models.py` is run to apply additional fixes
8082
- To regenerate locally:

scripts/postprocess_generated_models.py

Lines changed: 199 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
dependencies). The file is generated in full by datamodel-codegen; the trimming happens here.
1616
- Rename every kept class to add a `Dict` suffix so it doesn't clash with the Pydantic model name
1717
(e.g. `WebhookCreate` -> `WebhookCreateDict`) and rewire references.
18+
- Generate a camelCase sibling for every kept TypedDict (`FooDict` -> `FooCamelDict`) so users
19+
can pass API-shaped dicts and still satisfy the type checker. Field identifiers are looked up
20+
in the Pydantic alias map extracted from `_models.py`; nested TypedDict refs are rewired to
21+
the camel variant.
1822
- Add `@docs_group('Typed dicts')` to every kept class.
1923
"""
2024

@@ -391,6 +395,195 @@ def rename_with_dict_suffix(content: str, names: set[str]) -> str:
391395
return content
392396

393397

398+
def _extract_alias_from_field_call(field_call: ast.Call) -> str | None:
399+
"""Return the `alias=` kwarg value from a `Field(...)` call, or None if not present."""
400+
for kw in field_call.keywords:
401+
if kw.arg == 'alias' and isinstance(kw.value, ast.Constant) and isinstance(kw.value.value, str):
402+
return kw.value.value
403+
return None
404+
405+
406+
def _extract_class_field_aliases(class_node: ast.ClassDef) -> dict[str, str]:
    """Return `{snake_field: api_field}` for every annotated field declared on `class_node`.

    The alias is taken from the first `Field(alias='...')` call found in the field's annotation
    (the `Annotated[..., Field(alias=...)]` form) or, failing that, in its default value (the
    `name: T = Field(alias=...)` form). Both the bare `Field(...)` and the attribute-qualified
    `<module>.Field(...)` callee spellings are recognised.

    Fields without a `Field(alias=...)` map to themselves (their declared Python name matches the
    API name — typical for single-word fields like `url`, `id`). Only annotated assignments to a
    plain name are considered, and `model_config` is skipped because it is configuration rather
    than a field.
    """
    aliases: dict[str, str] = {}
    for stmt in class_node.body:
        if not isinstance(stmt, ast.AnnAssign) or not isinstance(stmt.target, ast.Name):
            continue
        field_name = stmt.target.id
        if field_name == 'model_config':
            continue

        # Search the annotation first so the `Annotated[...]` form keeps priority, then fall back
        # to the default value expression when one is present.
        search_roots = [stmt.annotation] if stmt.value is None else [stmt.annotation, stmt.value]
        found_aliases = (
            _extract_alias_from_field_call(sub)
            for root in search_roots
            for sub in ast.walk(root)
            if isinstance(sub, ast.Call)
            and (
                (isinstance(sub.func, ast.Name) and sub.func.id == 'Field')
                or (isinstance(sub.func, ast.Attribute) and sub.func.attr == 'Field')
            )
        )
        # Default: no alias means snake name == API name.
        aliases[field_name] = next((alias for alias in found_aliases if alias is not None), field_name)
    return aliases
430+
431+
432+
def build_alias_map(models_source: str) -> dict[str, dict[str, str]]:
    """Return `{ClassName: {snake_field: api_field}}` for every top-level class in `models_source`.

    `models_source` is parsed as Python source and each top-level class is handed to
    `_extract_class_field_aliases`. The resulting map is the source of truth for camelCase field
    names when synthesizing camelCase TypedDict variants: it records both explicit
    `Field(alias=...)` overrides and fields whose Python name already equals the API name.
    """
    alias_map: dict[str, dict[str, str]] = {}
    for statement in ast.parse(models_source).body:
        if isinstance(statement, ast.ClassDef):
            alias_map[statement.name] = _extract_class_field_aliases(statement)
    return alias_map
441+
442+
443+
def _camel_dict_name(snake_name: str) -> str:
    """Derive the camel sibling's name by inserting `Camel` before the trailing `Dict`.

    Example: `RequestDict` -> `RequestCamelDict`.

    Raises:
        ValueError: If `snake_name` does not end with `Dict`.
    """
    suffix = 'Dict'
    if snake_name.endswith(suffix):
        # The name ends with the suffix, so the last occurrence found by rpartition is the
        # trailing one and `stem` is everything before it.
        stem, _, _ = snake_name.rpartition(suffix)
        return f'{stem}Camel{suffix}'
    raise ValueError(f"Expected name to end with 'Dict': {snake_name!r}")
448+
449+
450+
def _is_dict_str_any(node: ast.expr) -> bool:
    """Return True if `node` is exactly a `dict[str, Any]` subscript (casing-agnostic open mapping).

    Any other `dict[...]` parametrization (e.g. `dict[str, int]` or `dict[str, FooDict]`) returns
    False, so callers do not mistake a typed mapping for the open `dict[str, Any]` form.
    """
    if not (isinstance(node, ast.Subscript) and isinstance(node.value, ast.Name) and node.value.id == 'dict'):
        return False
    type_args = node.slice
    # `dict[K, V]` parses with a two-element tuple as the subscript slice.
    if not isinstance(type_args, ast.Tuple) or len(type_args.elts) != 2:
        return False
    key_type, value_type = type_args.elts
    return (
        isinstance(key_type, ast.Name)
        and key_type.id == 'str'
        and isinstance(value_type, ast.Name)
        and value_type.id == 'Any'
    )
453+
454+
455+
def _rename_fields_in_class_block(block: list[str], field_aliases: dict[str, str]) -> list[str]:
    """Rewrite each field declaration line in `block` using `field_aliases`.

    Matches lines of the form `<indent><snake_ident>:` and substitutes the identifier when an
    alias is present. Multi-line annotations and trailing default values are preserved verbatim
    because only the field name on the first line is replaced. Lines that do not match, and
    fields with no alias or an alias equal to the current name, are copied through unchanged.

    Args:
        block: Source lines of one class definition.
        field_aliases: `{snake_field: api_field}` map for that class.

    Returns:
        A new list of lines with the aliased field names substituted.

    Raises:
        ValueError: If an alias cannot be written as a class-body field name because it is not a
            valid Python identifier or is a reserved keyword. Emitting it would produce a
            syntactically invalid class.
    """
    import keyword  # Local import so this helper does not rely on a module-level import.

    field_decl = re.compile(r'^(\s+)([a-z_][a-z0-9_]*)(\s*:)')
    out: list[str] = []
    for line in block:
        m = field_decl.match(line)
        if m is None:
            out.append(line)
            continue
        indent, name, colon = m.group(1), m.group(2), m.group(3)
        api_name = field_aliases.get(name)
        if api_name is None or api_name == name:
            out.append(line)
            continue
        if not api_name.isidentifier() or keyword.iskeyword(api_name):
            raise ValueError(
                f'Alias {api_name!r} for field {name!r} is not a valid Python identifier; '
                'it cannot be emitted as a class-syntax TypedDict key'
            )
        out.append(f'{indent}{api_name}{colon}{line[m.end() :]}')
    return out
476+
477+
478+
def _rename_typeddict_refs_in_block(block: list[str], rename_set: set[str]) -> list[str]:
    """Replace every whole-word occurrence of each `rename_set` name with its camel form.

    The lines are joined into one string before substitution, so a reference sitting on a
    continuation line of a wrapped annotation is rewritten as well. The result is split back
    into lines. An empty `rename_set` returns `block` unchanged.
    """
    if len(rename_set) == 0:
        return block
    joined = '\n'.join(block)
    # Names are processed in sorted order so the output does not depend on set iteration order;
    # the `\b` anchors restrict each substitution to whole-word matches.
    for original_name in sorted(rename_set):
        whole_word = re.compile(rf'\b{re.escape(original_name)}\b')
        joined = whole_word.sub(_camel_dict_name(original_name), joined)
    return joined.split('\n')
492+
493+
494+
def add_camel_case_typeddicts(content: str, alias_map: dict[str, dict[str, str]]) -> str:
    """Insert a camelCase sibling for every TypedDict and TypeAlias in `content`.

    For each top-level class `<Name>Dict` and each `<Name>Dict: TypeAlias = ...`, emit a sibling
    `<Name>CamelDict` directly after the original. Field identifiers are renamed using
    `alias_map[<Name>]`; references to other `*Dict` symbols are rewired to their camel variant
    via whole-word substitution.

    Aliases recognised by `_is_dict_str_any` (e.g. `TaskInputDict: TypeAlias = dict[str, Any]`)
    get a trivial `dict[str, Any]` camel alias too, so refs from other camel TypedDicts (e.g.
    `RequestBaseCamelDict.userData: NotRequired[RequestUserDataCamelDict]`) resolve cleanly.

    Idempotent: symbols whose name already ends with `CamelDict`, or whose camel sibling already
    exists in `content`, are skipped.

    Args:
        content: Source of the TypedDict module, after the `Dict`-suffix rename.
        alias_map: `{ModelName: {snake_field: api_field}}`, keyed by the name without the `Dict`
            suffix.

    Returns:
        The source with camel siblings inserted, passed through `_collapse_blank_lines`.
    """
    tree = ast.parse(content)
    lines = content.split('\n')

    symbols = _extract_top_level_symbols(tree)
    end_by_name: dict[str, int] = {name: end for name, _, end in symbols}
    existing_symbols: set[str] = {name for name, _, _ in symbols}

    def needs_camel_sibling(name: str) -> bool:
        """Return True for a snake-side `*Dict` symbol whose camel sibling is not present yet."""
        if not name.endswith('Dict') or name.endswith('CamelDict'):
            return False
        return _camel_dict_name(name) not in existing_symbols

    # Pass 1: gather every snake-side symbol that needs a camel sibling.
    snake_classes: list[tuple[ast.ClassDef, int, int]] = []  # node, block_start, block_end (exclusive)
    snake_aliases: list[tuple[int, int]] = []  # block_start, block_end
    flat_aliases: list[tuple[int, str]] = []  # block_end, alias_name

    for node in tree.body:
        if isinstance(node, ast.ClassDef):
            # Classes are selected purely by the `Dict` suffix; their bases are not inspected, so
            # both direct TypedDicts and ones inheriting from a sibling (e.g.
            # `RequestDict(RequestBaseDict)`) are picked up.
            if not needs_camel_sibling(node.name):
                continue
            # `node.lineno` points at the `class` keyword, so take the earliest decorator line
            # from the AST to include stacked or multi-line decorators in the cloned block.
            first_line = min((dec.lineno for dec in node.decorator_list), default=node.lineno)
            start = first_line - 1
            end = end_by_name.get(node.name, node.end_lineno or node.lineno)
            snake_classes.append((node, start, end))
        elif (
            isinstance(node, ast.AnnAssign)
            and isinstance(node.target, ast.Name)
            and isinstance(node.annotation, ast.Name)
            and node.annotation.id == 'TypeAlias'
        ):
            name = node.target.id
            if node.value is None or not needs_camel_sibling(name):
                continue
            start = node.lineno - 1
            end = end_by_name.get(name, node.end_lineno or node.lineno)
            if _is_dict_str_any(node.value):
                flat_aliases.append((end, name))
            else:
                snake_aliases.append((start, end))

    # The rename set covers EVERY snake-side `*Dict` symbol in the file (not just the ones we
    # need to clone) so nested refs inside a cloned block still rewire correctly even on re-runs
    # where most camel siblings already exist.
    rename_set: set[str] = {
        name for name in existing_symbols if name.endswith('Dict') and not name.endswith('CamelDict')
    }

    # Pass 2: build camel blocks. Each insertion is (line index to insert at, lines to insert);
    # the leading '' separates the clone from the original.
    insertions: list[tuple[int, list[str]]] = []

    for class_node, start, end in snake_classes:
        renamed_refs = _rename_typeddict_refs_in_block(lines[start:end], rename_set)
        field_aliases = alias_map.get(class_node.name[: -len('Dict')], {})
        camel_block = _rename_fields_in_class_block(renamed_refs, field_aliases)
        insertions.append((end, ['', *camel_block]))

    for start, end in snake_aliases:
        camel_block = _rename_typeddict_refs_in_block(lines[start:end], rename_set)
        insertions.append((end, ['', *camel_block]))

    insertions.extend(
        (end, ['', f'{_camel_dict_name(name)}: TypeAlias = dict[str, Any]']) for end, name in flat_aliases
    )

    # Insert in reverse line order so earlier indices stay valid.
    new_lines = lines[:]
    for after, block in sorted(insertions, key=lambda i: i[0], reverse=True):
        new_lines[after:after] = block

    return _collapse_blank_lines('\n'.join(new_lines))
585+
586+
394587
def postprocess_models(models_path: Path, literals_path: Path) -> list[Path]:
395588
"""Apply `_models.py`-specific fixes and emit `_literals.py`.
396589
@@ -414,13 +607,14 @@ def postprocess_models(models_path: Path, literals_path: Path) -> list[Path]:
414607
return changed
415608

416609

417-
def postprocess_typeddicts(path: Path) -> bool:
610+
def postprocess_typeddicts(path: Path, alias_map: dict[str, dict[str, str]]) -> bool:
418611
"""Apply `_typeddicts.py`-specific fixes. Returns True if the file changed."""
419612
original = path.read_text()
420613
pruned, kept = prune_typeddicts(original, RESOURCE_INPUT_TYPEDDICTS)
421614
renamed = rename_with_dict_suffix(pruned, kept)
422615
flattened = flatten_empty_typeddicts(renamed)
423-
final = add_docs_group_decorators(flattened, 'Typed dicts')
616+
camelized = add_camel_case_typeddicts(flattened, alias_map)
617+
final = add_docs_group_decorators(camelized, 'Typed dicts')
424618
if final == original:
425619
return False
426620
path.write_text(final)
@@ -442,9 +636,10 @@ def main() -> None:
442636
else:
443637
print('No fixes needed for _models.py / _literals.py')
444638

445-
if postprocess_typeddicts(TYPEDDICTS_PATH):
639+
alias_map = build_alias_map(MODELS_PATH.read_text())
640+
if postprocess_typeddicts(TYPEDDICTS_PATH, alias_map):
446641
changed.append(TYPEDDICTS_PATH)
447-
print(f'Pruned and renamed TypedDicts in {TYPEDDICTS_PATH}')
642+
print(f'Pruned, renamed, and camelized TypedDicts in {TYPEDDICTS_PATH}')
448643
else:
449644
print('No fixes needed for _typeddicts.py')
450645

src/apify_client/_resource_clients/request_queue.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,14 @@
4545
from datetime import timedelta
4646

4747
from apify_client._literals import GeneralAccess
48-
from apify_client._typeddicts import RequestDict, RequestDraftDeleteDict, RequestDraftDict
48+
from apify_client._typeddicts import (
49+
RequestCamelDict,
50+
RequestDict,
51+
RequestDraftCamelDict,
52+
RequestDraftDeleteCamelDict,
53+
RequestDraftDeleteDict,
54+
RequestDraftDict,
55+
)
4956
from apify_client._types import Timeout
5057

5158
_RQ_MAX_REQUESTS_PER_BATCH = 25
@@ -190,7 +197,7 @@ def list_and_lock_head(
190197

191198
def add_request(
192199
self,
193-
request: RequestDraftDict | RequestDraft,
200+
request: RequestDraftDict | RequestDraftCamelDict | RequestDraft,
194201
*,
195202
forefront: bool | None = None,
196203
timeout: Timeout = 'short',
@@ -252,7 +259,7 @@ def get_request(self, request_id: str, *, timeout: Timeout = 'short') -> Request
252259

253260
def update_request(
254261
self,
255-
request: RequestDict | Request,
262+
request: RequestDict | RequestCamelDict | Request,
256263
*,
257264
forefront: bool | None = None,
258265
timeout: Timeout = 'medium',
@@ -366,7 +373,7 @@ def delete_request_lock(
366373

367374
def batch_add_requests(
368375
self,
369-
requests: list[RequestDraft] | list[RequestDraftDict],
376+
requests: list[RequestDraft] | list[RequestDraftDict] | list[RequestDraftCamelDict],
370377
*,
371378
forefront: bool = False,
372379
max_parallel: int = 1,
@@ -464,7 +471,7 @@ def batch_add_requests(
464471

465472
def batch_delete_requests(
466473
self,
467-
requests: list[RequestDraftDelete] | list[RequestDraftDeleteDict],
474+
requests: list[RequestDraftDelete] | list[RequestDraftDeleteDict] | list[RequestDraftDeleteCamelDict],
468475
*,
469476
timeout: Timeout = 'short',
470477
) -> BatchDeleteResult:
@@ -747,7 +754,7 @@ async def list_and_lock_head(
747754

748755
async def add_request(
749756
self,
750-
request: RequestDraftDict | RequestDraft,
757+
request: RequestDraftDict | RequestDraftCamelDict | RequestDraft,
751758
*,
752759
forefront: bool | None = None,
753760
timeout: Timeout = 'short',
@@ -807,7 +814,7 @@ async def get_request(self, request_id: str, *, timeout: Timeout = 'short') -> R
807814

808815
async def update_request(
809816
self,
810-
request: RequestDict | Request,
817+
request: RequestDict | RequestCamelDict | Request,
811818
*,
812819
forefront: bool | None = None,
813820
timeout: Timeout = 'medium',
@@ -968,7 +975,7 @@ async def _batch_add_requests_worker(
968975

969976
async def batch_add_requests(
970977
self,
971-
requests: list[RequestDraft] | list[RequestDraftDict],
978+
requests: list[RequestDraft] | list[RequestDraftDict] | list[RequestDraftCamelDict],
972979
*,
973980
forefront: bool = False,
974981
max_parallel: int = 5,
@@ -1077,7 +1084,7 @@ async def batch_add_requests(
10771084

10781085
async def batch_delete_requests(
10791086
self,
1080-
requests: list[RequestDraftDelete] | list[RequestDraftDeleteDict],
1087+
requests: list[RequestDraftDelete] | list[RequestDraftDeleteDict] | list[RequestDraftDeleteCamelDict],
10811088
*,
10821089
timeout: Timeout = 'short',
10831090
) -> BatchDeleteResult:

0 commit comments

Comments
 (0)