
Commit 4af4195

refactor: fix MyPy type issues, refactor data loader and clean up debugging code
1 parent 5c2466a commit 4af4195

File tree

6 files changed (+41, -123 lines)

examples/multi_gpu/gather.py

Lines changed: 0 additions & 21 deletions
This file was deleted.

flair/distributed_utils.py

Lines changed: 1 addition & 4 deletions
@@ -2,17 +2,15 @@
 import os
 import random
 from multiprocessing.connection import Connection
-from typing import Any, Callable, Collection, Iterable, TypeVar
+from typing import Callable, Collection, Iterable, TypeVar

-import numpy as np
 import torch
 import torch.multiprocessing as mp
 from torch.distributed import destroy_process_group, init_process_group
 from torch.utils.data import Dataset

 import flair
 from flair.data import Corpus, _len_dataset
-from flair.training_utils import print_execution_time

 log = logging.getLogger("flair")

@@ -64,7 +62,6 @@ def is_main_process() -> bool:
     return True


-@print_execution_time
 def validate_corpus_same_each_process(corpus: Corpus) -> None:
     """Catches most cases in which a corpus is not the same on each process. However, there is no guarantee for two
     reasons: 1) It uses a sample for speed 2) It compares strings to avoid requiring the datasets to be serializable

flair/models/pairwise_regression_model.py

Lines changed: 9 additions & 15 deletions
@@ -299,18 +299,12 @@ def evaluate(
         if not isinstance(data_points, Dataset):
             data_points = FlairDatapointDataset(data_points)

-        if multi_gpu:
-            distributed_sampler: DistributedSampler = DistributedSampler(
-                data_points, shuffle=False
-            )
-            data_loader = DataLoader(
-                data_points,
-                batch_size=mini_batch_size,
-                shuffle=False,
-                sampler=distributed_sampler,
-            )
-        else:
-            data_loader = DataLoader(data_points, batch_size=mini_batch_size)
+        data_loader = DataLoader(
+            data_points,
+            batch_size=mini_batch_size,
+            shuffle=False,
+            sampler=DistributedSampler(data_points, shuffle=False) if multi_gpu else None,
+        )

         with torch.no_grad():
             eval_loss = torch.zeros(1, device=flair.device)
@@ -327,15 +321,15 @@ def evaluate(
                 if isinstance(batch, Sentence):
                     batch = [batch]

-                loss, num, scores = self._forward_loss_and_scores(batch, return_scores=True)
+                loss, num, scores_forward = self._forward_loss_and_scores(batch, return_scores=True)

                 true_values = []
                 for sentence in batch:
                     total_count += 1
                     for label in sentence.get_labels(gold_label_type):
                         true_values.append(float(label.value))

-                results = scores.cpu().tolist()
+                results = scores_forward.cpu().tolist()

                 eval_loss += loss

@@ -389,7 +383,7 @@ def evaluate(
            )

         else:  # if it's not the main process, just set a dummy Result
-            result = Result(0., "", {}, {'loss': 0.0})
+            result = Result(0.0, "", {}, {"loss": 0.0})

         if multi_gpu:
             result = broadcast_value(result, src=0)
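
The refactor works because DataLoader treats sampler=None as "use the default sampler", so the single-GPU and multi-GPU branches collapse into one call. A minimal standalone sketch of the pattern (the function name is illustrative, not flair's; DistributedSampler additionally requires an initialized process group when multi_gpu is true):

from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

def make_eval_loader(data: Dataset, batch_size: int, multi_gpu: bool) -> DataLoader:
    # sampler=None falls back to DataLoader's default SequentialSampler;
    # DistributedSampler(shuffle=False) keeps evaluation order deterministic
    # while partitioning the dataset across ranks.
    return DataLoader(
        data,
        batch_size=batch_size,
        shuffle=False,
        sampler=DistributedSampler(data, shuffle=False) if multi_gpu else None,
    )

One caveat: DistributedSampler pads the dataset so every rank receives the same number of samples, which can evaluate a few datapoints twice unless results are deduplicated downstream.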

flair/models/text_regression_model.py

Lines changed: 9 additions & 15 deletions
@@ -152,18 +152,12 @@ def evaluate(
         if not isinstance(data_points, Dataset):
             data_points = FlairDatapointDataset(data_points)

-        if multi_gpu:
-            distributed_sampler: DistributedSampler = DistributedSampler(
-                data_points, shuffle=False
-            )
-            data_loader = DataLoader(
-                data_points,
-                batch_size=mini_batch_size,
-                shuffle=False,
-                sampler=distributed_sampler,
-            )
-        else:
-            data_loader = DataLoader(data_points, batch_size=mini_batch_size)
+        data_loader = DataLoader(
+            data_points,
+            batch_size=mini_batch_size,
+            shuffle=False,
+            sampler=DistributedSampler(data_points, shuffle=False) if multi_gpu else None,
+        )

         with torch.no_grad():
             eval_loss = torch.zeros(1, device=flair.device)
@@ -176,15 +170,15 @@ def evaluate(
                 if isinstance(batch, Sentence):
                     batch = [batch]

-                scores, loss = self.forward_labels_and_loss(batch)
+                scores_forward, loss = self.forward_labels_and_loss(batch)

                 true_values = []
                 for sentence in batch:
                     total_count += 1
                     for label in sentence.get_labels(gold_label_type):
                         true_values.append(float(label.value))

-                results = scores[:, 0].cpu().tolist()
+                results = scores_forward[:, 0].cpu().tolist()

                 eval_loss += loss

@@ -239,7 +233,7 @@ def evaluate(
            )

         else:  # if it's not the main process, just set a dummy Result
-            result = Result(0., "", {}, {'loss': 0.0})
+            result = Result(0.0, "", {}, {"loss": 0.0})

         if multi_gpu:
             result = broadcast_value(result, src=0)
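
The scores → scores_forward rename here and in the file above is presumably what clears the MyPy complaints named in the commit message: without --allow-redefinition, MyPy pins a variable to the first type assigned to it, so reusing scores for a value of another type elsewhere in evaluate gets flagged. A minimal, torch-free sketch of that error class (hypothetical names):

def compute_scores() -> float:
    return 0.5

scores = compute_scores()          # mypy infers: float
# scores = [scores]                # error: Incompatible types in assignment
#                                  # (expression has type "list[float]", variable has type "float")

scores_forward = compute_scores()  # distinct names keep distinct types
results = [scores_forward]         # both assignments now type-check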

flair/nn/model.py

Lines changed: 21 additions & 45 deletions
@@ -5,7 +5,6 @@
 from abc import ABC, abstractmethod
 from collections import Counter
 from pathlib import Path
-from time import time
 from typing import Any, Optional, Union

 import torch.nn
@@ -19,8 +18,14 @@
 from flair.class_utils import get_non_abstract_subclasses
 from flair.data import DT, DT2, Corpus, Dictionary, Sentence, _iter_dataset
 from flair.datasets import DataLoader, FlairDatapointDataset
-from flair.distributed_utils import aggregate, aggregate_tensor_sum, broadcast_value, flatten_dicts, is_main_process, \
-    merge_sets
+from flair.distributed_utils import (
+    aggregate,
+    aggregate_tensor_sum,
+    broadcast_value,
+    flatten_dicts,
+    is_main_process,
+    merge_sets,
+)
 from flair.embeddings import Embeddings
 from flair.embeddings.base import load_embeddings
 from flair.file_utils import Tqdm, load_torch_state
@@ -274,8 +279,6 @@ def evaluate(
         multi_gpu: bool = False,
         **kwargs,
     ) -> Result:
-        t0 = time()
-        print('running custom evaluate..')
         exclude_labels = exclude_labels if exclude_labels is not None else []

         import numpy as np
@@ -302,25 +305,15 @@ def evaluate(
         all_true_values = {}
         all_predicted_values = {}

-        if multi_gpu:
-            distributed_sampler: DistributedSampler = DistributedSampler(
-                data_points, shuffle=False
-            )
-            loader = DataLoader(
-                data_points,
-                batch_size=mini_batch_size,
-                shuffle=False,
-                sampler=distributed_sampler,
-            )
-            rank = torch.distributed.get_rank()
-            print('rank =', rank)
-        else:
-            loader = DataLoader(data_points, batch_size=mini_batch_size)
-            rank = 0
+        loader = DataLoader(
+            data_points,
+            batch_size=mini_batch_size,
+            shuffle=False,
+            sampler=DistributedSampler(data_points, shuffle=False) if multi_gpu else None,
+        )
+        rank = torch.distributed.get_rank() if multi_gpu else 0

         sentence_id = 0
-        t1 = time()
-        print('time1', t1 - t0)
         for batch in Tqdm.tqdm(loader, disable=not is_main_process()):
             # remove any previously predicted labels
             for datapoint in batch:
@@ -381,19 +374,14 @@ def evaluate(
             if out_path:
                 lines.extend(self._print_predictions(batch, gold_label_type))

-        t2 = time()
-        print('time2', t2 - t1)
-        print('eval losssss', type(eval_loss), eval_loss)
         if multi_gpu:
             all_spans = aggregate(all_spans, merge_sets)
             all_true_values = aggregate(all_true_values, flatten_dicts)
             all_predicted_values = aggregate(all_predicted_values, flatten_dicts)
             average_over = aggregate(average_over, sum)
             eval_loss = aggregate(eval_loss, aggregate_tensor_sum)
-            print('eval loss =', eval_loss)
-            print('len all', len(all_spans), len(all_true_values), len(all_predicted_values), sep='\t')

-        result = Result(0., "", {}, {'loss': 0.0})
+        result = Result(0.0, "", {}, {"loss": 0.0})
         if is_main_process():

             # convert true and predicted values to two span-aligned lists
@@ -475,10 +463,8 @@ def evaluate(
                 target_names.append(label_name)
                 labels.append(evaluation_label_dictionary.get_idx_for_item(label_name))

-            #print(f"{len(data_points)}\t{len(y_true_save)}\n{len(y_true)}\t{len(y_pred)}\t{len(target_names)}\t{len(labels)}")
-
             # there is at least one gold label or one prediction (default)
-            if len(all_true_values) + len(all_predicted_values) > 1:
+            if is_main_process() and len(all_true_values) + len(all_predicted_values) > 1:
                 classification_report = sklearn.metrics.classification_report(
                     y_true,
                     y_pred,
@@ -512,9 +498,9 @@ def evaluate(
                     if metric_key != "support":
                         classification_report_dict["micro avg"][metric_key] = classification_report_dict["accuracy"]
                     else:
-                        classification_report_dict["micro avg"][metric_key] = classification_report_dict["macro avg"][
-                            "support"
-                        ]
+                        classification_report_dict["micro avg"][metric_key] = classification_report_dict[
+                            "macro avg"
+                        ]["support"]

                 detailed_result = (
                     "\nResults:"
@@ -536,14 +522,7 @@ def evaluate(
             if average_over > 0:
                 eval_loss /= average_over
             scores["loss"] = eval_loss.item()
-            print('scores', scores)
-
-            print('classification report')
-            print(classification_report_dict['micro avg'])

-            t3 = time()
-            print('time3', t3 - t2)
-            print('total time', t3 - t0)
             result = Result(
                 main_score=classification_report_dict[main_evaluation_metric[0]][main_evaluation_metric[1]],
                 detailed_results=detailed_result,
@@ -559,7 +538,7 @@ def evaluate(
                 f"- And no predictions were made!\n"
                 "Double check your corpus (if the test split has labels), and how you initialize the ModelTrainer!"
             )
-
+
             result = Result(
                 main_score=0.0,
                 detailed_results=error_text,
@@ -572,9 +551,6 @@ def evaluate(

         return result

-        # final_value
-        # return final_value
-
     @abstractmethod
     def predict(
         self,
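
The dummy Result(0.0, "", {}, {"loss": 0.0}) assigned on non-main ranks keeps evaluate's control flow symmetric: every process builds a Result, rank 0 builds the real one, and broadcast_value overwrites the dummies. Conceptually, a helper of this shape can be built on torch.distributed.broadcast_object_list (a sketch of the idea, not flair's actual implementation):

import torch.distributed as dist

def broadcast_value(value, src: int = 0):
    # Every rank passes a one-element list; broadcast_object_list replaces
    # the entries on all ranks with the copy held by `src`.
    holder = [value]
    dist.broadcast_object_list(holder, src=src)
    return holder[0]

The symmetry matters because collective operations block until every rank participates; if only the main process reached the broadcast, the job would hang.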

flair/training_utils.py

Lines changed: 1 addition & 23 deletions
@@ -1,14 +1,12 @@
-import functools
 import logging
 import pathlib
 import random
-import time
 from collections import defaultdict
 from enum import Enum
 from functools import reduce
 from math import inf
 from pathlib import Path
-from typing import Callable, Literal, NamedTuple, Optional, Union
+from typing import Literal, NamedTuple, Optional, Union

 from numpy import ndarray
 from scipy.stats import pearsonr, spearmanr
@@ -516,23 +514,3 @@ def create_labeled_sentence_from_entity_offsets(
     token_entities = [entity for entity in token_entities if entity.end_token_idx <= token_limit]

     return create_labeled_sentence_from_tokens(tokens, token_entities)
-
-
-def print_execution_time(func: Callable) -> Callable:
-    """
-    Decorator that prints the execution time of the decorated function.
-
-    :param func: The function to be decorated.
-    :return: The wrapped function with execution time printing.
-    """
-
-    @functools.wraps(func)
-    def wrapper(*args, **kwargs):
-        start_time = time.perf_counter()  # Start the timer
-        result = func(*args, **kwargs)  # Execute the function
-        end_time = time.perf_counter()  # End the timer
-        elapsed_time = end_time - start_time
-        print(f"Function '{func.__name__}' executed in {elapsed_time:.4f} seconds.")
-        return result
-
-    return wrapper
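
With print_execution_time deleted, ad-hoc timing no longer lives in library code. Should similar measurements be needed again during debugging, a context manager is a less invasive option than decorating library functions; a minimal sketch (not part of this commit, names are illustrative):

import time
from contextlib import contextmanager

@contextmanager
def timed(label: str):
    # Print the wall-clock duration of any block, without touching the
    # functions being measured.
    start = time.perf_counter()
    try:
        yield
    finally:
        print(f"{label} took {time.perf_counter() - start:.4f}s")

# usage:
# with timed("corpus validation"):
#     validate_corpus_same_each_process(corpus)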
