Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit a79b70e

Browse files
committed
fix cuda initialize issue
Signed-off-by: YunLiu <[email protected]>
1 parent 5672a90 commit a79b70e

File tree

1 file changed

+17
-17
lines changed

1 file changed

+17
-17
lines changed

generative/networks/layers/vector_quantizer.py

+17-17
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ def __init__(
8383
range(1, self.spatial_dims + 1)
8484
)
8585

86-
@torch.cuda.amp.autocast(enabled=False)
8786
def quantize(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
8887
"""
8988
Given an input it projects it to the quantized space and returns additional tensors needed for EMA loss.
@@ -100,28 +99,28 @@ def quantize(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, to
10099
encoding_indices_view = list(inputs.shape)
101100
del encoding_indices_view[1]
102101

103-
inputs = inputs.float()
102+
with torch.cuda.amp.autocast(enabled=False):
103+
inputs = inputs.float()
104104

105-
# Converting to channel last format
106-
flat_input = inputs.permute(self.flatten_permutation).contiguous().view(-1, self.embedding_dim)
105+
# Converting to channel last format
106+
flat_input = inputs.permute(self.flatten_permutation).contiguous().view(-1, self.embedding_dim)
107107

108-
# Calculate Euclidean distances
109-
distances = (
110-
(flat_input**2).sum(dim=1, keepdim=True)
111-
+ (self.embedding.weight.t() ** 2).sum(dim=0, keepdim=True)
112-
- 2 * torch.mm(flat_input, self.embedding.weight.t())
113-
)
108+
# Calculate Euclidean distances
109+
distances = (
110+
(flat_input**2).sum(dim=1, keepdim=True)
111+
+ (self.embedding.weight.t() ** 2).sum(dim=0, keepdim=True)
112+
- 2 * torch.mm(flat_input, self.embedding.weight.t())
113+
)
114114

115-
# Mapping distances to indexes
116-
encoding_indices = torch.max(-distances, dim=1)[1]
117-
encodings = torch.nn.functional.one_hot(encoding_indices, self.num_embeddings).float()
115+
# Mapping distances to indexes
116+
encoding_indices = torch.max(-distances, dim=1)[1]
117+
encodings = torch.nn.functional.one_hot(encoding_indices, self.num_embeddings).float()
118118

119-
# Quantize and reshape
120-
encoding_indices = encoding_indices.view(encoding_indices_view)
119+
# Quantize and reshape
120+
encoding_indices = encoding_indices.view(encoding_indices_view)
121121

122122
return flat_input, encodings, encoding_indices
123123

124-
@torch.cuda.amp.autocast(enabled=False)
125124
def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor:
126125
"""
127126
Given encoding indices of shape [B,D,H,W,1] embeds them in the quantized space
@@ -135,7 +134,8 @@ def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor:
135134
Returns:
136135
torch.Tensor: Quantize space representation of encoding_indices in channel first format.
137136
"""
138-
return self.embedding(embedding_indices).permute(self.quantization_permutation).contiguous()
137+
with torch.cuda.amp.autocast(enabled=False):
138+
return self.embedding(embedding_indices).permute(self.quantization_permutation).contiguous()
139139

140140
@torch.jit.unused
141141
def distributed_synchronization(self, encodings_sum: torch.Tensor, dw: torch.Tensor) -> None:

0 commit comments

Comments
 (0)