Fixed eval.py on MPS #702
Conversation
OMG. We knew there was a bug but we didn't know why. Thanks!
Yes, it was tough to find. It silently and randomly replaces some numbers in the tensor with completely meaningless ones, and the training/validation can fail only a few steps later. I literally debugged it line by line to find the issue.
Before / After comparison video: episode_1.mp4
I think this is the one-liner fix of the year.
Congrats @IliaLarchenko, this one was bugging us for a while! 🤗
Thanks @IliaLarchenko for this fix. I tested it on pusht with device = "mps". The policy now works with the fix, whereas it did not before. That is also likely affecting:
I also tried to add this on top of the eval and train files, as per https://www.reddit.com/r/pytorch/comments/1c3kwwg/how_do_i_fix_the_mps_notimplemented_error_for_m1/. It seems to help, but it's hard to tell. The pusht policy I trained is not as good as the lerobot/pusht one.
@aliberts snap, I started writing this post yesterday; have a look into:
Maybe more. And I'm unsure of the effect on the training stack, as it is harder to analyse.
What this does
I encountered a very weird bug while training/evaluating the model on the `mps` device: unexpectedly, `observation.state` values become completely random float numbers that mess up the whole evaluation. It turned out this happens after this line:
lerobot/lerobot/scripts/eval.py, line 157 (commit 638d411)
`non_blocking=True` is not supported by MPS devices and can result in random numbers. There are multiple issues related to it in different libraries. While writing this, I also found that there are two open issues with the same problem here: #475 and #496.
This fix should solve them both.
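For context, the failure mode can be sketched in isolation. This is a hypothetical repro, not part of this PR; whether it actually misbehaves depends on the PyTorch version and hardware:

```python
import torch

# Sketch of the failure mode described above: copying a non-contiguous CPU tensor
# to MPS with non_blocking=True can silently yield values that differ from a
# blocking copy on affected setups.
if torch.backends.mps.is_available():
    x = torch.randn(1024, 1024)[:, ::2]       # non-contiguous view on the CPU
    blocking = x.to("mps")                     # safe, blocking copy
    racy = x.to("mps", non_blocking=True)      # may contain random values
    # On affected setups this can print False even though both copies come from
    # the same source tensor.
    print(torch.equal(blocking.cpu(), racy.cpu()))
```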
My fix is simple: use `non_blocking=True` only on CUDA. `.to(device, non_blocking=True)` is also used in `train.py`, though I don't see the same issue in training. This is probably because the error happens only with non-contiguous tensors, which is not the case for training (but I didn't look into it; maybe `train.py` also requires this fix).
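The idea can be sketched as a small helper (illustrative only, not the actual diff; the name `move_to_device` and the dict-of-tensors batch shape are assumptions):

```python
import torch

def move_to_device(batch: dict[str, torch.Tensor], device: torch.device) -> dict[str, torch.Tensor]:
    # Only request a non-blocking transfer on CUDA, where it is well supported;
    # everywhere else (including MPS) fall back to a blocking copy to avoid the
    # silent corruption described above.
    non_blocking = device.type == "cuda"
    return {key: value.to(device, non_blocking=non_blocking) for key, value in batch.items()}
```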
How it was tested
It can be hard to reproduce as the error appears randomly, but just try to train or evaluate any model on an MPS device. (I used the Pusht dataset, but in other issues, it happened with Aloha.)
E.g. try to evaluate the diffusion_pusht model on mps: https://huggingface.co/lerobot/diffusion_pusht
It fails without the fix (either silently, by getting a 0 success rate, or by returning NaN at some step) but works well with the fix.