Skip to content

Commit 07349c2

Browse files
zyinghuaYinghua
andauthored
Fix deprecation warning for torch.utils._pytree._register_pytree_node in PyTorch 2.2 (#7008)
Fixed deprecation warning for torch.utils._pytree._register_pytree_node in PyTorch 2.2 Co-authored-by: Yinghua <[email protected]>
1 parent 8974c50 commit 07349c2

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

src/diffusers/utils/outputs.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import numpy as np
2323

24-
from .import_utils import is_torch_available
24+
from .import_utils import is_torch_available, is_torch_version
2525

2626

2727
def is_tensor(x) -> bool:
@@ -60,11 +60,18 @@ def __init_subclass__(cls) -> None:
6060
if is_torch_available():
6161
import torch.utils._pytree
6262

63-
torch.utils._pytree._register_pytree_node(
64-
cls,
65-
torch.utils._pytree._dict_flatten,
66-
lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
67-
)
63+
if is_torch_version("<", "2.2"):
64+
torch.utils._pytree._register_pytree_node(
65+
cls,
66+
torch.utils._pytree._dict_flatten,
67+
lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
68+
)
69+
else:
70+
torch.utils._pytree.register_pytree_node(
71+
cls,
72+
torch.utils._pytree._dict_flatten,
73+
lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
74+
)
6875

6976
def __post_init__(self) -> None:
7077
class_fields = fields(self)

0 commit comments

Comments
 (0)