diff --git a/README.md b/README.md
index 221efa5..29c5778 100644
--- a/README.md
+++ b/README.md
@@ -169,6 +169,7 @@ If you want to embed a policy server call in your own robot runtime, we have a m
 We provide more examples for how to fine-tune and run inference with our models on the ALOHA platform in the following READMEs:
 - [ALOHA Simulator](examples/aloha_sim)
 - [ALOHA Real](examples/aloha_real)
+- [UR5](examples/ur5)
diff --git a/examples/ur5/README.md b/examples/ur5/README.md
new file mode 100644
index 0000000..88c6d7a
--- /dev/null
+++ b/examples/ur5/README.md
@@ -0,0 +1,151 @@
+# UR5 Example
+
+Below we provide an outline of how to implement the key components mentioned in the "Finetune on your data" section of the [README](../../README.md) for fine-tuning on UR5 datasets.
+
+First, we will define the `UR5Inputs` and `UR5Outputs` classes, which map UR5 environment observations to model inputs and model outputs back to UR5 actions. Check the corresponding classes in `src/openpi/policies/libero_policy.py` for comments explaining each line.
+
+```python
+@dataclasses.dataclass(frozen=True)
+class UR5Inputs(transforms.DataTransformFn):
+    # The action dimension of the model. Used to pad state and actions.
+    action_dim: int
+
+    # Determines which model will be used.
+    model_type: _model.ModelType = _model.ModelType.PI0
+
+    def __call__(self, data: dict) -> dict:
+        # pi0 masks out padding image slots; pi0-FAST does not.
+        mask_padding = self.model_type == _model.ModelType.PI0
+
+        # First, concatenate the joints and gripper into the state vector.
+        # Pad to the expected input dimensionality of the model (same as action_dim).
+        state = np.concatenate([data["joints"], data["gripper"]])
+        state = transforms.pad_to_dim(state, self.action_dim)
+
+        # Parse images to uint8 (H,W,C) since LeRobot automatically stores them
+        # as float32 (C,H,W). This is a no-op during policy inference, where
+        # images already arrive as uint8.
+        base_image = _parse_image(data["base_rgb"])
+        wrist_image = _parse_image(data["wrist_rgb"])
+
+        # Create inputs dict.
+        inputs = {
+            "state": state,
+            "image": {
+                "base_0_rgb": base_image,
+                "left_wrist_0_rgb": wrist_image,
+                # Since there is no right wrist camera, replace the image with zeros.
+                "right_wrist_0_rgb": np.zeros_like(base_image),
+            },
+            "image_mask": {
+                "base_0_rgb": np.True_,
+                "left_wrist_0_rgb": np.True_,
+                # Since the "slot" for the right wrist is not used, mask it out
+                # (unless the model expects all masks to be True, as pi0-FAST does).
+                "right_wrist_0_rgb": np.False_ if mask_padding else np.True_,
+            },
+        }
+
+        # Pad actions to the model action dimension.
+        if "actions" in data:
+            # The robot produces 7D actions (6 DoF + 1 gripper), which we pad.
+            actions = transforms.pad_to_dim(data["actions"], self.action_dim)
+            inputs["actions"] = actions
+
+        # Pass the prompt (aka language instruction) to the model.
+        if "prompt" in data:
+            inputs["prompt"] = data["prompt"]
+
+        return inputs
+
+
+@dataclasses.dataclass(frozen=True)
+class UR5Outputs(transforms.DataTransformFn):
+    def __call__(self, data: dict) -> dict:
+        # Since the robot has 7 action dimensions (6 DoF + gripper), return the first 7 dims.
+        return {"actions": np.asarray(data["actions"][:, :7])}
+```
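+
+The `_parse_image` helper used above is not shown in this snippet; see `src/openpi/policies/libero_policy.py` for the reference version. A minimal sketch, assuming images arrive either as float32 (C,H,W) arrays in [0, 1] or already as uint8 (H,W,C), could look like this:
+
+```python
+import numpy as np
+
+
+def _parse_image(image) -> np.ndarray:
+    """Convert an image to uint8 (H,W,C), the format the model transforms expect."""
+    image = np.asarray(image)
+    # Float images are assumed to be in [0, 1]; rescale to [0, 255].
+    if np.issubdtype(image.dtype, np.floating):
+        image = (255 * image).astype(np.uint8)
+    # Transpose channel-first (C,H,W) images to channel-last (H,W,C).
+    if image.shape[0] == 3:
+        image = np.transpose(image, (1, 2, 0))
+    return image
+```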
+
+Next, we will define the `LeRobotUR5DataConfig` class, which defines how to process raw UR5 data from the LeRobot dataset for training. For a full example, see the `LeRobotLiberoDataConfig` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py).
+
+```python
+@dataclasses.dataclass(frozen=True)
+class LeRobotUR5DataConfig(DataConfigFactory):
+    @override
+    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
+        # Boilerplate for remapping keys from the LeRobot dataset to the names
+        # our `UR5Inputs` transform expects.
+        repack_transform = _transforms.Group(
+            inputs=[
+                _transforms.RepackTransform(
+                    {
+                        "base_rgb": "image",
+                        "wrist_rgb": "wrist_image",
+                        "joints": "joints",
+                        "gripper": "gripper",
+                        "prompt": "prompt",
+                    }
+                )
+            ]
+        )
+
+        # These transforms are the ones we wrote earlier.
+        data_transforms = _transforms.Group(
+            inputs=[UR5Inputs(action_dim=model_config.action_dim, model_type=model_config.model_type)],
+            outputs=[UR5Outputs()],
+        )
+
+        # Convert absolute actions to delta actions.
+        # By convention, we do not convert the gripper action (7th dimension).
+        delta_action_mask = _transforms.make_bool_mask(6, -1)
+        data_transforms = data_transforms.push(
+            inputs=[_transforms.DeltaActions(delta_action_mask)],
+            outputs=[_transforms.AbsoluteActions(delta_action_mask)],
+        )
+
+        # Model transforms include things like tokenizing the prompt and action targets.
+        # You do not need to change anything here for your own dataset.
+        model_transforms = ModelTransformFactory()(model_config)
+
+        # We return all data transforms for training and inference. No need to change anything here.
+        return dataclasses.replace(
+            self.create_base_config(assets_dirs),
+            repack_transforms=repack_transform,
+            data_transforms=data_transforms,
+            model_transforms=model_transforms,
+        )
+```
+
+Finally, we define a `TrainConfig` for fine-tuning pi0 on our UR5 dataset. See the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py) for more examples, e.g. for pi0-FAST or for LoRA fine-tuning.
+
+```python
+TrainConfig(
+    name="pi0_ur5",
+    model=pi0.Pi0Config(),
+    data=LeRobotUR5DataConfig(
+        repo_id="your_username/ur5_dataset",
+        # This config lets us reload the UR5 normalization stats from the base model checkpoint.
+        # Reloading normalization stats can help transfer pre-trained models to new environments.
+        # See [norm_stats.md](../../docs/norm_stats.md) for more details.
+        assets=AssetsConfig(
+            assets_dir="s3://openpi-assets/checkpoints/pi0_base/assets",
+            asset_id="ur5e",
+        ),
+        base_config=DataConfig(
+            local_files_only=True,  # Set to True if the dataset is stored locally.
+            # This flag determines whether we load the prompt (i.e. the task instruction)
+            # from the ``task`` field in the LeRobot dataset. The recommended setting is True.
+            prompt_from_task=True,
+        ),
+    ),
+    # Load the pi0 base model checkpoint.
+    weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_base/params"),
+    num_train_steps=30_000,
+)
+```
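+
+Once this config is added to the training config file, training and inference follow the standard openpi workflow. Below is a rough sketch of running inference with a fine-tuned checkpoint, following the policy-loading example in the main README; the checkpoint path and observation values are placeholders for your own setup:
+
+```python
+import numpy as np
+
+from openpi.policies import policy_config as _policy_config
+from openpi.training import config as _config
+
+# Load the train config defined above and a checkpoint produced by training.
+config = _config.get_config("pi0_ur5")
+checkpoint_dir = "checkpoints/pi0_ur5/my_experiment/29999"  # hypothetical path
+policy = _policy_config.create_trained_policy(config, checkpoint_dir)
+
+# Run inference on a single observation. The keys must match what `UR5Inputs` expects.
+example = {
+    "joints": np.zeros(6, dtype=np.float32),               # replace with real joint angles
+    "gripper": np.zeros(1, dtype=np.float32),              # replace with the real gripper position
+    "base_rgb": np.zeros((224, 224, 3), dtype=np.uint8),   # base camera image
+    "wrist_rgb": np.zeros((224, 224, 3), dtype=np.uint8),  # wrist camera image
+    "prompt": "pick up the cube",
+}
+action_chunk = policy.infer(example)["actions"]  # (action_horizon, 7) after UR5Outputs trims padding
+```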