Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions ptflops/pytorch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
* this file. If not visit https://opensource.org/licenses/MIT
'''

from functools import partial

import numpy as np
import torch
import torch.nn as nn
Expand Down Expand Up @@ -55,12 +57,11 @@ def bn_flops_counter_hook(module, input, output):
module.__flops__ += int(batch_flops)


def conv_flops_counter_hook(conv_module, input, output, extra_per_position_flops=0):
def conv_flops_counter_hook(conv_module, input, output, extra_per_position_flops=0, transpose=False):
# Can have multiple inputs, getting the first one
input = input[0]

batch_size = input.shape[0]
output_dims = list(output.shape[2:])

kernel_dims = list(conv_module.kernel_size)
in_channels = conv_module.in_channels
Expand All @@ -71,7 +72,12 @@ def conv_flops_counter_hook(conv_module, input, output, extra_per_position_flops
conv_per_position_flops = int(np.prod(kernel_dims, dtype=np.int64)) * \
(in_channels * filters_per_channel + extra_per_position_flops)

active_elements_count = batch_size * int(np.prod(output_dims, dtype=np.int64))
if transpose:
input_dims = list(input.shape[2:])
active_elements_count = batch_size * int(np.prod(input_dims, dtype=np.int64))
else:
output_dims = list(output.shape[2:])
active_elements_count = batch_size * int(np.prod(output_dims, dtype=np.int64))

overall_conv_flops = conv_per_position_flops * active_elements_count

Expand Down Expand Up @@ -301,9 +307,9 @@ def timm_attention_counter_hook(attention_module, input, output):
# Upscale
nn.Upsample: upsample_flops_counter_hook,
# Deconvolution
nn.ConvTranspose1d: conv_flops_counter_hook,
nn.ConvTranspose2d: conv_flops_counter_hook,
nn.ConvTranspose3d: conv_flops_counter_hook,
nn.ConvTranspose1d: partial(conv_flops_counter_hook, transpose=True),
nn.ConvTranspose2d: partial(conv_flops_counter_hook, transpose=True),
nn.ConvTranspose3d: partial(conv_flops_counter_hook, transpose=True),
# RNN
nn.RNN: rnn_flops_counter_hook,
nn.GRU: rnn_flops_counter_hook,
Expand Down
Loading