Skip to content

Commit

Permalink
get more demos
Browse files Browse the repository at this point in the history
  • Loading branch information
tribhi committed Feb 2, 2024
1 parent 1d2ed93 commit 144ad4b
Show file tree
Hide file tree
Showing 6 changed files with 8,003 additions and 27 deletions.
18 changes: 12 additions & 6 deletions loader/data_loader_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ def __getitem__(self, index):
full_traj = np.load(f)

len, robot_traj = get_traj_length_unique(full_traj)
robot_traj = full_traj
len = robot_traj.shape[0]
if len == 1:
robot_traj = np.vstack((robot_traj, np.array([robot_traj[0,0]-1, robot_traj[0,1]])))
print(robot_traj)
Expand All @@ -162,25 +164,27 @@ def __getitem__(self, index):
len, human_past_traj = get_traj_length_unique(full_traj)

human_past_traj = transpose_traj(human_past_traj)

past_traj_len = len
with open(self.image_fol+"/human_traj_fixed.npy", 'rb') as f:
full_traj = np.load(f)

len, human_future_traj = get_traj_length_unique(full_traj)
human_future_traj = transpose_traj(human_future_traj)
# if len(human_past_traj) == 0:
# human_past_traj = np.array([human_future_traj[0]])
np.append(human_past_traj, np.array([human_future_traj[0]]))
# human_past_traj = np.append(human_past_traj, np.array([human_future_traj[0]]))
# if not is_valid_traj(future_traj) or not is_valid_traj(future_other_traj):
# print("Bad future traj in ", self.image_fol)
# embed()

if past_traj_len == 0:
human_past_traj = np.array([[human_future_traj[0,0], human_future_traj[0,1]]])
# visualize rgb
# with open(self.image_fol+"/sdf.npy", 'rb') as f:
# sdf_feat = np.load(f)
feat = np.concatenate((goal_sink_feat, semantic_img_feat), axis = 0)
### Add the traj features
human_traj_feature = np.zeros(goal_sink_feat.shape)
print("debug", human_past_traj)
human_traj_feature[0,list(np.array(human_past_traj[:,0], np.int)), list(np.array(human_past_traj[:,1], np.int))] = 100
# other_traj_feature = np.zeros(goal_sink_feat.shape)

Expand All @@ -202,7 +206,7 @@ def __getitem__(self, index):
# print("for i mean is std is ", i, np.mean(feat[i]), np.std(feat[i]))
# print("After normalize min max", np.min(feat[i]), np.max(feat[i]))


print("what is the error ",robot_traj)
robot_traj = self.auto_pad_future(robot_traj[:, :2])
human_past_traj = self.auto_pad_past(human_past_traj[:, :2])
# future_other_traj = self.auto_pad_future(future_other_traj[:, :2])
Expand Down Expand Up @@ -239,10 +243,12 @@ def auto_pad_future(self, traj):
traj = traj[:self.grid_size, :]
#raise ValueError('traj length {} must be less than grid_size {}'.format(traj.shape[0], self.grid_size))
pad_len = fixed_len - traj.shape[0]

pad_list = []
for i in range(int(np.ceil(pad_len/2))):
pad_list.append([traj[-1,0]-1, [-1,1]])
pad_list.append(traj[-1])
pad_list.append([(traj[-1,0]-1), traj[-1,1]])
pad_list.append([traj[-1,0], traj[-1,1]])
print(pad_list)
pad_array = np.array(pad_list[:pad_len])
output = np.vstack((traj, pad_array))
return output
Loading

0 comments on commit 144ad4b

Please sign in to comment.