Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions src/openpi/training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import openpi.shared.download as _download
import openpi.shared.normalize as _normalize
import openpi.training.droid_rlds_dataset as droid_rlds_dataset
import openpi.training.misc.polaris_config as polaris_config
import openpi.training.misc.roboarena_config as roboarena_config
import openpi.training.optimizer as _optimizer
import openpi.training.weight_loaders as weight_loaders
Expand Down Expand Up @@ -93,8 +94,8 @@ class DataConfig:
rlds_data_dir: str | None = None
# Action space for DROID dataset.
action_space: droid_rlds_dataset.DroidActionSpace | None = None
# Path to the data filter file for DROID dataset
filter_dict_path: str | None = None
# List of datasets to sample from: name, version, weight, and optionally filter_dict_path
datasets: Sequence[droid_rlds_dataset.RLDSDataset] = ()


class GroupFactory(Protocol):
Expand Down Expand Up @@ -366,8 +367,16 @@ class RLDSDroidDataConfig(DataConfigFactory):
# Filtering options. Can pass a path to a dictionary that maps episodes to timestep ranges
# to tuples denoting ranges of time steps to keep (start, end). Episodes are uniquely identified with
# f"{recording_folderpath}--{file_path}", both of which are present in the RLDS episode metadata.
# Path to the filter dictionary file.
filter_dict_path: str | None = "gs://openpi-assets/droid/droid_sample_ranges_v1_0_1.json"

# List of datasets to sample from: name, version, weight, and optionally filter_dict_path
datasets: Sequence[droid_rlds_dataset.RLDSDataset] = (
droid_rlds_dataset.RLDSDataset(
name="droid",
version="1.0.1",
weight=1.0,
filter_dict_path="gs://openpi-assets/droid/droid_sample_ranges_v1_0_1.json",
),
)

@override
def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
Expand Down Expand Up @@ -410,7 +419,7 @@ def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig
model_transforms=model_transforms,
rlds_data_dir=self.rlds_data_dir,
action_space=self.action_space,
filter_dict_path=self.filter_dict_path,
datasets=self.datasets,
)


Expand Down Expand Up @@ -956,10 +965,9 @@ def __post_init__(self) -> None:
exp_name="debug_pi05",
wandb_enabled=False,
),
#
# RoboArena configs.
#
# RoboArena & PolaRiS configs.
*roboarena_config.get_roboarena_configs(),
*polaris_config.get_polaris_configs(),
]

if len({config.name for config in _CONFIGS}) != len(_CONFIGS):
Expand Down
2 changes: 1 addition & 1 deletion src/openpi/training/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def create_rlds_dataset(
shuffle=shuffle,
action_chunk_size=action_horizon,
action_space=data_config.action_space,
filter_dict_path=data_config.filter_dict_path,
datasets=data_config.datasets,
)


Expand Down
Loading