
Commit e4db5dc
Revert "[BE] remove unnecessary _dispatch_sqrt by using ** 0.5 (pytorch#131358)"
This reverts commit 4c7f22d. Reverted pytorch#131358 on behalf of https://github.com/janeyx99: internal code uses this private API, and landing the original change has been a pain, so it is being reverted first ([comment](pytorch#131358 (comment))).
1 parent: 2576dbb · commit: e4db5dc

File tree: 5 files changed (+22, -8 lines)

torch/optim/adam.py (+3, -2)
@@ -10,6 +10,7 @@
     _default_to_fused_or_foreach,
     _differentiable_doc,
     _disable_dynamo_if_unsupported,
+    _dispatch_sqrt,
     _foreach_doc,
     _fused_doc,
     _get_capturable_supported_devices,
@@ -424,7 +425,7 @@ def _single_tensor_adam(
 
         step_size = lr / bias_correction1
 
-        bias_correction2_sqrt = bias_correction2**0.5
+        bias_correction2_sqrt = _dispatch_sqrt(bias_correction2)
 
         if amsgrad:
             # Maintains the maximum of all 2nd moment running avg. till now
@@ -595,7 +596,7 @@ def _multi_tensor_adam(
 
         step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1])
 
-        bias_correction2_sqrt = [bc**0.5 for bc in bias_correction2]  # type: ignore[arg-type]
+        bias_correction2_sqrt = [_dispatch_sqrt(bc) for bc in bias_correction2]  # type: ignore[arg-type]
 
         if amsgrad:
             # Maintains the maximum of all 2nd moment running avg. till now
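For context (not part of the diff): bias_correction2 = 1 - beta2 ** step is Adam's second-moment bias correction, and its square root scales the denominator of the update. A minimal sketch of that term with illustrative hyperparameter values, assuming a plain-float step counter:

import math

beta2, step = 0.999, 10  # illustrative values, not from the diff
bias_correction2 = 1 - beta2 ** step
bias_correction2_sqrt = math.sqrt(bias_correction2)  # ~0.0997
# The Adam denominator is then roughly exp_avg_sq.sqrt() / bias_correction2_sqrt + eps.

When the step counter is tracked as a 0-dim tensor (e.g. in capturable mode), bias_correction2 can be a Tensor rather than a float, which is the case the restored _dispatch_sqrt helper dispatches on.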

torch/optim/adamw.py (+3, -2)
@@ -10,6 +10,7 @@
     _default_to_fused_or_foreach,
     _differentiable_doc,
     _disable_dynamo_if_unsupported,
+    _dispatch_sqrt,
     _foreach_doc,
     _fused_doc,
     _get_capturable_supported_devices,
@@ -425,7 +426,7 @@ def _single_tensor_adamw(
 
         step_size = lr / bias_correction1
 
-        bias_correction2_sqrt = bias_correction2**0.5
+        bias_correction2_sqrt = _dispatch_sqrt(bias_correction2)
 
         if amsgrad:
             # Maintains the maximum of all 2nd moment running avg. till now
@@ -592,7 +593,7 @@ def _multi_tensor_adamw(
         step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1])
 
         bias_correction2_sqrt = [
-            bc**0.5 for bc in bias_correction2  # type: ignore[arg-type]
+            _dispatch_sqrt(bc) for bc in bias_correction2  # type: ignore[arg-type]
         ]
 
         if amsgrad:

torch/optim/nadam.py (+3, -1)
@@ -10,6 +10,7 @@
     _default_to_fused_or_foreach,
     _differentiable_doc,
     _disable_dynamo_if_unsupported,
+    _dispatch_sqrt,
     _foreach_doc,
     _get_capturable_supported_devices,
     _get_scalar_dtype,
@@ -488,7 +489,8 @@ def _multi_tensor_nadam(
             torch._foreach_sqrt_(bias_correction_sqrt)
         else:
             bias_correction_sqrt = [
-                (1 - beta2 ** _get_value(step)) ** 0.5 for step in grouped_state_steps
+                _dispatch_sqrt(1 - beta2 ** _get_value(step))
+                for step in grouped_state_steps
             ]
             mus = [
                 beta1 * (1.0 - 0.5 * (0.96 ** (_get_value(step) * momentum_decay)))

torch/optim/optimizer.py (+10, -0)
@@ -2,6 +2,7 @@
 # mypy: allow-untyped-defs
 """Base optimizer."""
 import functools
+import math
 import warnings
 from collections import defaultdict, OrderedDict
 from copy import deepcopy
@@ -112,6 +113,15 @@ def _stack_if_compiling(x):
     return x
 
 
+def _dispatch_sqrt(
+    x: float,
+):  # float annotation is needed because of torchscript type inference
+    if not torch.jit.is_scripting() and isinstance(x, torch.Tensor):
+        return x.sqrt()
+    else:
+        return math.sqrt(x)
+
+
 def _disable_dynamo_if_unsupported(single_tensor_fn=None):
     # workaround for torchscript BC
     # it requires all called functions to be in the
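For context (not part of the diff): _dispatch_sqrt restores the type-based dispatch that ** 0.5 handled implicitly, since exponentiation works on both Python floats and Tensors. A minimal standalone sketch of the two code paths, re-declaring the helper from the hunk above with illustrative inputs:

import math

import torch

def _dispatch_sqrt(x: float):  # mirrors the definition restored above
    if not torch.jit.is_scripting() and isinstance(x, torch.Tensor):
        return x.sqrt()  # Tensor input: stays on-device as a tensor op
    return math.sqrt(x)  # plain float input: use the math module

print(_dispatch_sqrt(0.25))                # 0.5, via math.sqrt
print(_dispatch_sqrt(torch.tensor(0.25)))  # tensor(0.5000), via Tensor.sqrt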

torch/optim/radam.py (+3, -3)
@@ -11,6 +11,7 @@
     _default_to_fused_or_foreach,
     _differentiable_doc,
     _disable_dynamo_if_unsupported,
+    _dispatch_sqrt,
     _foreach_doc,
     _get_capturable_supported_devices,
     _get_scalar_dtype,
@@ -504,13 +505,12 @@ def _multi_tensor_radam(
             del bias_correction1
         else:
             rect = [
-                (
+                _dispatch_sqrt(
                     (rho_t - 4)  # type: ignore[arg-type]
                     * (rho_t - 2)
                     * rho_inf
                     / ((rho_inf - 4) * (rho_inf - 2) * rho_t)
                 )
-                ** 0.5
                 if rho_t > 5
                 else 0
                 for rho_t in rho_t_list
@@ -524,7 +524,7 @@ def _multi_tensor_radam(
                 (lr * rect / bc) * -1 for rect, bc in zip(unrectified, bias_correction1)
             ]
             bias_correction2 = [
-                ((1 - beta2 ** _get_value(step)) ** 0.5) * (lr * rect / bc) * -1
+                _dispatch_sqrt(1 - beta2 ** _get_value(step)) * (lr * rect / bc) * -1
                 for step, rect, bc in zip(grouped_state_steps, rect, bias_correction1)
             ]
 
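For context (not part of the diff): the first radam.py hunk wraps RAdam's variance-rectification term, previously computed with ** 0.5. A standalone sketch of that term for the scalar case, with illustrative beta2 and step values and math.sqrt standing in for the dispatching helper:

import math

beta2, step = 0.999, 100  # illustrative values, not from the diff
rho_inf = 2 / (1 - beta2) - 1
rho_t = rho_inf - 2 * step * beta2 ** step / (1 - beta2 ** step)
rect = (
    math.sqrt(
        (rho_t - 4) * (rho_t - 2) * rho_inf / ((rho_inf - 4) * (rho_inf - 2) * rho_t)
    )
    if rho_t > 5  # variance is tractable: apply the rectified, adaptive step
    else 0.0      # otherwise fall back to the unrectified step
)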
