# test_train_pipeline.py
from dataclasses import dataclass
from typing import Tuple
from tml.common.batch import DataclassBatch
from tml.common.testing_utils import mock_pg
from tml.core import train_pipeline
import torch
from torchrec.distributed import DistributedModelParallel


@dataclass
class MockDataclassBatch(DataclassBatch):
    continuous_features: torch.Tensor
    labels: torch.Tensor


class MockModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.model = torch.nn.Linear(10, 1)
        self.loss_fn = torch.nn.BCEWithLogitsLoss()

    def forward(self, batch: MockDataclassBatch) -> Tuple[torch.Tensor, torch.Tensor]:
        pred = self.model(batch.continuous_features)
        loss = self.loss_fn(pred, batch.labels)
        return (loss, pred)


def create_batch(bsz: int) -> MockDataclassBatch:
    """Creates a random batch of `bsz` examples with 10 continuous features and binary labels."""
    return MockDataclassBatch(
        continuous_features=torch.rand(bsz, 10).float(),
        labels=torch.bernoulli(torch.empty(bsz, 1).uniform_(0, 1)).float(),
    )
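

# Illustrative sketch, not part of the original tests: run MockModule directly on a
# synthetic batch to show the shapes the tests below rely on, without
# DistributedModelParallel or the mocked process group.
def _example_forward_pass() -> None:
    model = MockModule()
    batch = create_batch(4)  # continuous_features: (4, 10), labels: (4, 1)
    loss, pred = model(batch)
    assert pred.shape == (4, 1)
    assert loss.ndim == 0  # BCEWithLogitsLoss reduces to a scalar by default.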


def test_sparse_pipeline():
    device = torch.device("cpu")
    model = MockModule().to(device)
    steps = 8
    example = create_batch(1)
    dataloader = iter(example for _ in range(steps + 2))

    results = []
    with mock_pg():
        d_model = DistributedModelParallel(model)
        pipeline = train_pipeline.TrainPipelineSparseDist(
            model=d_model,
            optimizer=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9),
            device=device,
            grad_accum=2,
        )
        for _ in range(steps):
            results.append(pipeline.progress(dataloader))

    results = [elem.detach().numpy() for elem in results]
    # Check that gradients are accumulated: with grad_accum=2 the parameters only update
    # every second step, so the outputs of steps (0, 1), (2, 3), ... are identical.
    for first, second in zip(results[::2], results[1::2]):
        assert first == second, results
    # Check that gradients are applied: after each optimizer step the parameters change,
    # so the outputs of steps (1, 2), (3, 4), ... differ.
    for first, second in zip(results[1::2], results[2::2]):
        assert first != second, results
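

# Illustrative sketch, not part of the original tests: what grad_accum=2 means in plain
# PyTorch terms, assuming TrainPipelineSparseDist accumulates gradients across progress()
# calls and steps the optimizer once per accumulation window. Parameters only change after
# the second micro-batch, which is why paired outputs in test_sparse_pipeline are identical.
def _example_manual_grad_accum(steps: int = 4, grad_accum: int = 2) -> None:
    model = MockModule()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    for step in range(steps):
        loss, _ = model(create_batch(1))
        (loss / grad_accum).backward()  # scale so the accumulated gradient is an average
        if (step + 1) % grad_accum == 0:
            optimizer.step()
            optimizer.zero_grad()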


def test_amp():
    device = torch.device("cpu")
    model = MockModule().to(device)
    steps = 8
    example = create_batch(1)
    dataloader = iter(example for _ in range(steps + 2))

    results = []
    with mock_pg():
        d_model = DistributedModelParallel(model)
        pipeline = train_pipeline.TrainPipelineSparseDist(
            model=d_model,
            optimizer=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9),
            device=device,
            enable_amp=True,
            # Not supported on CPU.
            enable_grad_scaling=False,
        )
        for _ in range(steps):
            results.append(pipeline.progress(dataloader))

    results = [elem.detach() for elem in results]
    for value in results:
        assert value.dtype == torch.bfloat16
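

# Illustrative sketch, not part of the original tests: the dtype behavior test_amp checks,
# assuming enable_amp wraps the forward pass in torch.autocast. On CPU, autocast with
# bfloat16 makes linear layers produce bfloat16 outputs. Only the inner Linear is called
# here to keep the sketch minimal.
def _example_cpu_autocast() -> None:
    model = MockModule()
    batch = create_batch(1)
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        pred = model.model(batch.continuous_features)
    assert pred.dtype == torch.bfloat16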