Commit 002cf2e

Author: Hansel
Committed: initial commit
1 parent 08f0f38 commit 002cf2e

74 files changed: +15277 additions, -0 deletions

.gitignore

Lines changed: 7 additions & 0 deletions
@@ -158,3 +158,10 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+
+# costum
+wandb/*
+TrainingRuns/*
+Data/*
+*.sh
+

Classifier/Models.py

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@ (new file)

from copy import deepcopy
import torchvision.models as models
from torchvision.models import ViT_B_16_Weights
import torch.nn.functional as F
import torch
import numpy as np
from Variables import *


class Classifier_ResNet50(torch.nn.Module):
    """ResNet-50 classifier adapted to single-channel EM images."""

    def __init__(self):
        super().__init__()
        model = models.resnet50(pretrained=True)
        num_ftrs = model.fc.in_features
        model.fc = torch.nn.Linear(num_ftrs, OUTPUT_NEURONS)
        # single-channel stem for grayscale EM images
        model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # for resnet50 use pretrained weights from Conrad, Ryan, and Kedar Narayan.
        # "CEM500K, a large-scale heterogeneous unlabeled cellular electron microscopy
        # image dataset for deep learning." Elife 10 (2021): e65894.
        state = torch.load(EM_PRETRAINED_WEIGHTS, map_location='cpu')
        state_dict = state['state_dict']
        # format the parameter names to match torchvision resnet50
        resnet50_state_dict = deepcopy(state_dict)
        for k in list(resnet50_state_dict.keys()):
            # only keep query encoder parameters; discard the fc projection head
            if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
                resnet50_state_dict[k[len("module.encoder_q."):]] = resnet50_state_dict[k]
            # delete renamed or unused k
            del resnet50_state_dict[k]
        # load model weights
        model.load_state_dict(resnet50_state_dict, strict=False)
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)


class Classifier_Oracle(torch.nn.Module):
    """Scores a transformed mask by its overlap with the ground-truth mask,
    normalised by the analytic capsid area."""

    def __init__(self):
        super().__init__()

    def forward(self, transformed_mask: torch.Tensor, gt_mask: torch.Tensor, capside_radius: torch.Tensor) -> torch.Tensor:
        area = (np.pi * capside_radius ** 2).to(DEVICE)
        overlap = transformed_mask.squeeze() * gt_mask.to(DEVICE).squeeze()
        return torch.sum(overlap) / area


class Classifier_ViT(torch.nn.Module):
    """ViT-B/16 classifier with ImageNet weights and a new prediction head."""

    def __init__(self):
        super().__init__()
        model = models.vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
        num_ftrs = model.heads.head.in_features
        model.heads.head = torch.nn.Linear(num_ftrs, OUTPUT_NEURONS)
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)


class Classifier_ResNet101(torch.nn.Module):
    """ResNet-101 classifier with ImageNet weights and a new prediction head."""

    def __init__(self):
        super().__init__()
        model = models.resnet101(pretrained=True)
        num_ftrs = model.fc.in_features
        model.fc = torch.nn.Linear(num_ftrs, OUTPUT_NEURONS)
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)
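As a point of reference, below is a minimal usage sketch for the classifiers defined above. It is not part of the commit: the import path, batch size, and 224x224 input size are assumptions, and it presumes that Variables.py provides DEVICE and EM_PRETRAINED_WEIGHTS (as the star import implies) with the CEM500K checkpoint present on disk.

# Hypothetical usage sketch (not part of the commit).
import torch

from Classifier.Models import Classifier_ResNet50, Classifier_ViT  # assumed import path
from Variables import DEVICE                                        # assumed to be defined there

# Classifier_ResNet50 has a single-channel conv1, so it takes grayscale EM images;
# constructing it loads the CEM500K checkpoint from EM_PRETRAINED_WEIGHTS.
resnet = Classifier_ResNet50().to(DEVICE).eval()
x_gray = torch.randn(4, 1, 224, 224, device=DEVICE)  # illustrative batch and size

# Classifier_ViT keeps the default 3-channel ViT-B/16 stem (224x224 input).
vit = Classifier_ViT().to(DEVICE).eval()
x_rgb = torch.randn(4, 3, 224, 224, device=DEVICE)

with torch.no_grad():
    print(resnet(x_gray).shape)  # torch.Size([4, OUTPUT_NEURONS])
    print(vit(x_rgb).shape)      # torch.Size([4, OUTPUT_NEURONS])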

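Classifier_Oracle is not a learned model: it scores how well a transformed candidate mask covers the ground-truth mask, normalising the overlapping pixel count by the analytic capsid disc area π·r². A self-contained toy check of that computation (made-up 64x64 masks, device transfer omitted) could look like this:

# Toy check of the oracle score; the masks and radius are made up for illustration.
import numpy as np
import torch

# 64x64 ground-truth disc of radius 10 and a candidate disc shifted by 3 pixels.
yy, xx = torch.meshgrid(torch.arange(64), torch.arange(64), indexing='ij')
gt_mask = ((yy - 32) ** 2 + (xx - 32) ** 2 <= 10 ** 2).float()
candidate = ((yy - 32) ** 2 + (xx - 35) ** 2 <= 10 ** 2).float()

radius = torch.tensor(10.0)
area = np.pi * radius ** 2               # analytic capsid area, as in Classifier_Oracle
overlap = (candidate * gt_mask).sum()    # pixels covered by both masks

score = overlap / area                   # approaches 1.0 as the candidate covers the capsid
print(f"oracle-style score: {score.item():.3f}")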