-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmethods.py
107 lines (86 loc) · 4.1 KB
/
methods.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os

import pytorch_lightning as pl
import torch
from torch import nn, optim
from torch.distributions.poisson import Poisson
from torchvision import models, transforms
import torchmetrics as metrics

# CUDA_LAUNCH_BLOCKING only takes effect as an *environment variable* read by
# the CUDA runtime (and it must be set before the first CUDA call); a plain
# Python variable of that name does nothing. `setdefault` keeps any value the
# user already exported.
# NOTE(review): synchronous kernel launches are a debugging aid and slow down
# training — consider removing this for production runs.
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")
# Kept so any existing code importing this module-level name still works.
CUDA_LAUNCH_BLOCKING = 1
class ObjectDistributionModel(pl.LightningModule):
    """Predict per-class object-count rates from an aerial image.

    A ResNet-18 backbone maps ``data["aerial_img"]`` to ``num_classes``
    non-negative Poisson rates (softplus-activated); training minimises
    the Poisson negative log-likelihood of the observed counts.
    """

    def __init__(self, num_classes=91):
        """Args:
            num_classes: number of object categories (default 91, COCO-style).
        """
        super().__init__()
        self.save_hyperparameters()
        self.aerial_model = models.resnet18(pretrained=True)
        # Replace the 1000-way ImageNet head with a count-rate head.
        self.aerial_model.fc = nn.Linear(512, num_classes)
        self.train_metrics = metrics.Accuracy()
        self.test_metrics = metrics.Accuracy()
        self.val_metrics = metrics.Accuracy()

    def forward(self, data):
        """Return Poisson rates of shape (batch, num_classes).

        Softplus keeps the rates strictly positive, as Poisson requires.
        """
        aerial_feats = self.aerial_model(data["aerial_img"])
        return nn.functional.softplus(aerial_feats)

    def get_loss(self, logits, labels, step="train", metric=None):
        """Poisson NLL of ``labels`` under rates ``logits``; logs loss and accuracy.

        The parameter was renamed ``metrics`` -> ``metric``: the old name
        shadowed the module-level ``torchmetrics as metrics`` import and was
        inconsistent with NlcdBaselineModel.

        Args:
            logits: positive Poisson rates, shape (batch, num_classes).
            labels: observed counts. Assumes a dtype Poisson.log_prob accepts
                (float) — TODO confirm against the dataloader.
            step: "train" / "val" / "test", used as the log-key prefix.
            metric: torchmetrics Accuracy instance for this step.
        """
        m = Poisson(logits)
        # NOTE(review): a stochastic sample is used as the "prediction" for
        # the accuracy metric, so the logged accuracy is noisy; the
        # distribution mode would be deterministic — confirm intent.
        output = m.sample().int()
        loss = (-m.log_prob(labels)).mean()
        self.log(f"{step}_loss", loss, prog_bar=True)
        self.log(f"{step}_accuracy", metric(output, labels), prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx=None):
        return self._shared_step(batch, step="train", metric=self.train_metrics)

    def validation_step(self, batch, batch_idx=None):
        return self._shared_step(batch, step="val", metric=self.val_metrics)

    def test_step(self, batch, batch_idx=None):
        return self._shared_step(batch, step="test", metric=self.test_metrics)

    def _shared_step(self, batch, batch_idx=None, step="train", metric=None):
        """Common train/val/test logic: forward, then Poisson NLL loss."""
        logits = self(batch)
        labels = batch["labels_counts"]
        return self.get_loss(logits, labels, step=step, metric=metric)

    def configure_optimizers(self):
        return optim.AdamW(self.parameters(), lr=1e-3, betas=(0.9, 0.99), eps=1e-5)
class NlcdBaselineModel(pl.LightningModule):
    """Classify aerial images into coarse NLCD land-cover classes.

    A ResNet-18 backbone produces ``num_classes`` scores which are
    L2-normalized and scored with cross-entropy against
    ``batch["nlcd_coarse_labels"]``.
    """

    def __init__(self, num_classes=8):
        super().__init__()
        self.save_hyperparameters()
        backbone = models.resnet18(pretrained=True)
        backbone.fc = nn.Linear(512, num_classes)
        self.aerial_model = backbone
        self.criterion = nn.CrossEntropyLoss()
        self.train_metrics = metrics.Accuracy()
        self.val_metrics = metrics.Accuracy()
        self.test_metrics = metrics.Accuracy()

    def forward(self, data):
        """Run the backbone and stash normalized features in ``data``."""
        # shape = (Batch, Feature)
        data["aerial_feats"] = nn.functional.normalize(
            self.aerial_model(data["aerial_img"])
        )
        return data

    def get_loss(self, feats, labels, step="train", metric=None):
        # ``step`` and ``metric`` are accepted for signature compatibility;
        # logging happens in _shared_step.
        return self.criterion(feats, labels)

    def training_step(self, batch, batch_idx=None):
        return self._shared_step(batch, step="train", metric=self.train_metrics)

    def validation_step(self, batch, batch_idx=None):
        return self._shared_step(batch, step="val", metric=self.val_metrics)

    def test_step(self, batch, batch_idx=None):
        return self._shared_step(batch, step="test", metric=self.test_metrics)

    def _shared_step(self, batch, batch_idx=None, step="train", metric=None):
        """Common train/val/test logic: forward, cross-entropy, accuracy logging."""
        out = self(batch)
        labels, feats = out["nlcd_coarse_labels"], out["aerial_feats"]
        loss = self.get_loss(feats, labels, step=step, metric=metric)
        accuracy = metric(feats.argmax(dim=1), labels).detach()
        self.log(f"{step}/acc", accuracy)
        self.log(f"{step}/loss", loss)
        return loss

    def configure_optimizers(self):
        return optim.AdamW(self.parameters(), lr=1e-3, betas=(0.9, 0.99), eps=1e-5)
class NlcdPretrainedModel(NlcdBaselineModel):
    """NlcdBaselineModel whose backbone is warm-started from an
    ObjectDistributionModel checkpoint, with a fresh classification head.
    """

    def __init__(self, num_classes=8, checkpoint_path=None):
        """Args:
            num_classes: number of NLCD target classes.
            checkpoint_path: path to a Lightning checkpoint produced by
                training ObjectDistributionModel. Required.
        """
        # Forward num_classes so the parent (and its hyperparameters) are
        # built consistently; the original called super().__init__() bare.
        super().__init__(num_classes=num_classes)
        self.save_hyperparameters()
        # map_location lets a GPU-trained checkpoint load on CPU-only hosts.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        # Rebuild the pretraining model with the class count it was actually
        # trained with (recorded by save_hyperparameters), not this model's
        # num_classes — the two generally differ (e.g. 91 object classes vs
        # 8 NLCD classes), and a mismatched head makes load_state_dict fail.
        pretrain_classes = checkpoint.get("hyper_parameters", {}).get(
            "num_classes", num_classes
        )
        obj_dist_model = ObjectDistributionModel(num_classes=pretrain_classes)
        obj_dist_model.load_state_dict(checkpoint["state_dict"])
        # Keep the pretrained backbone, swap in a head sized for this task.
        self.aerial_model = obj_dist_model.aerial_model
        self.aerial_model.fc = nn.Linear(512, num_classes)