
What is the Expected Format for Training Dataset? #81

Open
b4r4d41z opened this issue Feb 15, 2025 · 1 comment

Comments

@b4r4d41z

I want to train the model to perform tasks using a dual-arm robot. I am working with my Kuavo robot from Leju. According to the guide, I need to collect my own dataset. Currently, I can successfully gather data in the .bag format, but I am unsure about the required data format for proper training (what format does the script expect?).

Could you please advise if there are any datasets used by the developers for pre-training that I can download to understand the required structure? If they are publicly available, please share the link.

Alternatively, point me to this information in the README.md if I have missed it.

Thanks ❤️

@b4r4d41z
Author

b4r4d41z commented Mar 3, 2025

I'm still trying to get to the bottom of this and could use some help. If you've faced the same questions and managed to find answers, please share your experience in this issue.

I think I've figured out how the fine-tuning process works at the level of file interactions:

finetune.sh → main.py → train/train.py → VLAConsumerDataset → HDF5VLADataset.parse_hdf5_file

finetune.sh: This script launches the fine-tuning process for the model. It sets up the necessary environment variables and calls main.py with the appropriate arguments.

deepspeed --hostfile=hostfile.txt main.py \
    --deepspeed="./configs/zero2.json" \
    --pretrained_model_name_or_path="robotics-diffusion-transformer/rdt-1b" \
    --pretrained_text_encoder_name_or_path=$TEXT_ENCODER_NAME \
    --pretrained_vision_encoder_name_or_path=$VISION_ENCODER_NAME \
    --output_dir=$OUTPUT_DIR \
    --train_batch_size=32 \
    --sample_batch_size=64 \
    --max_train_steps=200000 \
    --checkpointing_period=1000 \
    --sample_period=500 \
    --checkpoints_total_limit=40 \
    --lr_scheduler="constant" \
    --learning_rate=1e-4 \
    --mixed_precision="bf16" \
    --dataloader_num_workers=8 \
    --image_aug \
    --dataset_type="finetune" \
    --state_noise_snr=40 \
    --load_from_hdf5 \
    --report_to=wandb

main.py: The main entry script, which processes the provided arguments and initiates the training process by calling the train() function from the train/train.py module.

def train(args, logger):
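
Concretely, I'd expect the hand-off inside main.py to look roughly like the sketch below (the parse_args helper and the flag subset are my assumptions inferred from finetune.sh, not the repo's verbatim code):

# minimal sketch of main.py's hand-off to the trainer; the helper name and
# flag subset are assumptions inferred from finetune.sh
import argparse
import logging

from train.train import train

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained_model_name_or_path", type=str)
    parser.add_argument("--dataset_type", type=str, default="finetune")
    parser.add_argument("--load_from_hdf5", action="store_true")
    # ... the remaining flags passed by finetune.sh
    return parser.parse_args()

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    train(parse_args(), logging.getLogger(__name__))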

train/train.py: Contains the train() function, which is responsible for setting up the model, loading the data, and executing the training loop. During this process, it creates an instance of the VLAConsumerDataset class for data preparation.

    # Dataset and DataLoaders creation:                                                           
    train_dataset = VLAConsumerDataset(
        config=config["dataset"],
        tokenizer=tokenizer,
        image_processor=image_processor,
        num_cameras=config["common"]["num_cameras"],
        img_history_size=config["common"]["img_history_size"],
        dataset_type=args.dataset_type,
        image_aug=args.image_aug,
        cond_mask_prob=args.cond_mask_prob,
        cam_ext_mask_prob=args.cam_ext_mask_prob,
        state_noise_snr=args.state_noise_snr,
        use_hdf5=args.load_from_hdf5,
        use_precomp_lang_embed=args.precomp_lang_embed,
    )
    sample_dataset = VLAConsumerDataset(
        config=config["dataset"],
        tokenizer=tokenizer,
        image_processor=image_processor,
        num_cameras=config["common"]["num_cameras"],
        img_history_size=config["common"]["img_history_size"],
        dataset_type=args.dataset_type,
        image_aug=False,
        cond_mask_prob=0,
        cam_ext_mask_prob=-1,
        state_noise_snr=None,
        use_hdf5=args.load_from_hdf5,
        use_precomp_lang_embed=args.precomp_lang_embed,
    )                              
    
    data_collator = DataCollatorForVLAConsumerDataset(tokenizer)         
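
Downstream, I would expect these datasets to be wrapped in ordinary PyTorch DataLoaders, roughly like this sketch (the exact arguments in the repo may differ):

import torch

# sketch: batch the per-sample dicts produced by VLAConsumerDataset
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.train_batch_size,
    shuffle=True,
    num_workers=args.dataloader_num_workers,
    collate_fn=data_collator,   # pads/stacks the dict fields into tensors
)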

train/dataset.py: This file defines the VLAConsumerDataset class, which manages data loading and preprocessing for the model. If the configuration specifies loading data from HDF5 files, it creates an instance of the HDF5VLADataset class.

from data.hdf5_vla_dataset import HDF5VLADataset
...
if use_hdf5:
    self.hdf5_dataset = HDF5VLADataset()

data/hdf5_vla_dataset.py: Contains the HDF5VLADataset class, which is responsible for reading data from HDF5 files. The parse_hdf5_file method of this class processes an individual episode, extracting the necessary data (states, actions, images, and the instruction) and returning it as a dictionary for use in training.

def parse_hdf5_file(self, file_path):
        """[Modify] Parse an HDF5 file to generate a training sample at
            a random timestep.

        Args:
            file_path (str): the path to the HDF5 file

        Returns:
            valid (bool): whether the episode is valid, which is useful for filtering.
                If False, this episode will be dropped.
            dict: a dictionary containing the training sample,
                {
                    "meta": {
                        "dataset_name": str,    # the name of your dataset.
                        "#steps": int,          # the number of steps in the episode,
                                                # also the total timesteps.
                        "instruction": str      # the language instruction for this episode.
                    },
                    "step_id": int,             # the index of the sampled step,
                                                # also the timestep t.
                    "state": ndarray,           # state[t], (1, STATE_DIM).
                    "state_std": ndarray,       # std(state[:]), (STATE_DIM,).
                    "state_mean": ndarray,      # mean(state[:]), (STATE_DIM,).
                    "state_norm": ndarray,      # norm(state[:]), (STATE_DIM,).
                    "actions": ndarray,         # action[t:t+CHUNK_SIZE], (CHUNK_SIZE, STATE_DIM).
                    "state_indicator": ndarray, # indicates the validity of each dim, (STATE_DIM,).
                    "cam_high": ndarray,        # external camera image, (IMG_HISTORY_SIZE, H, W, 3)
                                                # or (IMG_HISTORY_SIZE, 0, 0, 0) if unavailable.
                    "cam_high_mask": ndarray,   # boolean array (IMG_HISTORY_SIZE,) indicating the
                                                # validity of each timestep; for the first
                                                # IMG_HISTORY_SIZE-1 timesteps of an episode,
                                                # the mask should be False.
                    "cam_left_wrist": ndarray,  # left wrist camera image, (IMG_HISTORY_SIZE, H, W, 3)
                                                # or (IMG_HISTORY_SIZE, 0, 0, 0) if unavailable.
                    "cam_left_wrist_mask": ndarray,   # same semantics as cam_high_mask.
                    "cam_right_wrist": ndarray, # right wrist camera image, (IMG_HISTORY_SIZE, H, W, 3)
                                                # or (IMG_HISTORY_SIZE, 0, 0, 0) if unavailable.
                                                # If only one wrist camera is available, treat it
                                                # as the right wrist.
                    "cam_right_wrist_mask": ndarray   # same semantics as cam_high_mask.
                } or None if the episode is invalid.
        """

Based on the last point, I currently assume that the dataset must be provided in exactly the format that parse_hdf5_file expects (as shown in the last code segment). At the moment, I am working on a script to convert .bag files to .hdf5; my current sketch is below. If you have a working algorithm or have seen a similar project, I would greatly appreciate it if you could share it.
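
Here is the rough shape of the converter I am experimenting with (ROS1; the topic names, message fields, and single-camera layout are assumptions for my setup, adjust them for yours):

import h5py
import numpy as np
import rosbag                     # ROS1 API; for ROS2 bags use rosbag2_py instead
from cv_bridge import CvBridge

def bag_to_hdf5(bag_path, hdf5_path):
    """Rough sketch: dump joint states, commands, and camera frames to HDF5."""
    bridge = CvBridge()
    qpos, actions, cam_high = [], [], []
    with rosbag.Bag(bag_path, "r") as bag:
        # assumed topic names; check yours with `rosbag info`
        for topic, msg, _t in bag.read_messages(
                topics=["/joint_states", "/joint_cmd", "/cam_high/image_raw"]):
            if topic == "/joint_states":
                qpos.append(np.asarray(msg.position, dtype=np.float32))
            elif topic == "/joint_cmd":
                actions.append(np.asarray(msg.position, dtype=np.float32))
            elif topic == "/cam_high/image_raw":
                cam_high.append(bridge.imgmsg_to_cv2(msg, desired_encoding="rgb8"))

    # naive alignment: truncate every stream to the shortest one; in practice the
    # streams should be resampled to a common rate by message timestamp
    n = min(len(qpos), len(actions), len(cam_high))
    with h5py.File(hdf5_path, "w") as f:
        f.create_dataset("observations/qpos", data=np.stack(qpos[:n]))
        f.create_dataset("action", data=np.stack(actions[:n]))
        f.create_dataset("observations/images/cam_high",
                         data=np.stack(cam_high[:n]), compression="gzip")
        f.attrs["instruction"] = "pick up the bottle"   # placeholder language instruction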
