
Commit

Update TransMIL.py
wyhsleep authored May 23, 2024
1 parent ee41a58 commit 6f57883
Showing 1 changed file with 1 addition and 210 deletions.
211 changes: 1 addition & 210 deletions models/TransMIL.py
@@ -335,218 +335,9 @@ def relocate(self):
        self.classifier = self.classifier.to(device)


class TransMIL_l(nn.Module):
    def __init__(self, input_dim, n_classes, dropout, act, layer, survival=False):
        super(TransMIL_l, self).__init__()
        self._fc1 = [nn.Linear(input_dim, 512)]
        if act.lower() == 'relu':
            self._fc1 += [nn.ReLU()]
        elif act.lower() == 'gelu':
            self._fc1 += [nn.GELU()]
        if dropout:
            self._fc1 += [nn.Dropout(dropout)]
            print("dropout: ", dropout)
        self._fc1 = nn.Sequential(*self._fc1)
        self.cls_token = nn.Parameter(torch.randn(1, 1, 512))
        nn.init.normal_(self.cls_token, std=1e-6)
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        for _ in range(layer):
            self.layers.append(TransLayer(dim=512))
            if _ != layer - 1:
                self.layers.append(PPEG(dim=512))

        self.norm = nn.LayerNorm(512)
        self.classifier = nn.Linear(512, self.n_classes)

        self.apply(initialize_weights)
        self.survival = survival

    def forward(self, x):
        if len(x.shape) == 2:
            x = x.expand(1, -1, -1)

        h = x.float()  # [B, n, input_dim]

        h = self._fc1(h)  # [B, n, 512]

        # ---->pad the sequence up to the next perfect square so PPEG can reshape it to a 2D grid
        H = h.shape[1]
        _H, _W = int(np.ceil(np.sqrt(H))), int(np.ceil(np.sqrt(H)))
        add_length = _H * _W - H
        h = torch.cat([h, h[:, :add_length, :]], dim=1)  # [B, N, 512]

        # ---->cls_token
        cls_tokens = self.cls_token.expand(1, -1, -1).cuda()
        h = torch.cat((cls_tokens, h), dim=1)

        for layer in self.layers:
            if isinstance(layer, TransLayer):
                h = layer(h)
            else:
                h = layer(h, _H, _W)
        h = self.norm(h)[:, 0]

        logits = self.classifier(h)  # [B, n_classes]
        Y_prob = F.softmax(logits, dim=1)
        Y_hat = torch.topk(logits, 1, dim=1)[1]
        A_raw = None
        results_dict = None
        if self.survival:
            Y_hat = torch.topk(logits, 1, dim=1)[1]
            hazards = torch.sigmoid(logits)
            S = torch.cumprod(1 - hazards, dim=1)
            return hazards, S, Y_hat, None, None
        return logits, Y_prob, Y_hat, A_raw, results_dict

    def relocate(self):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._fc1 = self._fc1.to(device)
        self.layers = self.layers.to(device)
        self.norm = self.norm.to(device)
        self.classifier = self.classifier.to(device)


class TransMIL_l_v2(nn.Module):
    def __init__(self, input_dim, n_classes, dropout, act, layer, survival=False):
        super(TransMIL_l_v2, self).__init__()
        self._fc1 = [nn.Linear(input_dim, 512)]
        if act.lower() == 'relu':
            self._fc1 += [nn.ReLU()]
        elif act.lower() == 'gelu':
            self._fc1 += [nn.GELU()]
        if dropout:
            self._fc1 += [nn.Dropout(dropout)]
            print("dropout: ", dropout)
        self._fc1 = nn.Sequential(*self._fc1)
        self.cls_token = nn.Parameter(torch.randn(1, 1, 512))
        nn.init.normal_(self.cls_token, std=1e-6)
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        for _ in range(layer):
            self.layers.append(TransLayer(dim=512))
            if _ == layer / 2 - 1:
                self.layers.append(PPEG(dim=512))

        self.norm = nn.LayerNorm(512)
        self.classifier = nn.Linear(512, self.n_classes)

        self.apply(initialize_weights)
        self.survival = survival

    def forward(self, x):
        if len(x.shape) == 2:
            x = x.expand(1, -1, -1)

        h = x.float()  # [B, n, input_dim]

        h = self._fc1(h)  # [B, n, 512]

        # ---->pad the sequence up to the next perfect square so PPEG can reshape it to a 2D grid
        H = h.shape[1]
        _H, _W = int(np.ceil(np.sqrt(H))), int(np.ceil(np.sqrt(H)))
        add_length = _H * _W - H
        h = torch.cat([h, h[:, :add_length, :]], dim=1)  # [B, N, 512]

        # ---->cls_token
        cls_tokens = self.cls_token.expand(1, -1, -1).cuda()
        h = torch.cat((cls_tokens, h), dim=1)

        for layer in self.layers:
            if isinstance(layer, TransLayer):
                h = layer(h)
            else:
                h = layer(h, _H, _W)
        h = self.norm(h)[:, 0]

        logits = self.classifier(h)  # [B, n_classes]
        Y_prob = F.softmax(logits, dim=1)
        Y_hat = torch.topk(logits, 1, dim=1)[1]
        A_raw = None
        results_dict = None
        if self.survival:
            Y_hat = torch.topk(logits, 1, dim=1)[1]
            hazards = torch.sigmoid(logits)
            S = torch.cumprod(1 - hazards, dim=1)
            return hazards, S, Y_hat, None, None
        return logits, Y_prob, Y_hat, A_raw, results_dict

    def relocate(self):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._fc1 = self._fc1.to(device)
        self.layers = self.layers.to(device)
        self.norm = self.norm.to(device)
        self.classifier = self.classifier.to(device)


class TransMIL_l_v3(nn.Module):
    def __init__(self, input_dim, n_classes, dropout, act, layer, survival=False):
        super(TransMIL_l_v3, self).__init__()
        self._fc1 = [nn.Linear(input_dim, 512)]
        if act.lower() == 'relu':
            self._fc1 += [nn.ReLU()]
        elif act.lower() == 'gelu':
            self._fc1 += [nn.GELU()]
        if dropout:
            self._fc1 += [nn.Dropout(dropout)]
            print("dropout: ", dropout)
        self._fc1 = nn.Sequential(*self._fc1)
        self.cls_token = nn.Parameter(torch.randn(1, 1, 512))
        nn.init.normal_(self.cls_token, std=1e-6)
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        for _ in range(layer):
            self.layers.append(TransLayer(dim=512))

        self.norm = nn.LayerNorm(512)
        self.classifier = nn.Linear(512, self.n_classes)

        self.apply(initialize_weights)
        self.survival = survival

    def forward(self, x):
        if len(x.shape) == 2:
            x = x.expand(1, -1, -1)

        h = x.float()  # [B, n, input_dim]

        h = self._fc1(h)  # [B, n, 512]

        # ---->pad the sequence up to the next perfect square
        H = h.shape[1]
        _H, _W = int(np.ceil(np.sqrt(H))), int(np.ceil(np.sqrt(H)))
        add_length = _H * _W - H
        h = torch.cat([h, h[:, :add_length, :]], dim=1)  # [B, N, 512]

        # ---->cls_token
        cls_tokens = self.cls_token.expand(1, -1, -1).cuda()
        h = torch.cat((cls_tokens, h), dim=1)

        for layer in self.layers:
            if isinstance(layer, TransLayer):
                h = layer(h)
            else:
                h = layer(h, _H, _W)
        h = self.norm(h)[:, 0]

        logits = self.classifier(h)  # [B, n_classes]
        Y_prob = F.softmax(logits, dim=1)
        Y_hat = torch.topk(logits, 1, dim=1)[1]
        A_raw = None
        results_dict = None
        if self.survival:
            Y_hat = torch.topk(logits, 1, dim=1)[1]
            hazards = torch.sigmoid(logits)
            S = torch.cumprod(1 - hazards, dim=1)
            return hazards, S, Y_hat, None, None
        return logits, Y_prob, Y_hat, A_raw, results_dict

    def relocate(self):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._fc1 = self._fc1.to(device)
        self.layers = self.layers.to(device)
        self.norm = self.norm.to(device)
        self.classifier = self.classifier.to(device)



@@ -560,4 +351,4 @@ def relocate(self):
    # results_dict = model1(data)
    print(model2)
    results_dict = model2(data)
    print(results_dict)
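
For reference, below is a minimal standalone sketch (dummy tensors and illustrative shapes, not code from this repository) of the forward-pass plumbing the removed classes share: padding the instance sequence up to the next perfect square, prepending a class token, and, when survival=True, converting logits into per-interval hazards and a survival curve.

import numpy as np
import torch
import torch.nn.functional as F

# Illustrative sketch only; shapes and values are made up.
B, n, d, n_classes = 1, 65, 512, 4
h = torch.randn(B, n, d)                                  # one bag of 65 instance features

# Square padding: duplicate leading tokens until the length is a perfect square.
H = h.shape[1]
_H = _W = int(np.ceil(np.sqrt(H)))                        # 65 -> 9 x 9 grid
add_length = _H * _W - H                                  # 81 - 65 = 16 padded tokens
h = torch.cat([h, h[:, :add_length, :]], dim=1)           # [1, 81, 512]

# Prepend a (here random) class token; the models above use a learnable nn.Parameter.
cls_token = torch.randn(1, 1, d)
h = torch.cat([cls_token.expand(B, -1, -1), h], dim=1)    # [1, 82, 512]

# Head outputs: classification probabilities vs. discrete-time survival.
logits = torch.randn(B, n_classes)                        # stand-in for classifier(h[:, 0])
Y_prob = F.softmax(logits, dim=1)                         # class probabilities
hazards = torch.sigmoid(logits)                           # per-interval hazard
S = torch.cumprod(1 - hazards, dim=1)                     # survival curve across intervals
print(h.shape, Y_prob.shape, S.shape)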
