|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +import numpy as np |
| 3 | +import os |
| 4 | +import random |
| 5 | +import cv2 |
| 6 | + |
| 7 | + |
| 8 | +def load_mnist(traing_num=50000): |
| 9 | + dat = np.load("data/mnist.npz") |
| 10 | + X = dat['x_train'][:traing_num] |
| 11 | + Y = dat['y_train'][:traing_num] |
| 12 | + X_test = dat['x_test'] |
| 13 | + Y_test = dat['y_test'] |
| 14 | + Y = Y.reshape((Y.shape[0],)) |
| 15 | + Y_test = Y_test.reshape((Y_test.shape[0],)) |
| 16 | + return X, Y, X_test, Y_test |
| 17 | + |
| 18 | + |
| 19 | +def move_step(v0, p0, bounding_box): |
| 20 | + xmin, xmax, ymin, ymax = bounding_box |
| 21 | + assert (p0[0]>=xmin) and (p0[0]<=xmax) and (p0[1]>=ymin) and (p0[1]<=ymax) |
| 22 | + v = v0.copy() |
| 23 | + assert v[0] != 0.0 and v[1] != 0.0 |
| 24 | + p = v0 + p0 |
| 25 | + while (p[0]<xmin) or (p[0]>xmax) or (p[1]<ymin) or (p[1]>ymax): |
| 26 | + vx, vy = v |
| 27 | + x, y = p |
| 28 | + dist = np.zeros((4,)) |
| 29 | + dist[0] = abs(x-xmin) if ymin <= (xmin-x)*vy/vx+y<=ymax else np.inf |
| 30 | + dist[1] = abs(x-xmax) if ymin <= (xmax-x)*vy/vx+y<=ymax else np.inf |
| 31 | + dist[2] = abs((y-ymin)*vx/vy) if xmin <= (ymin-y)*vx/vy+x<=xmax else np.inf |
| 32 | + dist[3] = abs((y-ymax)*vx/vy) if xmin <= (ymax-y)*vx/vy+x<=xmax else np.inf |
| 33 | + n = np.argmin(dist) |
| 34 | + if n == 0: |
| 35 | + v[0] = -v[0] |
| 36 | + p[0] = 2*xmin-p[0] |
| 37 | + elif n == 1: |
| 38 | + v[0] = -v[0] |
| 39 | + p[0] = 2*xmax-p[0] |
| 40 | + elif n == 2: |
| 41 | + v[1] = -v[1] |
| 42 | + p[1] = 2*ymin-p[1] |
| 43 | + elif n == 3: |
| 44 | + v[1] = -v[1] |
| 45 | + p[1] = 2*ymax-p[1] |
| 46 | + else: |
| 47 | + assert False |
| 48 | + return v, p |
| 49 | + |
| 50 | + |
| 51 | + |
| 52 | +class MovingMNISTIterator(object): |
| 53 | + def __init__(self): |
| 54 | + self.mnist_train_img, self.mnist_train_label,self.mnist_test_img, self.mnist_test_label = load_mnist() |
| 55 | + |
| 56 | + def sample(self, digitnum, |
| 57 | + width, |
| 58 | + height, |
| 59 | + seqlen, |
| 60 | + batch_size, |
| 61 | + index_range=(0, 50000)): |
| 62 | + """""" |
| 63 | + """ |
| 64 | +
|
| 65 | + :param digitnum: The num of the digits |
| 66 | + :param width: The width of the generated images |
| 67 | + :param height: The height of the generated images |
| 68 | + :param seqlen: The length of the image sequence |
| 69 | + :param index_range: by default |
| 70 | + :return: |
| 71 | + """ |
| 72 | + character_indices = np.random.randint(low=index_range[0], high=index_range[1],size=(batch_size, digitnum)) |
| 73 | + angles = np.random.random((batch_size, digitnum)) * (2 * np.pi) |
| 74 | + magnitudes = np.random.random((batch_size, digitnum)) * (5 - 3) + 3 |
| 75 | + velocities = np.zeros((batch_size, digitnum, 2), dtype='float32') |
| 76 | + velocities[..., 0] = magnitudes * np.cos(angles) |
| 77 | + velocities[..., 1] = magnitudes * np.sin(angles) |
| 78 | + xmin = 14.0 |
| 79 | + xmax = float(width) - 14.0 |
| 80 | + ymin = 14.0 |
| 81 | + ymax = float(height) - 14.0 |
| 82 | + positions = np.random.uniform(low=xmin, high=xmax,size=(batch_size, digitnum, 2)) |
| 83 | + seq = np.zeros((seqlen, batch_size, 1, height, width), dtype='uint8') |
| 84 | + for i in range(batch_size): |
| 85 | + for j in range(digitnum): |
| 86 | + ind = character_indices[i, j] |
| 87 | + v = velocities[i, j, :] |
| 88 | + p = positions[i, j, :] |
| 89 | + img = self.mnist_train_img[ind].reshape((28, 28)) |
| 90 | + for k in range(seqlen): |
| 91 | + topleft_y = int(p[0] - img.shape[0] / 2) |
| 92 | + topleft_x = int(p[1] - img.shape[1] / 2) |
| 93 | + seq[k, i, 0, topleft_y:topleft_y + 28, topleft_x:topleft_x + 28] = np.maximum(seq[k, i, 0, topleft_y:topleft_y + 28, topleft_x:topleft_x + 28],img) |
| 94 | + v, p = move_step(v, p, [xmin, xmax, ymin, ymax]) |
| 95 | + return seq |
| 96 | + |
| 97 | + |
| 98 | + |
| 99 | +class MovingMnist_Generation(object): |
| 100 | + def __init__(self,digtnum, width, height, seq_length): |
| 101 | + self.digtnum = digtnum |
| 102 | + self.width = width |
| 103 | + self.height = height |
| 104 | + self.seq_length = seq_length |
| 105 | + |
| 106 | + def next_batch(self,batch_size,next_seqlen=1,return_one=True,norm=False): |
| 107 | + movingmnist = MovingMNISTIterator() |
| 108 | + |
| 109 | + |
| 110 | + sample = movingmnist.sample(digitnum=self.digtnum, |
| 111 | + width=self.width, |
| 112 | + height=self.height, |
| 113 | + seqlen=self.seq_length+next_seqlen, |
| 114 | + batch_size=batch_size) |
| 115 | + sample = np.transpose(sample,(1,0,2,3,4)) |
| 116 | + |
| 117 | + |
| 118 | + x_batch = sample[:,0:self.seq_length,:,:,:] |
| 119 | + y_batch = sample[:,self.seq_length:(self.seq_length+next_seqlen),:,:,:] |
| 120 | + |
| 121 | + if return_one is True and next_seqlen == 1: |
| 122 | + y_batch = np.reshape(y_batch,(batch_size,1,self.width,self.height)) |
| 123 | + |
| 124 | + # return the x_batch with shape(batchsize,seq_length,channels,width,height) |
| 125 | + # return the y_batch with shape(batchsize,seq_length,channels,width,height) or (batchsize,channels,width,height) when y_batch has only one timestep |
| 126 | + if norm: |
| 127 | + return x_batch/255.0 , y_batch/255.0 |
| 128 | + else: |
| 129 | + return x_batch,y_batch |
| 130 | + |
| 131 | + |
| 132 | + |
| 133 | +class SCMD_Generation(object): |
| 134 | + def __init__(self,seq_length=5,next_seq=1,isTrain=True,return_one=True,norm=False,baseline=False): |
| 135 | + self.seq_length = seq_length # the length of the squence for training |
| 136 | + self.next_seq = next_seq |
| 137 | + self.is_train = isTrain |
| 138 | + self.return_one = return_one |
| 139 | + self.norm = norm |
| 140 | + self.baseline = baseline |
| 141 | + self.train_root = "data/SCMD2016/TRAIN" |
| 142 | + self.test_root = "data/SCMD2016/TEST" |
| 143 | + self.data_length = 0 |
| 144 | + |
| 145 | + |
| 146 | + def next_batch(self,batchsize): |
| 147 | + x_batch = np.ndarray(shape=(batchsize,self.seq_length,1,200,200),dtype=np.float32) |
| 148 | + y_batch = np.ndarray(shape=(batchsize,1,1,200,200),dtype=np.float32) |
| 149 | + |
| 150 | + if self.is_train: |
| 151 | + datalist = os.listdir(self.train_root) |
| 152 | + else: |
| 153 | + datalist = os.listdir(self.test_root) |
| 154 | + self.data_length = len(datalist) |
| 155 | + |
| 156 | + random_order = random.sample(range(1,self.data_length),batchsize) |
| 157 | + if self.is_train: |
| 158 | + root_path = self.train_root |
| 159 | + else: |
| 160 | + root_path = self.test_root |
| 161 | + |
| 162 | + for i in range(batchsize): |
| 163 | + for k in range(self.seq_length): |
| 164 | + x_batch[i,k,0,:,:] = cv2.imread(root_path+"/SCMD_"+str(random_order[i])+"/"+str(k+1)+".png",cv2.IMREAD_GRAYSCALE) |
| 165 | + y_batch[i,0,0,:,:] = cv2.imread(root_path+"/SCMD_"+str(random_order[i])+"/"+str(self.seq_length+1)+".png",cv2.IMREAD_GRAYSCALE) |
| 166 | + |
| 167 | + x_batch,y_batch = x_batch*10.0,y_batch*10.0 |
| 168 | + |
| 169 | + if self.baseline: |
| 170 | + x_batch = x_batch[:,:,:,100,100] |
| 171 | + y_batch = y_batch[:,:,:,100,100] |
| 172 | + |
| 173 | + if self.norm: |
| 174 | + x_batch = x_batch/255.0 |
| 175 | + y_batch = y_batch/255.0 |
| 176 | + |
| 177 | + |
| 178 | + return x_batch,y_batch |
| 179 | + |
0 commit comments