Labeled tensors #1411
base: main
Codecov Report
❌ The patch check failed because the patch coverage (78.54%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

@@            Coverage Diff             @@
##             main    #1411      +/-   ##
==========================================
- Coverage   82.12%   82.03%    -0.09%
==========================================
  Files         211      227       +16
  Lines       49757    51128     +1371
  Branches     8819     9020      +201
==========================================
+ Hits        40862    41944     +1082
- Misses       6715     6926      +211
- Partials     2180     2258       +78
If anybody wants to fix mypy that's very welcome :)

@@ -0,0 +1,219 @@
# HERE LIE DRAGONS
new files need a license.
@lucianopaz can we bring your pre-commit hook over?
Sure
Can be a separate PR; I wouldn't be surprised if we have files missing it in main.
Strategy
We implement xarray-like dummy Ops that respect / propagate dims semantics, and lower them to regular PyTensor graphs with rewrites.
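A minimal sketch of the intended syntax and lowering, assuming a `pytensor.xtensor` module with an `xtensor` constructor (the exact names are illustrative, not a settled API):

```python
import pytensor
import pytensor.xtensor as px  # assumed entry point for the labeled-tensor API

# Dummy xtensor inputs carry dims alongside the usual shape/dtype
x = px.xtensor("x", dims=("city", "year"), shape=(None, None))

# Operations respect / propagate dims semantics, like xarray
out = (x + x).sum(dim="year")

# Compiling applies the lowering rewrites, replacing the dummy Ops with
# regular tensor Ops boxed between TensorFromXtensor / XTensorFromTensor
fn = pytensor.function([x], out)
```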
Note that in the example above the dummy `TensorFromXtensor` and `XTensorFromTensor` remain in the final graph. If we had instead created a function with Tensor inputs and outputs that are only then converted (symbolically) to and from xtensor, respectively, the final graph would show no sign of the dimension operations, other than in how it was constructed.
I suggest registering those rewrites in an `xtensor_lowering` database.
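A sketch of what that could look like, following PyTensor's usual rewrite-database pattern (`xtensor_lowering`, the `XTensorAdd` Op, and the conversion helpers are placeholders for the dummy Ops in this PR):

```python
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.graph.rewriting.db import EquilibriumDB

# Placeholder database collecting the xtensor -> tensor lowering rewrites
xtensor_lowering = EquilibriumDB()

@node_rewriter([XTensorAdd])  # XTensorAdd stands in for one of the dummy Ops
def lower_xtensor_add(fgraph, node):
    # Unbox the xtensor inputs, apply the regular tensor op, box the result
    x, y = (tensor_from_xtensor(inp) for inp in node.inputs)
    out = x + y
    return [xtensor_from_tensor(out, dims=node.outputs[0].type.dims)]

xtensor_lowering.register("lower_xtensor_add", lower_xtensor_add)
```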
Coordinates
For now I'm playing with how far we can get without coordinates. This means the graphs produced by an xarray-like syntax are much more amenable to the numpy-like backend of PyTensor. Otherwise it involves a lot of Pandas-like machinery (e.g., MultiIndex) that we don't really have. It may be feasible, especially if nothing is symbolic, but I fear a rabbit hole of edge cases.
Gradients
These ops are currently not differentiable, but one can lower the graph and then call the gradient, as sketched below. I do want to try the lazy grad approach from #788.
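A sketch of that workflow, combining the note above about Tensor inputs with an assumed lowering pass (the `as_xtensor`/`tensor_from_xtensor` converters and the `xtensor_lowering` rewrite tag are placeholders):

```python
import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph

# Plain tensor input, converted symbolically to xtensor (cf. the note above)
x_tensor = pt.matrix("x")
x = px.as_xtensor(x_tensor, dims=("city", "year"))  # assumed converter
out = (x * x).sum(dim="year")

# Lower the xtensor Ops away, then the regular gradient machinery applies
lowered = rewrite_graph(tensor_from_xtensor(out), include=("xtensor_lowering",))
g = pt.grad(lowered.sum(), wrt=x_tensor)
```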
Help implementing more Ops so we have an MVP to try out with PyMC next. We need some Ops:

Open a PR on top of this branch, I'll try to merge quickly! Try to make it clean (one commit per Op, unless it's like a factory of related Ops).

Implementing means:
3.1 The rewrites "box" the lowered tensor operations between `TensorFromXTensor` and `XTensorFromTensor` calls, so that the replacements are valid in terms of types. There are rewrites that remove chains of useless `TensorFromXTensor`/`XTensorFromTensor` and should clean up everything in the middle of the graph, as sketched below.
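A sketch of that cleanup as a standard node rewriter (the Op classes are the ones named above; the body is illustrative):

```python
from pytensor.graph.rewriting.basic import node_rewriter

@node_rewriter([TensorFromXTensor])
def remove_useless_conversions(fgraph, node):
    # TensorFromXTensor(XTensorFromTensor(x)) -> x
    [inp] = node.inputs
    if inp.owner is not None and isinstance(inp.owner.op, XTensorFromTensor):
        return [inp.owner.inputs[0]]
    return None
```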
Interplay between XTensorTypes and TensorTypes / weakly typed inputs (`__add__` and the like, so you can do `x + x`)

Meta Ops
- where (`math.switch` probably can't support `drop=True`)
- Scan (should add a `time` dim to the outputs, and perhaps use that to also align the `sequences`)

Math stuff
Shape stuff
Array creation stuff
- zeros_like / ones_like (can these just be `self.x * 0` and `self.x * 0 + 1`? PyTensor will do the right thing when it gets lowered)

Indexing stuff
- `__getitem__` + isel (Implement indexing operations for XTensorVariables #1429)
- `__getitem__` + isel for boolean indices (should work fine, just need to test and lift the raise error)
- It probably makes sense to convert the non-XTensor indices to XTensor indices if they can be rendered equivalent, to reduce the logic needed. See the sketch after this list.
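A sketch of the two indexing entry points, following xarray semantics (assuming XTensorVariables support both the same way):

```python
x = px.xtensor("x", dims=("a", "b"), shape=(None, None))

# Positional indexing, dims matched by order (like xarray's DataArray[...])
y1 = x[0]

# The same selection spelled by dim name
y2 = x.isel(a=0)
```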
RandomVariables
This is quite important, as we'll need those for PyMC models! They are a mix of Blockwise + a size argument (which may or may not be redundant).
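A pseudo-usage sketch of that redundancy (the `px.random` namespace and the dict-valued `size` are assumptions, not a settled API):

```python
mu = px.xtensor("mu", dims=("city",), shape=(None,))

# size implied by (blockwise) broadcasting the parameters' dims...
x1 = px.random.normal(mu, 1.0)

# ...or passed explicitly, which is redundant when it just matches mu's dims
x2 = px.random.normal(mu, 1.0, size={"city": 10})
```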
Graph transformations
📚 Documentation preview 📚: https://pytensor--1411.org.readthedocs.build/en/1411/