Skip to content

Commit

Permalink
Update api body params to schema (#461)
Browse files Browse the repository at this point in the history
  • Loading branch information
wu-clan authored Nov 16, 2024
1 parent 12d7ec8 commit 5e60d9c
Show file tree
Hide file tree
Showing 12 changed files with 76 additions and 43 deletions.
9 changes: 5 additions & 4 deletions backend/app/admin/api/v1/sys/casbin.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
DeletePolicyParam,
DeleteUserRoleParam,
GetPolicyListDetails,
UpdatePoliciesParam,
UpdatePolicyParam,
)
from backend.app.admin.service.casbin_service import casbin_service
Expand Down Expand Up @@ -92,8 +93,8 @@ async def create_policies(ps: list[CreatePolicyParam]) -> ResponseModel:
DependsRBAC,
],
)
async def update_policy(old: UpdatePolicyParam, new: UpdatePolicyParam) -> ResponseModel:
data = await casbin_service.update_policy(old=old, new=new)
async def update_policy(obj: UpdatePolicyParam) -> ResponseModel:
data = await casbin_service.update_policy(obj=obj)
return response_base.success(data=data)


Expand All @@ -105,8 +106,8 @@ async def update_policy(old: UpdatePolicyParam, new: UpdatePolicyParam) -> Respo
DependsRBAC,
],
)
async def update_policies(old: list[UpdatePolicyParam], new: list[UpdatePolicyParam]) -> ResponseModel:
data = await casbin_service.update_policies(old=old, new=new)
async def update_policies(obj: UpdatePoliciesParam) -> ResponseModel:
data = await casbin_service.update_policies(obj=obj)
return response_base.success(data=data)


Expand Down
2 changes: 1 addition & 1 deletion backend/app/admin/schema/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
class ApiSchemaBase(SchemaBase):
name: str
method: MethodType = Field(default=MethodType.GET, description='请求方法')
path: str = Field(..., description='api路径')
path: str = Field(description='api路径')
remark: str | None = None


Expand Down
24 changes: 15 additions & 9 deletions backend/app/admin/schema/casbin_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,19 @@


class CreatePolicyParam(SchemaBase):
sub: str = Field(..., description='用户uuid / 角色ID')
path: str = Field(..., description='api 路径')
sub: str = Field(description='用户uuid / 角色ID')
path: str = Field(description='api 路径')
method: MethodType = Field(default=MethodType.GET, description='请求方法')


class UpdatePolicyParam(CreatePolicyParam):
pass
class UpdatePolicyParam(SchemaBase):
old: CreatePolicyParam
new: CreatePolicyParam


class UpdatePoliciesParam(SchemaBase):
old: list[CreatePolicyParam]
new: list[CreatePolicyParam]


class DeletePolicyParam(CreatePolicyParam):
Expand All @@ -26,8 +32,8 @@ class DeleteAllPoliciesParam(SchemaBase):


class CreateUserRoleParam(SchemaBase):
uuid: str = Field(..., description='用户 uuid')
role: str = Field(..., description='角色')
uuid: str = Field(description='用户 uuid')
role: str = Field(description='角色')


class DeleteUserRoleParam(CreateUserRoleParam):
Expand All @@ -38,9 +44,9 @@ class GetPolicyListDetails(SchemaBase):
model_config = ConfigDict(from_attributes=True)

id: int
ptype: str = Field(..., description='规则类型, p / g')
v0: str = Field(..., description='用户 uuid / 角色')
v1: str = Field(..., description='api 路径 / 角色')
ptype: str = Field(description='规则类型, p / g')
v0: str = Field(description='用户 uuid / 角色')
v1: str = Field(description='api 路径 / 角色')
v2: str | None = None
v3: str | None = None
v4: str | None = None
Expand Down
8 changes: 4 additions & 4 deletions backend/app/admin/schema/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,21 @@ class AuthLoginParam(AuthSchemaBase):

class RegisterUserParam(AuthSchemaBase):
nickname: str | None = None
email: EmailStr = Field(..., examples=['[email protected]'])
email: EmailStr = Field(examples=['[email protected]'])


class AddUserParam(AuthSchemaBase):
dept_id: int
roles: list[int]
nickname: str | None = None
email: EmailStr = Field(..., examples=['[email protected]'])
email: EmailStr = Field(examples=['[email protected]'])


class UserInfoSchemaBase(SchemaBase):
dept_id: int | None = None
username: str
nickname: str
email: EmailStr = Field(..., examples=['[email protected]'])
email: EmailStr = Field(examples=['[email protected]'])
phone: CustomPhoneNumber | None = None


Expand All @@ -49,7 +49,7 @@ class UpdateUserRoleParam(SchemaBase):


class AvatarParam(SchemaBase):
url: HttpUrl = Field(..., description='头像 http 地址')
url: HttpUrl = Field(description='头像 http 地址')


class GetUserInfoNoRelationDetail(UserInfoSchemaBase):
Expand Down
15 changes: 10 additions & 5 deletions backend/app/admin/service/casbin_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
DeleteAllPoliciesParam,
DeletePolicyParam,
DeleteUserRoleParam,
UpdatePoliciesParam,
UpdatePolicyParam,
)
from backend.common.exception import errors
Expand Down Expand Up @@ -49,19 +50,23 @@ async def create_policies(*, ps: list[CreatePolicyParam]) -> bool:
return data

@staticmethod
async def update_policy(*, old: UpdatePolicyParam, new: UpdatePolicyParam) -> bool:
async def update_policy(*, obj: UpdatePolicyParam) -> bool:
old_obj = obj.old
new_obj = obj.new
enforcer = await rbac.enforcer()
_p = enforcer.has_policy(old.sub, old.path, old.method)
_p = enforcer.has_policy(old_obj.sub, old_obj.path, old_obj.method)
if not _p:
raise errors.NotFoundError(msg='权限不存在')
data = await enforcer.update_policy([old.sub, old.path, old.method], [new.sub, new.path, new.method])
data = await enforcer.update_policy(
[old_obj.sub, old_obj.path, old_obj.method], [new_obj.sub, new_obj.path, new_obj.method]
)
return data

@staticmethod
async def update_policies(*, old: list[UpdatePolicyParam], new: list[UpdatePolicyParam]) -> bool:
async def update_policies(*, obj: UpdatePoliciesParam) -> bool:
enforcer = await rbac.enforcer()
data = await enforcer.update_policies(
[list(o.model_dump().values()) for o in old], [list(n.model_dump().values()) for n in new]
[list(o.model_dump().values()) for o in obj.old], [list(n.model_dump().values()) for n in obj.new]
)
return data

Expand Down
11 changes: 4 additions & 7 deletions backend/app/generator/api/v1/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
# -*- coding: utf-8 -*-
from typing import Annotated

from fastapi import APIRouter, Body, Depends, Path, Query
from fastapi import APIRouter, Depends, Path, Query
from fastapi.responses import StreamingResponse

from backend.app.generator.conf import generator_settings
from backend.app.generator.schema.gen import ImportParam
from backend.app.generator.service.gen_service import gen_service
from backend.common.response.response_schema import ResponseModel, response_base
from backend.common.security.jwt import DependsJwtAuth
Expand All @@ -29,12 +30,8 @@ async def get_all_tables(table_schema: Annotated[str, Query(..., description='
DependsRBAC,
],
)
async def import_table(
app: Annotated[str, Body(..., description='应用名称,用于代码生成到指定 app')],
table_name: Annotated[str, Body(..., description='数据库表名')],
table_schema: Annotated[str, Body(..., description='数据库名')] = 'fba',
) -> ResponseModel:
await gen_service.import_business_and_model(app=app, table_schema=table_schema, table_name=table_name)
async def import_table(obj: ImportParam) -> ResponseModel:
await gen_service.import_business_and_model(obj=obj)
return response_base.success()


Expand Down
11 changes: 11 additions & 0 deletions backend/app/generator/schema/gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from pydantic import Field

from backend.common.schema import SchemaBase


class ImportParam(SchemaBase):
app: str = Field(description='应用名称,用于代码生成到指定 app')
table_name: str = Field(description='数据库表名')
table_schema: str = Field(description='数据库名')
11 changes: 6 additions & 5 deletions backend/app/generator/service/gen_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from backend.app.generator.crud.crud_gen_business import gen_business_dao
from backend.app.generator.crud.crud_gen_model import gen_model_dao
from backend.app.generator.model import GenBusiness
from backend.app.generator.schema.gen import ImportParam
from backend.app.generator.schema.gen_business import CreateGenBusinessParam
from backend.app.generator.schema.gen_model import CreateGenModelParam
from backend.app.generator.service.gen_model_service import gen_model_service
Expand All @@ -32,17 +33,17 @@ async def get_tables(*, table_schema: str) -> Sequence[str]:
return await gen_dao.get_all_tables(db, table_schema)

@staticmethod
async def import_business_and_model(*, app: str, table_schema: str, table_name: str) -> None:
async def import_business_and_model(*, obj: ImportParam) -> None:
async with async_db_session.begin() as db:
table_info = await gen_dao.get_table(db, table_name)
table_info = await gen_dao.get_table(db, obj.table_name)
if not table_info:
raise errors.NotFoundError(msg='数据库表不存在')
business_info = await gen_business_dao.get_by_name(db, table_name)
business_info = await gen_business_dao.get_by_name(db, obj.table_name)
if business_info:
raise errors.ForbiddenError(msg='已存在相同数据库表业务')
table_name = table_info[0]
business_data = {
'app_name': app,
'app_name': obj.app,
'table_name_en': table_name,
'table_name_zh': table_info[1] or ' '.join(table_name.split('_')),
'table_simple_name_zh': table_info[1] or table_name.split('_')[-1],
Expand All @@ -51,7 +52,7 @@ async def import_business_and_model(*, app: str, table_schema: str, table_name:
new_business = GenBusiness(**CreateGenBusinessParam(**business_data).model_dump())
db.add(new_business)
await db.flush()
column_info = await gen_dao.get_all_columns(db, table_schema, table_name)
column_info = await gen_dao.get_all_columns(db, obj.table_schema, table_name)
for column in column_info:
column_type = column[-1].split('(')[0].upper()
pd_type = sql_type_to_pydantic(column_type)
Expand Down
11 changes: 4 additions & 7 deletions backend/app/task/api/v1/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
# -*- coding: utf-8 -*-
from typing import Annotated

from fastapi import APIRouter, Body, Depends, Path
from fastapi import APIRouter, Depends, Path

from backend.app.task.schema.task import RunParam
from backend.app.task.service.task_service import task_service
from backend.common.response.response_schema import ResponseModel, response_base
from backend.common.security.jwt import DependsJwtAuth
Expand Down Expand Up @@ -45,10 +46,6 @@ async def get_task_result(tid: Annotated[str, Path(description='任务ID')]) ->
DependsRBAC,
],
)
async def run_task(
name: Annotated[str, Path(description='任务名称')],
args: Annotated[list | None, Body(description='任务函数位置参数')] = None,
kwargs: Annotated[dict | None, Body(description='任务函数关键字参数')] = None,
) -> ResponseModel:
task = task_service.run(name=name, args=args, kwargs=kwargs)
async def run_task(obj: RunParam) -> ResponseModel:
task = task_service.run(name=obj.name, args=obj.args, kwargs=obj.kwargs)
return response_base.success(data=task)
2 changes: 2 additions & 0 deletions backend/app/task/schema/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
11 changes: 11 additions & 0 deletions backend/app/task/schema/task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from pydantic import Field

from backend.common.schema import SchemaBase


class RunParam(SchemaBase):
name: str = Field(description='任务名称')
args: list | None = Field(default=None, description='任务函数位置参数')
kwargs: dict | None = Field(default=None, description='任务函数关键字参数')
4 changes: 3 additions & 1 deletion backend/utils/gen_template.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Sequence

from jinja2 import Environment, FileSystemLoader, Template, select_autoescape
from pydantic.alias_generators import to_pascal, to_snake

Expand Down Expand Up @@ -77,7 +79,7 @@ def get_code_gen_path(self, tpl_path: str, business: GenBusiness) -> str:
return code_gen_path_mapping[tpl_path]

@staticmethod
def get_vars(business: GenBusiness, models: list[GenModel]) -> dict:
def get_vars(business: GenBusiness, models: Sequence[GenModel]) -> dict:
"""
获取模版变量
Expand Down

0 comments on commit 5e60d9c

Please sign in to comment.