Skip to content

Commit 4734dbf

Browse files
authored
♻️ refactor the codegen for future
1 parent 5ba5574 commit 4734dbf

39 files changed

+8833
-85
lines changed

codegen/__init__.py

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
from jinja2 import Environment, PackageLoader
88

99
from .log import logger
10-
from .config import Config
1110
from .source import get_source
11+
from .config import Config, RestConfig, WebhookConfig
1212
from .parser import (
1313
OpenAPIData,
1414
WebhookData,
@@ -39,16 +39,14 @@
3939
def load_config() -> Config:
4040
pyproject = tomli.loads(Path("./pyproject.toml").read_text())
4141
config_dict: Dict[str, Any] = pyproject.get("tool", {}).get("codegen", {})
42-
config_dict = {
43-
k.replace("--", "").replace("-", "_"): v for k, v in config_dict.items()
44-
}
42+
4543
return Config.parse_obj(config_dict)
4644

4745

48-
def build_rest_api(data: OpenAPIData, config: Config):
46+
def build_rest_api(data: OpenAPIData, rest: RestConfig):
4947
logger.info("Start generating rest api codes...")
5048

51-
client_path = Path(config.client_output)
49+
client_path = Path(rest.output_dir)
5250
shutil.rmtree(client_path)
5351
client_path.mkdir(parents=True, exist_ok=True)
5452

@@ -74,7 +72,7 @@ def build_rest_api(data: OpenAPIData, config: Config):
7472
tag_path = client_path / f"{tag}.py"
7573
tag_path.write_text(
7674
client_template.render(
77-
tag=tag, endpoints=endpoints, rest_api_version=config.rest_api_version
75+
tag=tag, endpoints=endpoints, rest_api_version=rest.version
7876
)
7977
)
8078
logger.info(f"Successfully built endpoints for tag {tag}!")
@@ -92,21 +90,21 @@ def build_rest_api(data: OpenAPIData, config: Config):
9290
logger.info("Successfully generated rest api codes!")
9391

9492

95-
def build_webhook(data: WebhookData, config: Config):
93+
def build_webhook(data: WebhookData, webhook: WebhookConfig):
9694
logger.info("Start generating webhook codes...")
9795

9896
# build models
9997
logger.info("Building webhook models...")
10098
models_template = env.get_template("models/webhooks.py.jinja")
101-
models_path = Path(config.webhooks_output)
99+
models_path = Path(webhook.output)
102100
models_path.parent.mkdir(parents=True, exist_ok=True)
103101
models_path.write_text(models_template.render(models=data.models))
104102
logger.info("Successfully built webhook models!")
105103

106104
# build types
107105
logger.info("Building webhook types...")
108106
types_template = env.get_template("models/webhook_types.py.jinja")
109-
types_path = Path(config.webhook_types_output)
107+
types_path = Path(webhook.types_output)
110108
types_path.parent.mkdir(parents=True, exist_ok=True)
111109
types_path.write_text(
112110
types_template.render(
@@ -133,30 +131,31 @@ def build():
133131
config = load_config()
134132
logger.info(f"Loaded config: {config!r}")
135133

136-
logger.info("Start getting OpenAPI source...")
137-
source = get_source(httpx.URL(config.rest_description_source))
138-
logger.info(f"Getting schema from {source.uri} succeeded!")
139-
140-
logger.info("Start parsing OpenAPI spec...")
141-
_patch_openapi_spec(source.root)
142-
parsed_data = parse_openapi_spec(source, config)
143-
logger.info(
144-
"Successfully parsed OpenAPI spec: "
145-
f"{len(parsed_data.schemas)} schemas, {len(parsed_data.endpoints)} endpoints"
146-
)
134+
for versioned_rest in config.rest:
135+
logger.info(f"Start getting OpenAPI source for {versioned_rest.version}...")
136+
source = get_source(httpx.URL(versioned_rest.description_source))
137+
logger.info(f"Getting schema from {source.uri} succeeded!")
138+
139+
logger.info(f"Start parsing OpenAPI spec for {versioned_rest.version}...")
140+
_patch_openapi_spec(source.root)
141+
parsed_data = parse_openapi_spec(source, versioned_rest, config)
142+
logger.info(
143+
f"Successfully parsed OpenAPI spec {versioned_rest.version}: "
144+
f"{len(parsed_data.schemas)} schemas, {len(parsed_data.endpoints)} endpoints"
145+
)
147146

148-
build_rest_api(parsed_data, config)
147+
build_rest_api(parsed_data, versioned_rest)
149148

150-
del parsed_data
149+
del parsed_data
151150

152151
logger.info("Start getting Webhook source...")
153-
source = get_source(httpx.URL(config.webhook_schema_source))
152+
source = get_source(httpx.URL(config.webhook.schema_source))
154153
logger.info(f"Getting schema from {source.uri} succeeded!")
155154

156155
logger.info("Start parsing Webhook spec...")
157-
parsed_data = parse_webhook_schema(source, config)
156+
parsed_data = parse_webhook_schema(source, config.webhook, config)
158157
logger.info(
159158
f"Successfully parsed Webhook spec: {len(parsed_data.definitions)} schemas"
160159
)
161160

162-
build_webhook(parsed_data, config)
161+
build_webhook(parsed_data, config.webhook)

codegen/config.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,27 @@
1-
from typing import Any, Dict
1+
from typing import Any, Dict, List
22

3-
from pydantic import BaseModel
43
import openapi_schema_pydantic as oas
4+
from pydantic import Field, BaseModel
55

66

7-
class Config(BaseModel):
8-
rest_description_source: str
9-
rest_api_version: str
10-
webhook_schema_source: str
11-
class_overrides: Dict[str, str] = {}
12-
field_overrides: Dict[str, str] = {}
13-
schema_overrides: Dict[str, Dict[str, Any]] = {}
7+
class Overridable(BaseModel):
8+
class_overrides: Dict[str, str] = Field(default_factory=dict)
9+
field_overrides: Dict[str, str] = Field(default_factory=dict)
10+
schema_overrides: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
1411

15-
client_output: str
16-
webhooks_output: str
17-
webhook_types_output: str
12+
13+
class RestConfig(Overridable):
14+
version: str
15+
description_source: str
16+
output_dir: str
17+
18+
19+
class WebhookConfig(Overridable):
20+
schema_source: str
21+
output: str
22+
types_output: str
23+
24+
25+
class Config(Overridable):
26+
rest: List[RestConfig]
27+
webhook: WebhookConfig

codegen/parser/__init__.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
from contextvars import ContextVar
2-
from typing import Dict, List, Union, Optional
2+
from typing import Dict, List, Tuple, Union, Optional
33

44
import httpx
55
from openapi_schema_pydantic import OpenAPI
66

77
# parser context
8+
_override_config: ContextVar[Tuple["Overridable", ...]] = ContextVar("override_config")
89
_schemas: ContextVar[Dict[httpx.URL, "SchemaData"]] = ContextVar("schemas")
9-
_config: ContextVar["Config"] = ContextVar("config")
1010

1111

12-
def get_config() -> "Config":
13-
return _config.get()
12+
def get_override_config() -> Tuple["Overridable", ...]:
13+
return _override_config.get()
1414

1515

1616
def get_schemas() -> Dict[httpx.URL, "SchemaData"]:
@@ -27,7 +27,6 @@ def add_schema(ref: httpx.URL, schema: "SchemaData"):
2727
_schemas.get()[ref] = schema
2828

2929

30-
from ..config import Config
3130
from ..source import Source
3231
from .endpoints import parse_endpoint
3332
from .utils import sanitize as sanitize
@@ -39,13 +38,14 @@ def add_schema(ref: httpx.URL, schema: "SchemaData"):
3938
from .endpoints import EndpointData as EndpointData
4039
from .schemas import SchemaData, UnionSchema, parse_schema
4140
from .utils import fix_reserved_words as fix_reserved_words
41+
from ..config import Config, RestConfig, Overridable, WebhookConfig
4242

4343

44-
def parse_openapi_spec(source: Source, config: Config) -> OpenAPIData:
44+
def parse_openapi_spec(source: Source, rest: RestConfig, config: Config) -> OpenAPIData:
4545
source = source.get_root()
4646

47+
_ot = _override_config.set((rest, config))
4748
_st = _schemas.set({})
48-
_ct = _config.set(config)
4949

5050
try:
5151
openapi = OpenAPI.parse_obj(source.root)
@@ -70,15 +70,17 @@ def parse_openapi_spec(source: Source, config: Config) -> OpenAPIData:
7070
schemas=list(get_schemas().values()),
7171
)
7272
finally:
73+
_override_config.reset(_ot)
7374
_schemas.reset(_st)
74-
_config.reset(_ct)
7575

7676

77-
def parse_webhook_schema(source: Source, config: Config) -> WebhookData:
77+
def parse_webhook_schema(
78+
source: Source, webhook: WebhookConfig, config: Config
79+
) -> WebhookData:
7880
source = source.get_root()
7981

82+
_ot = _override_config.set((webhook, config))
8083
_st = _schemas.set({})
81-
_ct = _config.set(config)
8284

8385
try:
8486
root_schema = parse_schema(source, "webhook_schema")
@@ -106,5 +108,5 @@ def parse_webhook_schema(source: Source, config: Config) -> WebhookData:
106108
definitions=definitions,
107109
)
108110
finally:
111+
_override_config.reset(_ot)
109112
_schemas.reset(_st)
110-
_config.reset(_ct)

codegen/parser/utils.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
from pydantic import parse_obj_as
88
import openapi_schema_pydantic as oas
99

10-
from . import get_config
1110
from ..source import Source
11+
from . import get_override_config
1212

1313
DELIMITERS = r"\. _-"
1414

@@ -80,20 +80,29 @@ def build_boolean(value: Union[bool, str]) -> bool:
8080

8181

8282
def build_class_name(name: str) -> str:
83-
config = get_config()
83+
sources = get_override_config()
8484
class_name = fix_reserved_words(pascal_case(name))
85-
return config.class_overrides.get(class_name, class_name)
85+
for override_source in sources:
86+
if override := override_source.class_overrides.get(class_name):
87+
return override
88+
return class_name
8689

8790

8891
def build_prop_name(name: str) -> str:
89-
config = get_config()
90-
name = config.field_overrides.get(name, name)
92+
sources = get_override_config()
93+
for override_source in sources:
94+
if override := override_source.field_overrides.get(name):
95+
name = override
96+
break
9197
return fix_reserved_words(snake_case(name))
9298

9399

94100
def get_schema_override(source: Source) -> Optional[Dict[str, Any]]:
95-
config = get_config()
96-
return config.schema_overrides.get(source.pointer.path, None)
101+
sources = get_override_config()
102+
for override_source in sources:
103+
if schema := override_source.schema_overrides.get(source.pointer.path):
104+
return schema
105+
return None
97106

98107

99108
def merge_dict(old: dict, new: dict):

codegen/source.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __truediv__(self, other: Union[str, int]) -> "Source":
3838

3939

4040
@cache
41-
def get_content(source: Union[httpx.URL, Path]) -> OpenAPI:
41+
def get_content(source: Union[httpx.URL, Path]) -> dict:
4242
return (
4343
json.loads(source.read_text())
4444
if isinstance(source, Path)

codegen/templates/client/_request.py.jinja

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,27 +11,32 @@ url = "{{ endpoint.path }}"
1111
{% endmacro %}
1212

1313
{% macro build_query(params) %}
14+
{% if params %}
1415
params = {
1516
{% for param in params %}
1617
"{{ param.name }}": {{ param.prop_name }},
1718
{% endfor %}
1819
}
20+
{% endif %}
1921
{% endmacro %}
2022

2123
{% macro build_header(params) %}
2224
headers = {
2325
{% for param in params %}
2426
"{{ param.name }}": {{ param.prop_name }},
2527
{% endfor %}
28+
"X-GitHub-Api-Version": self._REST_API_VERSION,
2629
}
2730
{% endmacro %}
2831

2932
{% macro build_cookie(params) %}
33+
{% if params %}
3034
cookies = {
3135
{% for param in params %}
3236
"{{ param.name }}": {{ param.prop_name }},
3337
{% endfor %}
3438
}
39+
{% endif %}
3540
{% endmacro %}
3641

3742
{% macro build_body(request_body) %}
@@ -49,9 +54,9 @@ if not kwargs:
4954

5055
{% macro build_request(endpoint) %}
5156
{{ build_path(endpoint) }}
52-
{% if endpoint.query_params %}
5357
{{ build_query(endpoint.query_params) }}
54-
{% endif %}
58+
{{ build_header(endpoint.header_params) }}
59+
{{ build_cookie(endpoint.cookie_params) }}
5560
{% if endpoint.request_body %}
5661
{{ build_body(endpoint.request_body) }}
5762
{% endif %}
@@ -67,11 +72,7 @@ params=exclude_unset(params),
6772
{% set name = TYPE_MAPPING[endpoint.request_body.type] %}
6873
{{ name }}=exclude_unset({{ name }}),
6974
{% endif %}
70-
{% if endpoint.header_params %}
71-
headers=exclude_unset({**headers, "X-GitHub-Api-Version": self._REST_API_VERSION}),
72-
{% else %}
73-
headers={"X-GitHub-Api-Version": self._REST_API_VERSION},
74-
{% endif %}
75+
headers=exclude_unset(headers),
7576
{% if endpoint.cookie_params %}
7677
cookies=exclude_unset(cookies),
7778
{% endif %}

0 commit comments

Comments
 (0)