Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ jobs:
pip install .[dev]
- name: Testing with pytest
run: |
python -m pytest . -s
python -m pytest ./tests -s -v
- name: Linting with flake8
run: |
python -m flake8 .
python -m isort -rc --check-only --diff ./ptflops ./tests
python -m flake8 ./ptflops ./tests ./samples
python -m isort -rc --check-only --diff ./ptflops ./tests ./samples
2 changes: 1 addition & 1 deletion .isort.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[isort]
line_length = 79
line_length = 89
multi_line_output = 0
known_standard_library = setuptools
known_first_party = ptflops
Expand Down
4 changes: 2 additions & 2 deletions ptflops/aten_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def f(*args):

def exit_module(self, name):
def f(*args):
assert(self.parents[-1] == name)
assert (self.parents[-1] == name)
self.parents.pop()
return f

Expand Down Expand Up @@ -138,7 +138,7 @@ def get_flops_aten(model, input_res,

except Exception as e:
print("Flops estimation was not finished successfully because of"
f" the following exception:\n{type(e)} : {e}")
f" the following exception: \n{type(e)}: {e}")
traceback.print_exc()

return None, None
Expand Down
6 changes: 3 additions & 3 deletions ptflops/pytorch_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
import torch.nn as nn
import torch.nn.functional as F

from .pytorch_ops import (CUSTOM_MODULES_MAPPING, FUNCTIONAL_MAPPING,
MODULES_MAPPING, TENSOR_OPS_MAPPING)
from .pytorch_ops import (CUSTOM_MODULES_MAPPING, FUNCTIONAL_MAPPING, MODULES_MAPPING,
TENSOR_OPS_MAPPING)
from .utils import flops_to_string, params_to_string


Expand Down Expand Up @@ -72,7 +72,7 @@ def reset_environment():

except Exception as e:
print("Flops estimation was not finished successfully because of"
f" the following exception:\n{type(e)} : {e}")
f" the following exception: \n{type(e)}: {e}")
traceback.print_exc()
reset_environment()

Expand Down
6 changes: 3 additions & 3 deletions ptflops/pytorch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def bn_flops_counter_hook(module, input, output):
module.__flops__ += int(batch_flops)


def conv_flops_counter_hook(conv_module, input, output, extra_per_position_flops=0, transpose=False):
def conv_flops_counter_hook(conv_module, input, output, extra_per_position_flops=0,
transpose=False):
# Can have multiple inputs, getting the first one
input = input[0]

Expand All @@ -84,8 +85,7 @@ def conv_flops_counter_hook(conv_module, input, output, extra_per_position_flops
bias_flops = 0

if conv_module.bias is not None:

bias_flops = out_channels * active_elements_count
bias_flops = batch_size * int(np.prod(list(output.shape[1:]), dtype=np.int64))

overall_flops = overall_conv_flops + bias_flops

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ classifiers = [

[project.optional-dependencies]
dev = [
"flake8==3.8.1",
"flake8==5.0.1",
"flake8-import-order==0.18.1",
"isort==4.3.21",
"torchvision>=0.5.0",
Expand Down
11 changes: 11 additions & 0 deletions tests/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,17 @@ def test_conv(self, default_input_image_size, backend: FLOPS_BACKEND):
assert params == 3 * 3 * 2 * 3 + 2
assert macs == 2759904

@pytest.mark.parametrize("backend", [FLOPS_BACKEND.PYTORCH, FLOPS_BACKEND.ATEN])
def test_conv_t(self, default_input_image_size, backend: FLOPS_BACKEND):
net = nn.ConvTranspose2d(3, 2, 3, stride=(2, 2), bias=True)
macs, params = get_model_complexity_info(net, default_input_image_size,
as_strings=False,
print_per_layer_stat=False,
backend=backend)

assert params == 3 * 3 * 2 * 3 + 2
assert macs == 3112706

@pytest.mark.parametrize("backend", [FLOPS_BACKEND.PYTORCH, FLOPS_BACKEND.ATEN])
def test_fc(self, backend: FLOPS_BACKEND):
net = nn.Sequential(nn.Linear(3, 2, bias=True))
Expand Down