GRU support #70

Open
ZM-J opened this issue Oct 22, 2024 · 0 comments
ZM-J commented Oct 22, 2024

I have a GRU ONNX model that runs inference correctly with onnxruntime. However, when I try to convert it to its PyTorch counterpart with the following code:

import onnx
import torch
from onnx2pytorch import ConvertModel

# onnx_model_path is the path to the GRU .onnx file (not shown here)
onnx_model = onnx.load(onnx_model_path)
pytorch_model = ConvertModel(onnx_model)

# Dummy input of shape (1, 60, 78)
X = torch.randn((1, 60, 78))
with torch.inference_mode():
    y = pytorch_model(X)

print(y.shape)

I got the following error:

Traceback (most recent call last):
  File "/convert.py", line 41, in <module>
    y = pytorch_model(X)
  File "/miniconda3/envs/convert_model/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/miniconda3/envs/convert_model/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/miniconda3/envs/convert_model/lib/python3.10/site-packages/onnx2pytorch/convert/model.py", line 224, in forward
    activations[out_op_id] = op(*in_activations)
TypeError: gru() received an invalid combination of arguments - got (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor), but expected one of:
 * (Tensor data, Tensor batch_sizes, Tensor hx, tuple of Tensors params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional)
 * (Tensor input, Tensor hx, tuple of Tensors params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first)

I don't know what is going wrong internally, though :(
