|
1 | 1 | import logging
|
2 |
| -from typing import Any, Callable, Dict, Optional |
| 2 | +from typing import Any, Callable, Dict, List, Optional |
3 | 3 |
|
4 | 4 | import torch
|
5 | 5 | from torch._decomp import register_decomposition
|
@@ -83,11 +83,6 @@ def inplace_op(*args, **kwargs): # type: ignore
|
83 | 83 | replace_inplace_op(aten.scatter_reduce_, aten.scatter_reduce)
|
84 | 84 |
|
85 | 85 |
|
86 |
| -@register_torch_trt_decomposition(aten.std, registry=TORCH_TRT_DECOMPOSITIONS) |
87 |
| -def std_replacement(*args, **kwargs) -> torch.Tensor: # type: ignore |
88 |
| - return torch.sqrt(torch.var(*args, **kwargs)) |
89 |
| - |
90 |
| - |
@register_torch_trt_decomposition(aten.rsqrt, registry=TORCH_TRT_DECOMPOSITIONS)
def rsqrt_replacement(*args, **kwargs) -> torch.Tensor:  # type: ignore
    """Decompose aten.rsqrt as the reciprocal of the square root."""
    sqrt_result = torch.sqrt(*args, **kwargs)
    return torch.reciprocal(sqrt_result)
|
@@ -135,6 +130,54 @@ def reciprocal_replacement(
|
135 | 130 | return torch.div(1, input_)
|
136 | 131 |
|
137 | 132 |
|
@register_torch_trt_decomposition(
    torch.ops.prims.var.default, registry=TORCH_TRT_DECOMPOSITIONS
)
def var_decomposition(
    input_tensor: torch.Tensor,
    dims: Optional[List[int]],
    # Fixed annotation: the runtime branches below accept None, int, or
    # float; per PEP 484, Optional[float] also admits int.
    correction: Optional[float],
    output_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Decompose ``prims.var`` into elementary tensor ops.

    Computes the corrected sample variance over ``dims`` following the
    formula in the ``torch.var`` documentation:
    ``sum((x - mean(x))**2) / (N - correction)``.

    Args:
        input_tensor: Tensor whose variance is computed.
        dims: Dimensions to reduce over. ``None`` or an empty list/tuple
            means reduce over all dimensions.
        correction: Difference between the sample size and the degrees of
            freedom. ``None`` defaults to 1 (Bessel's correction).
        output_dtype: Present to match the prims signature; not used in
            this decomposition.

    Returns:
        The variance with the reduced dimensions removed (keepdim=False).

    Raises:
        RuntimeError: If ``correction`` is neither int nor float, or if
            the correction term ``N - correction`` is non-positive.
    """
    if dims is None:
        dims = []

    # Number of samples reduced over: every element when dims is empty,
    # otherwise the product of the reduced dimension sizes.
    if isinstance(dims, (tuple, list)) and len(dims) == 0:
        N = input_tensor.numel()
    else:
        N = 1
        for dim_i in dims:
            N *= input_tensor.shape[dim_i]

    # prims does not support keepdim, so keep the reduced dimensions on
    # the first reduction (so the mean broadcasts against the input),
    # then drop them on the second.
    sample_mean = torch.mean(input_tensor, dims, keepdim=True)
    diff = input_tensor - sample_mean
    squared_diff = diff * diff
    variance_unnormalized = torch.sum(squared_diff, dims, keepdim=False)

    if correction is None:
        correction_term = float(N - 1)
    elif isinstance(correction, int):
        correction_term = float(N - correction)
    elif isinstance(correction, float):
        correction_term = float(N) - correction
    else:
        raise RuntimeError("correction must be int or float")

    if correction_term <= 0:
        raise RuntimeError(f"correction term was non-positive, got: {correction_term}")

    variance = variance_unnormalized / correction_term

    return variance
| 179 | + |
| 180 | + |
138 | 181 | def get_decompositions(
|
139 | 182 | enable_experimental_decompositions: bool = False,
|
140 | 183 | ) -> Dict[OpOverload, Callable[[Any], Any]]:
|
|
0 commit comments