Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding support for both static and non-static fields #26

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

ricor07
Copy link

@ricor07 ricor07 commented Dec 30, 2024

Edited functions in struct so that some fields can be pytree nodes and some not.

Copy link

google-cla bot commented Dec 30, 2024

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@shoyer
Copy link
Member

shoyer commented Jan 3, 2025

Thanks @ricor07!

I would prefer to implement this with one of the existing mechanisms for static fields from jax.tree_util.register_dataclasses rather than inventing a new approach here with the class variable: https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.register_dataclass.html

This suggests two options:

  1. putting static_fields (or a similar argument) as an argument on the struct function, or
  2. checking for a default value indicated with dataclasses.field

This change also needs a unit test to verify that it works.

@ricor07
Copy link
Author

ricor07 commented Jan 4, 2025

I think I may need help. I modified the struct function this way:

` @Property
def fields(self):
return dataclasses.fields(self)

def asdict(self):
return {field.name: getattr(self, field.name) for field in self.fields}

def astuple(self):
return tuple(getattr(self, field.name) for field in self.fields if field.name != "static")

def tree_flatten(self):
children = [getattr(self, field.name) for field in self.fields if field.name != "static"]
return children, None

@classmethod
def tree_unflatten(cls, _, children):
return cls(*children)

cls_as_struct = type(cls.name,
(VectorMixin, dataclasses.dataclass(cls)),
{'fields': fields,
'asdict': asdict,
'astuple': astuple,
'replace': dataclasses.replace,
'tree_flatten': tree_flatten,
'tree_unflatten': tree_unflatten,
'module': cls.module})
return jax.tree_util.register_pytree_node_class(cls_as_struct)`

And I modified TestStruct this way, according to the jax documentation:

class TestStruct: a: ArrayLike b: ArrayLike op: str = "static"

However, I receive TypeErrors since str is not a valid argument. I don't know what to do. I already tried different implementations

@ricor07 ricor07 marked this pull request as draft January 4, 2025 14:37
@ricor07
Copy link
Author

ricor07 commented Jan 17, 2025

Hello, could you give me a feedback? Thank you

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants