import torch_xla
import torch_xla.debug.metrics as met
+import torch_xla.experimental.scan as scan_module
from torch_xla.experimental.scan import scan, value_and_grad_partitioned, tree_flatten_none

parent_folder = os.path.dirname(os.path.dirname(__file__))
sys.path.append(parent_folder)
from test_utils import XlaTestCase  # type:ignore
+from absl.testing import parameterized


def _loopy_scan(fn, init, xs):
@@ -44,6 +46,8 @@ class TestBase(XlaTestCase):
  def setUp(self):
    super().setUp()
    self.device = torch_xla.device()
+    # Clear the scan computation cache before each test to avoid cross-test contamination.
+    scan_module._SCAN_COMPUTATION_CACHE.clear()

  def compare_pytree(self, expected_pytree, actual_pytree):
    flat_expected_pytree, expected_spec = tree_flatten(expected_pytree)
@@ -59,13 +63,14 @@ def compare_pytree(self, expected_pytree, actual_pytree):
    super().compareResults(flat_expected_pytree, flat_actual_pytree)


-class ScanTest(TestBase):
+class ScanTest(TestBase, parameterized.TestCase):

  def run_test(self,
               fn,
               init: PyTree,
               xs: PyTree,
-               partition_fn=default_partition):
+               partition_fn=default_partition,
+               is_fn_pure: bool = False):
    """Compares the result of scanning with `fn` with our optimized HLO implementation
    against a for loop implementation. Checks both output values and gradients.
    """
@@ -78,7 +83,12 @@ def run_test(self,
    # Actual output
    init_scan = tree_map(dupe, init)
    xs_scan = tree_map(dupe, xs)
-    final_carry, ys = scan(fn, init_scan, xs_scan, partition_fn=partition_fn)
+    final_carry, ys = scan(
+        fn,
+        init_scan,
+        xs_scan,
+        partition_fn=partition_fn,
+        is_fn_pure=is_fn_pure)
    # Add up all leaves and `backward()` once.
    (squish(final_carry) + squish(ys)).backward()
    torch_xla.sync()
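A minimal usage sketch of the new flag (my illustration, not part of this diff, assuming only the `scan` signature exercised in this file): `is_fn_pure=True` declares that the step function has no side effects, which is what allows `scan` to reuse a cached traced computation on later calls.

import torch
import torch_xla
from torch_xla.experimental.scan import scan

def step_fn(carry, x):
  # Pure step function: outputs depend only on the inputs.
  return carry + x, x

init = torch.zeros(2, device=torch_xla.device())
xs = torch.ones(3, 2, device=torch_xla.device())
# Later calls with the same pure step function can reuse the cached computation.
final_carry, ys = scan(step_fn, init, xs, is_fn_pure=True)
torch_xla.sync()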
@@ -105,7 +115,8 @@ def run_test(self,
    return final_carry, ys

-  def test_scan_simple(self):
+  @parameterized.parameters(True, False)
+  def test_scan_simple(self, is_fn_pure: bool):
    """This test uses `scan` to implement `torch.cumsum`."""

    def step_fn(carry, x):
@@ -117,7 +128,7 @@ def step_fn(carry, x):
    xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                      requires_grad=True,
                      device=self.device)
-    final_carry, ys = self.run_test(step_fn, init, xs)
+    final_carry, ys = self.run_test(step_fn, init, xs, is_fn_pure=is_fn_pure)

    # Also ensure that our loop-based scan is correct, with manual checks
    # that replicate the step_fn.
@@ -140,7 +151,8 @@ def test_scan_incompatible_length(self):
    with self.assertRaises(ValueError):
      scan(lambda a, b: (a, b), init, (xs_1, xs_2))

-  def test_scan_tuples(self):
+  @parameterized.parameters(True, False)
+  def test_scan_tuples(self, is_fn_pure: bool):
    """Test scanning over the leading axis of a tuple of tensors simultaneously,
    which is a simple PyTree."""
@@ -163,9 +175,10 @@ def fn(carry, x):
                      requires_grad=True,
                      device=self.device))

-    self.run_test(fn, init, xs)
+    self.run_test(fn, init, xs, is_fn_pure=is_fn_pure)

-  def test_scan_create_tensors(self):
+  @parameterized.parameters(True, False)
+  def test_scan_create_tensors(self, is_fn_pure: bool):
    """Test scanning over a function that internally creates tensors."""

    def fn(carry, x):
@@ -177,7 +190,7 @@ def fn(carry, x):
    xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                      requires_grad=True,
                      device=self.device)
-    self.run_test(fn, init, xs)
+    self.run_test(fn, init, xs, is_fn_pure=is_fn_pure)

  def test_scan_create_tensors_no_transfers_from_device(self):
    """Test that scanning over a function that internally creates tensors
@@ -220,7 +233,8 @@ def fn(carry, x):
                      device=self.device)
    self.run_test(fn, init, xs)

-  def test_scan_input_output_aliases_carry(self):
+  @parameterized.parameters(True, False)
+  def test_scan_input_output_aliases_carry(self, is_fn_pure: bool):
    """
    Test scan still works when a fn output aliases its carry input.
    """
@@ -232,9 +246,10 @@ def fn(carry, x):
    xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                      requires_grad=True,
                      device=self.device)
-    self.run_test(fn, init, xs)
+    self.run_test(fn, init, xs, is_fn_pure=is_fn_pure)

-  def test_scan_input_output_aliases_x(self):
+  @parameterized.parameters(True, False)
+  def test_scan_input_output_aliases_x(self, is_fn_pure: bool):
    """
    Test scan still works when a fn output aliases its x input.
    """
@@ -246,7 +261,7 @@ def fn(carry, x):
    xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                      requires_grad=True,
                      device=self.device)
-    self.run_test(fn, init, xs)
+    self.run_test(fn, init, xs, is_fn_pure=is_fn_pure)

  def test_scan_input_in_place_mutation(self):
    """
@@ -288,7 +303,8 @@ def step_fn(carry, x):
    with self.assertRaisesRegex(AssertionError, "FakeTensor"):
      scan(step_fn, init, xs)

-  def test_scan_gradness(self):
+  @parameterized.parameters(True, False)
+  def test_scan_gradness(self, is_fn_pure: bool):
    """
    Test the gradient output of `scan` when various inputs do or don't
    require gradients.
@@ -307,7 +323,7 @@ def fn(carry, x):
      xs = torch.tensor([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0]],
                        requires_grad=xs_requires_grad,
                        device=self.device)
-      self.run_test(fn, init, xs)
+      self.run_test(fn, init, xs, is_fn_pure=is_fn_pure)

    test_case(True, True)
    test_case(True, False)
@@ -445,7 +461,8 @@ def fn(carry, x):
    self.assertEqual(bf16_ys.dtype, torch.bfloat16)
    self.assertEqual(f32_ys.dtype, torch.float32)

-  def test_scan_activation_aliases_input(self):
+  @parameterized.parameters(True, False)
+  def test_scan_activation_aliases_input(self, is_fn_pure: bool):
    """Test that if an intermediate activation of fn aliases an input,
    we directly save the input tensor into the context object, instead of
    indexing into the leading dimension during the while loop and copying
@@ -470,7 +487,7 @@ def unpack(x):

    # Intercept the tensors stored in the context object.
    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
-      final_carry, ys = scan(fn, carry, xs)
+      final_carry, ys = scan(fn, carry, xs, is_fn_pure=is_fn_pure)
    ys.sum().backward()
    torch_xla.sync()
@@ -487,6 +504,112 @@ def unpack(x):
    # as opposed to just numerically identical but otherwise an extra copy.
    assert id(stored_xs) == id(xs)

+  def test_scan_computation_cache(self):
+    """
+    Test that the computation cache is populated correctly.
+    """
+    fn1_call_count = 0
+
+    def fn1(carry, x):
+      nonlocal fn1_call_count
+      fn1_call_count += 1
+      return carry + x, x
+
+    init = torch.tensor([0.0, 0.0], device=self.device)
+    xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
+                      device=self.device,
+                      requires_grad=True)
+
+    for _ in range(10):
+      scan(fn1, init, xs, is_fn_pure=True)
+
+    cache = scan_module._SCAN_COMPUTATION_CACHE
+
+    # Check that fn1 has a first-level entry in the cache.
+    assert fn1 in cache, "fn1 should be in the cache"
+
+    # Inspect the second-level cache for fn1.
+    second_level_cache = cache[fn1]
+    assert len(second_level_cache) > 0, "Second-level cache should not be empty"
+
+    # Even though scan was called 10 times, fn1 should be traced only twice:
+    # once to construct the forward graph and once to construct the backward
+    # graph. All subsequent calls hit the cache.
+    assert fn1_call_count == 2, \
+        "fn1 should be called only twice (once for constructing the forward graph and once for constructing the backward graph), but was called " + str(fn1_call_count)
+
+    # Every second-level cache entry should hold the cached forward, alias_input
+    # and backward computations.
+    for key, value in second_level_cache.items():
+      forward, alias_input, backward = value
+      assert callable(forward)
+      assert callable(alias_input)
+      assert callable(backward)
+
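The assertions above imply a two-level cache layout. A hedged sketch (my illustration, not part of this diff) of how one could walk `_SCAN_COMPUTATION_CACHE`: the outer dict is keyed by the step function, each inner dict is keyed by tracing details such as the partition function (see the next test), and each value is a `(forward, alias_input, backward)` triple of cached computations.

import torch_xla.experimental.scan as scan_module

def dump_scan_cache():
  # Walk the presumed {fn: {inner_key: (forward, alias_input, backward)}} layout.
  for fn, inner in scan_module._SCAN_COMPUTATION_CACHE.items():
    for inner_key, (forward, alias_input, backward) in inner.items():
      print(fn, inner_key, callable(forward), callable(alias_input), callable(backward))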
+  def test_scan_computation_cache_by_fn_and_partition_fn(self):
+    """
+    Test that the computation cache is keyed by both fn and partition_fn.
+    """
+
+    def fn1(carry, x):
+      return carry + x, x
+
+    def fn2(carry, x):
+      return carry * x, x
+
+    init = torch.tensor([0.0, 0.0], device=self.device)
+    xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
+                      device=self.device,
+                      requires_grad=True)
+    scan(fn1, init, xs, is_fn_pure=True)
+    scan(fn2, init, xs, is_fn_pure=True)
+
+    cache = scan_module._SCAN_COMPUTATION_CACHE
+
+    # Each step function should get its own first-level cache entry.
+    assert fn1 in cache, "fn1 should be in the cache"
+    assert fn2 in cache, "fn2 should be in the cache"
+
+    # Inspect the second-level cache for fn1.
+    second_level_cache = cache[fn1]
+    assert len(
+        second_level_cache) == 1, "Second-level cache should have exactly one entry"
+
+    # Inspect the second-level cache for fn2.
+    second_level_cache = cache[fn2]
+    assert len(
+        second_level_cache) == 1, "Second-level cache should have exactly one entry"
+
+    # Scanning fn1 again with a different partition function should create a
+    # second entry in fn1's second-level cache.
+    scan(
+        fn1,
+        init,
+        xs,
+        partition_fn=min_cut_rematerialization_partition,
+        is_fn_pure=True)
+    second_level_cache = cache[fn1]
+    assert len(second_level_cache
+              ) == 2, "Second-level cache should have exactly two entries. Got: " + str(
+                  len(second_level_cache))
+
+  def test_scan_computation_cache_disabled_when_fn_is_not_pure(self):
+    """
+    Test that the computation cache is not populated when the function is not pure.
+    """
+
+    def fn1(carry, x):
+      return carry + x, x
+
+    init = torch.tensor([0.0, 0.0], device=self.device)
+    xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
+                      device=self.device,
+                      requires_grad=True)
+    scan(fn1, init, xs, is_fn_pure=False)
+
+    cache = scan_module._SCAN_COMPUTATION_CACHE
+
+    # With is_fn_pure=False (the default), fn1 should not be cached.
+    assert fn1 not in cache, "fn1 should not be in the cache"
+
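Why impure step functions bypass the cache (my reasoning from the tests above, not text from the diff): on a cache hit the step function is not re-traced, so a Python-level side effect inside it would only run the first time. A self-contained sketch, assuming the same `scan` API used in this file:

import torch
import torch_xla
from torch_xla.experimental.scan import scan

trace_count = 0

def impure_step(carry, x):
  global trace_count
  trace_count += 1  # side effect: runs only when the step function is traced
  return carry + x, x

init = torch.zeros(2, device=torch_xla.device())
xs = torch.ones(3, 2, device=torch_xla.device())
# Leaving is_fn_pure at its default (False) keeps the cache disabled, so the
# step function is traced (and the side effect runs) on every scan call.
scan(impure_step, init, xs, is_fn_pure=False)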


class PyTreeTest(TestBase):