Commit e0e4f1a

Revert "[functorch] linearize (pytorch#94173)"
This reverts commit b6b9e1e. Reverted pytorch#94173 on behalf of https://github.com/kshitij12345 because it broke the lint runner.
1 parent: b6b9e1e · commit: e0e4f1a

File tree: 5 files changed (+3, -220 lines)

Diff for: docs/source/func.api.rst (-1 line)

@@ -16,7 +16,6 @@ Function Transforms
     grad_and_value
     vjp
     jvp
-    linearize
     jacrev
     jacfwd
     hessian
Diff for: test/functorch/test_eager_transforms.py (+2, -104 lines)

@@ -20,9 +20,8 @@
 import unittest
 import warnings
 import math
-from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCPU, dtypes, onlyCUDA
+from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCPU
 from torch.testing._internal.common_dtype import get_all_fp_dtypes
-from torch.testing import make_tensor
 from torch._subclasses.fake_tensor import FakeTensorMode
 from functools import partial
 from functorch.experimental import replace_all_batch_norm_modules_
@@ -41,7 +40,7 @@
 from torch._ops import PyOperator
 from torch._functorch.utils import enable_single_level_autograd_function
 import torch.autograd.forward_ad as fwAD
-from torch.func import functional_call, stack_module_state, linearize
+from torch.func import functional_call, stack_module_state
 
 # NB: numpy is a testing dependency!
 import numpy as np
@@ -2501,102 +2500,6 @@ def push_jvp(dummy, x):
         vmap(vmap(push_jvp, (0, None)))(dummy, x)
 
 
-class TestLinearize(TestCase):
-    @dtypes(torch.float)
-    def test_linearize_basic(self, device, dtype):
-        x_p = make_tensor((3, 1), device=device, dtype=dtype)
-        x_t = make_tensor((3, 1), device=device, dtype=dtype)
-
-        def fn(x):
-            return x.cos()
-
-        actual_output, jvp_fn = linearize(fn, x_p)
-        actual_jvp = jvp_fn(x_t)
-        expected_output, expected_jvp = jvp(fn, (x_p,), (x_t,))
-        self.assertEqual(actual_output, expected_output)
-        self.assertEqual(actual_jvp, expected_jvp)
-
-    @dtypes(torch.float)
-    def test_linearize_return(self, device, dtype):
-        x_p = make_tensor((3, 1), device=device, dtype=dtype)
-        x_t = make_tensor((3, 1), device=device, dtype=dtype)
-
-        def fn(x):
-            return (x.cos(), x.sum())
-
-        actual_output, jvp_fn = linearize(fn, x_p)
-        actual_jvp = jvp_fn(x_t)
-        expected_output, expected_jvp = jvp(fn, (x_p,), (x_t,))
-        self.assertEqual(actual_output, expected_output)
-        self.assertEqual(actual_jvp, expected_jvp)
-
-    @dtypes(torch.float)
-    def test_linearize_composition(self, device, dtype):
-        x_p = make_tensor((3, 1), device=device, dtype=dtype)
-        x_t = make_tensor((3, 3, 1), device=device, dtype=dtype)
-
-        def fn(x):
-            return (x.cos(), x.sum())
-
-        _, jvp_fn = linearize(fn, x_p)
-        actual_batched_jvp = vmap(jvp_fn)(x_t)
-
-        def jvp_fn(x_t):
-            return jvp(fn, (x_p,), (x_t,))[1]
-        expected_batched_jvp = vmap(jvp_fn)(x_t)
-
-        self.assertEqual(actual_batched_jvp, expected_batched_jvp)
-
-    @dtypes(torch.float)
-    def test_linearize_nested_input_nested_output(self, device, dtype):
-        x_p = make_tensor((3, 1), device=device, dtype=dtype)
-        x_t = make_tensor((3, 1), device=device, dtype=dtype)
-        y_p = make_tensor((3, 1), device=device, dtype=dtype)
-        y_t = make_tensor((3, 1), device=device, dtype=dtype)
-        z_p = make_tensor((3, 1), device=device, dtype=dtype)
-        z_t = make_tensor((3, 1), device=device, dtype=dtype)
-
-        def fn(arg):
-            x = arg['x']
-            y = arg['yz'][0]
-            z = arg['yz'][1]
-
-            return {'a': x.sum(), 'b': {'c': y + z, 'd': (x * z, y.exp())}}
-
-        inp_p = {'x': x_p, 'yz': (y_p, z_p)}
-        inp_t = {'x': x_t, 'yz': (y_t, z_t)}
-        actual_output, jvp_fn = linearize(fn, inp_p)
-        actual_jvp = jvp_fn(inp_t)
-
-        expected_output, expected_jvp = jvp(fn, (inp_p,), (inp_t,))
-
-        self.assertEqual(actual_output, expected_output)
-        self.assertEqual(actual_jvp, expected_jvp)
-
-    @onlyCUDA
-    def test_linearize_errors(self):
-        dtype = torch.float
-        device = torch.device('cpu')
-        x_p = make_tensor((3, 1), device=device, dtype=dtype)
-        x_t = make_tensor((3, 1), device=device, dtype=dtype)
-
-        def fn(x):
-            return x.sin()
-
-        _, jvp_fn = linearize(fn, x_p)
-
-        with self.assertRaisesRegex(RuntimeError, "to have the same argspec as the primals"):
-            jvp_fn((x_t, x_t))
-
-        with self.assertRaisesRegex(RuntimeError, "in flattened pytree doesn't match the shape"):
-            jvp_fn(x_t.unsqueeze(0))
-
-        with self.assertRaisesRegex(RuntimeError, "in flattened pytree doesn't match the dtype"):
-            jvp_fn(x_t.to(torch.double))
-
-        with self.assertRaisesRegex(RuntimeError, "in flattened pytree doesn't match the device"):
-            jvp_fn(x_t.to(torch.device('cuda')))
-
 # The tests here follow the cases in [Forward Grad View/inplace]
 # https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/autograd_meta.cpp#L18-L43
 class TestVmapJvpInplaceView(TestCase):
@@ -4549,11 +4452,6 @@ def test_functional_call_multiple_dicts(self):
     globals(),
     only_for=only_for,
 )
-instantiate_device_type_tests(
-    TestLinearize,
-    globals(),
-    only_for=only_for,
-)
 instantiate_device_type_tests(
     TestVmapJvpInplaceView,
     globals(),
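
The removed TestLinearize suite pins down the intended contract: linearize(fn, x_p) returns fn(x_p) together with a jvp_fn such that jvp_fn(x_t) matches jvp(fn, (x_p,), (x_t,))[1]. Below is a minimal sketch of that reference computation using only torch.func.jvp and torch.func.vmap, both of which are unaffected by this revert; the function and shapes are illustrative, not taken from the PR:

import torch
from torch.func import jvp, vmap

def fn(x):
    return x.cos()

x_p = torch.randn(3, 1)  # primal: the point the tests linearize around
x_t = torch.randn(3, 1)  # a single tangent

# Reference values the reverted tests compare linearize() against.
expected_output, expected_jvp = jvp(fn, (x_p,), (x_t,))

# For many tangents at one primal (cf. test_linearize_composition above),
# the same reference is obtained by vmapping jvp over a batch of tangents.
x_ts = torch.randn(3, 3, 1)
expected_batched_jvp = vmap(lambda t: jvp(fn, (x_p,), (t,))[1])(x_ts)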

Diff for: torch/_functorch/eager_transforms.py (+1, -112 lines)

@@ -8,9 +8,7 @@
 import torch
 from functools import partial, wraps
 import contextlib
-from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map, tree_map_only
-from torch.fx.experimental import const_fold
-from torch.fx.experimental.proxy_tensor import make_fx
+from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map
 from .pytree_hacks import tree_map_, treespec_pprint
 import torch.autograd.forward_ad as fwAD
 
@@ -1602,112 +1600,3 @@ def wrapped(*args, **kwargs):
         finally:
             _func_decrement_nesting()
     return wrapped
-
-@exposed_in("torch.func")
-def linearize(func: Callable, *primals) -> Tuple[Any, Callable]:
-    '''
-    Returns the value of ``func`` at ``primals`` and linear approximation
-    at ``primals``.
-
-    Args:
-        func (Callable): A Python function that takes one or more arguments.
-        primals (Tensors): Positional arguments to ``func`` that must all be
-            Tensors. These are the values at which the function is linearly approximated.
-
-    Returns:
-        Returns a ``(output, jvp_fn)`` tuple containing the output of ``func``
-        applied to ``primals`` and a function that computes the jvp of
-        ``func`` evaluated at ``primals``.
-
-    linearize is useful if jvp is to be computed multiple times at ``primals``. However,
-    to achieve this, linearize saves intermediate computation and has higher memory requirements
-    than directly applying `jvp`. So, if all the ``tangents`` are known, it may be more efficient
-    to compute vmap(jvp) instead of using linearize.
-
-    .. note::
-        linearize evaluates ``func`` twice. Please file an issue for an implementation
-        with a single evaluation.
-
-    Example::
-        >>> import torch
-        >>> from torch.func import linearize
-        >>> def fn(x):
-        ...     return x.sin()
-        ...
-        >>> output, jvp_fn = linearize(fn, torch.zeros(3, 3))
-        >>> jvp_fn(torch.ones(3, 3))
-        tensor([[1., 1., 1.],
-                [1., 1., 1.],
-                [1., 1., 1.]])
-        >>>
-
-    '''
-    # Note: We evaluate `fn` twice.
-    # Once for returning the output and once more while
-    # tracing the graph.
-    # If this becomes a bottleneck, we should update
-    # make_fx such that it also returns the output.
-
-    output = func(*primals)
-    _, output_spec = tree_flatten(output)
-
-    flat_primals, primals_argspec = tree_flatten(primals)
-
-    # tangents for tracing
-    flat_tangents = tuple(p.new_empty(()).expand_as(p) for p in flat_primals)
-
-    # function to trace
-    def trace_fn(flat_tangents):
-        with fwAD.dual_level():
-            flat_duals = tuple(fwAD.make_dual(p, t) for p, t in zip(flat_primals, flat_tangents))
-            duals = tree_unflatten(flat_duals, primals_argspec)
-            output = func(*duals)
-            tangents = tree_map_only(torch.Tensor, lambda t: fwAD.unpack_dual(t)[1], output)
-
-        return tangents
-
-    jvp_graph = make_fx(trace_fn)(flat_tangents)
-    const_folded_jvp_graph = const_fold.split_const_subgraphs(jvp_graph)
-
-    # Hold only the meta-data regarding the primals.
-    flat_primals_shape = tuple(p.shape for p in flat_primals)
-    flat_primals_device = tuple(p.device for p in flat_primals)
-    flat_primals_dtype = tuple(p.dtype for p in flat_primals)
-
-    def forward_ad_checks(flat_tangents):
-        for idx, t in enumerate(flat_tangents):
-            if t.shape != flat_primals_shape[idx]:
-                msg = (f"tangent:{idx} with shape {t.shape} in flattened "
-                       f"pytree doesn't match the shape {flat_primals_shape[idx]} "
-                       "of the corresponding primal.")
-                raise RuntimeError(msg)
-
-            if t.device != flat_primals_device[idx]:
-                msg = (f"tangent:{idx} with device {t.device} in flattened "
-                       f"pytree doesn't match the device {flat_primals_device[idx]} "
-                       "of the corresponding primal.")
-                raise RuntimeError(msg)
-
-            if t.dtype != flat_primals_dtype[idx]:
-                msg = (f"tangent:{idx} with dtype {t.dtype} in flattened "
-                       f"pytree doesn't match the dtype {flat_primals_dtype[idx]} "
-                       "of the corresponding primal.")
-                raise RuntimeError(msg)
-
-    # jvp_fn : callable to return
-    # It takes care of checking the argspec of tangents,
-    # calling the folded fx graph and unflattening fx graph output
-    def jvp_fn(*tangents):
-        flat_tangents, tangent_argspec = tree_flatten(tangents)
-        if tangent_argspec != primals_argspec:
-            raise RuntimeError(f"Expected the tangents {tangent_argspec} to have "
-                               f"the same argspec as the primals {primals_argspec}")
-
-        forward_ad_checks(flat_tangents)
-
-        flat_output = const_folded_jvp_graph(*flat_tangents)
-        # const folded graph can return flat output,
-        # so transform output.
-        return tree_unflatten(flat_output, output_spec)
-
-    return output, jvp_fn
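
The core of the reverted implementation is trace_fn, which evaluates func under forward-mode AD (primals and tangents packed into dual tensors) and is then traced with make_fx so the resulting jvp graph can be replayed for fresh tangents. Below is a stripped-down sketch of just that forward-mode step, using the public torch.autograd.forward_ad API; fn, the shapes, and the variable names are placeholders rather than code from this commit:

import torch
import torch.autograd.forward_ad as fwAD

def fn(x):
    return x.sin()

primal = torch.randn(3, 3)
tangent = torch.ones(3, 3)

with fwAD.dual_level():
    # Pair the primal with its tangent as a dual tensor, run the function,
    # and read the jvp back off the output dual; this mirrors trace_fn above.
    dual_input = fwAD.make_dual(primal, tangent)
    dual_output = fn(dual_input)
    out_primal, out_tangent = fwAD.unpack_dual(dual_output)

# out_tangent equals cos(primal) * tangent, i.e. the jvp of sin at `primal`.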

Diff for: torch/func/__init__.py (-1 line)

@@ -7,7 +7,6 @@
     jacfwd,
     hessian,
     functionalize,
-    linearize
 )
 from torch._functorch.functional_call import functional_call, stack_module_state
 from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_

Diff for: torch/fx/experimental/const_fold.py (-2 lines)

@@ -6,8 +6,6 @@
 from torch.fx.passes.split_module import split_module
 
 
-__all__ = ['FoldedGraphModule', 'get_unique_attr_name_in_module', 'split_const_subgraphs']
-
 class FoldedGraphModule(torch.fx.GraphModule):
     """
     FoldedGraphModule is a GraphModule which also contains another