How to run inference of the model with a single image and no proprioception data? #5
Comments
Thank you for your question. Our RDT model is pretrained and fine-tuned using proprioception data along with single or multiple images. If you intend to perform inference using only a single image without proprioception data, we recommend fine-tuning the RDT model with data in this format. For further details, please refer to the fine-tuning section of our documentation. However, it is important to note that we have not yet conducted experiments fine-tuning the RDT model without proprioception data, so we cannot guarantee the quality of the resulting performance.
Thank you for your engagement. We are currently working on a clearer, minimal inference example as you requested. We will inform you as soon as the updates are complete. Feel free to let me know if you'd like any further adjustments!
Thank you for the prompt response! That would really help our use case. And yes, I should have been clearer: I do want to finetune, but before I do that I wanted to do a sanity check to make sure I could at least run inference in the environment. I do not expect the model to perform well without finetuning, since as you said, it is quite a different input space without the proprioception data. Again, please don't hesitate to ask if you'd like me to clarify anything more on my end. Thank you!
Apologies for the late response. We have prepared a beta version of the minimal implementation for inference as requested. Please note that due to urgent circumstances, it has not been tested yet. We hope this version clarifies the inference process. Thank you for your understanding.
# install dependencies as shown in the README here https://github.com/alik-git/RoboticsDiffusionTransformer?tab=readme-ov-file#installation
import yaml
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from configs.state_vec import STATE_VEC_IDX_MAPPING
from models.multimodal_encoder.siglip_encoder import SiglipVisionTower
from models.rdt_runner import RDTRunner
# other imports
config_path = "configs/base.yaml" # default config
pretrained_model_name_or_path = "path/to/rdt-model"
device = torch.device('cuda:0')
dtype = torch.bfloat16  # recommended
cfg_scale = 2.0
# suppose you control a 7-DoF arm via joint positions (plus a gripper)
STATE_INDICES = [
    STATE_VEC_IDX_MAPPING['arm_joint_0_pos'],
    STATE_VEC_IDX_MAPPING['arm_joint_1_pos'],
    STATE_VEC_IDX_MAPPING['arm_joint_2_pos'],
    STATE_VEC_IDX_MAPPING['arm_joint_3_pos'],
    STATE_VEC_IDX_MAPPING['arm_joint_4_pos'],
    STATE_VEC_IDX_MAPPING['arm_joint_5_pos'],
    STATE_VEC_IDX_MAPPING['arm_joint_6_pos'],
    STATE_VEC_IDX_MAPPING['gripper_open']
]
with open(config_path, "r") as fp:
    config = yaml.safe_load(fp)
# Load vision encoder
vision_encoder = SiglipVisionTower(
    vision_tower="/path/to/siglip-so400m-patch14-384",
    args=None
)
vision_encoder.to(device, dtype=dtype)
vision_encoder.eval()
image_processor = vision_encoder.image_processor
# Load pretrained model (in HF style)
rdt = RDTRunner.from_pretrained(pretrained_model_name_or_path)
rdt.to(device, dtype=dtype)
rdt.eval()
previous_image_path = "path/to/previous_image.png"  # placeholder; set to your own image
current_image_path = "path/to/current_image.png"    # placeholder; set to your own image
# previous_image = None  # if t = 0 (no previous frame yet)
previous_image = Image.open(previous_image_path).convert("RGB")  # if t > 0
current_image = Image.open(current_image_path).convert("RGB")
# here I suppose you only have an exterior image (e.g., 3rd-person view) and no state information
# the images should be arranged in the sequence [exterior_image, right_wrist_image, left_wrist_image] * img_history_size (e.g., 2)
rgbs_lst = [
    [previous_image, None, None],
    [current_image, None, None]
]
# if you have a right_wrist_image, then it should be:
# rgbs_lst = [
#     [previous_image, previous_right_wrist_image, None],
#     [current_image, current_right_wrist_image, None]
# ]
# image pre-processing
# The background image used for padding
background_color = np.array([
    int(x * 255) for x in image_processor.image_mean
], dtype=np.uint8).reshape(1, 1, 3)
background_image = np.ones((
    image_processor.size["height"],
    image_processor.size["width"], 3), dtype=np.uint8
) * background_color
image_tensor_list = []
for step in range(config["common"]["img_history_size"]):
    rgbs = rgbs_lst[step % len(rgbs_lst)]
    for rgb in rgbs:
        if rgb is None:
            # Replace a missing view with the background image
            image = Image.fromarray(background_image)
        elif isinstance(rgb, Image.Image):
            # Already a PIL image (as loaded above)
            image = rgb
        else:
            # Otherwise assume a float array in [0, 1]
            image = Image.fromarray((rgb * 255).astype(np.uint8))
        # Optionally brighten very dark images (controlled by the config flag)
        if config["dataset"].get("auto_adjust_image_brightness", False):
            pixel_values = list(image.getdata())
            average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)
            if average_brightness <= 0.15:
                image = transforms.ColorJitter(brightness=(1.75, 1.75))(image)
        # Pad the image to a square with the processor's mean color, if configured
        if config["dataset"].get("image_aspect_ratio", "pad") == 'pad':
            def expand2square(pil_img, background_color):
                width, height = pil_img.size
                if width == height:
                    return pil_img
                elif width > height:
                    result = Image.new(pil_img.mode, (width, width), background_color)
                    result.paste(pil_img, (0, (width - height) // 2))
                    return result
                else:
                    result = Image.new(pil_img.mode, (height, height), background_color)
                    result.paste(pil_img, ((height - width) // 2, 0))
                    return result
            image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
        image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        image_tensor_list.append(image)
image_tensor = torch.stack(image_tensor_list, dim=0).to(device, dtype=dtype)
# encode images
image_embeds = vision_encoder(image_tensor).detach()
image_embeds = image_embeds.reshape(-1, vision_encoder.hidden_size).unsqueeze(0)
# Load language embeddings
# (this part I believe can be done via the `encode_lang.py` script, which I have managed to do)
lang_embeddings = torch.load("path/to/outs/handover_pan.pt", map_location=device)
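# (These language embeddings are how the text prompt enters the model: the instruction
#  string is encoded offline and saved as a .pt file, which is loaded here; see the
#  encode_lang.py script mentioned above for the exact usage.)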
# suppose you do not have proprio
# it's kind of tricky; I strongly suggest adding proprio as input and further fine-tuning
B, N = 1, 1  # batch size and state history size
states = torch.zeros(
    (B, N, config["model"]["state_token_dim"]),
    device=device, dtype=dtype
)
# if you have proprio, you can do it like this
# format: [arm_joint_0_pos, arm_joint_1_pos, arm_joint_2_pos, arm_joint_3_pos, arm_joint_4_pos, arm_joint_5_pos, arm_joint_6_pos, gripper_open]
# proprio = torch.tensor([0, 1, 2, 3, 4, 5, 6, 0.5]).reshape((1, 1, -1))
# states[:, :, STATE_INDICES] = proprio
state_elem_mask = torch.zeros(
    (B, config["model"]["state_token_dim"]),
    device=device, dtype=torch.bool
)
state_elem_mask[:, STATE_INDICES] = True
states, state_elem_mask = states.to(device, dtype=dtype), state_elem_mask.to(device, dtype=dtype)
states = states[:, -1:, :] # only use the last state
actions = rdt.predict_action(
    lang_tokens=lang_embeddings.to(device, dtype=dtype),
    lang_attn_mask=torch.ones(
        lang_embeddings.shape[:2], dtype=torch.bool,
        device=device
    ),
    img_tokens=image_embeds,
    state_tokens=states,  # how can I get this?
    action_mask=state_elem_mask.unsqueeze(1),  # how can I get this?
    ctrl_freqs=torch.tensor([25.0], device=device),  # would this default work?
)  # (1, chunk_size, 128)
# select the meaningful dimensions via STATE_INDICES
action = actions[:, :, STATE_INDICES]  # (1, chunk_size, len(STATE_INDICES)) = (1, chunk_size, 7 + 1)
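As a rough, untested follow-up sketch of how the selected chunk might then be executed: send_joint_command below is a hypothetical placeholder for whatever command API your robot or simulator exposes, and this assumes the selected dimensions are already in your controller's units.
import time

CTRL_FREQ = 25.0  # should match the ctrl_freqs value passed to predict_action above
action_chunk = action[0].float().cpu().numpy()  # (chunk_size, 8)

for target in action_chunk:
    # target[:7] are the 7 joint-position targets, target[7] is the gripper open value
    send_joint_command(target[:7], gripper=target[7])  # hypothetical robot/sim API
    time.sleep(1.0 / CTRL_FREQ)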
Apologies for the delayed response; I just returned from my weekend trip. You can use a padding value when fine-tuning and running RDT. We're also in the process of fine-tuning and running RDT on ManiSkill2/SimplerEnv. Once that work is completed, we'll be releasing everything. Stay tuned!
Great work on the project! I'm particularly interested in replicating it in a simulation environment. Do you have an estimated timeline for when the ManiSkill2/SimplerEnv RDT implementation will be available on GitHub? Thank you for sharing such innovative work!
This is a little bit challenging since the controller in the simulator is not aligned with the real-robot data. But I believe we will reach a conclusion within one week. |
Hi there!
Thank you for your great research and open-source contributions.
I just have a few questions about running your model.
What I am trying to do
I am trying to run RDT on SimplerEnv to see how it compares to other baselines, but I am struggling to understand how to run inference with the pre-trained checkpoint given only a single image and a text prompt as input, which is the format SimplerEnv provides. Is this possible with RDT? I understand that RDT requires more inputs than just an image and text prompt (it also requires proprioception data, control frequency, a noisy action chunk, and a diffusion timestep), but I am wondering if there is a way to pass blank (or default) values for those inputs when they are not available.
What I have already tried
First I tried to run agilex_inference.py as documented in the deployment section of your README here. But it seems that script expects to process inputs via ROS, which is not my situation.
What would help if possible
If you have anything similar to the code snippet in the OpenVLA readme (in their getting started section), something like that would be ideal. Note that I got most of my code from your sample.py script, but I am struggling to get it to run as I'm not sure how to pass only a single image and text prompt to the model.
I hope it is clear what I'm trying to do, but please feel free to ask me to clarify any of the above if needed. Any suggestions on how to achieve this would be much appreciated. Thank you!
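For concreteness, the evaluation loop I am hoping to slot RDT into looks roughly like the sketch below, following the gym-style example in the SimplerEnv README; the helpers marked as hypothetical are placeholders, and mapping RDT's joint-position output onto the simulator's controller is exactly the part I am unsure about.
import simpler_env
from PIL import Image

env = simpler_env.make("google_robot_pick_coke_can")  # example SimplerEnv task
obs, reset_info = env.reset()
instruction = env.get_language_instruction()  # the text prompt for this episode

done, truncated = False, False
while not (done or truncated):
    # Pull the current RGB frame out of the observation dict
    frame = Image.fromarray(get_rgb_frame(obs))  # hypothetical helper

    # Single-image RDT inference (as in the snippet above), returning a chunk of
    # joint-position targets (7 joints + gripper)
    action_chunk = run_rdt_single_image(frame, instruction)  # hypothetical wrapper

    # Open question: converting joint-position targets into the simulator's
    # action space (the controller mismatch mentioned earlier in this thread)
    sim_action = convert_to_sim_action(action_chunk[0])  # hypothetical conversion
    obs, reward, done, truncated, info = env.step(sim_action)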