Skip to content

kwargs reduceprod #72

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 22 additions & 15 deletions onnx2pytorch/convert/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,21 @@
import numpy as np
import onnx
import torch
from onnx import numpy_helper
from torch import nn
from torch.nn import functional as F
from onnx import numpy_helper
from torch.nn.modules.linear import Identity

from onnx2pytorch.convert.attribute import extract_attributes
from onnx2pytorch.convert.layer import (
convert_layer,
convert_linear_layer,
convert_batch_norm_layer,
convert_instance_norm_layer,
convert_lstm_layer,
)
from onnx2pytorch.convert.layer import (convert_batch_norm_layer,
convert_instance_norm_layer,
convert_layer, convert_linear_layer,
convert_lstm_layer)
from onnx2pytorch.operations import *
from onnx2pytorch.operations import Hardsigmoid, Resize, Upsample
from onnx2pytorch.operations.base import OperatorWrapper
from onnx2pytorch.operations import Resize, Upsample, Hardsigmoid
from onnx2pytorch.utils import (
get_inputs_names,
get_outputs_names,
value_wrapper,
)
from onnx2pytorch.utils import (get_inputs_names, get_outputs_names,
value_wrapper)


def get_buffer_name(param_name):
Expand Down Expand Up @@ -211,7 +205,20 @@ def convert_operations(onnx_graph, opset_version, batch_dim=0, enable_pruning=Tr
elif node.op_type == "ReduceProd":
kwargs = dict(keepdim=True)
kwargs.update(extract_attributes(node))
op = partial(torch.prod, **kwargs)
def reduceprod_wrapper(x, **kw):
    """Compute an ONNX-style ReduceProd over tensor ``x``.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.
    **kw
        Optional ``dim`` (int or sequence of ints, the ONNX ``axes``) and
        ``keepdim`` (bool, the ONNX ``keepdims``; defaults to True here,
        matching the converter's ``dict(keepdim=True)`` default).

    Returns
    -------
    torch.Tensor
        The product over the requested axes. When no ``dim`` is given, all
        elements are reduced; with ``keepdim=True`` the result keeps the
        input's rank as all-ones dimensions, otherwise it is a 0-d scalar.
    """
    dim = kw.pop("dim", None)
    keepdim = kw.get("keepdim", True)
    if dim is None:
        # Full reduction. torch.prod(x) always returns a 0-d tensor, so
        # "keepdim" must be simulated by reshaping to rank len(x.shape).
        # Bug fix: the previous version reshaped unconditionally, which
        # produced the wrong rank when keepdims=0 (keepdim=False).
        out = torch.prod(x)
        if keepdim:
            out = out.view([1] * x.dim())
        return out
    if isinstance(dim, (tuple, list)):
        # torch.prod only accepts a single int dim; emulate multi-axis
        # reduction (ONNX axes may list several) by reducing one axis at a
        # time. Normalize negative axes and go from highest to lowest so
        # earlier reductions don't shift later axis indices when
        # keepdim=False.
        out = x
        for d in sorted((d % x.dim() for d in dim), reverse=True):
            out = torch.prod(out, dim=d, keepdim=keepdim)
        return out
    return torch.prod(x, dim=dim, **kw)
# Use the kwargs when binding reduceprod_wrapper.
op = lambda x: reduceprod_wrapper(x, **kwargs)
elif node.op_type == "ReduceSum":
op = ReduceSum(opset_version=opset_version, **extract_attributes(node))
elif node.op_type == "ReduceL2":
Expand Down