Skip to content

Commit 0727453

Browse files
shoyertree-math authors
authored and
tree-math authors
committed
[tree-math] add replace() method
This convenience method is copied from flax.struct. PiperOrigin-RevId: 621690005
1 parent 4f9cd0a commit 0727453

File tree

4 files changed

+9
-2
lines changed

4 files changed

+9
-2
lines changed

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
setuptools.setup(
3030
name='tree-math',
3131
description='Mathematical operations for JAX pytrees',
32-
version='0.2.0 ',
32+
version='0.2.1',
3333
license='Apache 2.0',
3434
author='Google LLC',
3535
author_email='[email protected]',

tree_math/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,4 @@
2424
from tree_math._src.vector import Vector, VectorMixin
2525
import tree_math.numpy
2626

27-
__version__ = '0.2.0'
27+
__version__ = '0.2.1'

tree_math/_src/structs.py

+1
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def tree_unflatten(cls, _, children):
7272
{'fields': fields,
7373
'asdict': asdict,
7474
'astuple': astuple,
75+
'replace': dataclasses.replace,
7576
'tree_flatten': tree_flatten,
7677
'tree_unflatten': tree_unflatten,
7778
'__module__': cls.__module__})

tree_math/_src/structs_test.py

+6
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,12 @@ def testPickle(self):
109109
restored = pickle.loads(pickle.dumps(struct))
110110
self.assertTreeEqual(struct, restored, check_dtypes=True)
111111

112+
def testReplace(self):
113+
struct = TestStruct(1, 2)
114+
replaced = struct.replace(b=3)
115+
expected = TestStruct(1, 3)
116+
self.assertTreeEqual(replaced, expected, check_dtypes=True)
117+
112118

113119
if __name__ == '__main__':
114120
absltest.main()

0 commit comments

Comments
 (0)