Initial commit

This commit is contained in:
Ury Zhilinsky
2024-12-23 13:38:06 -08:00
commit 385780ecc3
121 changed files with 15572 additions and 0 deletions

0
scripts/__init__.py Normal file
View File

479
scripts/aloha_hd5.py Normal file
View File

@@ -0,0 +1,479 @@
# ruff: noqa
"""
Script courtesy of Raziel90 https://github.com/huggingface/lerobot/pull/586/files
Example usage
python scripts/aloha_hd5.py --raw-path ~/data/ --dataset-repo-id <hf-username>/<dataset-name> --robot-type <aloha-stationary|aloha-mobile> --fps 50 --video-encoding=false --push=false
The data will be saved locally the value of the LEROBOT_HOME environment variable. By default this is set to ~/.cache/huggingface/lerobot
If you wish to submit the dataset to the hub, you can do so by setting up the hf cli https://huggingface.co/docs/huggingface_hub/en/guides/cli and setting --push=true
"""
import argparse
import logging
import os
from pathlib import Path
import shutil
import traceback
import cv2
import h5py
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
import torch
class AlohaHD5Extractor:
    """Static helpers to read Aloha HDF5 episode files and derive LeRobot feature specs."""

    # Tags attached to the dataset when it is pushed to the Hugging Face Hub.
    TAGS = ["aloha", "robotics", "hdf5"]
    aloha_stationary = "aloha-stationary"
    aloha_mobile = "aloha-mobile"

    @staticmethod
    def get_cameras(hdf5_data: h5py.File):
        """
        Extracts the list of RGB camera keys from the given HDF5 data.

        Parameters
        ----------
        hdf5_data : h5py.File
            The HDF5 file object containing the dataset.

        Returns
        -------
        list of str
            A list of keys corresponding to RGB cameras in the dataset.
        """
        # Depth streams are stored alongside RGB ones; filter them out by name.
        rgb_cameras = [key for key in hdf5_data["/observations/images"] if "depth" not in key]
        return rgb_cameras

    @staticmethod
    def check_format(episode_list: list[str] | list[Path], image_compressed: bool = True):
        """
        Check the format of the given list of HDF5 files.

        Parameters
        ----------
        episode_list : list of str or list of Path
            List of paths to the HDF5 files to be checked.
        image_compressed : bool, optional
            Flag indicating whether the images are compressed (default is True).

        Raises
        ------
        ValueError
            If the episode_list is empty.
            If any HDF5 file is missing required keys '/action' or '/observations/qpos'.
            If the '/action' or '/observations/qpos' keys do not have 2 dimensions.
            If the number of frames in '/action' and '/observations/qpos' keys do not match.
            If the number of frames in '/observations/images/{camera}' does not match the number of frames in '/action' and '/observations/qpos'.
            If the dimensions of images do not match the expected dimensions based on the image_compressed flag.
            If uncompressed images do not have the expected (h, w, c) format.
        """
        if not episode_list:
            raise ValueError("No hdf5 files found in the raw directory. Make sure they are named 'episode_*.hdf5'")
        for episode_path in episode_list:
            with h5py.File(episode_path, "r") as data:
                if not all(key in data for key in ["/action", "/observations/qpos"]):
                    raise ValueError(
                        "Missing required keys in the hdf5 file. Make sure the keys '/action' and '/observations/qpos' are present."
                    )
                if not data["/action"].ndim == data["/observations/qpos"].ndim == 2:
                    raise ValueError("The '/action' and '/observations/qpos' keys should have both 2 dimensions.")
                if (num_frames := data["/action"].shape[0]) != data["/observations/qpos"].shape[0]:
                    raise ValueError(
                        "The '/action' and '/observations/qpos' keys should have the same number of frames."
                    )
                for camera in AlohaHD5Extractor.get_cameras(data):
                    if num_frames != data[f"/observations/images/{camera}"].shape[0]:
                        raise ValueError(
                            f"The number of frames in '/observations/images/{camera}' should be the same as in '/action' and '/observations/qpos' keys."
                        )
                    # Compressed streams are (frames, bytes); raw streams are (frames, h, w, c).
                    expected_dims = 2 if image_compressed else 4
                    if data[f"/observations/images/{camera}"].ndim != expected_dims:
                        raise ValueError(
                            f"Expect {expected_dims} dimensions for {'compressed' if image_compressed else 'uncompressed'} images but {data[f'/observations/images/{camera}'].ndim} provided."
                        )
                    if not image_compressed:
                        b, h, w, c = data[f"/observations/images/{camera}"].shape
                        # BUGFIX: the original `not c < h and c < w` parsed as
                        # `(not (c < h)) and (c < w)`. For an (h, w, c) layout the
                        # channel count must be smaller than BOTH spatial dims.
                        if not (c < h and c < w):
                            raise ValueError(f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided.")

    @staticmethod
    def extract_episode_frames(
        episode_path: str | Path, features: dict[str, dict], image_compressed: bool
    ) -> list[dict[str, torch.Tensor]]:
        """
        Extract frames from an episode stored in an HDF5 file.

        Parameters
        ----------
        episode_path : str or Path
            Path to the HDF5 file containing the episode data.
        features : dict of str to dict
            Dictionary where keys are feature identifiers and values are dictionaries with feature details.
        image_compressed : bool
            Flag indicating whether the images are stored in a compressed format.

        Returns
        -------
        list of dict of str to torch.Tensor
            List of frames, where each frame is a dictionary mapping feature identifiers to tensors.
        """
        frames = []
        with h5py.File(episode_path, "r") as file:
            for frame_idx in range(file["/action"].shape[0]):
                frame = {}
                for feature_id in features:
                    # Feature ids use '.' separators; HDF5 paths use '/'.
                    feature_name_hd5 = feature_id.replace(".", "/")
                    if "images" in feature_id.split("."):
                        image = (
                            (file[feature_name_hd5][frame_idx])
                            if not image_compressed
                            else cv2.imdecode(file[feature_name_hd5][frame_idx], 1)
                        )
                        # Store images channel-first: (h, w, c) -> (c, h, w).
                        frame[feature_id] = torch.from_numpy(image.transpose(2, 0, 1))
                    else:
                        frame[feature_id] = torch.from_numpy(file[feature_name_hd5][frame_idx])
                frames.append(frame)
        return frames

    @staticmethod
    def define_features(
        hdf5_file_path: Path, image_compressed: bool = True, encode_as_video: bool = True
    ) -> dict[str, dict]:
        """
        Define features from an HDF5 file.

        Parameters
        ----------
        hdf5_file_path : Path
            The path to the HDF5 file.
        image_compressed : bool, optional
            Whether the images are compressed, by default True.
        encode_as_video : bool, optional
            Whether to encode images as video or as images, by default True.

        Returns
        -------
        dict[str, dict]
            A dictionary where keys are topic names and values are dictionaries
            containing feature information such as dtype, shape, and names.
        """
        # Initialize lists to store topics and features
        topics = []
        features = {}
        # Open the HDF5 file
        with h5py.File(hdf5_file_path, "r") as hdf5_file:
            # Collect all dataset names in the HDF5 file
            hdf5_file.visititems(lambda name, obj: topics.append(name) if isinstance(obj, h5py.Dataset) else None)
            # Iterate over each topic to define its features
            for topic in topics:
                # If the topic is an image, define it as a video feature
                if "images" in topic.split("/"):
                    sample = hdf5_file[topic][0]
                    features[topic.replace("/", ".")] = {
                        "dtype": "video" if encode_as_video else "image",
                        # BUGFIX: frames are stored channel-first (see
                        # extract_episode_frames), so the uncompressed raw
                        # (h, w, c) sample must be transposed too; otherwise the
                        # declared shape disagrees with the names below.
                        "shape": cv2.imdecode(hdf5_file[topic][0], 1).transpose(2, 0, 1).shape
                        if image_compressed
                        else sample.transpose(2, 0, 1).shape,
                        "names": [
                            "channel",
                            "height",
                            "width",
                        ],
                    }
                # Skip compressed length topics
                elif "compress_len" in topic.split("/"):
                    continue
                # Otherwise, define it as a regular feature
                else:
                    features[topic.replace("/", ".")] = {
                        "dtype": str(hdf5_file[topic][0].dtype),
                        "shape": (topic_shape := hdf5_file[topic][0].shape),
                        "names": [f"{topic.split('/')[-1]}_{k}" for k in range(topic_shape[0])],
                    }
        # Return the defined features
        return features
class DatasetConverter:
    """
    A class to convert datasets to Lerobot format.

    Parameters
    ----------
    raw_path : Path or str
        The path to the raw dataset.
    dataset_repo_id : str
        The repository ID where the dataset will be stored.
    fps : int
        Frames per second for the dataset.
    robot_type : str, optional
        The type of robot, by default "".
    encode_as_videos : bool, optional
        Whether to encode images as videos, by default True.
    image_compressed : bool, optional
        Whether the images are compressed, by default True.
    image_writer_processes : int, optional
        Number of processes for writing images, by default 0.
    image_writer_threads : int, optional
        Number of threads for writing images, by default 0.

    Methods
    -------
    extract_episode(episode_path, task_description='')
        Extracts frames from a single episode and saves it with a description.
    extract_episodes(episode_description='')
        Extracts frames from all episodes and saves them with a description.
    push_dataset_to_hub(dataset_tags=None, private=False, push_videos=True, license="apache-2.0")
        Pushes the dataset to the Hugging Face Hub.
    init_lerobot_dataset()
        Initializes the Lerobot dataset.
    """

    def __init__(
        self,
        raw_path: Path | str,
        dataset_repo_id: str,
        fps: int,
        robot_type: str = "",
        encode_as_videos: bool = True,
        image_compressed: bool = True,
        image_writer_processes: int = 0,
        image_writer_threads: int = 0,
    ):
        self.raw_path = raw_path if isinstance(raw_path, Path) else Path(raw_path)
        self.dataset_repo_id = dataset_repo_id
        self.fps = fps
        self.robot_type = robot_type
        self.image_compressed = image_compressed
        self.image_writer_threads = image_writer_threads
        self.image_writer_processes = image_writer_processes
        self.encode_as_videos = encode_as_videos

        self.logger = logging.getLogger(self.__class__.__name__)
        self.logger.setLevel(logging.INFO)
        # Attach the console handler only once: instantiating the converter
        # repeatedly would otherwise duplicate every log line.
        if not self.logger.handlers:
            console_handler = logging.StreamHandler()
            console_handler.setLevel(logging.INFO)
            formatter = logging.Formatter("%(asctime)s - [%(name)s] - %(message)s")
            console_handler.setFormatter(formatter)
            self.logger.addHandler(console_handler)

        self.logger.info(f"{'-'*10} Aloha HD5 -> Lerobot Converter {'-'*10}")
        self.logger.info(f"Processing Aloha HD5 dataset from {self.raw_path}")
        self.logger.info(f"Dataset will be stored in {self.dataset_repo_id}")
        self.logger.info(f"FPS: {self.fps}")
        self.logger.info(f"Robot type: {self.robot_type}")
        self.logger.info(f"Image compressed: {self.image_compressed}")
        self.logger.info(f"Encoding images as videos: {self.encode_as_videos}")
        self.logger.info(f"#writer processes: {self.image_writer_processes}")
        self.logger.info(f"#writer threads: {self.image_writer_threads}")

        # Sort for a deterministic episode order; glob order is filesystem dependent.
        self.episode_list = sorted(self.raw_path.glob("episode_*.hdf5"))
        AlohaHD5Extractor.check_format(self.episode_list, image_compressed=self.image_compressed)
        self.features = AlohaHD5Extractor.define_features(
            self.episode_list[0],
            image_compressed=self.image_compressed,
            encode_as_video=self.encode_as_videos,
        )

    def extract_episode(self, episode_path, task_description: str = ""):
        """
        Extracts frames from an episode and saves them to the dataset.

        Parameters
        ----------
        episode_path : str
            The path to the episode file.
        task_description : str, optional
            A description of the task associated with the episode (default is an empty string).

        Returns
        -------
        None
        """
        for frame in AlohaHD5Extractor.extract_episode_frames(episode_path, self.features, self.image_compressed):
            self.dataset.add_frame(frame)
        self.logger.info(f"Saving Episode with Description: {task_description} ...")
        self.dataset.save_episode(task=task_description)

    def extract_episodes(self, episode_description: str = ""):
        """
        Extracts episodes from the episode list and processes them.

        Parameters
        ----------
        episode_description : str, optional
            A description of the task to be passed to the extract_episode method (default is '').

        Notes
        -----
        Errors in a single episode are logged and the remaining episodes are still
        processed. After processing all episodes, the dataset is consolidated.
        """
        for episode_path in self.episode_list:
            try:
                self.extract_episode(episode_path, task_description=episode_description)
            except Exception as e:
                # Use the class logger (not bare print) so failures show up in the
                # same stream/format as the rest of the conversion log.
                self.logger.error(f"Error processing episode {episode_path}: {e}")
                traceback.print_exc()
                continue
        self.dataset.consolidate()

    def push_dataset_to_hub(
        self,
        dataset_tags: list[str] | None = None,
        private: bool = False,
        push_videos: bool = True,
        license: str | None = "apache-2.0",
    ):
        """
        Pushes the dataset to the Hugging Face Hub.

        Parameters
        ----------
        dataset_tags : list of str, optional
            A list of tags to associate with the dataset on the Hub. Default is None.
        private : bool, optional
            If True, the dataset will be private. Default is False.
        push_videos : bool, optional
            If True, videos will be pushed along with the dataset. Default is True.
        license : str, optional
            The license under which the dataset is released. Default is "apache-2.0".

        Returns
        -------
        None
        """
        self.logger.info(f"Pushing dataset to Hugging Face Hub. ID: {self.dataset_repo_id} ...")
        self.dataset.push_to_hub(
            tags=dataset_tags,
            license=license,
            push_videos=push_videos,
            private=private,
        )

    def init_lerobot_dataset(self):
        """
        Initializes the LeRobot dataset.

        This method cleans the cache if the dataset already exists and then creates a new LeRobot dataset.

        Returns
        -------
        LeRobotDataset
            The initialized LeRobot dataset.
        """
        # Clean the cache if the dataset already exists
        if os.path.exists(LEROBOT_HOME / self.dataset_repo_id):
            shutil.rmtree(LEROBOT_HOME / self.dataset_repo_id)
        self.dataset = LeRobotDataset.create(
            repo_id=self.dataset_repo_id,
            fps=self.fps,
            robot_type=self.robot_type,
            features=self.features,
            image_writer_threads=self.image_writer_threads,
            image_writer_processes=self.image_writer_processes,
        )
        return self.dataset
def str2bool(value):
    """Parse a CLI flag value into a bool.

    Accepts an actual bool (returned unchanged) or a case-insensitive string
    such as "yes"/"no", "true"/"false", "t"/"f", "y"/"n", "1"/"0".

    Raises
    ------
    argparse.ArgumentTypeError
        If the string is not a recognized boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in {"yes", "true", "t", "y", "1"}:
        return True
    if lowered in {"no", "false", "f", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
def main():
    """
    Convert Aloha HD5 dataset and push to Hugging Face hub.

    This script processes raw HDF5 files from the Aloha dataset, converts them into a specified format,
    and optionally uploads the dataset to the Hugging Face hub.

    Parameters
    ----------
    --raw-path : Path
        Directory containing the raw HDF5 files.
    --dataset-repo-id : str
        Repository ID where the dataset will be stored.
    --fps : int
        Frames per second for the dataset.
    --description : str, optional
        Task description attached to every episode. Default is "Aloha recorded dataset.".
    --robot-type : str, optional
        Type of robot, either "aloha-stationary" or "aloha-mobile". Default is "aloha-stationary".
    --private : bool, optional
        Set to True to make the dataset private. Default is False.
    --push : bool, optional
        Set to True to push the dataset (videos included) to the hub. Default is True.
    --license : str, optional
        License for the dataset. Default is "apache-2.0".
    --image-compressed : bool, optional
        Set to True if the images are compressed. Default is True.
    --video-encoding : bool, optional
        Set to True to encode images as videos. Default is True.
    --nproc : int, optional
        Number of image writer processes. Default is 10.
    --nthreads : int, optional
        Number of image writer threads. Default is 5.
    """
    parser = argparse.ArgumentParser(description="Convert Aloha HD5 dataset and push to Hugging Face hub.")
    parser.add_argument("--raw-path", type=Path, required=True, help="Directory containing the raw hdf5 files.")
    parser.add_argument(
        "--dataset-repo-id", type=str, required=True, help="Repository ID where the dataset will be stored."
    )
    parser.add_argument("--fps", type=int, required=True, help="Frames per second for the dataset.")
    parser.add_argument(
        "--description", type=str, help="Description of the dataset.", default="Aloha recorded dataset."
    )
    parser.add_argument(
        "--robot-type",
        type=str,
        choices=["aloha-stationary", "aloha-mobile"],
        default="aloha-stationary",
        help="Type of robot.",
    )
    parser.add_argument("--private", type=str2bool, default=False, help="Set to True to make the dataset private.")
    parser.add_argument("--push", type=str2bool, default=True, help="Set to True to push videos to the hub.")
    parser.add_argument("--license", type=str, default="apache-2.0", help="License for the dataset.")
    parser.add_argument(
        "--image-compressed", type=str2bool, default=True, help="Set to True if the images are compressed."
    )
    parser.add_argument("--video-encoding", type=str2bool, default=True, help="Set to True to encode images as videos.")
    parser.add_argument("--nproc", type=int, default=10, help="Number of image writer processes.")
    parser.add_argument("--nthreads", type=int, default=5, help="Number of image writer threads.")
    args = parser.parse_args()

    # (removed leftover debug print of args.video_encoding)
    converter = DatasetConverter(
        raw_path=args.raw_path,
        dataset_repo_id=args.dataset_repo_id,
        fps=args.fps,
        robot_type=args.robot_type,
        image_compressed=args.image_compressed,
        encode_as_videos=args.video_encoding,
        image_writer_processes=args.nproc,
        image_writer_threads=args.nthreads,
    )
    converter.init_lerobot_dataset()
    converter.extract_episodes(episode_description=args.description)
    if args.push:
        converter.push_dataset_to_hub(
            dataset_tags=AlohaHD5Extractor.TAGS, private=args.private, push_videos=True, license=args.license
        )


if __name__ == "__main__":
    main()

29
scripts/compose.yml Normal file
View File

@@ -0,0 +1,29 @@
# Run with:
# docker compose -f scripts/compose.yml up --build
services:
  openpi_server:
    image: openpi_server
    build:
      context: ..
      dockerfile: scripts/serve_policy.Dockerfile
    init: true
    tty: true
    # Host networking so the websocket policy server is reachable on the host port.
    network_mode: host
    # Populate configured openpi data home to /openpi_assets inside the container.
    # Populate aws credential inside the container.
    volumes:
      - $PWD:/app
      - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
      - ~/.aws/:/root/.aws/
    environment:
      # SERVER_ARGS is forwarded to scripts/serve_policy.py by the image CMD.
      - SERVER_ARGS
      - OPENPI_DATA_HOME=/openpi_assets
    # Comment out this block if not running on a machine with GPUs.
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]

View File

@@ -0,0 +1,67 @@
"""Compute normalization statistics for a config.
This script is used to compute the normalization statistics for a given config. It
will compute the mean and standard deviation of the data in the dataset and save it
to the config metadata directory.
"""
import numpy as np
import tqdm
import tyro
import openpi.shared.normalize as normalize
import openpi.training.config as _config
import openpi.training.data_loader as _data_loader
def create_dataset(config: _config.TrainConfig) -> tuple[str, _data_loader.Dataset]:
    """Build the transformed dataset described by *config*.

    Returns the data config's repo_id together with the dataset wrapped in the
    repack and data input transforms (only the input halves — stats are computed
    on model-space inputs).

    Raises
    ------
    ValueError
        If the data config has no repo_id.
    """
    model = config.create_model()
    data_config = config.data.create(config.metadata_dir, model)
    if data_config.repo_id is None:
        raise ValueError("Data config must have a repo_id")
    dataset = _data_loader.TransformedDataset(
        _data_loader.create_dataset(data_config, model),
        [
            *data_config.repack_transforms.inputs,
            *data_config.data_transforms.inputs,
        ],
    )
    return data_config.repo_id, dataset
def main(config_name: str, max_frames: int | None = None):
    """Compute per-key normalization statistics for a config and save them.

    Streams the dataset one frame at a time, accumulates running stats for
    "state" and "actions", and writes the result under the config's metadata
    directory. When *max_frames* caps the dataset, frames are sampled shuffled.
    """
    config = _config.get_config(config_name)
    repo_id, dataset = create_dataset(config)

    num_frames = len(dataset)
    shuffle = False
    if max_frames is not None and max_frames < num_frames:
        # Only a subset is visited, so shuffle to get a representative sample.
        num_frames = max_frames
        shuffle = True

    data_loader = _data_loader.TorchDataLoader(
        dataset,
        local_batch_size=1,
        num_workers=8,
        shuffle=shuffle,
        num_batches=num_frames,
    )

    keys = ["state", "actions"]
    trackers = {key: normalize.RunningStats() for key in keys}
    for batch in tqdm.tqdm(data_loader, total=num_frames, desc="Computing stats"):
        for key in keys:
            values = np.asarray(batch[key][0])
            # Flatten all leading dims so stats are per final-axis component.
            trackers[key].update(values.reshape(-1, values.shape[-1]))

    norm_stats = {key: tracker.get_statistics() for key, tracker in trackers.items()}

    output_path = config.metadata_dir / repo_id
    print(f"Writing stats to: {output_path}")
    normalize.save(output_path, norm_stats)


if __name__ == "__main__":
    tyro.cli(main)

View File

@@ -0,0 +1,37 @@
#!/bin/bash
# Install Docker Engine on Ubuntu and configure it for non-root use.

# Add Docker's official GPG key:
sudo apt-get update
sudo apt-get install -y ca-certificates curl
sudo install -m 0755 -d /etc/apt/keyrings
sudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc
sudo chmod a+r /etc/apt/keyrings/docker.asc

# Add the repository to Apt sources:
echo \
    "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu \
  $(. /etc/os-release && echo "$VERSION_CODENAME") stable" |
    sudo tee /etc/apt/sources.list.d/docker.list >/dev/null
sudo apt-get update

sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin

# Add current user to the 'docker' group, which allows them to use docker commands (docker build, docker run, etc).
# See https://docs.docker.com/engine/install/linux-postinstall/
username=$(whoami)
sudo usermod -aG docker "$username"

# Configure docker to start automatically on system boot.
sudo systemctl enable docker.service
sudo systemctl enable containerd.service

# https://forums.docker.com/t/docker-credential-desktop-exe-executable-file-not-found-in-path-using-wsl2/100225/5
# BUGFIX: `[ ~/.docker/config.json ]` tests a non-empty string and is always
# true; `-f` actually checks that the file exists before rewriting it.
if [ -f ~/.docker/config.json ]; then
    sed -i 's/credsStore/credStore/g' ~/.docker/config.json
fi

echo ""
echo "********************************************************************"
echo "**** Restart to allow Docker permission changes to take effect. ****"
echo "********************************************************************"
echo ""

View File

@@ -0,0 +1,17 @@
#!/bin/bash

# Installs the NVIDIA Container Toolkit, which allows Docker containers to access NVIDIA GPUs.
# NVIDIA's official documentation: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg &&
    curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list |
    sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' |
    sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list

# Enable the "experimental" package lines in the apt source list.
# NVIDIA's documentation omits 'sudo' in the following command, but it is required.
sudo sed -i -e '/experimental/ s/^#//g' /etc/apt/sources.list.d/nvidia-container-toolkit.list

sudo apt-get update
sudo apt-get install -y nvidia-container-toolkit

# Register the NVIDIA runtime with Docker and restart the daemon to pick it up.
sudo nvidia-ctk runtime configure --runtime=docker
sudo systemctl restart docker

View File

@@ -0,0 +1,34 @@
# Dockerfile for serving a PI policy.

# Based on UV's instructions: https://docs.astral.sh/uv/guides/integration/docker/#developing-in-a-container

# Build the container:
# docker build . -t openpi_server -f scripts/serve_policy.Dockerfile

# Run the container:
# docker run --rm -it --network=host -v .:/app --gpus=all openpi_server /bin/bash

# Base image is pinned by digest for reproducible builds.
FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/

WORKDIR /app

# Needed because LeRobot uses git-lfs.
RUN apt-get update && apt-get install -y git git-lfs

# Copy from the cache instead of linking since it's a mounted volume
ENV UV_LINK_MODE=copy

# Write the virtual environment outside of the project directory so it doesn't
# leak out of the container when we mount the application code.
ENV UV_PROJECT_ENVIRONMENT=/.venv

# Install the project's dependencies using the lockfile and settings
RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
RUN --mount=type=cache,target=/root/.cache/uv \
    --mount=type=bind,source=uv.lock,target=uv.lock \
    --mount=type=bind,source=pyproject.toml,target=pyproject.toml \
    --mount=type=bind,source=packages/openpi-client/pyproject.toml,target=packages/openpi-client/pyproject.toml \
    --mount=type=bind,source=packages/openpi-client/src,target=packages/openpi-client/src \
    GIT_LFS_SKIP_SMUDGE=1 uv sync --frozen --no-install-project --no-dev

# SERVER_ARGS is forwarded from the environment (see scripts/compose.yml).
CMD /bin/bash -c "uv run scripts/serve_policy.py $SERVER_ARGS"

243
scripts/serve_policy.py Normal file
View File

@@ -0,0 +1,243 @@
from collections.abc import Sequence
import dataclasses
import enum
import logging
from typing import Any
import tyro
from openpi import transforms
from openpi.models import exported as _exported
from openpi.models import model as _model
from openpi.policies import aloha_policy
from openpi.policies import calvin_policy
from openpi.policies import droid_policy
from openpi.policies import libero_policy
from openpi.policies import policy as _policy
from openpi.policies import policy_config as _policy_config
from openpi.serving import websocket_policy_server
from openpi.training import config as _config
class EnvMode(enum.Enum):
    """Supported environments."""

    ALOHA = "aloha"
    ALOHA_SIM = "aloha_sim"
    DROID = "droid"
    CALVIN = "calvin"
    LIBERO = "libero"
@dataclasses.dataclass
class Exported:
    """Load an exported checkpoint."""

    # Checkpoint directory (e.g., "s3://openpi-assets/exported/pi0_aloha/model").
    dir: str

    # Processor name to load the norm stats from. If not provided, the default processor for the environment will be used.
    processor: str | None = None
@dataclasses.dataclass
class Checkpoint:
    """Load a policy from a trained checkpoint."""

    # Training config name (e.g., "pi0_aloha_sim").
    config: str

    # Checkpoint directory (e.g., "checkpoints/pi0_aloha_sim/exp/10000").
    dir: str
@dataclasses.dataclass
class Args:
    """Arguments for the serve_policy script."""

    # Environment to serve the policy for.
    env: EnvMode = EnvMode.ALOHA_SIM

    # Specifies how to load the policy. If not provided, the default policy for the environment will be used.
    policy: Checkpoint | Exported | None = None

    # If provided, overrides the default prompt for the policy.
    default_prompt: str | None = None

    # Port to serve the policy on.
    port: int = 8000

    # Record the policy's behavior for debugging.
    record: bool = False
def repack_from_env(env: EnvMode) -> transforms.Group:
"""Creates environment specific repack transforms."""
# TODO(ury): Move this to the runtime.
match env:
case EnvMode.ALOHA:
return transforms.Group(
inputs=[aloha_policy.ActInputsRepack()],
outputs=[aloha_policy.ActOutputsRepack()],
)
case EnvMode.ALOHA_SIM:
return transforms.Group(
inputs=[aloha_policy.ActInputsRepack()],
outputs=[aloha_policy.ActOutputsRepack()],
)
case _:
return transforms.Group()
# Default exported models.
# Maps each environment to the exported checkpoint + processor used when the
# caller does not provide an explicit policy (see create_default_policy).
DEFAULT_EXPORTED: dict[EnvMode, Exported] = {
    EnvMode.ALOHA: Exported(
        dir="s3://openpi-assets/exported/pi0_aloha/model",
        processor="trossen_biarm_single_base_cam_24dim",
    ),
    EnvMode.ALOHA_SIM: Exported(
        dir="s3://openpi-assets/exported/pi0_aloha_sim/model",
        processor="huggingface_aloha_sim_transfer_cube",
    ),
    EnvMode.DROID: Exported(
        dir="s3://openpi-assets/exported/pi0_droid/model",
        processor="openx_droid",
    ),
    EnvMode.CALVIN: Exported(
        dir="s3://openpi-assets/exported/pi0_calvin/model",
        processor="calvin",
    ),
    EnvMode.LIBERO: Exported(
        dir="s3://openpi-assets/exported/pi0_libero/model",
        processor="libero",
    ),
}
def create_default_policy(
    env: EnvMode, *, default_prompt: str | None = None, exported: Exported | None = None
) -> _policy.Policy:
    """Build the default policy for *env* from an exported checkpoint.

    Uses the checkpoint/processor from *exported* when provided, falling back to
    the DEFAULT_EXPORTED entry for anything missing, then wires up the
    environment-specific input/output transform stacks.

    Raises
    ------
    ValueError
        If *env* is not a known environment mode.
    """
    model: _model.BaseModel
    config: _policy_config.PolicyConfig

    default_exported = DEFAULT_EXPORTED[env]
    if exported:
        checkpoint_dir = exported.dir
        # Fall back to the env's default processor when the override omits one.
        processor = exported.processor or default_exported.processor
    else:
        checkpoint_dir = default_exported.dir
        processor = default_exported.processor
    assert processor, "Default processor must be always set"

    logging.info("Loading model...")
    model = _exported.PiModel.from_checkpoint(checkpoint_dir)

    def make_policy_config(
        input_layers: Sequence[transforms.DataTransformFn],
        output_layers: Sequence[transforms.DataTransformFn],
        sample_kwargs: dict[str, Any] | None = None,
    ):
        # Shared scaffolding: every env differs only in its transform layers.
        sample_kwargs = sample_kwargs or {"num_steps": 10}
        return _policy_config.PolicyConfig(
            model=model,
            norm_stats=model.norm_stats(processor),
            default_prompt=default_prompt,
            input_layers=input_layers,
            output_layers=output_layers,
            sample_kwargs=sample_kwargs,
        )

    logging.info("Creating policy...")
    match env:
        case EnvMode.ALOHA:
            # NOTE(review): mask appears to mark delta-encoded action dims for the
            # two 6-DoF arms (grippers excluded) — confirm against make_bool_mask.
            delta_action_mask = _policy_config.make_bool_mask(6, -1, 6, -1)
            config = make_policy_config(
                input_layers=[
                    aloha_policy.ActInputsRepack(),
                    aloha_policy.AlohaInputs(
                        action_dim=model.action_dim,
                        delta_action_mask=delta_action_mask,
                        adapt_to_pi=True,
                    ),
                ],
                output_layers=[
                    aloha_policy.AlohaOutputs(
                        delta_action_mask=delta_action_mask,
                        adapt_to_pi=True,
                    ),
                    aloha_policy.ActOutputsRepack(),
                ],
            )
        case EnvMode.ALOHA_SIM:
            config = make_policy_config(
                input_layers=[
                    aloha_policy.ActInputsRepack(),
                    aloha_policy.AlohaInputs(action_dim=model.action_dim),
                ],
                output_layers=[
                    aloha_policy.AlohaOutputs(),
                    aloha_policy.ActOutputsRepack(),
                ],
            )
        case EnvMode.DROID:
            config = make_policy_config(
                input_layers=[
                    droid_policy.DroidInputs(action_dim=model.action_dim),
                ],
                output_layers=[
                    droid_policy.DroidOutputs(),
                    # Emit every 5th action step.
                    transforms.SubsampleActions(stride=5),
                ],
            )
        case EnvMode.CALVIN:
            config = make_policy_config(
                input_layers=[
                    calvin_policy.CalvinInputs(action_dim=model.action_dim),
                ],
                output_layers=[
                    calvin_policy.CalvinOutputs(),
                ],
            )
        case EnvMode.LIBERO:
            config = make_policy_config(
                input_layers=[
                    libero_policy.LiberoInputs(action_dim=model.action_dim),
                ],
                output_layers=[
                    libero_policy.LiberoOutputs(),
                ],
            )
        case _:
            raise ValueError(f"Unknown environment mode: {env}")

    return _policy_config.create_policy(config)
def create_policy(args: Args) -> _policy.Policy:
    """Instantiate the policy selected by *args*.

    A Checkpoint spec loads a trained policy wrapped with env-specific repack
    transforms; an Exported spec (or no spec) falls back to the default
    exported policy for the environment.
    """
    spec = args.policy
    if isinstance(spec, Checkpoint):
        return _policy_config.create_trained_policy(
            _config.get_config(spec.config),
            spec.dir,
            repack_transforms=repack_from_env(args.env),
            default_prompt=args.default_prompt,
        )
    if isinstance(spec, Exported):
        return create_default_policy(args.env, default_prompt=args.default_prompt, exported=spec)
    if spec is None:
        return create_default_policy(args.env, default_prompt=args.default_prompt)
def main(args: Args) -> None:
    """Create the requested policy and serve it over a websocket until killed."""
    policy = create_policy(args)

    # Record the policy's behavior.
    if args.record:
        policy = _policy.PolicyRecorder(policy, "policy_records")

    logging.info("Creating server...")
    # Bind on all interfaces so containers/remote clients can connect.
    server = websocket_policy_server.WebsocketPolicyServer(policy=policy, host="0.0.0.0", port=args.port)
    logging.info("Serving...")
    server.serve_forever()


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, force=True)
    main(tyro.cli(Args))

284
scripts/train.py Normal file
View File

@@ -0,0 +1,284 @@
import dataclasses
from functools import partial
import logging
import platform
from typing import Any
import etils.epath as epath
from flax.training import common_utils
import jax
import jax._src.tree_util as private_tree_util
import jax.experimental
import jax.numpy as jnp
import optax
import tqdm_loggable.auto as tqdm
import wandb
import openpi.models.common as _common
import openpi.models.model as _model
import openpi.shared.array_typing as at
import openpi.training.checkpoints as _checkpoints
import openpi.training.config as _config
import openpi.training.data_loader as _data_loader
import openpi.training.optimizer as _optimizer
import openpi.training.sharding as sharding
import openpi.training.utils as training_utils
import openpi.training.weight_loaders as _weight_loaders
def init_logging():
    """Custom logging format for better readability.

    Sets the root logger to INFO and installs a formatter that abbreviates the
    level name to a single letter and appends process/file/line provenance.
    """
    level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}

    class CustomFormatter(logging.Formatter):
        def format(self, record):
            record.levelname = level_mapping.get(record.levelname, record.levelname)
            return super().format(record)

    formatter = CustomFormatter(
        fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
        datefmt="%H:%M:%S",
    )

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    # BUGFIX: handlers[0] raised IndexError when nothing (e.g. basicConfig) had
    # attached a handler yet; attach one ourselves in that case.
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    logger.handlers[0].setFormatter(formatter)
def init_wandb(config: _config.TrainConfig, *, resuming: bool, log_code: bool = False, enabled: bool = True):
    """Initialize a Weights & Biases run for this training job.

    When *resuming*, the run id is read back from <checkpoint_dir>/wandb_id.txt
    and the existing run is resumed; otherwise a fresh run is created and its id
    persisted there so a later restart can attach to it. With enabled=False a
    disabled (no-op) run is created instead.
    """
    if not enabled:
        wandb.init(mode="disabled")
        return

    ckpt_dir = config.checkpoint_dir
    if not ckpt_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")
    if resuming:
        run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
        # resume="must" fails loudly if the recorded run no longer exists.
        wandb.init(id=run_id, resume="must", project=config.project_name)
    else:
        wandb.init(
            name=config.exp_name,
            config=dataclasses.asdict(config),
            project=config.project_name,
        )
        (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)

    if log_code:
        # Snapshot the repository root (two levels up from this script).
        wandb.run.log_code(epath.Path(__file__).parent.parent)
def _load_weights_and_validate(weight_loader: _weight_loaders.WeightLoader, params: at.Params) -> at.Params:
    """Runs the weight loader and validates that the params structure, shapes, and dtypes are unchanged."""
    # Pass a shallow tree copy so the loader cannot mutate the caller's params.
    new_params = weight_loader.load(jax.tree.map(lambda x: x, params))

    # First check: the pytree structure itself must be identical.
    if errors := list(private_tree_util.equality_errors(params, new_params)):
        raise ValueError(
            "Weight loading changed the params structure:\n"
            + (
                "\n".join(
                    f" - {jax.tree_util.keystr(path)} changed from {thing1} to {thing2}, so {explanation}.\n"
                    for path, thing1, thing2, explanation in errors
                )
            )
        )

    # Second check: every leaf must keep its shape and dtype (values may change).
    def check(kp, x, y):
        if (x := jax.ShapeDtypeStruct(x.shape, x.dtype)) != (y := jax.ShapeDtypeStruct(y.shape, y.dtype)):
            raise ValueError(
                f"Weight loading changed the params structure: expected {y}, got {x} at {jax.tree_util.keystr(kp)}"
            )

    jax.tree_util.tree_map_with_path(check, params, new_params)
    return new_params
@at.typecheck
def init_train_state(
    config: _config.TrainConfig,
    model: _model.Model,
    init_rng: at.KeyArrayLike,
    batch: tuple[_common.Observation, _common.Actions],
    mesh: jax.sharding.Mesh,
    data_sharding: jax.sharding.Sharding,
    *,
    resume: bool,
) -> tuple[training_utils.TrainState, Any]:
    """Create the initial TrainState and its FSDP sharding spec.

    When `resume` is True, returns only the abstract state (shapes/dtypes from
    `jax.eval_shape`) plus its sharding — the caller restores concrete values
    from a checkpoint.
    """
    # NOTE(review): both masks are currently disabled (None) — confirm this is intended.
    weight_decay_mask = None
    freeze_mask = None
    tx = _optimizer.create_optimizer(config.optimizer, config.lr_schedule, weight_decay_mask, freeze_mask)
    def init(
        rng: at.KeyArrayLike,
        data: tuple[_common.Observation, _common.Actions],
        params_sharding: jax.sharding.Sharding | None = None,
    ) -> training_utils.TrainState:
        rng, model_rng = jax.random.split(rng)
        observation, actions = data
        params = model.init_params(model_rng, observation, actions)
        # jax.experimental.io_callback raises spmd partitioning warnings, setting constraints
        # to replicate params to avoid the warnings. the returned train state will be sharded still
        # since fsdp sharding is specified as output_sharding when jitting this function.
        if params_sharding is not None:
            params = jax.lax.with_sharding_constraint(params, params_sharding)
        # Run the (host-side) weight loader via an ordered io_callback so it executes
        # outside tracing; output shapes/dtypes are validated against `params`.
        params = jax.experimental.io_callback(
            partial(_load_weights_and_validate, config.weight_loader),
            params,
            params,
            ordered=True,
        )
        if params_sharding is not None:
            params = jax.lax.with_sharding_constraint(params, params_sharding)
        return training_utils.TrainState(
            step=0,
            params=params,
            opt_state=tx.init(params),
            tx=tx,
            ema_decay=config.ema_decay,
            # EMA params start as a copy of the freshly loaded params.
            ema_params=None if config.ema_decay is None else params,
        )
    replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
    # Trace `init` to obtain the state's shapes without allocating any arrays.
    train_state_shape = jax.eval_shape(init, init_rng, batch)
    state_sharding = sharding.fsdp_sharding(train_state_shape, mesh, log=True)
    if resume:
        return train_state_shape, state_sharding
    # params_sharding is passed as a static argument (argnum 2), so only rng and batch
    # need entries in in_shardings.
    train_state = jax.jit(
        init,
        in_shardings=(replicated_sharding, data_sharding),
        out_shardings=state_sharding,
        static_argnums=(2,),
    )(init_rng, batch, replicated_sharding)
    return train_state, state_sharding
@at.typecheck
def train_step(
    rng: at.KeyArrayLike,
    state: training_utils.TrainState,
    model: _model.Model,
    batch: tuple[_common.Observation, _common.Actions],
) -> tuple[training_utils.TrainState, dict[str, at.Array]]:
    """Run one optimization step and return the updated state plus logging metrics."""
    observation, actions = batch
    # Fold the step count into the rng so every step gets a fresh, deterministic key.
    step_rng = jax.random.fold_in(rng, state.step)

    def loss_fn(params: at.Params, rng: at.KeyArrayLike, observation: _common.Observation, actions: _common.Actions):
        chunked_loss = model.compute_loss(rng, observation, actions, params=params, train=True)
        return jnp.mean(chunked_loss)

    loss, grads = jax.value_and_grad(loss_fn)(state.params, step_rng, observation, actions)
    updates, new_opt_state = state.tx.update(grads, state.opt_state, state.params)
    new_params = optax.apply_updates(state.params, updates)
    new_state = state.replace(step=state.step + 1, params=new_params, opt_state=new_opt_state)

    if state.ema_decay is not None:
        # Exponential moving average of the parameters.
        ema_params = jax.tree.map(
            lambda old, new: state.ema_decay * old + (1 - state.ema_decay) * new, state.ema_params, new_params
        )
        new_state = new_state.replace(ema_params=ema_params)

    # Report the norm of weight-matrix ('kernel') params only; other leaves are masked out.
    kernel_mask = training_utils.mask_from_regex(r".*\['kernel'\]", state.params)
    kernel_params = jax.tree.map(lambda p, m: p if m else None, state.params, kernel_mask)
    info = {
        "loss": loss,
        "grad_norm": optax.global_norm(grads),  # TODO: do not compute norm for frozen params
        "param_norm": optax.global_norm(kernel_params),
    }
    return new_state, info
def main(config: _config.TrainConfig):
    """Entry point: set up devices, data, and train state, then run the training loop."""
    init_logging()
    logging.info(f"Running on: {platform.node()}")

    # Validate device/batch configuration before any expensive setup.
    if config.batch_size % jax.device_count() != 0:
        raise ValueError(
            f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}."
        )
    jax.config.update("jax_threefry_partitionable", True)  # noqa: FBT003
    # Persistent compilation cache to speed up repeated runs.
    jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax").expanduser()))
    rng = jax.random.key(config.seed)
    train_rng, init_rng = jax.random.split(rng)
    if jax.device_count() % config.fsdp_devices != 0:
        raise ValueError(
            f"Number of devices {jax.device_count()} must be divisible by the number of FSDP devices {config.fsdp_devices}."
        )
    # 2-D device mesh: data parallelism over "batch", FSDP parameter sharding over "model".
    mesh_shape = (jax.device_count() // config.fsdp_devices, config.fsdp_devices)
    mesh = jax.make_mesh(mesh_shape, ("batch", "model"))
    # Batches are sharded across both mesh axes; everything else stays replicated.
    data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(("batch", "model")))
    replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
    checkpoint_manager, resuming = _checkpoints.initialize_checkpoint_dir(
        config.checkpoint_dir,
        keep_interval=config.keep_interval,
        overwrite=config.overwrite,
        resume=config.resume,
    )
    init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)
    model = config.create_model()
    data_loader = _data_loader.create_data_loader(
        config,
        model,
        sharding=data_sharding,
        num_workers=config.num_workers,
        shuffle=True,
    )
    data_iter = iter(data_loader)
    # Pull one batch up front: its shapes are needed to initialize the train state.
    batch = next(data_iter)
    logging.info(f"Initialized data loader:\n{training_utils.array_tree_to_info(batch)}")
    train_state, train_state_sharding = init_train_state(
        config, model, init_rng, batch, mesh, data_sharding, resume=resuming
    )
    jax.block_until_ready(train_state)
    logging.info(f"Initialized train state:\n{training_utils.array_tree_to_info(train_state.params)}")
    if resuming:
        train_state = _checkpoints.restore_state(checkpoint_manager, train_state, data_loader)
    # donate_argnums=(1,) lets XLA reuse the old train state's buffers in place.
    ptrain_step = jax.jit(
        train_step,
        in_shardings=(replicated_sharding, train_state_sharding, None, data_sharding),
        out_shardings=(train_state_sharding, replicated_sharding),
        donate_argnums=(1,),
    )
    # start_step is nonzero when resuming from a checkpoint.
    start_step = int(train_state.step)
    pbar = tqdm.tqdm(
        range(start_step, config.num_train_steps),
        initial=start_step,
        total=config.num_train_steps,
        dynamic_ncols=True,
    )
    infos = []
    for step in pbar:
        train_state, info = ptrain_step(train_rng, train_state, model, batch)
        infos.append(info)
        if step % config.log_interval == 0:
            # Average the per-step metrics collected since the last log, then reset.
            stacked_infos = common_utils.stack_forest(infos)
            reduced_info = jax.device_get(jax.tree.map(jnp.mean, stacked_infos))
            info_str = ", ".join(f"{k}={v:.4f}" for k, v in reduced_info.items())
            pbar.write(f"Step {step}: {info_str}")
            wandb.log(reduced_info, step=step)
            infos = []
        batch = next(data_iter)
        # Save on the interval (but not at the resumed start step) and at the final step.
        if (step % config.save_interval == 0 and step > start_step) or step == config.num_train_steps - 1:
            _checkpoints.save_state(checkpoint_manager, train_state, data_loader, step)
    logging.info("Waiting for checkpoint manager to finish")
    checkpoint_manager.wait_until_finished()
if __name__ == "__main__":
    # Parse the training config from the command line and launch training.
    main(_config.cli())

# ==== scripts/train_test.py (new file, 27 lines) ====
import dataclasses
import pathlib
import pytest
from openpi.training import config as _config
from . import train
@pytest.mark.parametrize("config_name", ["debug"])
def test_train(tmp_path: pathlib.Path, config_name: str):
    """Smoke-test training for a couple of steps, then resume from the checkpoint."""
    base_config = _config._CONFIGS_DICT[config_name]  # noqa: SLF001
    overrides = {
        "batch_size": 2,
        "checkpoint_base_dir": tmp_path / "checkpoint",
        "exp_name": "test",
        "overwrite": False,
        "resume": False,
        "num_train_steps": 2,
        "log_interval": 1,
    }
    config = dataclasses.replace(base_config, **overrides)
    train.main(config)

    # Restart the same run with resume enabled and a larger step budget.
    config = dataclasses.replace(config, resume=True, num_train_steps=4)
    train.main(config)