Skip to content

Adding support for both static and non-static fields #26

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 70 additions & 58 deletions tree_math/_src/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,61 +19,73 @@


def struct(cls):
"""Class decorator that enables JAX function transforms as well as tree math.

Decorating a class with `@struct` makes it a dataclass that is compatible
with arithmetic infix operators like `+`, `-`, `*` and `/`. The decorated
class is also a valid pytree, making it compatible with JAX function
transformations such as `jit` and `grad`.

Example usage:
```
import jax
import tree_math

@tree_math.struct
class Point:
x: float
y: float

a = Point(0.0, 1.0)
b = Point(2.0, 3.0)

a + 3 * b # Point(6.0, 10.0)
jax.grad(lambda x, y: x @ y)(a, b) # Point(2.0, 3.0)
```

Args:
cls: a class, written with the same syntax as a `dataclass`.

Returns:
A wrapped version of `cls` that implements dataclass, pytree and tree_math
functionality.
"""
@property
def fields(self):
return dataclasses.fields(self)

def asdict(self):
return {field.name: getattr(self, field.name) for field in self.fields}

def astuple(self):
return tuple(getattr(self, field.name) for field in self.fields)

def tree_flatten(self):
return self.astuple(), None

@classmethod
def tree_unflatten(cls, _, children):
return cls(*children)

cls_as_struct = type(cls.__name__,
(VectorMixin, dataclasses.dataclass(cls)),
{'fields': fields,
'asdict': asdict,
'astuple': astuple,
'replace': dataclasses.replace,
'tree_flatten': tree_flatten,
'tree_unflatten': tree_unflatten,
'__module__': cls.__module__})
return jax.tree_util.register_pytree_node_class(cls_as_struct)
"""Class decorator that enables JAX function transforms as well as tree math.

Decorating a class with `@struct` makes it a dataclass that is compatible
with arithmetic infix operators like `+`, `-`, `*` and `/`. The decorated
class is also a valid pytree, making it compatible with JAX function
transformations such as `jit` and `grad`.

Example usage:
```
import jax
import tree_math

@tree_math.struct
class Point:
x: float
y: float
static_field: int = 0 # base case

a = Point(0.0, 1.0)
b = Point(2.0, 3.0)

a + 3 * b # Point(6.0, 10.0)
jax.grad(lambda x, y: x @ y)(a, b) # Point(2.0, 3.0)
```

Args:
cls: a class, written with the same syntax as a `dataclass`.

Returns:
A wrapped version of `cls` that implements dataclass, pytree and tree_math
functionality.
"""

# Get static fields from the class if defined
static_fields = getattr(cls, 'static_fields', [])

@property
def fields(self):
return dataclasses.fields(self)

def asdict(self):
return {field.name: getattr(self, field.name) for field in self.fields}

def astuple(self):
return tuple(getattr(self, field.name) for field in self.fields if field.name not in static_fields)

def tree_flatten(self):
# Flatten only the non-static fields
children = [getattr(self, field.name) for field in self.fields if field.name not in static_fields]
return children, None

@classmethod
def tree_unflatten(cls, _, children):
# Create an instance with the provided children and static fields
instance = cls(*children)
for field in cls.static_fields:
setattr(instance, field, getattr(cls, field)) # Set static fields
return instance

cls_as_struct = type(cls.__name__,
(VectorMixin, dataclasses.dataclass(cls)),
{'fields': fields,
'asdict': asdict,
'astuple': astuple,
'replace': dataclasses.replace,
'tree_flatten': tree_flatten,
'tree_unflatten': tree_unflatten,
'__module__': cls.__module__})

return jax.tree_util.register_pytree_node_class(cls_as_struct)
10 changes: 8 additions & 2 deletions tree_math/_src/structs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,14 @@

@tree_math.struct
class TestStruct:
a: ArrayLike
b: ArrayLike
a: ArrayLike
b: ArrayLike
static_field: int = 0 # This will be a static field

# Define static fields as a class variable
static_fields = ['static_field'] # Specify which fields are static




class StructsTest(test_util.TestCase):
Expand Down