We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 7c33db5 commit 9bacc1cCopy full SHA for 9bacc1c
lmdeploy/pytorch/models/q_modules.py
@@ -1,6 +1,6 @@
1
# Copyright (c) OpenMMLab. All rights reserved.
2
3
-from dataclasses import dataclass
+from dataclasses import dataclass, fields
4
5
import torch
6
import torch.nn as nn
@@ -19,13 +19,15 @@ class QTensor:
19
scale: torch.Tensor
20
zero_point: torch.Tensor = None
21
22
+ def __post_init__(self):
23
+ self.fields = [field.name for field in fields(self)]
24
+
25
def __getattr__(self, name: str):
26
"""Allows attribute access to be forwarded to the wrapped tensor when
27
the attribute doesn't exist in QTensor."""
- try:
28
+ if name in self.fields:
29
return super().__getattr__(name)
- except AttributeError:
- return getattr(self.tensor, name)
30
+ return getattr(self.tensor, name)
31
32
33
class QRMSNorm(nn.Module):
0 commit comments