The pretrained MTLSD checkpoint used for inference is not compatible with PyTorch #26

Open
yumang1cv opened this issue Feb 10, 2025 · 0 comments

While running the inference step of the Colab-based tutorial with the pretrained MTLSD checkpoint, I encountered an error with the following code:

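```python
# `predict` is the inference helper defined earlier in the tutorial notebook
checkpoint = 'model_checkpoint_50000'  # downloaded pretrained MTLSD checkpoint
raw_file = 'testing_data.zarr'         # Zarr container holding the test data
raw_dataset = 'raw/0'                  # dataset path inside the container

raw, pred_lsds, pred_affs = predict(checkpoint, raw_file, raw_dataset)
```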

The error I received was:

```
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/gunpowder/nodes/batch_provider.py", line 193, in request_batch
    batch = self.provide(upstream_request)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/gunpowder/nodes/batch_filter.py", line 148, in provide
    dependencies = self.prepare(request)
                   ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/gunpowder/nodes/generic_predict.py", line 116, in prepare
    self.start()
  File "/usr/local/lib/python3.11/dist-packages/gunpowder/torch/nodes/predict.py", line 105, in start
    checkpoint = torch.load(self.checkpoint, map_location=self.device)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/serialization.py", line 1384, in load
    return _legacy_load(
           ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/serialization.py", line 1628, in _legacy_load
    magic_number = pickle_module.load(f, **pickle_load_args)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
_pickle.UnpicklingError: invalid load key, '<'.

Exception in pipeline:
ZarrSource[testing_data.zarr] -> Normalize -> Unsqueeze -> Stack -> Predict -> Scan -> Squeeze -> Squeeze
while trying to process request

    RAW: ROI: [0:5000, 0:5000] (5000, 5000), voxel size: None, interpolatable: None, non-spatial: False, dtype: None, placeholder: False
    PRED_LSDS: ROI: [40:4960, 40:4960] (4920, 4920), voxel size: None, interpolatable: None, non-spatial: False, dtype: None, placeholder: False
    PRED_AFFS: ROI: [40:4960, 40:4960] (4920, 4920), voxel size: None, interpolatable: None, non-spatial: False, dtype: None, placeholder: False
```
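
The last frame of the traceback is `pickle_module.load(f, **pickle_load_args)` inside `torch.load`'s `_legacy_load`, so the exception is raised by Python's unpickler while reading the start of the file, before any model weights are touched. A minimal standard-library sketch (the helper name `first_load_key_error` is hypothetical) shows that the same message appears for any input whose first byte is `<`, which is how HTML and XML documents begin:

```python
import pickle


def first_load_key_error(data: bytes) -> str:
    """Try to unpickle `data`; return the UnpicklingError message, or "" if it loads."""
    try:
        pickle.loads(data)
    except pickle.UnpicklingError as exc:
        return str(exc)
    return ""


# Illustrative bytes only: a minimal HTML document, which starts with '<'.
html_bytes = b"<!DOCTYPE html><html><body></body></html>"
print(first_load_key_error(html_bytes))  # invalid load key, '<'.
```

So the message on its own indicates only that the file begins with `<`; it does not by itself identify which framework, if any, produced the file.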

The checkpoint download completed without errors, as shown below:

```
--2025-02-10 09:22:56--  https://www.dropbox.com/s/r1u8pvji5lbanyq/model_checkpoint_50000
Resolving www.dropbox.com (www.dropbox.com)... 162.125.65.18, 2620:100:6021:18::a27d:4112
Connecting to www.dropbox.com (www.dropbox.com)|162.125.65.18|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [text/html]
Saving to: ‘model_checkpoint_50000’

model_checkpoint_50     [  <=>               ]  72.45K   209KB/s    in 0.3s

2025-02-10 09:22:58 (209 KB/s) - ‘model_checkpoint_50000’ saved [74187]
```

After looking into it, I suspect that this checkpoint was saved with TensorFlow and is therefore incompatible with PyTorch. Could you please provide a PyTorch-compatible version of the checkpoint?
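
One quick way to check what the downloaded file actually contains, before concluding which framework produced it, is to look at its leading bytes. This is a heuristic sketch under stated assumptions: the helper `sniff_checkpoint` is hypothetical, and the signatures it checks (a zip archive for the `torch.save` format used by default since PyTorch 1.6, a pickle stream starting with byte `0x80` for the legacy format, the HDF5 signature used by Keras `.h5` weights, and a leading `<` for HTML) are illustrative, not exhaustive.

```python
import zipfile
from pathlib import Path

# Illustrative file signatures (an assumption for this sketch, not an exhaustive list).
HDF5_SIGNATURE = b"\x89HDF\r\n\x1a\n"  # HDF5 files, e.g. Keras .h5 weights
PICKLE_PROTO_OPCODE = b"\x80"          # pickle protocol >= 2, as in legacy torch.save output


def sniff_checkpoint(path) -> str:
    """Heuristically classify a file by its leading bytes.

    Returns one of: "zip", "pickle", "hdf5", "html", "unknown".
    This is a quick diagnostic, not a validation of the file's contents.
    """
    path = Path(path)
    with path.open("rb") as f:
        head = f.read(16)

    # torch.save writes a zip archive by default since PyTorch 1.6.
    if zipfile.is_zipfile(path):
        return "zip"
    # Older torch.save files (and other pickles) begin with the PROTO opcode.
    if head.startswith(PICKLE_PROTO_OPCODE):
        return "pickle"
    if head.startswith(HDF5_SIGNATURE):
        return "hdf5"
    # HTML/XML text, e.g. a web page saved in place of the expected file.
    if head.lstrip().startswith(b"<"):
        return "html"
    return "unknown"


if __name__ == "__main__":
    import sys

    # Usage: pass one or more file paths on the command line.
    for name in sys.argv[1:]:
        print(f"{name}: {sniff_checkpoint(name)}")
```

For example, calling `sniff_checkpoint('model_checkpoint_50000')` in the notebook returns one of the labels above. A result of `"html"` would suggest the download saved a web page rather than model weights, which would also fit the `Length: unspecified [text/html]` line in the wget output; `"zip"` or `"pickle"` would suggest a file that `torch.load` can at least begin to parse.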
