Labels
Hacktoberfest, PyDataGlobal (PyData Global 2020 Sprint), enhancement, help wanted, module: metrics (Metrics module)
Description
Currently, the test data in the distributed tests is generated without taking the distributed context into account, for example in:
ignite/tests/ignite/metrics/test_confusion_matrix.py, lines 559 to 565 in e17acc7:

```python
def _test_distrib_multiclass_images(device):
    def _test(metric_device):
        num_classes = 3
        cm = ConfusionMatrix(num_classes=num_classes, device=metric_device)
        y_true, y_pred = get_y_true_y_pred()
```
The same applies to the MultiLabelConfusionMatrix tests.
We would like to rewrite these tests in the same way as in:
ignite/tests/ignite/metrics/test_recall.py, lines 778 to 785 in e17acc7:

```python
def _test(average, n_epochs, metric_device):
    n_iters = 60
    s = 16
    n_classes = 7
    offset = n_iters * s
    y_true = torch.randint(0, 2, size=(offset * idist.get_world_size(), n_classes, 6, 8)).to(device)
    y_preds = torch.randint(0, 2, size=(offset * idist.get_world_size(), n_classes, 6, 8)).to(device)
```
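To illustrate why the data should be sized for the whole world, here is a minimal, framework-free sketch in plain Python. It simulates `world_size` ranks in a single process, assuming every rank seeds its generator identically so that all ranks build the same global `y_true`/`y_pred` of length `offset * world_size`. Each simulated rank feeds only its own contiguous shard to the metric, batch by batch, and summing the per-rank matrices stands in for the all-reduce the metric performs. The helper names (`confusion_matrix`, `add_matrices`, `simulate_distributed_check`) are hypothetical and are not part of ignite.

```python
import random


def confusion_matrix(y_true, y_pred, num_classes):
    # Hypothetical helper: rows are true classes, columns are predicted classes.
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm


def add_matrices(a, b):
    # Element-wise sum of two equally shaped nested lists.
    return [[x + y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]


def simulate_distributed_check(world_size, n_iters=60, s=16, num_classes=3, seed=12):
    offset = n_iters * s
    # Assumption: identical seed on every rank, so every rank builds the same
    # global data, sized like `offset * idist.get_world_size()` in test_recall.py.
    rng = random.Random(seed)
    total = offset * world_size
    y_true = [rng.randrange(num_classes) for _ in range(total)]
    y_pred = [rng.randrange(num_classes) for _ in range(total)]

    reduced = [[0] * num_classes for _ in range(num_classes)]
    for rank in range(world_size):
        # Each simulated rank only sees its own contiguous shard of the data.
        local_cm = [[0] * num_classes for _ in range(num_classes)]
        for i in range(n_iters):
            lo = i * s + rank * offset
            hi = (i + 1) * s + rank * offset
            batch_cm = confusion_matrix(y_true[lo:hi], y_pred[lo:hi], num_classes)
            local_cm = add_matrices(local_cm, batch_cm)
        # Summing over ranks plays the role of the metric's all-reduce.
        reduced = add_matrices(reduced, local_cm)

    # Reference computed once on the full, world-sized data.
    expected = confusion_matrix(y_true, y_pred, num_classes)
    return reduced, expected


for ws in (1, 2, 4):
    reduced, expected = simulate_distributed_check(ws)
    print(ws, reduced == expected)
```

Because the reduced matrix matches the one computed on the full data, a rewritten distributed test can compare the metric's output against a reference computed on all of `y_true`/`y_pred`. The same slicing should carry over to the MultiLabelConfusionMatrix tests; only the per-batch metric changes.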
See https://github.com/pytorch/ignite/issues#issuecomment-782750122
cc @touqir14