8
8
9
9
import torch
10
10
11
- from lerobot .common .datasets .lerobot_dataset import LeRobotDataset , LeRobotDatasetMetadata
11
+ from lerobot .common .datasets .lerobot_dataset import (
12
+ LeRobotDataset ,
13
+ LeRobotDatasetMetadata ,
14
+ )
12
15
from lerobot .common .datasets .utils import dataset_to_policy_features
13
16
from lerobot .common .policies .diffusion .configuration_diffusion import DiffusionConfig
14
17
from lerobot .common .policies .diffusion .modeling_diffusion import DiffusionPolicy
@@ -34,12 +37,18 @@ def main():
34
37
# - dataset stats: for normalization and denormalization of input/outputs
35
38
dataset_metadata = LeRobotDatasetMetadata ("lerobot/pusht" )
36
39
features = dataset_to_policy_features (dataset_metadata .features )
37
- output_features = {key : ft for key , ft in features .items () if ft .type is FeatureType .ACTION }
38
- input_features = {key : ft for key , ft in features .items () if key not in output_features }
40
+ output_features = {
41
+ key : ft for key , ft in features .items () if ft .type is FeatureType .ACTION
42
+ }
43
+ input_features = {
44
+ key : ft for key , ft in features .items () if key not in output_features
45
+ }
39
46
40
47
# Policies are initialized with a configuration class, in this case `DiffusionConfig`. For this example,
41
48
# we'll just use the defaults and so no arguments other than input/output features need to be passed.
42
- cfg = DiffusionConfig (input_features = input_features , output_features = output_features )
49
+ cfg = DiffusionConfig (
50
+ input_features = input_features , output_features = output_features
51
+ )
43
52
44
53
# We can now instantiate our policy with this config and the dataset stats.
45
54
policy = DiffusionPolicy (cfg , dataset_stats = dataset_metadata .stats )
@@ -49,8 +58,12 @@ def main():
49
58
# Another policy-dataset interaction is with the delta_timestamps. Each policy expects a given number frames
50
59
# which can differ for inputs, outputs and rewards (if there are some).
51
60
delta_timestamps = {
52
- "observation.image" : [i / dataset_metadata .fps for i in cfg .observation_delta_indices ],
53
- "observation.state" : [i / dataset_metadata .fps for i in cfg .observation_delta_indices ],
61
+ "observation.image" : [
62
+ i / dataset_metadata .fps for i in cfg .observation_delta_indices
63
+ ],
64
+ "observation.state" : [
65
+ i / dataset_metadata .fps for i in cfg .observation_delta_indices
66
+ ],
54
67
"action" : [i / dataset_metadata .fps for i in cfg .action_delta_indices ],
55
68
}
56
69
@@ -63,7 +76,24 @@ def main():
63
76
# Load the previous action (-0.1), the next action to be executed (0.0),
64
77
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
65
78
# used to supervise the policy.
66
- "action" : [- 0.1 , 0.0 , 0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 , 0.8 , 0.9 , 1.0 , 1.1 , 1.2 , 1.3 , 1.4 ],
79
+ "action" : [
80
+ - 0.1 ,
81
+ 0.0 ,
82
+ 0.1 ,
83
+ 0.2 ,
84
+ 0.3 ,
85
+ 0.4 ,
86
+ 0.5 ,
87
+ 0.6 ,
88
+ 0.7 ,
89
+ 0.8 ,
90
+ 0.9 ,
91
+ 1.0 ,
92
+ 1.1 ,
93
+ 1.2 ,
94
+ 1.3 ,
95
+ 1.4 ,
96
+ ],
67
97
}
68
98
69
99
# We can then instantiate the dataset with these delta_timestamps configuration.
0 commit comments