[SOT] Mark dynamic dims by type annotations #2771
Conversation
Thanks for your contribution!
How should I understand that the Tensor type carries an implicit dynamic dim (0,)? Is that a rule?
fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py
Yes, it is a rule. Essentially every Tensor's batch dim is dynamic, so marking each one explicitly would be tedious; this default rule was therefore added (the same implicit rule as in vLLM).
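To make the rule concrete, here is a minimal, hypothetical sketch of how a resolver could apply it; `dynamic_dims_of` and the simplified `DynamicDims` stand-in below are illustrative only, not the project's actual API:

```python
from typing import Annotated, Any, get_args, get_origin

import paddle


class DynamicDims:
    """Simplified stand-in for the Annotated metadata marker in this PR."""

    def __init__(self, dims: tuple[int, ...]) -> None:
        self.dims = dims


def dynamic_dims_of(annotation: Any) -> tuple[int, ...]:
    """Hypothetical helper: which dims of a Tensor argument are dynamic."""
    if get_origin(annotation) is Annotated:
        for meta in get_args(annotation)[1:]:
            if isinstance(meta, DynamicDims):
                return meta.dims
    # Implicit default rule (same as vLLM): a bare Tensor annotation is
    # treated as dynamic on the batch dim.
    return (0,)


assert dynamic_dims_of(paddle.Tensor) == (0,)
assert dynamic_dims_of(Annotated[paddle.Tensor, DynamicDims((1, 2))]) == (1, 2)
```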
LGTM
```diff
@@ -406,7 +406,7 @@ def load_state_dict(self, state_dict):
     def forward(
         self,
         ids_remove_padding: paddle.Tensor,
-        image_features: paddle.Tensor,
+        image_features: Optional[paddle.Tensor],
```
Please write Optional[paddle.Tensor] in the other forward functions as well.
Done, updated.
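For reference, a sketch of the resulting signature (simplified; the None-handling branch is an assumption for illustration, the diff itself only changes the annotation):

```python
from typing import Optional

import paddle


class DummyMultiModalModel:  # illustrative container, not the real class
    def forward(
        self,
        ids_remove_padding: paddle.Tensor,
        image_features: Optional[paddle.Tensor],
    ) -> None:
        if image_features is None:
            # e.g. a text-only batch (assumption for illustration)
            ...
```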
```diff
@@ -32,4 +33,5 @@
     "FlashAttentionBackend",
     "IluvatarAttnBackend",
     "BlockAttentionBackend",
+    "Attention",
```
If I remember correctly, adding this causes a circular import?
It won't, unless an earlier design flaw had already introduced one.
```python
def extract_inner_types(self, data, data_name, tp) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
    raise NotImplementedError


def resolve(self, data, data_name, tp) -> None:
```
All function parameters need type annotations.
Added.
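A sketch of what the fully annotated signatures could look like; the concrete types for `data` and `tp` are assumptions inferred from the return annotation above, and `Accessor` is stubbed out so the sketch is self-contained:

```python
from typing import Any, Generic, TypeVar

T = TypeVar("T")
R = TypeVar("R")


class Accessor(Generic[T, R]):
    """Placeholder for the project's Accessor type."""


class TypeResolverSketch:
    def extract_inner_types(
        self, data: Any, data_name: str, tp: type[Any]
    ) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
        raise NotImplementedError

    def resolve(self, data: Any, data_name: str, tp: type[Any]) -> None:
        raise NotImplementedError
```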
```diff
@@ -937,11 +937,11 @@ def _setting_environ_variables(self):
     "SOT_LOG_LEVEL": os.getenv("SOT_LOG_LEVEL", default="0"),
     "SOT_UNSAFE_CACHE_FASTPATH": os.getenv("SOT_UNSAFE_CACHE_FASTPATH", default="1"),
     "SOT_ENABLE_0_SIZE_FALLBACK": os.getenv("SOT_ENABLE_0_SIZE_FALLBACK", default="0"),
+    "SOT_SPECIALIZED_DIM_NUMBERS": os.getenv("SOT_SPECIALIZED_DIM_NUMBERS", default="no"),
```
Is the default here "no"?
In the framework the default is to specialize dims of size 1, i.e. "1"; in FD the default is no specialization, i.e. "no".
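So a deployment that wants the framework's behavior back would override the FD default before the workers start; a minimal sketch:

```python
import os

# FD defaults to "no" (no dim-number specialization); PaddleSOT's own default
# would specialize dims of size 1 ("1"). An explicit environment setting wins
# over the FD default shown in the diff above.
os.environ.setdefault("SOT_SPECIALIZED_DIM_NUMBERS", "1")
```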
```diff
@@ -144,7 +147,7 @@ def get_kv_cache_shape(
             self.head_dim,
         )

-    def init_attention_metadata(self, forward_meta: ForwardMeta):
+    def init_attention_metadata(self, forward_meta: "ForwardMeta"):
```
"ForwardMeta" 包字符串的作用是什么?
String annotations are commonly used to
- resolve circular imports
- reduce runtime annotation-resolution overhead

But all of these files use PEP 563, so here it makes no difference whether the quotes are added or not; they were restored just to improve readability.
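A minimal sketch of the equivalence under PEP 563 (the import path for ForwardMeta is an assumption, and the method is shown as a free function for brevity):

```python
from __future__ import annotations  # PEP 563: annotations are kept as strings

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only imported for type checking, so no circular import at runtime.
    # Assumed path; the real module may differ.
    from fastdeploy.model_executor.forward_meta import ForwardMeta


def init_attention_metadata(forward_meta: "ForwardMeta") -> None:
    # With PEP 563 the quotes are redundant: `forward_meta: ForwardMeta`
    # would produce exactly the same (string) annotation.
    ...
```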
This PR uses type hints to mark dynamic dimensions, so that dynamic-to-static conversion converges in a single pass and repeated graph construction is avoided. Container annotations such as dataclass and Optional[T] are supported. A plain Tensor annotation implicitly carries the dynamic dims (0,); the remaining dims are marked explicitly via Annotated[Tensor, DynamicDims((1, 2))]. An example of marking the remaining dims is shown below.
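Putting the pieces together, the marking could look like the following sketch (the DynamicDims import path and the hidden_states parameter are assumptions for illustration):

```python
from typing import Annotated, Optional

import paddle

# Assumed import path for the marker introduced by this PR.
from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import DynamicDims


class ExampleModel:
    def forward(
        self,
        # Bare Tensor: implicitly dynamic on the batch dim (0,).
        ids_remove_padding: paddle.Tensor,
        # Optional[T] is unwrapped before the inner Tensor type is inspected.
        image_features: Optional[paddle.Tensor],
        # Explicit marking: dims 1 and 2 are dynamic (hypothetical parameter).
        hidden_states: Annotated[paddle.Tensor, DynamicDims((1, 2))],
    ) -> None:
        ...
```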
In addition, the warmup phase checks whether any graph break occurred and whether the graph was built more than once to decide whether a faster implementation can be used that executes the compiled code directly; when it cannot, a warning is reported.