Commit 45aa122: "initial"
xuanbinh-nguyen96 committed Jun 10, 2021
1 parent 5713fdb
Showing 26 changed files with 3,602 additions and 0 deletions.
13 changes: 13 additions & 0 deletions .gitignore
@@ -0,0 +1,13 @@
# Ignore data and output folders to keep the repository lightweight
data_RAD
data_RAD/
data
data/
saved_models
saved_models/
results
results/
.idea
.idea/
*.pyc
*.swp
128 changes: 128 additions & 0 deletions attention.py
@@ -0,0 +1,128 @@
"""
This code is extended from Jin-Hwa Kim, Jaehyun Jun, Byoung-Tak Zhang's repository.
https://github.com/jnhwkim/ban-vqa
This code is modified from ZCYang's repository.
https://github.com/zcyang/imageqa-san
"""
import torch
import torch.nn as nn
from torch.nn.utils.weight_norm import weight_norm
from bc import BCNet

# Bilinear Attention
class BiAttention(nn.Module):
    def __init__(self, x_dim, y_dim, z_dim, glimpse, dropout=[.2, .5]):
        super(BiAttention, self).__init__()

        self.glimpse = glimpse
        self.logits = weight_norm(BCNet(x_dim, y_dim, z_dim, glimpse, dropout=dropout, k=3),
                                  name='h_mat', dim=None)

    def forward(self, v, q, v_mask=True):
        """
        v: [batch, k, vdim]
        q: [batch, q_len, qdim]
        """
        p, logits = self.forward_all(v, q, v_mask)
        return p, logits

    def forward_all(self, v, q, v_mask=True):
        v_num = v.size(1)
        q_num = q.size(1)
        logits = self.logits(v, q)  # b x g x v x q

        if v_mask:
            # Mask out padded image regions (all-zero feature rows) before the softmax
            mask = (0 == v.abs().sum(2)).unsqueeze(1).unsqueeze(3).expand(logits.size())
            logits.data.masked_fill_(mask.data, -float('inf'))

        p = nn.functional.softmax(logits.view(-1, self.glimpse, v_num * q_num), 2)
        return p.view(-1, self.glimpse, v_num, q_num), logits
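
# A minimal usage sketch (illustrative only, not part of the original file;
# the dimensions below are hypothetical and BCNet comes from the companion
# bc.py module in this repository):
#
#   att = BiAttention(x_dim=1024, y_dim=1024, z_dim=1024, glimpse=2)
#   v = torch.randn(8, 36, 1024)   # image features, [batch, k, vdim]
#   q = torch.randn(8, 12, 1024)   # question features, [batch, q_len, qdim]
#   p, logits = att(v, q)          # p: [8, 2, 36, 12] attention maps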

# Stacked Attention
class StackedAttention(nn.Module):
    def __init__(self, num_stacks, img_feat_size, ques_feat_size, att_size, output_size, drop_ratio):
        super(StackedAttention, self).__init__()

        self.img_feat_size = img_feat_size
        self.ques_feat_size = ques_feat_size
        self.att_size = att_size
        self.output_size = output_size
        self.drop_ratio = drop_ratio
        self.num_stacks = num_stacks
        self.layers = nn.ModuleList()

        self.dropout = nn.Dropout(drop_ratio)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=1)

        self.fc11 = nn.Linear(ques_feat_size, att_size, bias=True)
        self.fc12 = nn.Linear(img_feat_size, att_size, bias=False)
        self.fc13 = nn.Linear(att_size, 1, bias=True)

        # Three linear layers per additional stack
        for stack in range(num_stacks - 1):
            self.layers.append(nn.Linear(att_size, att_size, bias=True))
            self.layers.append(nn.Linear(img_feat_size, att_size, bias=False))
            self.layers.append(nn.Linear(att_size, 1, bias=True))

    def forward(self, img_feat, ques_feat, v_mask=True):

        # Batch size
        B = ques_feat.size(0)

        # Stack 1
        ques_emb_1 = self.fc11(ques_feat)
        img_emb_1 = self.fc12(img_feat)

        # Compute attention distribution
        h1 = self.tanh(ques_emb_1.view(B, 1, self.att_size) + img_emb_1)
        h1_emb = self.fc13(self.dropout(h1))
        # Mask padded bounding boxes (all-zero features) before the softmax
        if v_mask:
            mask = (0 == img_emb_1.abs().sum(2)).unsqueeze(2).expand(h1_emb.size())
            h1_emb.data.masked_fill_(mask.data, -float('inf'))

        p1 = self.softmax(h1_emb)

        # Compute weighted sum of image embeddings
        img_att_1 = img_emb_1 * p1
        weight_sum_1 = torch.sum(img_att_1, dim=1)

        # Combine with question vector
        u1 = ques_emb_1 + weight_sum_1

        # Other stacks
        us = []
        ques_embs = []
        img_embs = []
        hs = []
        h_embs = []
        ps = []
        img_atts = []
        weight_sums = []

        us.append(u1)
        for stack in range(self.num_stacks - 1):
            ques_embs.append(self.layers[3 * stack + 0](us[-1]))
            img_embs.append(self.layers[3 * stack + 1](img_feat))

            # Compute attention distribution
            hs.append(self.tanh(ques_embs[-1].view(B, -1, self.att_size) + img_embs[-1]))
            h_embs.append(self.layers[3 * stack + 2](self.dropout(hs[-1])))
            # Mask padded bounding boxes (all-zero features) before the softmax
            if v_mask:
                mask = (0 == img_embs[-1].abs().sum(2)).unsqueeze(2).expand(h_embs[-1].size())
                h_embs[-1].data.masked_fill_(mask.data, -float('inf'))
            ps.append(self.softmax(h_embs[-1]))

            # Compute weighted sum of image embeddings
            img_atts.append(img_embs[-1] * ps[-1])
            weight_sums.append(torch.sum(img_atts[-1], dim=1))

            # Combine with the previous stack's query vector and feed forward
            ux = us[-1] + weight_sums[-1]
            us.append(ux)

        return us[-1]
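
# A minimal smoke test, added for illustration (not part of the original
# commit); all feature sizes below are hypothetical. Note that the returned
# vector lives in att_size dimensions: each stack adds the attended image
# embedding to the question embedding.
if __name__ == '__main__':
    san = StackedAttention(num_stacks=2, img_feat_size=1024, ques_feat_size=1024,
                           att_size=512, output_size=512, drop_ratio=0.5)
    img_feat = torch.randn(8, 36, 1024)   # [batch, num_boxes, img_feat_size]
    ques_feat = torch.randn(8, 1024)      # [batch, ques_feat_size]
    print(san(img_feat, ques_feat).shape)  # torch.Size([8, 512])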
60 changes: 60 additions & 0 deletions auto_encoder.py
@@ -0,0 +1,60 @@
"""
Auto-encoder module for MEVF model
This code is written by Binh X. Nguyen and Binh D. Nguyen
<link paper>
"""
import torch.nn as nn
from torch.distributions.normal import Normal
import functools
import operator
import torch.nn.functional as F
import torch

def add_noise(images, mean=0, std=0.1):
    """Corrupt a batch of images with element-wise Gaussian noise."""
    normal_dst = Normal(mean, std)
    noise = normal_dst.sample(images.shape)
    noisy_image = noise + images
    return noisy_image

def print_model(model):
    """Print the model and its total number of parameters."""
    print(model)
    nParams = 0
    for w in model.parameters():
        nParams += functools.reduce(operator.mul, w.size(), 1)
    print(nParams)
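
# Example usage of add_noise (illustrative only; the image shape is
# hypothetical). The noise tensor matches the input shape, so this works
# for any batch of images:
#
#   imgs = torch.rand(16, 1, 84, 84)          # batch of 1-channel images
#   noisy = add_noise(imgs, mean=0, std=0.1)  # same shape, Gaussian-corrupted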

class Auto_Encoder_Model(nn.Module):
    def __init__(self):
        super(Auto_Encoder_Model, self).__init__()
        # Encoder: 1 -> 64 -> 32 -> 16 channels, downsampling twice
        self.conv1 = nn.Conv2d(1, 64, padding=1, kernel_size=3)
        self.max_pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(64, 32, padding=1, kernel_size=3)
        self.max_pool2 = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(32, 16, padding=1, kernel_size=3)

        # Decoder: 16 -> 32 -> 64 -> 1 channels, upsampling twice
        self.tran_conv1 = nn.ConvTranspose2d(16, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv4 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.tran_conv2 = nn.ConvTranspose2d(32, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv5 = nn.Conv2d(64, 1, kernel_size=3, padding=1)

    def forward_pass(self, x):
        # Encode the input image into a 16-channel feature map
        output = F.relu(self.conv1(x))
        output = self.max_pool1(output)
        output = F.relu(self.conv2(output))
        output = self.max_pool2(output)
        output = F.relu(self.conv3(output))
        return output

    def reconstruct_pass(self, x):
        # Decode the feature map back to a 1-channel image in [0, 1]
        output = F.relu(self.tran_conv1(x))
        output = F.relu(self.conv4(output))
        output = F.relu(self.tran_conv2(output))
        output = torch.sigmoid(self.conv5(output))
        return output

    def forward(self, x):
        output = self.forward_pass(x)
        output = self.reconstruct_pass(output)
        return output
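
# A minimal smoke test, added for illustration (not part of the original
# commit); the 84x84 input size is hypothetical. The two 2x2 max-pools halve
# the spatial size twice, and the two stride-2 transposed convolutions restore
# it, so the reconstruction matches the input shape.
if __name__ == '__main__':
    ae = Auto_Encoder_Model()
    imgs = torch.rand(4, 1, 84, 84)   # batch of 1-channel images
    recon = ae(add_noise(imgs))       # denoising-style forward pass
    print(recon.shape)                # torch.Size([4, 1, 84, 84])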