Skip to content

Commit d919c20

Browse files
authored
Add files via upload
1 parent 67352d7 commit d919c20

34 files changed

+2434
-0
lines changed

Diff for: data/sample.png

361 KB
Loading

Diff for: dataset.py

+179
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
# -*- coding: utf-8 -*-
2+
import numpy as np
3+
import os
4+
import random
5+
import cv2
6+
7+
8+
def load_mnist(traing_num=50000):
9+
dat = np.load("data/mnist.npz")
10+
X = dat['x_train'][:traing_num]
11+
Y = dat['y_train'][:traing_num]
12+
X_test = dat['x_test']
13+
Y_test = dat['y_test']
14+
Y = Y.reshape((Y.shape[0],))
15+
Y_test = Y_test.reshape((Y_test.shape[0],))
16+
return X, Y, X_test, Y_test
17+
18+
19+
def move_step(v0, p0, bounding_box):
20+
xmin, xmax, ymin, ymax = bounding_box
21+
assert (p0[0]>=xmin) and (p0[0]<=xmax) and (p0[1]>=ymin) and (p0[1]<=ymax)
22+
v = v0.copy()
23+
assert v[0] != 0.0 and v[1] != 0.0
24+
p = v0 + p0
25+
while (p[0]<xmin) or (p[0]>xmax) or (p[1]<ymin) or (p[1]>ymax):
26+
vx, vy = v
27+
x, y = p
28+
dist = np.zeros((4,))
29+
dist[0] = abs(x-xmin) if ymin <= (xmin-x)*vy/vx+y<=ymax else np.inf
30+
dist[1] = abs(x-xmax) if ymin <= (xmax-x)*vy/vx+y<=ymax else np.inf
31+
dist[2] = abs((y-ymin)*vx/vy) if xmin <= (ymin-y)*vx/vy+x<=xmax else np.inf
32+
dist[3] = abs((y-ymax)*vx/vy) if xmin <= (ymax-y)*vx/vy+x<=xmax else np.inf
33+
n = np.argmin(dist)
34+
if n == 0:
35+
v[0] = -v[0]
36+
p[0] = 2*xmin-p[0]
37+
elif n == 1:
38+
v[0] = -v[0]
39+
p[0] = 2*xmax-p[0]
40+
elif n == 2:
41+
v[1] = -v[1]
42+
p[1] = 2*ymin-p[1]
43+
elif n == 3:
44+
v[1] = -v[1]
45+
p[1] = 2*ymax-p[1]
46+
else:
47+
assert False
48+
return v, p
49+
50+
51+
52+
class MovingMNISTIterator(object):
53+
def __init__(self):
54+
self.mnist_train_img, self.mnist_train_label,self.mnist_test_img, self.mnist_test_label = load_mnist()
55+
56+
def sample(self, digitnum,
57+
width,
58+
height,
59+
seqlen,
60+
batch_size,
61+
index_range=(0, 50000)):
62+
""""""
63+
"""
64+
65+
:param digitnum: The num of the digits
66+
:param width: The width of the generated images
67+
:param height: The height of the generated images
68+
:param seqlen: The length of the image sequence
69+
:param index_range: by default
70+
:return:
71+
"""
72+
character_indices = np.random.randint(low=index_range[0], high=index_range[1],size=(batch_size, digitnum))
73+
angles = np.random.random((batch_size, digitnum)) * (2 * np.pi)
74+
magnitudes = np.random.random((batch_size, digitnum)) * (5 - 3) + 3
75+
velocities = np.zeros((batch_size, digitnum, 2), dtype='float32')
76+
velocities[..., 0] = magnitudes * np.cos(angles)
77+
velocities[..., 1] = magnitudes * np.sin(angles)
78+
xmin = 14.0
79+
xmax = float(width) - 14.0
80+
ymin = 14.0
81+
ymax = float(height) - 14.0
82+
positions = np.random.uniform(low=xmin, high=xmax,size=(batch_size, digitnum, 2))
83+
seq = np.zeros((seqlen, batch_size, 1, height, width), dtype='uint8')
84+
for i in range(batch_size):
85+
for j in range(digitnum):
86+
ind = character_indices[i, j]
87+
v = velocities[i, j, :]
88+
p = positions[i, j, :]
89+
img = self.mnist_train_img[ind].reshape((28, 28))
90+
for k in range(seqlen):
91+
topleft_y = int(p[0] - img.shape[0] / 2)
92+
topleft_x = int(p[1] - img.shape[1] / 2)
93+
seq[k, i, 0, topleft_y:topleft_y + 28, topleft_x:topleft_x + 28] = np.maximum(seq[k, i, 0, topleft_y:topleft_y + 28, topleft_x:topleft_x + 28],img)
94+
v, p = move_step(v, p, [xmin, xmax, ymin, ymax])
95+
return seq
96+
97+
98+
99+
class MovingMnist_Generation(object):
100+
def __init__(self,digtnum, width, height, seq_length):
101+
self.digtnum = digtnum
102+
self.width = width
103+
self.height = height
104+
self.seq_length = seq_length
105+
106+
def next_batch(self,batch_size,next_seqlen=1,return_one=True,norm=False):
107+
movingmnist = MovingMNISTIterator()
108+
109+
110+
sample = movingmnist.sample(digitnum=self.digtnum,
111+
width=self.width,
112+
height=self.height,
113+
seqlen=self.seq_length+next_seqlen,
114+
batch_size=batch_size)
115+
sample = np.transpose(sample,(1,0,2,3,4))
116+
117+
118+
x_batch = sample[:,0:self.seq_length,:,:,:]
119+
y_batch = sample[:,self.seq_length:(self.seq_length+next_seqlen),:,:,:]
120+
121+
if return_one is True and next_seqlen == 1:
122+
y_batch = np.reshape(y_batch,(batch_size,1,self.width,self.height))
123+
124+
# return the x_batch with shape(batchsize,seq_length,channels,width,height)
125+
# return the y_batch with shape(batchsize,seq_length,channels,width,height) or (batchsize,channels,width,height) when y_batch has only one timestep
126+
if norm:
127+
return x_batch/255.0 , y_batch/255.0
128+
else:
129+
return x_batch,y_batch
130+
131+
132+
133+
class SCMD_Generation(object):
134+
def __init__(self,seq_length=5,next_seq=1,isTrain=True,return_one=True,norm=False,baseline=False):
135+
self.seq_length = seq_length # the length of the squence for training
136+
self.next_seq = next_seq
137+
self.is_train = isTrain
138+
self.return_one = return_one
139+
self.norm = norm
140+
self.baseline = baseline
141+
self.train_root = "data/SCMD2016/TRAIN"
142+
self.test_root = "data/SCMD2016/TEST"
143+
self.data_length = 0
144+
145+
146+
def next_batch(self,batchsize):
147+
x_batch = np.ndarray(shape=(batchsize,self.seq_length,1,200,200),dtype=np.float32)
148+
y_batch = np.ndarray(shape=(batchsize,1,1,200,200),dtype=np.float32)
149+
150+
if self.is_train:
151+
datalist = os.listdir(self.train_root)
152+
else:
153+
datalist = os.listdir(self.test_root)
154+
self.data_length = len(datalist)
155+
156+
random_order = random.sample(range(1,self.data_length),batchsize)
157+
if self.is_train:
158+
root_path = self.train_root
159+
else:
160+
root_path = self.test_root
161+
162+
for i in range(batchsize):
163+
for k in range(self.seq_length):
164+
x_batch[i,k,0,:,:] = cv2.imread(root_path+"/SCMD_"+str(random_order[i])+"/"+str(k+1)+".png",cv2.IMREAD_GRAYSCALE)
165+
y_batch[i,0,0,:,:] = cv2.imread(root_path+"/SCMD_"+str(random_order[i])+"/"+str(self.seq_length+1)+".png",cv2.IMREAD_GRAYSCALE)
166+
167+
x_batch,y_batch = x_batch*10.0,y_batch*10.0
168+
169+
if self.baseline:
170+
x_batch = x_batch[:,:,:,100,100]
171+
y_batch = y_batch[:,:,:,100,100]
172+
173+
if self.norm:
174+
x_batch = x_batch/255.0
175+
y_batch = y_batch/255.0
176+
177+
178+
return x_batch,y_batch
179+

Diff for: demo.py

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# -*- coding: utf-8 -*-
2+
import torch
3+
import dataset
4+
import cv2
5+
6+
model_clstm_m = "checkpoint/clstm_m/model_best.pth"
7+
model_clstm_s = "checkpoint/clstm_s/model_best.pth"
8+
mdoel_forecast_clstm_m = "checkpoint/forecast_clstm_m/model_best.pth"
9+
mdoel_forecast_clstm_s = "checkpoint/forecast_clstm_s/model_best.pth"
10+
model_forecast_clstm_forecaster="checkpoint/forecast_clstm_forecaster/model_best.pth"
11+
12+
13+
14+
def demo_mnist(model_path):
15+
model = torch.load(model_path)
16+
mnist = dataset.MovingMnist_Generation(digtnum=2,
17+
width=64,
18+
height=64,
19+
seq_length=9)
20+
x_batch,y_batch = mnist.next_batch(batch_size=1,
21+
next_seqlen=1,
22+
return_one=False,
23+
norm=False)
24+
x_batch = torch.from_numpy(x_batch).float()
25+
26+
output = model.forward(x_batch)
27+
output = output.detach().cpu().numpy()
28+
cv2.imwrite("demo_mnist.png",output[0][0][0])
29+
30+
31+
32+
def demo_scmd(model_path):
33+
model = torch.load(model_path)
34+
scmd = dataset.SCDMD_Generation()
35+
x_batch,y_batch = scmd.next_batch(batchsize=1)
36+
x_batch = torch.from_numpy(x_batch).float()
37+
38+
output = model.forward(x_batch)
39+
output = output.detach().cpu().numpy()
40+
cv2.imwrite("demo_scmd.png",output[0][0][0])

0 commit comments

Comments
 (0)