Skip to content

Commit

Permalink
add train_diffusion.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Mark-tz committed Feb 25, 2025
1 parent 60ff84f commit 126094a
Show file tree
Hide file tree
Showing 10 changed files with 680 additions and 26 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ git_stats

3rdParty

__temp__
__temp__*

.vscode
# python
Expand Down
23 changes: 23 additions & 0 deletions ZBin/py_playground/rocos/conf/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import os
from omegaconf import OmegaConf
def get_config(name="default"):
path = []
if os.path.isabs(name):
path.append(os.path.dirname(name))
else:
path.append(os.path.dirname(__file__))
path.append(os.getcwd())
path.append(os.path.join(os.path.expanduser("~"), ".config", "rocos"))
for p in path:
try:
cfg = OmegaConf.load(os.path.join(p, name + ".yaml"))
break
except FileNotFoundError:
cfg = None
if cfg is None:
raise FileNotFoundError(f"Config file {name}.yaml not found in {path}")
return cfg

if __name__ == "__main__":
cfg = get_config("default")
print(cfg)
23 changes: 23 additions & 0 deletions ZBin/py_playground/rocos/conf/default.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
ssl:
addr:
local: 127.0.0.1
ref: 224.5.23.1
vision: 224.5.23.2
zss: 233.233.233.233
port:
# ssl
vision: 10005
ref: 10003
# grSim
sim_cmd: 20011
sim_status: [30011, 30012]
sim_control: 10300
sim_blue_control: 10301
sim_yellow_control: 10302
# rocos : Core <> Client
cmd: [50011, 50012]
status: [60001, 60002]
fusion_vision: [23333, 23334]
fusion_vision_py: [41001, 41002]
debug_msg: [20001, 20002]
debug_heapmap: [20003, 20004]
32 changes: 27 additions & 5 deletions ZBin/py_playground/rocos/log/data/loader.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,39 @@
import sys
from torch.utils.data import DataLoader

sys.path.append('../../..')
from rocos.log.data.tracker_vision import TrackerVisionDataset

def data_loader(args, path):
dset = TrackerVisionDataset(
path,
obs_len=args.obs_len,
pred_len=args.pred_len,
skip=args.skip,
delim=args.delim)
skip=args.skip)

loader = DataLoader(
dset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.loader_num_workers,
collate_fn=seq_collate)
return dset, loader
num_workers=args.loader_num_workers)
return dset, loader


if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', type=str, default='__log__')
parser.add_argument('--obs_len', type=int, default=8)
parser.add_argument('--pred_len', type=int, default=12)
parser.add_argument('--skip', type=int, default=5)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--loader_num_workers', type=int, default=4)
parser.add_argument('--delim', type=str, default=' ')
args = parser.parse_args()

dset, loader = data_loader(args, args.data_dir)
for data in iter(loader):
print(len(data))
for d in data:
print(d.shape)
break
47 changes: 27 additions & 20 deletions ZBin/py_playground/rocos/log/data/tracker_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ def __init__(self, data_dir, obs_len = 8, pred_len = 12, skip = 5):
self.skip = skip
self.seq_len = self.obs_len + self.pred_len

self.data = []
self.dtype = np.float32

self.obs_data = []
self.pred_data = []

all_log_files = []
# get log files
Expand All @@ -31,16 +34,18 @@ def __init__(self, data_dir, obs_len = 8, pred_len = 12, skip = 5):
content = get_content(log_file)
msgs, type = read_log(content)
assert type == TYPE.SSL_VISION_TRACKER_2020, "Unknown message type"
data_seqs = self.generate_seq(msgs, TrackerWrapperPacket)
self.data.extend(data_seqs)
obs_seqs, pred_seqs = self.generate_seq(msgs, TrackerWrapperPacket)
self.obs_data.extend(obs_seqs)
self.pred_data.extend(pred_seqs)

print(f"Data Dir : {data_dir}, Total search {len(all_log_files)} log files, found {len(self.data)} sequences")
print(f"Data Dir : {data_dir}, Total search {len(all_log_files)} log files, found {len(self.obs_data)} sequences")

def generate_seq(self, msgs, MsgType):
if len(msgs) < LOG_MIN_LEN:
return []
return [], []

dataset_seqs = []
obs_seqs = []
pred_seqs = []

data_seq = []
for data in msgs:
Expand All @@ -53,9 +58,12 @@ def generate_seq(self, msgs, MsgType):
data_seq.append(data)

for i in range(0, len(data_seq) - self.seq_len, self.skip):
dataset_seqs.extend(self.generate_single_seq(data_seq[i:i+self.seq_len]))
single_seq = self.generate_single_seq(data_seq[i:i+self.seq_len])
obs, pred = [d[:self.obs_len] for d in single_seq], [d[self.obs_len:] for d in single_seq]
obs_seqs.append(obs)
pred_seqs.append(pred)

return dataset_seqs
return obs_seqs, pred_seqs

def parse_single_msg(self, msg: TrackerWrapperPacket):
frame = msg.tracked_frame
Expand All @@ -67,13 +75,13 @@ def parse_single_msg(self, msg: TrackerWrapperPacket):

for r in frame.robots:
# robot_data
rd = np.array([r.pos.x, r.pos.y, r.orientation, r.vel.x, r.vel.y, r.vel_angular, r.visibility])
rd = np.array([r.pos.x, r.pos.y, r.orientation, r.vel.x, r.vel.y, r.vel_angular],dtype=self.dtype)
if r.robot_id.team_color == TeamColor.TEAM_COLOR_BLUE:
robot_blue[r.robot_id.id] = rd
elif r.robot_id.team_color == TeamColor.TEAM_COLOR_YELLOW:
robot_yellow[r.robot_id.id] = rd
return {
"ball": np.array([ball.pos.x, ball.pos.y, ball.pos.z, ball.vel.x, ball.vel.y, ball.vel.z, ball.visibility]),
"ball": np.array([ball.pos.x, ball.pos.y, ball.vel.x, ball.vel.y],dtype=self.dtype),
"blue": robot_blue,
"yellow": robot_yellow
}
Expand All @@ -90,9 +98,9 @@ def generate_single_seq(self, seqs):
checked_yellow_id = list(set(checked_yellow_id) & set(seq["yellow"].keys()))
assert len(checked_blue_id) == len(blue_id) and len(checked_yellow_id) == len(yellow_id), "robot number not correct"

_ball_seq = np.empty((0, seqs[0]["ball"].shape[0]))
_blue_seq = np.empty((0, len(blue_id), seqs[0]["blue"][blue_id[0]].shape[0]))
_yellow_seq = np.empty((0, len(yellow_id), seqs[0]["yellow"][yellow_id[0]].shape[0]))
_ball_seq = np.empty((0, seqs[0]["ball"].shape[0]), dtype=self.dtype)
_blue_seq = np.empty((0, len(blue_id), seqs[0]["blue"][blue_id[0]].shape[0]), dtype=self.dtype)
_yellow_seq = np.empty((0, len(yellow_id), seqs[0]["yellow"][yellow_id[0]].shape[0]), dtype=self.dtype)

for seq in seqs:
ball = seq["ball"]
Expand All @@ -102,17 +110,16 @@ def generate_single_seq(self, seqs):
_blue_seq = np.vstack((_blue_seq, blue[None, :, :]))
_yellow_seq = np.vstack((_yellow_seq, yellow[None, :, :]))

return [{
"ball": _ball_seq,
"blue": _blue_seq,
"yellow": _yellow_seq
}]
return [_ball_seq, _blue_seq, _yellow_seq]

def __len__(self):
return len(self.data)
return len(self.obs_data)

def __getitem__(self, idx):
return self.data[idx]
return [
self.obs_data[idx],
self.pred_data[idx]
]

if __name__ == "__main__":
data_dir = sys.argv[1]
Expand Down
1 change: 1 addition & 0 deletions ZBin/py_playground/rocos/nn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from rocos.nn.unet1d import ConditionalUnet1D
Empty file.
Loading

0 comments on commit 126094a

Please sign in to comment.