@@ -83,7 +83,6 @@ def __init__(
83
83
range (1 , self .spatial_dims + 1 )
84
84
)
85
85
86
- @torch .cuda .amp .autocast (enabled = False )
87
86
def quantize (self , inputs : torch .Tensor ) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
88
87
"""
89
88
Given an input it projects it to the quantized space and returns additional tensors needed for EMA loss.
@@ -100,28 +99,28 @@ def quantize(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, to
100
99
encoding_indices_view = list (inputs .shape )
101
100
del encoding_indices_view [1 ]
102
101
103
- inputs = inputs .float ()
102
+ with torch .cuda .amp .autocast (enabled = False ):
103
+ inputs = inputs .float ()
104
104
105
- # Converting to channel last format
106
- flat_input = inputs .permute (self .flatten_permutation ).contiguous ().view (- 1 , self .embedding_dim )
105
+ # Converting to channel last format
106
+ flat_input = inputs .permute (self .flatten_permutation ).contiguous ().view (- 1 , self .embedding_dim )
107
107
108
- # Calculate Euclidean distances
109
- distances = (
110
- (flat_input ** 2 ).sum (dim = 1 , keepdim = True )
111
- + (self .embedding .weight .t () ** 2 ).sum (dim = 0 , keepdim = True )
112
- - 2 * torch .mm (flat_input , self .embedding .weight .t ())
113
- )
108
+ # Calculate Euclidean distances
109
+ distances = (
110
+ (flat_input ** 2 ).sum (dim = 1 , keepdim = True )
111
+ + (self .embedding .weight .t () ** 2 ).sum (dim = 0 , keepdim = True )
112
+ - 2 * torch .mm (flat_input , self .embedding .weight .t ())
113
+ )
114
114
115
- # Mapping distances to indexes
116
- encoding_indices = torch .max (- distances , dim = 1 )[1 ]
117
- encodings = torch .nn .functional .one_hot (encoding_indices , self .num_embeddings ).float ()
115
+ # Mapping distances to indexes
116
+ encoding_indices = torch .max (- distances , dim = 1 )[1 ]
117
+ encodings = torch .nn .functional .one_hot (encoding_indices , self .num_embeddings ).float ()
118
118
119
- # Quantize and reshape
120
- encoding_indices = encoding_indices .view (encoding_indices_view )
119
+ # Quantize and reshape
120
+ encoding_indices = encoding_indices .view (encoding_indices_view )
121
121
122
122
return flat_input , encodings , encoding_indices
123
123
124
- @torch .cuda .amp .autocast (enabled = False )
125
124
def embed (self , embedding_indices : torch .Tensor ) -> torch .Tensor :
126
125
"""
127
126
Given encoding indices of shape [B,D,H,W,1] embeds them in the quantized space
@@ -135,7 +134,8 @@ def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor:
135
134
Returns:
136
135
torch.Tensor: Quantize space representation of encoding_indices in channel first format.
137
136
"""
138
- return self .embedding (embedding_indices ).permute (self .quantization_permutation ).contiguous ()
137
+ with torch .cuda .amp .autocast (enabled = False ):
138
+ return self .embedding (embedding_indices ).permute (self .quantization_permutation ).contiguous ()
139
139
140
140
@torch .jit .unused
141
141
def distributed_synchronization (self , encodings_sum : torch .Tensor , dw : torch .Tensor ) -> None :
0 commit comments