File tree Expand file tree Collapse file tree 1 file changed +13
-6
lines changed Expand file tree Collapse file tree 1 file changed +13
-6
lines changed Original file line number Diff line number Diff line change 21
21
22
22
import numpy as np
23
23
24
- from .import_utils import is_torch_available
24
+ from .import_utils import is_torch_available , is_torch_version
25
25
26
26
27
27
def is_tensor (x ) -> bool :
@@ -60,11 +60,18 @@ def __init_subclass__(cls) -> None:
60
60
if is_torch_available ():
61
61
import torch .utils ._pytree
62
62
63
- torch .utils ._pytree ._register_pytree_node (
64
- cls ,
65
- torch .utils ._pytree ._dict_flatten ,
66
- lambda values , context : cls (** torch .utils ._pytree ._dict_unflatten (values , context )),
67
- )
63
+ if is_torch_version ("<" , "2.2" ):
64
+ torch .utils ._pytree ._register_pytree_node (
65
+ cls ,
66
+ torch .utils ._pytree ._dict_flatten ,
67
+ lambda values , context : cls (** torch .utils ._pytree ._dict_unflatten (values , context )),
68
+ )
69
+ else :
70
+ torch .utils ._pytree .register_pytree_node (
71
+ cls ,
72
+ torch .utils ._pytree ._dict_flatten ,
73
+ lambda values , context : cls (** torch .utils ._pytree ._dict_unflatten (values , context )),
74
+ )
68
75
69
76
def __post_init__ (self ) -> None :
70
77
class_fields = fields (self )
You can’t perform that action at this time.
0 commit comments