
Commit d1ddeed

Add cache to value_and_grad_partitioned (#9163)
1 parent 402612d commit d1ddeed

File tree

2 files changed: +182 -19 lines changed

test/scan/test_scan.py

Lines changed: 140 additions & 17 deletions
@@ -10,11 +10,13 @@
 
 import torch_xla
 import torch_xla.debug.metrics as met
+import torch_xla.experimental.scan as scan_module
 from torch_xla.experimental.scan import scan, value_and_grad_partitioned, tree_flatten_none
 
 parent_folder = os.path.dirname(os.path.dirname(__file__))
 sys.path.append(parent_folder)
 from test_utils import XlaTestCase  # type:ignore
+from absl.testing import parameterized
 
 
 def _loopy_scan(fn, init, xs):
@@ -44,6 +46,8 @@ class TestBase(XlaTestCase):
   def setUp(self):
     super().setUp()
     self.device = torch_xla.device()
+    # Clear the scan computation cache before each test to avoid cross-test contamination.
+    scan_module._SCAN_COMPUTATION_CACHE.clear()
 
   def compare_pytree(self, expected_pytree, actual_pytree):
     flat_expected_pytree, expected_spec = tree_flatten(expected_pytree)
@@ -59,13 +63,14 @@ def compare_pytree(self, expected_pytree, actual_pytree):
     super().compareResults(flat_expected_pytree, flat_actual_pytree)
 
 
-class ScanTest(TestBase):
+class ScanTest(TestBase, parameterized.TestCase):
 
   def run_test(self,
                fn,
                init: PyTree,
                xs: PyTree,
-               partition_fn=default_partition):
+               partition_fn=default_partition,
+               is_fn_pure: bool = False):
     """Compares the result of scanning with `fn` with our optimized HLO implementation
     against a for loop implementation. Checks both output values and gradients.
     """
@@ -78,7 +83,12 @@ def run_test(self,
     # Actual output
     init_scan = tree_map(dupe, init)
     xs_scan = tree_map(dupe, xs)
-    final_carry, ys = scan(fn, init_scan, xs_scan, partition_fn=partition_fn)
+    final_carry, ys = scan(
+        fn,
+        init_scan,
+        xs_scan,
+        partition_fn=partition_fn,
+        is_fn_pure=is_fn_pure)
     # Add up all leaves and `backward()` once.
     (squish(final_carry) + squish(ys)).backward()
     torch_xla.sync()
@@ -105,7 +115,8 @@ def run_test(self,
 
     return final_carry, ys
 
-  def test_scan_simple(self):
+  @parameterized.parameters(True, False)
+  def test_scan_simple(self, is_fn_pure: bool):
     """This test uses `scan` to implement `torch.cumsum`."""
 
     def step_fn(carry, x):
@@ -117,7 +128,7 @@ def step_fn(carry, x):
     xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                       requires_grad=True,
                       device=self.device)
-    final_carry, ys = self.run_test(step_fn, init, xs)
+    final_carry, ys = self.run_test(step_fn, init, xs, is_fn_pure=is_fn_pure)
 
     # Also ensure that our loop-based scan is correct, with manual checks
     # that replicate the step_fn.
@@ -140,7 +151,8 @@ def test_scan_incompatible_length(self):
     with self.assertRaises(ValueError):
       scan(lambda a, b: (a, b), init, (xs_1, xs_2))
 
-  def test_scan_tuples(self):
+  @parameterized.parameters(True, False)
+  def test_scan_tuples(self, is_fn_pure: bool):
     """Test scanning over the leading axis of a tuple of tensors simultaneously,
     which is a simple PyTree."""
 
@@ -163,9 +175,10 @@ def fn(carry, x):
                        requires_grad=True,
                        device=self.device))
 
-    self.run_test(fn, init, xs)
+    self.run_test(fn, init, xs, is_fn_pure=is_fn_pure)
 
-  def test_scan_create_tensors(self):
+  @parameterized.parameters(True, False)
+  def test_scan_create_tensors(self, is_fn_pure: bool):
     """Test scanning over a function that internally creates tensors."""
 
     def fn(carry, x):
@@ -177,7 +190,7 @@ def fn(carry, x):
     xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                       requires_grad=True,
                       device=self.device)
-    self.run_test(fn, init, xs)
+    self.run_test(fn, init, xs, is_fn_pure=is_fn_pure)
 
   def test_scan_create_tensors_no_transfers_from_device(self):
     """Test that scanning over a function that internally creates tensors
@@ -220,7 +233,8 @@ def fn(carry, x):
                       device=self.device)
     self.run_test(fn, init, xs)
 
-  def test_scan_input_output_aliases_carry(self):
+  @parameterized.parameters(True, False)
+  def test_scan_input_output_aliases_carry(self, is_fn_pure: bool):
     """
     Test scan still works when a fn output aliases its carry input.
     """
@@ -232,9 +246,10 @@ def fn(carry, x):
     xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                       requires_grad=True,
                       device=self.device)
-    self.run_test(fn, init, xs)
+    self.run_test(fn, init, xs, is_fn_pure=is_fn_pure)
 
-  def test_scan_input_output_aliases_x(self):
+  @parameterized.parameters(True, False)
+  def test_scan_input_output_aliases_x(self, is_fn_pure: bool):
     """
     Test scan still works when a fn output aliases its x input.
     """
@@ -246,7 +261,7 @@ def fn(carry, x):
     xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                       requires_grad=True,
                       device=self.device)
-    self.run_test(fn, init, xs)
+    self.run_test(fn, init, xs, is_fn_pure=is_fn_pure)
 
   def test_scan_input_in_place_mutation(self):
     """
@@ -288,7 +303,8 @@ def step_fn(carry, x):
     with self.assertRaisesRegex(AssertionError, "FakeTensor"):
       scan(step_fn, init, xs)
 
-  def test_scan_gradness(self):
+  @parameterized.parameters(True, False)
+  def test_scan_gradness(self, is_fn_pure: bool):
     """
     Test the gradient output of `scan` when various inputs require or don't
     require gradients.
@@ -307,7 +323,7 @@ def fn(carry, x):
       xs = torch.tensor([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0]],
                         requires_grad=xs_requires_grad,
                         device=self.device)
-      self.run_test(fn, init, xs)
+      self.run_test(fn, init, xs, is_fn_pure=is_fn_pure)
 
     test_case(True, True)
     test_case(True, False)
@@ -445,7 +461,8 @@ def fn(carry, x):
     self.assertEqual(bf16_ys.dtype, torch.bfloat16)
     self.assertEqual(f32_ys.dtype, torch.float32)
 
-  def test_scan_activation_aliases_input(self):
+  @parameterized.parameters(True, False)
+  def test_scan_activation_aliases_input(self, is_fn_pure: bool):
     """Test that if an intermediate activation of fn aliases an input,
     we directly save the input tensor into the context object, instead of
     indexing into the leading dimension during the while loop and copying
@@ -470,7 +487,7 @@ def unpack(x):
 
     # Intercept the tensors stored in the context object.
     with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
-      final_carry, ys = scan(fn, carry, xs)
+      final_carry, ys = scan(fn, carry, xs, is_fn_pure=is_fn_pure)
     ys.sum().backward()
     torch_xla.sync()
 
@@ -487,6 +504,112 @@ def unpack(x):
     # as opposed to just numerically identical but otherwise an extra copy.
     assert id(stored_xs) == id(xs)
 
+  def test_scan_computation_cache(self):
+    """
+    Test that the computation cache is populated correctly.
+    """
+    fn1_call_count = 0
+
+    def fn1(carry, x):
+      nonlocal fn1_call_count
+      fn1_call_count += 1
+      return carry + x, x
+
+    init = torch.tensor([0.0, 0.0], device=self.device)
+    xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
+                      device=self.device,
+                      requires_grad=True)
+
+    for _ in range(10):
+      scan(fn1, init, xs, is_fn_pure=True)
+
+    cache = scan_module._SCAN_COMPUTATION_CACHE
+
+    # Check that fn1 is in the cache.
+    assert fn1 in cache, "fn1 should be in the cache"
+
+    # Inspect the second-level cache for fn1.
+    second_level_cache = cache[fn1]
+    assert len(second_level_cache) > 0, "Second-level cache should not be empty"
+
+    # fn1 should be traced exactly twice despite the ten scan calls.
+    assert fn1_call_count == 2, \
+        "fn1 should be called only twice (once to construct the forward graph and once to construct the backward graph), but was called " + str(fn1_call_count)
+
+    # Every cache entry stores the traced forward, alias_input, and backward graphs.
+    for key, value in second_level_cache.items():
+      forward, alias_input, backward = value
+      # All three cached graphs should be callable.
+      assert callable(forward)
+      assert callable(alias_input)
+      assert callable(backward)
+
+  def test_scan_computation_cache_by_fn_and_partition_fn(self):
+    """
+    Test that the computation cache is keyed by both fn and partition_fn.
+    """
+
+    def fn1(carry, x):
+      return carry + x, x
+
+    def fn2(carry, x):
+      return carry * x, x
+
+    init = torch.tensor([0.0, 0.0], device=self.device)
+    xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
+                      device=self.device,
+                      requires_grad=True)
+    scan(fn1, init, xs, is_fn_pure=True)
+    scan(fn2, init, xs, is_fn_pure=True)
+
+    cache = scan_module._SCAN_COMPUTATION_CACHE
+
+    # Check that both functions are in the cache.
+    assert fn1 in cache, "fn1 should be in the cache"
+    assert fn2 in cache, "fn2 should be in the cache"
+
+    # Inspect the second-level cache for fn1.
+    second_level_cache = cache[fn1]
+    assert len(
+        second_level_cache) == 1, "Second-level cache should have exactly 1 entry"
+
+    # Inspect the second-level cache for fn2.
+    second_level_cache = cache[fn2]
+    assert len(
+        second_level_cache) == 1, "Second-level cache should have exactly 1 entry"
+
+    # Check that a different partition function creates a new cache entry for fn1.
+    scan(
+        fn1,
+        init,
+        xs,
+        partition_fn=min_cut_rematerialization_partition,
+        is_fn_pure=True)
+    # Inspect the second-level cache for fn1 again.
+    second_level_cache = cache[fn1]
+    assert len(second_level_cache
+              ) == 2, "Second-level cache should have exactly 2 entries. Got: " + str(
+                  len(second_level_cache))
+
+  def test_scan_computation_cache_disabled_when_fn_is_not_pure(self):
+    """
+    Test that the computation cache is not populated when the function is not pure.
+    """
+
+    def fn1(carry, x):
+      return carry + x, x
+
+    init = torch.tensor([0.0, 0.0], device=self.device)
+    xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
+                      device=self.device,
+                      requires_grad=True)
+    scan(fn1, init, xs, is_fn_pure=False)
+
+    cache = scan_module._SCAN_COMPUTATION_CACHE
+
+    # fn1 should not be cached because it was not declared pure.
+    assert fn1 not in cache, "fn1 should not be in the cache"
+
 
 class PyTreeTest(TestBase):
 
torch_xla/experimental/scan.py

Lines changed: 42 additions & 2 deletions
@@ -51,17 +51,26 @@
 from torch_xla.distributed.spmd.xla_sharding import shard_as
 import torch_xla.debug.profiler as xp
 import torch_xla.runtime
+from weakref import WeakKeyDictionary
 
 Carry = TypeVar('Carry')
 X = TypeVar('X')
 Y = TypeVar('Y')
 
+# A cache of the forward, alias_input, backward for `scan`. It has a structure
+# of {fn_ref: {input_key: (forward, alias_input, backward)}}.
+# The `fn_ref` is the address of the given function and is weakly referenced.
+# The input_key is computed using the shapes, dtypes, and the pytree specs of
+# the carry and xs.
+_SCAN_COMPUTATION_CACHE = WeakKeyDictionary()
+
 
 def scan(
     fn: Callable[[Carry, X], tuple[Carry, Y]],
     init: Carry,
     xs: X,
     partition_fn=default_partition,
+    is_fn_pure: bool = False,
     # TODO: consider exposing knobs to control the RNG seed used in each `fn` iteration.
 ) -> tuple[Carry, Y]:
   """Apply a function over leading dimension of tensors while carrying along state.
@@ -110,6 +119,11 @@ def scan(fn, init, xs):
       based activation checkpointing. You may also write your own partitioner to insert any
       custom logic such as host offloading of activations.
 
+    is_fn_pure: (Optional[bool]) If `fn` is pure, the tracing cache will be enabled. A pure
+      function always produces the same output for the same input, and it doesn't have any
+      side effects, meaning it doesn't modify any state outside of itself. Essentially, it's
+      like a mathematical function that only depends on its input arguments.
+
   Returns:
     (carry, ys): A tuple where `carry` is the last carry object returned by `fn`, and
@@ -160,7 +174,7 @@ def scan(fn, init, xs):
     raise ValueError(f"`xs` {xs} is an empty PyTree.")
 
   forward, alias_input, backward = value_and_grad_partitioned(
-      fn, init, xs, partition_fn=partition_fn)
+      fn, init, xs, partition_fn=partition_fn, is_fn_pure=is_fn_pure)
   carry, ys = Scan.apply(forward, alias_input, backward, init,
                          xs)  # type: ignore
   return carry, ys
@@ -170,7 +184,8 @@ def value_and_grad_partitioned(
     fn: Callable[[Carry, X], tuple[Carry, Y]],
     init: Carry,
     xs: X,
-    partition_fn=default_partition) -> tuple[Callable, Callable, Callable]:
+    partition_fn=default_partition,
+    is_fn_pure=True) -> tuple[Callable, Callable, Callable]:
   """
   Given a user `fn` to be scanned over the leading dimension of the input `xs`
   PyTree and an initial carry object `init`, symbolically traces `fn` and
@@ -213,11 +228,23 @@
 
     partition_fn: An optional partitioning function used to partition fn into
       forward and backward graphs.
+
+    is_fn_pure: (Optional[bool]) If `fn` is pure, the tracing cache will be enabled.
 
   Returns:
     A tuple of `(forward, alias_input, backward)`, detailed in the docstring of this function.
   """
 
+  # Compute the second-level cache key used to look up the traced forward and backward graphs.
+  # The key is a tuple of partition_fn's id and the shapes, dtypes, and pytree specs of the carry and xs.
+  def compute_second_level_cache_key(carry_pytree, x_pytree):
+    carry_flat, carry_flat_spec = tree_flatten(carry_pytree)
+    x_flat, x_flat_spec = tree_flatten(x_pytree)
+    carry_key = tuple(
+        (tuple(tensor.shape), tensor.dtype) for tensor in carry_flat)
+    x_key = tuple((tuple(tensor.shape), tensor.dtype) for tensor in x_flat)
+    return (id(partition_fn), carry_key, x_key, carry_flat_spec, x_flat_spec)
+
   # Make some fake tensors to trace the user function and obtain the
   # forward and backward graphs. Note that the init/carry fake tensor
   # always requires grad. That's because even if the user passed in some
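
To make the key concrete, here is a small standalone illustration (assuming a single carry tensor and a single per-step xs slice) of the kind of tuple this helper produces; `example_cache_key` and `some_partition_fn` are illustrative names, not torch_xla APIs.

import torch
from torch.utils._pytree import tree_flatten


def some_partition_fn(*args, **kwargs):
  ...  # stand-in for default_partition / min_cut_rematerialization_partition


def example_cache_key(partition_fn, carry_pytree, x_pytree):
  # Mirrors compute_second_level_cache_key: shapes, dtypes, and pytree specs.
  carry_flat, carry_spec = tree_flatten(carry_pytree)
  x_flat, x_spec = tree_flatten(x_pytree)
  carry_key = tuple((tuple(t.shape), t.dtype) for t in carry_flat)
  x_key = tuple((tuple(t.shape), t.dtype) for t in x_flat)
  return (id(partition_fn), carry_key, x_key, carry_spec, x_spec)


key = example_cache_key(some_partition_fn, torch.zeros(2), torch.zeros(2))
# key is a tuple of the form:
#   (id(partition_fn),
#    (((2,), torch.float32),),  # carry shapes/dtypes
#    (((2,), torch.float32),),  # x shapes/dtypes
#    carry_spec, x_spec)        # pytree structures of carry and xs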
@@ -233,6 +260,12 @@ def make_fake_tensor(v: torch.Tensor, requires_grad=True) -> torch.Tensor:
   fake_x_pytree = tree_map(
       lambda v: make_fake_tensor(v[0], requires_grad=v.requires_grad), xs)
 
+  second_level_cache_key = compute_second_level_cache_key(
+      fake_carry_pytree, fake_x_pytree)
+  if is_fn_pure and fn in _SCAN_COMPUTATION_CACHE:
+    if second_level_cache_key in _SCAN_COMPUTATION_CACHE[fn]:
+      return _SCAN_COMPUTATION_CACHE[fn][second_level_cache_key]
+
   # If an output of `fn` aliases the input, `aot_function` will handle that
   # pair of variables with an epilogue inside its generated autograd.Function
   # that we can't access. In other words, the captured graph won't contain
@@ -328,6 +361,13 @@ def backward(carry, x):
     grad_carry, grad_x = unflatten_bwd_out(out)
     return grad_carry, grad_x
 
+  # Cache the forward and backward graphs for later use.
+  if is_fn_pure:
+    if fn not in _SCAN_COMPUTATION_CACHE:
+      _SCAN_COMPUTATION_CACHE[fn] = {}
+    _SCAN_COMPUTATION_CACHE[fn][second_level_cache_key] = (forward, alias_input,
+                                                           backward)
+
   return forward, alias_input, backward
 
 
0 commit comments