
Commit bcfe712 (1 parent: 0fa5847)

! fixes for pytorch 0.2.0 release

File tree (6 files changed: +12 -8 lines)

  core/heads/dynamic_write_head.py   +4 -2
  core/heads/static_head.py          +2 -1
  utils/fake_ops.py                  +2 -1
  utils/helpers.py                      -1
  utils/options.py                   +2 -2
  utils/similarities.py              +2 -1

core/heads/dynamic_write_head.py (+4 -2)
@@ -72,7 +72,8 @@ def _allocation(self, usage_vb, epsilon=1e-6):
         # TODO: seems we have to wait for this PR: https://github.com/pytorch/pytorch/pull/1439
         prod_sorted_usage_vb = fake_cumprod(cat_sorted_usage_vb)
         # prod_sorted_usage_vb = torch.cumprod(cat_sorted_usage_vb, dim=1) # TODO: use this once the PR is ready
-        alloc_weight_vb = (1 - sorted_usage_vb) * prod_sorted_usage_vb # equ. (1)
+        # alloc_weight_vb = (1 - sorted_usage_vb) * prod_sorted_usage_vb # equ. (1) # 0.1.12
+        alloc_weight_vb = (1 - sorted_usage_vb) * prod_sorted_usage_vb.squeeze() # equ. (1) # 0.2.0
         _, indices_vb = torch.topk(indices_vb, k=self.mem_hei, dim=1, largest=False)
         alloc_weight_vb = alloc_weight_vb.gather(1, indices_vb)
         return alloc_weight_vb
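
The .squeeze() added here is the counterpart of the keepdim=True change inside fake_cumprod (see utils/fake_ops.py below): the running product now comes back with an extra size-1 dimension that has to be dropped before the elementwise product with sorted_usage_vb. A minimal sketch of the shape issue, with illustrative shapes only and assuming the extra dimension is a trailing singleton:

    import torch

    sorted_usage = torch.rand(2, 5)          # illustrative [batch_size x mem_hei]
    prod_sorted_usage = torch.rand(2, 5, 1)  # a keepdim=True reduction keeps a size-1 dim

    alloc = (1 - sorted_usage) * prod_sorted_usage.squeeze()  # back to [2, 5], as equ. (1) expects
    print(alloc.size())                      # torch.Size([2, 5])
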
@@ -187,7 +188,8 @@ def _update_precedence_weights(self, prev_preced_vb):
         returns:
             preced_vb: [batch_size x num_write_heads x mem_hei]
         """
-        write_sum_vb = torch.sum(self.wl_curr_vb, 2)
+        # write_sum_vb = torch.sum(self.wl_curr_vb, 2) # 0.1.12
+        write_sum_vb = torch.sum(self.wl_curr_vb, 2, keepdim=True) # 0.2.0
         return (1 - write_sum_vb).expand_as(prev_preced_vb) * prev_preced_vb + self.wl_curr_vb

     def _temporal_link(self, prev_link_vb, prev_preced_vb):
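
Most of this commit follows one pattern: PyTorch 0.2.0 changed the reduction ops (sum, prod, norm, ...) to drop the reduced dimension by default, so keepdim=True is now needed wherever the result is expanded back to the original rank. A minimal sketch of the difference, with illustrative shapes:

    import torch

    w = torch.rand(4, 3, 5)                      # e.g. [batch_size x num_write_heads x mem_hei]

    print(torch.sum(w, 2).size())                # 0.2.0 default: torch.Size([4, 3])
    print(torch.sum(w, 2, keepdim=True).size())  # torch.Size([4, 3, 1]), matching the old 0.1.12 default

    # only the keepdim'ed result can be expand_as-ed back over mem_hei,
    # as in (1 - write_sum_vb).expand_as(prev_preced_vb) above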

core/heads/static_head.py (+2 -1)
@@ -87,7 +87,8 @@ def _location_focus(self):
         wg_vb = self.wc_vb * self.gate_vb + self.wl_prev_vb * (1. - self.gate_vb)
         ws_vb = self._shift(wg_vb, self.shift_vb)
         wp_vb = ws_vb.pow(self.gamma_vb.expand_as(ws_vb))
-        self.wl_curr_vb = wp_vb / wp_vb.sum(2).expand_as(wp_vb)
+        # self.wl_curr_vb = wp_vb / wp_vb.sum(2).expand_as(wp_vb) # 0.1.12
+        self.wl_curr_vb = wp_vb / wp_vb.sum(2, keepdim=True).expand_as(wp_vb) # 0.2.0

     def forward(self, hidden_vb, memory_vb):
         # outputs for computing addressing for heads
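
Same keepdim pattern here: normalising the sharpened weighting along the memory dimension needs the size-1 dim so that expand_as can stretch the sum back over mem_hei. A short sketch with illustrative shapes:

    import torch

    wp = torch.rand(4, 1, 16)                        # illustrative [batch x num_heads x mem_hei]
    wl = wp / wp.sum(2, keepdim=True).expand_as(wp)  # each head's weighting now sums to 1
    print(wl.sum(2))                                 # all ones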

utils/fake_ops.py (+2 -1)
@@ -22,6 +22,7 @@ def fake_cumprod(vb):
         mul_mask_vb[i, :, :i+1] = 1
     add_mask_vb = 1 - mul_mask_vb
     vb = vb.expand_as(mul_mask_vb) * mul_mask_vb + add_mask_vb
-    vb = torch.prod(vb, 2).transpose(0, 2)
+    # vb = torch.prod(vb, 2).transpose(0, 2) # 0.1.12
+    vb = torch.prod(vb, 2, keepdim=True).transpose(0, 2) # 0.2.0
     # print(real_cumprod - vb.data) # NOTE: checked, ==0
     return vb
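
fake_cumprod emulates a cumulative product along dim 1 with a mask-and-prod trick while the TODO in dynamic_write_head.py waits on pytorch/pytorch#1439; the keepdim=True on torch.prod just keeps its intermediate shapes consistent under 0.2.0. Per the commented-out line in _allocation, the intended eventual replacement is the native op, sketched here:

    import torch

    x = torch.rand(2, 5)
    print(torch.cumprod(x, dim=1))   # running product along dim 1, same shape as the input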

utils/helpers.py (-1)
@@ -3,7 +3,6 @@
 from __future__ import print_function
 import logging
 import numpy as np
-import cv2
 from collections import namedtuple

 def loggerConfig(log_file, verbose=2):

utils/options.py (+2 -2)
@@ -22,10 +22,10 @@ def __init__(self):

         # training signature
         self.machine = "daim" # "machine_id"
-        self.timestamp = "17080200" # "yymmdd##"
+        self.timestamp = "17080800" # "yymmdd##"
         # training configuration
         self.mode = 1 # 1(train) | 2(test model_file)
-        self.config = 1
+        self.config  = 1

         self.seed = 1
         self.render = False # whether render the window from the original envs or not

utils/similarities.py (+2 -1)
@@ -12,7 +12,8 @@ def batch_cosine_sim(u, v, epsilon=1e-6):
     """
     assert u.dim() == 3 and v.dim() == 3
     numerator = torch.bmm(u, v.transpose(1, 2))
-    denominator = torch.sqrt(torch.bmm(u.norm(2, 2).pow(2) + epsilon, v.norm(2, 2).pow(2).transpose(1, 2) + epsilon))
+    # denominator = torch.sqrt(torch.bmm(u.norm(2, 2).pow(2) + epsilon, v.norm(2, 2).pow(2).transpose(1, 2) + epsilon)) # 0.1.12
+    denominator = torch.sqrt(torch.bmm(u.norm(2, 2, keepdim=True).pow(2) + epsilon, v.norm(2, 2, keepdim=True).pow(2).transpose(1, 2) + epsilon)) # 0.2.0
     k = numerator / (denominator + epsilon)
     return k
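
For reference, a self-contained sketch of this batched cosine similarity under 0.2.0 semantics (shapes are illustrative; u and v are random stand-ins): norm(2, 2) would drop the last dimension, so keepdim=True is what keeps both operands rank-3 for torch.bmm.

    import torch

    u, v, eps = torch.rand(2, 4, 8), torch.rand(2, 3, 8), 1e-6     # [batch x rows x M]

    numerator = torch.bmm(u, v.transpose(1, 2))                    # [2, 4, 3]
    u_sq = u.norm(2, 2, keepdim=True).pow(2)                       # [2, 4, 1]; without keepdim: [2, 4]
    v_sq = v.norm(2, 2, keepdim=True).pow(2)                       # [2, 3, 1]
    denominator = torch.sqrt(torch.bmm(u_sq + eps, v_sq.transpose(1, 2) + eps))
    k = numerator / (denominator + eps)                            # [2, 4, 3] batched cosine similarity
    print(k.size())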
