-
Notifications
You must be signed in to change notification settings - Fork 1.6k
add cute.union #2788
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
add cute.union #2788
Conversation
| - The alignment is the maximum alignment of all objects | ||
| - The size is the maximum size of all objects | ||
| **Usage:** |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The usage looks really nice! It’s worth documenting this in an *.rst file. Maybe you could create a new file just for types and include the union there. We can fill in the rest of the regular types after your PT.
https://github.com/NVIDIA/cutlass/tree/main/media/docs/pythonDSL/cute_dsl_general/types.rst
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
made claude draft it, someone's gotta look over it
python/CuTeDSL/cutlass/cute/core.py
Outdated
| self._align_of = max_alignment | ||
| self._size_of = struct.align_offset(max_size, max_alignment) | ||
|
|
||
| def __call__(self, base: Any) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| def __call__(self, base: Any) -> None: | |
| @dsl_user_op | |
| def __call__(self, base: Any, *, loc=None, ip=None) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@dsl_user_op does multiple things like generating location information. Let's add it here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added. note that struct call does not have it, is that right?
python/CuTeDSL/cutlass/cute/core.py
Outdated
| if isinstance(obj, struct._AlignMeta): | ||
| obj = obj.dtype | ||
| if struct._is_scalar_type(obj): | ||
| new_obj = recast_ptr(base + off, dtype=obj) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| new_obj = recast_ptr(base + off, dtype=obj) | |
| new_obj = recast_ptr(base + off, dtype=obj, loc=loc, ip=ip) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
| @cute.union | ||
| class BasicUnion: | ||
| as_int: cutlass.Int32 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe it's worth adding failing tests like
@cute.union
class BasicUnion:
as_int: cutlass.Int32
as_float # no annotation
|
The PR is in great shape — thanks for adding Let's also wait @anakinxc's review |
| ) | ||
| return cls | ||
|
|
||
| def size_in_bytes(self) -> int: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you might also add
def __repr__(self) -> str:
return f"<union {self._cls.__name__} size={self._size_of} align={self._align_of}>"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, not struct doesn't have a repr either btw
| :return: The decorated union class. | ||
| """ | ||
|
|
||
| def __init__(self, cls): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably also implement __setitem__ for union to prevent user add new field after init?
Consider the following code
@cute.union
class value_union:
as_int : cutlass.Int32
as_float : cutlass.Float32
def foo():
vu = value_union()
vu.as_some_craziness = myStruct() # This should be rejected
vu.__sizeof__() # what should be the size now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added. struct doesn't have that either btw
| @@ -0,0 +1,190 @@ | |||
| ################################################################################################# | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@brandon-yujie-sun @grypp @anakinxc seems we miss directory for testing :) Maybe we should consider add them. For this PR, I think it's okay to put it here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree, we should add test folder.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1, it would be better to put it under test/python/CuTeDSL or something like that. cc @zekunf-nv whom also plans to add some tests to the repo.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
moved
e9a3cca to
8c4338e
Compare
No description provided.