Skip to content

Commit 9a68f20

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 2449b3c commit 9a68f20

File tree

84 files changed

+1238
-474
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

84 files changed

+1238
-474
lines changed

.devcontainer/devcontainer.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,4 @@
1717
"ms-python.python",
1818
"ms-vscode.cpptools"
1919
]
20-
}
20+
}

docker/manyskill-lerobot-gpu/10_nvidia.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
"ICD" : {
44
"library_path" : "libEGL_nvidia.so.0"
55
}
6-
}
6+
}

docker/manyskill-lerobot-gpu/nvidia_icd.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
"library_path": "libGLX_nvidia.so.0",
55
"api_version" : "1.2.155"
66
}
7-
}
7+
}

docker/manyskill-lerobot-gpu/nvidia_layers.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,4 @@
1818
"DISABLE_LAYER_NV_OPTIMUS_1": ""
1919
}
2020
}
21-
}
21+
}

examples/3_train_policy.py

+37-7
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88

99
import torch
1010

11-
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
11+
from lerobot.common.datasets.lerobot_dataset import (
12+
LeRobotDataset,
13+
LeRobotDatasetMetadata,
14+
)
1215
from lerobot.common.datasets.utils import dataset_to_policy_features
1316
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
1417
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
@@ -34,12 +37,18 @@ def main():
3437
# - dataset stats: for normalization and denormalization of input/outputs
3538
dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
3639
features = dataset_to_policy_features(dataset_metadata.features)
37-
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
38-
input_features = {key: ft for key, ft in features.items() if key not in output_features}
40+
output_features = {
41+
key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION
42+
}
43+
input_features = {
44+
key: ft for key, ft in features.items() if key not in output_features
45+
}
3946

4047
# Policies are initialized with a configuration class, in this case `DiffusionConfig`. For this example,
4148
# we'll just use the defaults and so no arguments other than input/output features need to be passed.
42-
cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
49+
cfg = DiffusionConfig(
50+
input_features=input_features, output_features=output_features
51+
)
4352

4453
# We can now instantiate our policy with this config and the dataset stats.
4554
policy = DiffusionPolicy(cfg, dataset_stats=dataset_metadata.stats)
@@ -49,8 +58,12 @@ def main():
4958
# Another policy-dataset interaction is with the delta_timestamps. Each policy expects a given number frames
5059
# which can differ for inputs, outputs and rewards (if there are some).
5160
delta_timestamps = {
52-
"observation.image": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
53-
"observation.state": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
61+
"observation.image": [
62+
i / dataset_metadata.fps for i in cfg.observation_delta_indices
63+
],
64+
"observation.state": [
65+
i / dataset_metadata.fps for i in cfg.observation_delta_indices
66+
],
5467
"action": [i / dataset_metadata.fps for i in cfg.action_delta_indices],
5568
}
5669

@@ -63,7 +76,24 @@ def main():
6376
# Load the previous action (-0.1), the next action to be executed (0.0),
6477
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
6578
# used to supervise the policy.
66-
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
79+
"action": [
80+
-0.1,
81+
0.0,
82+
0.1,
83+
0.2,
84+
0.3,
85+
0.4,
86+
0.5,
87+
0.6,
88+
0.7,
89+
0.8,
90+
0.9,
91+
1.0,
92+
1.1,
93+
1.2,
94+
1.3,
95+
1.4,
96+
],
6797
}
6898

6999
# We can then instantiate the dataset with these delta_timestamps configuration.

examples/advanced/2_calculate_validation_loss.py

+21-6
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,6 @@
1212

1313
import torch
1414

15-
from lerobot.common.datasets.lerobot_dataset import (
16-
LeRobotDataset,
17-
LeRobotDatasetMetadata,
18-
)
1915
from lerobot.common.datasets.lerobot_dataset import (
2016
LeRobotDataset,
2117
LeRobotDatasetMetadata,
@@ -44,7 +40,24 @@ def main():
4440
# Load the previous action (-0.1), the next action to be executed (0.0),
4541
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
4642
# used to calculate the loss.
47-
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
43+
"action": [
44+
-0.1,
45+
0.0,
46+
0.1,
47+
0.2,
48+
0.3,
49+
0.4,
50+
0.5,
51+
0.6,
52+
0.7,
53+
0.8,
54+
0.9,
55+
1.0,
56+
1.1,
57+
1.2,
58+
1.3,
59+
1.4,
60+
],
4861
}
4962

5063
# Load the last 10% of episodes of the dataset as a validation set.
@@ -63,7 +76,9 @@ def main():
6376
train_dataset = LeRobotDataset(
6477
"lerobot/pusht", episodes=train_episodes, delta_timestamps=delta_timestamps
6578
)
66-
val_dataset = LeRobotDataset("lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps)
79+
val_dataset = LeRobotDataset(
80+
"lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps
81+
)
6782
print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
6883
print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")
6984

lerobot/common/datasets/compute_stats.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -45,20 +45,20 @@ def get_stats_einops_patterns(dataset, num_workers=0):
4545
if key in dataset.meta.camera_keys:
4646
# sanity check that images are channel first
4747
_, c, h, w = batch[key].shape
48-
assert (
49-
c < h and c < w
50-
), f"expect channel first images, but instead {batch[key].shape}"
48+
assert c < h and c < w, (
49+
f"expect channel first images, but instead {batch[key].shape}"
50+
)
5151

5252
# sanity check that images are float32 in range [0,1]
53-
assert (
54-
batch[key].dtype == torch.float32
55-
), f"expect torch.float32, but instead {batch[key].dtype=}"
56-
assert (
57-
batch[key].max() <= 1
58-
), f"expect pixels lower than 1, but instead {batch[key].max()=}"
59-
assert (
60-
batch[key].min() >= 0
61-
), f"expect pixels greater than 1, but instead {batch[key].min()=}"
53+
assert batch[key].dtype == torch.float32, (
54+
f"expect torch.float32, but instead {batch[key].dtype=}"
55+
)
56+
assert batch[key].max() <= 1, (
57+
f"expect pixels lower than 1, but instead {batch[key].max()=}"
58+
)
59+
assert batch[key].min() >= 0, (
60+
f"expect pixels greater than 1, but instead {batch[key].min()=}"
61+
)
6262

6363
stats_patterns[key] = "b c h w -> c 1 1"
6464
elif batch[key].ndim == 2:

lerobot/common/datasets/factory.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,9 @@ def resolve_delta_timestamps(
5858
if key == "action" and cfg.action_delta_indices is not None:
5959
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
6060
if key.startswith("observation.") and cfg.observation_delta_indices is not None:
61-
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
61+
delta_timestamps[key] = [
62+
i / ds_meta.fps for i in cfg.observation_delta_indices
63+
]
6264

6365
if len(delta_timestamps) == 0:
6466
delta_timestamps = None
@@ -79,11 +81,15 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
7981
LeRobotDataset | MultiLeRobotDataset
8082
"""
8183
image_transforms = (
82-
ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
84+
ImageTransforms(cfg.dataset.image_transforms)
85+
if cfg.dataset.image_transforms.enable
86+
else None
8387
)
8488

8589
if isinstance(cfg.dataset.repo_id, str):
86-
ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id, local_files_only=cfg.dataset.local_files_only)
90+
ds_meta = LeRobotDatasetMetadata(
91+
cfg.dataset.repo_id, local_files_only=cfg.dataset.local_files_only
92+
)
8793
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
8894
dataset = LeRobotDataset(
8995
cfg.dataset.repo_id,
@@ -110,6 +116,8 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
110116
if cfg.dataset.use_imagenet_stats:
111117
for key in dataset.meta.camera_keys:
112118
for stats_type, stats in IMAGENET_STATS.items():
113-
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
119+
dataset.meta.stats[key][stats_type] = torch.tensor(
120+
stats, dtype=torch.float32
121+
)
114122

115123
return dataset

lerobot/common/datasets/lerobot_dataset.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,9 @@ def create(
326326
# as this would break the dict flattening in the stats computation, which uses '/' as separator
327327
for key in features:
328328
if "/" in key:
329-
raise ValueError(f"Feature names should not contain '/'. Found '/' in feature '{key}'.")
329+
raise ValueError(
330+
f"Feature names should not contain '/'. Found '/' in feature '{key}'."
331+
)
330332

331333
features = {**features, **DEFAULT_FEATURES}
332334

lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,9 @@ def check_format(raw_dir) -> bool:
7575
else:
7676
assert data[f"/observations/images/{camera}"].ndim == 4
7777
b, h, w, c = data[f"/observations/images/{camera}"].shape
78-
assert (
79-
c < h and c < w
80-
), f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
78+
assert c < h and c < w, (
79+
f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
80+
)
8181

8282

8383
def load_from_raw(

lerobot/common/datasets/transforms.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -153,12 +153,16 @@ def _check_input(self, sharpness):
153153
return float(sharpness[0]), float(sharpness[1])
154154

155155
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
156-
sharpness_factor = torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
156+
sharpness_factor = (
157+
torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
158+
)
157159
return {"sharpness_factor": sharpness_factor}
158160

159161
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
160162
sharpness_factor = params["sharpness_factor"]
161-
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
163+
return self._call_kernel(
164+
F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor
165+
)
162166

163167

164168
@dataclass

lerobot/common/envs/configs.py

+18-6
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,16 @@ class AlohaEnv(EnvConfig):
4747

4848
def __post_init__(self):
4949
if self.obs_type == "pixels":
50-
self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
50+
self.features["top"] = PolicyFeature(
51+
type=FeatureType.VISUAL, shape=(480, 640, 3)
52+
)
5153
elif self.obs_type == "pixels_agent_pos":
52-
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,))
53-
self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
54+
self.features["agent_pos"] = PolicyFeature(
55+
type=FeatureType.STATE, shape=(14,)
56+
)
57+
self.features["pixels/top"] = PolicyFeature(
58+
type=FeatureType.VISUAL, shape=(480, 640, 3)
59+
)
5460

5561
@property
5662
def gym_kwargs(self) -> dict:
@@ -88,9 +94,13 @@ class PushtEnv(EnvConfig):
8894

8995
def __post_init__(self):
9096
if self.obs_type == "pixels_agent_pos":
91-
self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3))
97+
self.features["pixels"] = PolicyFeature(
98+
type=FeatureType.VISUAL, shape=(384, 384, 3)
99+
)
92100
elif self.obs_type == "environment_state_agent_pos":
93-
self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,))
101+
self.features["environment_state"] = PolicyFeature(
102+
type=FeatureType.ENV, shape=(16,)
103+
)
94104

95105
@property
96106
def gym_kwargs(self) -> dict:
@@ -129,7 +139,9 @@ class XarmEnv(EnvConfig):
129139

130140
def __post_init__(self):
131141
if self.obs_type == "pixels_agent_pos":
132-
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))
142+
self.features["agent_pos"] = PolicyFeature(
143+
type=FeatureType.STATE, shape=(4,)
144+
)
133145

134146
@property
135147
def gym_kwargs(self) -> dict:

lerobot/common/envs/factory.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
# limitations under the License.
1616
import importlib
1717
from collections import deque
18-
from collections import deque
1918

2019
import gymnasium as gym
2120

@@ -33,7 +32,9 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
3332
raise ValueError(f"Policy type '{env_type}' is not available.")
3433

3534

36-
def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> gym.vector.VectorEnv | None:
35+
def make_env(
36+
cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False
37+
) -> gym.vector.VectorEnv | None:
3738
"""Makes a gym vector environment according to the config.
3839
3940
Args:
@@ -57,15 +58,20 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
5758
try:
5859
importlib.import_module(package_name)
5960
except ModuleNotFoundError as e:
60-
print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
61+
print(
62+
f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`"
63+
)
6164
raise e
6265

6366
gym_handle = f"{package_name}/{cfg.task}"
6467

6568
# batched version of the env that returns an observation of shape (b, c)
6669
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
6770
env = env_cls(
68-
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
71+
[
72+
lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs)
73+
for _ in range(n_envs)
74+
]
6975
)
7076

7177
return env

lerobot/common/envs/utils.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
5252

5353
# sanity check that images are channel last
5454
_, h, w, c = img.shape
55-
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
55+
assert c < h and c < w, (
56+
f"expect channel last images, but instead got {img.shape=}"
57+
)
5658

5759
# sanity check that images are uint8
5860
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
@@ -94,7 +96,9 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
9496
for key, ft in env_cfg.features.items():
9597
if ft.type is FeatureType.VISUAL:
9698
if len(ft.shape) != 3:
97-
raise ValueError(f"Number of dimensions of {key} != 3 (shape={ft.shape})")
99+
raise ValueError(
100+
f"Number of dimensions of {key} != 3 (shape={ft.shape})"
101+
)
98102

99103
shape = get_channel_first_image_shape(ft.shape)
100104
feature = PolicyFeature(type=ft.type, shape=shape)

lerobot/common/logger.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -250,9 +250,9 @@ def load_last_training_state(
250250
)
251251
# For the case where the optimizer is a dictionary of optimizers (e.g., sac)
252252
if type(training_state["optimizer"]) is dict:
253-
assert set(training_state["optimizer"].keys()) == set(
254-
optimizer.keys()
255-
), "Optimizer dictionaries do not have the same keys during resume!"
253+
assert set(training_state["optimizer"].keys()) == set(optimizer.keys()), (
254+
"Optimizer dictionaries do not have the same keys during resume!"
255+
)
256256
for k, v in training_state["optimizer"].items():
257257
optimizer[k].load_state_dict(v)
258258
else:

0 commit comments

Comments
 (0)