@@ -47,8 +47,7 @@ def test_same_manual_seed(self):
4747 y = torch .randn ((3 , 3 ))
4848 self .assertIsInstance (y , tensor .Tensor )
4949
50- self .assertTrue (
51- torch .equal (torchax .tensor .j2t (x ._elem ), torchax .tensor .j2t (y ._elem )))
50+ self .assertTrue (torch .allclose (x , y ))
5251
5352 def test_different_manual_seed (self ):
5453 with xla_env :
@@ -60,36 +59,30 @@ def test_different_manual_seed(self):
6059 y = torch .randn ((3 , 3 ))
6160 self .assertIsInstance (y , tensor .Tensor )
6261
63- self .assertFalse (
64- torch .equal (torchax .tensor .j2t (x ._elem ), torchax .tensor .j2t (y ._elem )))
62+ self .assertFalse (torch .allclose (x , y ))
6563
6664 def test_jit_with_rng (self ):
6765
68- @xla_env
69- def random_op ():
70- x = torch .randn (3 , 3 )
71- y = torch .randn (3 , 3 )
72- return x @ y
66+ with xla_env :
67+
68+ def random_op ():
69+ x = torch .randn (3 , 3 )
70+ y = torch .randn (3 , 3 )
71+ return x @ y
7372
74- random_jit = torchax .interop .jax_jit (random_op )
75- self .assertIsInstance (random_jit (), tensor .Tensor )
73+ random_jit = torchax .interop .jax_jit (random_op )
74+ self .assertIsInstance (random_jit (), tensor .Tensor )
7675
77- # Result always expected to be the same for a jitted function because seeds
78- # are baked in
79- torch .testing .assert_close (
80- torchax .tensor .j2t (random_jit ()._elem ),
81- torchax .tensor .j2t (random_jit ()._elem ),
82- atol = 0 ,
83- rtol = 0 )
76+ # Result always expected to be the same for a jitted function because seeds
77+ # are baked in
78+ torch .testing .assert_close (random_jit (), random_jit (), atol = 0 , rtol = 0 )
8479
8580 def test_generator_seed (self ):
8681 with xla_env :
8782 x = torch .randn (2 , 3 , generator = torch .Generator ().manual_seed (0 ))
8883 y = torch .randn (2 , 3 , generator = torch .Generator ().manual_seed (0 ))
8984
90- # Values will be different, but still check device, layout, dtype, etc
91- torch .testing .assert_close (
92- torchax .tensor .j2t (x ._elem ), torchax .tensor .j2t (y ._elem ))
85+ torch .testing .assert_close (x , y )
9386
9487 def test_buffer (self ):
9588
0 commit comments