Initial commit
This commit is contained in:
0
scripts/__init__.py
Normal file
0
scripts/__init__.py
Normal file
479
scripts/aloha_hd5.py
Normal file
479
scripts/aloha_hd5.py
Normal file
@@ -0,0 +1,479 @@
|
||||
# ruff: noqa
|
||||
"""
|
||||
Script courtesy of Raziel90 https://github.com/huggingface/lerobot/pull/586/files
|
||||
|
||||
Example usage
|
||||
python scripts/aloha_hd5.py --raw-path ~/data/ --dataset-repo-id <hf-username>/<dataset-name> --robot-type <aloha-stationary|aloha-mobile> --fps 50 --video-encoding=false --push=false
|
||||
|
||||
The data will be saved locally the value of the LEROBOT_HOME environment variable. By default this is set to ~/.cache/huggingface/lerobot
|
||||
If you wish to submit the dataset to the hub, you can do so by setting up the hf cli https://huggingface.co/docs/huggingface_hub/en/guides/cli and setting --push=true
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import traceback
|
||||
|
||||
import cv2
|
||||
import h5py
|
||||
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
import torch
|
||||
|
||||
|
||||
class AlohaHD5Extractor:
    """Static helpers to read Aloha HDF5 episode files and describe their features."""

    # Tags attached to the dataset when pushing to the Hugging Face Hub.
    TAGS = ["aloha", "robotics", "hdf5"]
    # Known robot-type identifiers.
    aloha_stationary = "aloha-stationary"
    aloha_mobile = "aloha-mobile"

    @staticmethod
    def get_cameras(hdf5_data: h5py.File):
        """
        Extracts the list of RGB camera keys from the given HDF5 data.

        Parameters
        ----------
        hdf5_data : h5py.File
            The HDF5 file object containing the dataset.

        Returns
        -------
        list of str
            A list of keys corresponding to RGB cameras in the dataset.
        """
        # Depth streams are stored alongside RGB under /observations/images;
        # filter them out by name.
        rgb_cameras = [key for key in hdf5_data["/observations/images"] if "depth" not in key]
        return rgb_cameras

    @staticmethod
    def check_format(episode_list: list[str] | list[Path], image_compressed: bool = True):
        """
        Check the format of the given list of HDF5 files.

        Parameters
        ----------
        episode_list : list of str or list of Path
            List of paths to the HDF5 files to be checked.
        image_compressed : bool, optional
            Flag indicating whether the images are compressed (default is True).

        Raises
        ------
        ValueError
            If the episode_list is empty.
            If any HDF5 file is missing required keys '/action' or '/observations/qpos'.
            If the '/action' or '/observations/qpos' keys do not have 2 dimensions.
            If the number of frames in '/action' and '/observations/qpos' keys do not match.
            If the number of frames in '/observations/images/{camera}' does not match the number of frames in '/action' and '/observations/qpos'.
            If the dimensions of images do not match the expected dimensions based on the image_compressed flag.
            If uncompressed images do not have the expected (h, w, c) format.
        """
        if not episode_list:
            raise ValueError("No hdf5 files found in the raw directory. Make sure they are named 'episode_*.hdf5'")
        for episode_path in episode_list:
            with h5py.File(episode_path, "r") as data:
                if not all(key in data for key in ["/action", "/observations/qpos"]):
                    raise ValueError(
                        "Missing required keys in the hdf5 file. Make sure the keys '/action' and '/observations/qpos' are present."
                    )

                if not data["/action"].ndim == data["/observations/qpos"].ndim == 2:
                    raise ValueError("The '/action' and '/observations/qpos' keys should have both 2 dimensions.")

                if (num_frames := data["/action"].shape[0]) != data["/observations/qpos"].shape[0]:
                    raise ValueError(
                        "The '/action' and '/observations/qpos' keys should have the same number of frames."
                    )

                # Compressed frames are 2-D (frame index x byte buffer); raw frames
                # are 4-D (frame, h, w, c). Loop-invariant, so computed once.
                expected_dims = 2 if image_compressed else 4
                for camera in AlohaHD5Extractor.get_cameras(data):
                    if num_frames != data[f"/observations/images/{camera}"].shape[0]:
                        raise ValueError(
                            f"The number of frames in '/observations/images/{camera}' should be the same as in '/action' and '/observations/qpos' keys."
                        )

                    if data[f"/observations/images/{camera}"].ndim != expected_dims:
                        raise ValueError(
                            f"Expect {expected_dims} dimensions for {'compressed' if image_compressed else 'uncompressed'} images but {data[f'/observations/images/{camera}'].ndim} provided."
                        )
                    if not image_compressed:
                        b, h, w, c = data[f"/observations/images/{camera}"].shape
                        # BUG FIX: original `if not c < h and c < w:` parsed as
                        # `(not (c < h)) and (c < w)` due to precedence, so most
                        # malformed layouts slipped through. Intent: require a
                        # channels-last layout where c is smaller than both h and w.
                        if not (c < h and c < w):
                            raise ValueError(f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided.")

    @staticmethod
    def extract_episode_frames(
        episode_path: str | Path, features: dict[str, dict], image_compressed: bool
    ) -> list[dict[str, torch.Tensor]]:
        """
        Extract frames from an episode stored in an HDF5 file.

        Parameters
        ----------
        episode_path : str or Path
            Path to the HDF5 file containing the episode data.
        features : dict of str to dict
            Dictionary where keys are feature identifiers and values are dictionaries with feature details.
        image_compressed : bool
            Flag indicating whether the images are stored in a compressed format.

        Returns
        -------
        list of dict of str to torch.Tensor
            List of frames, where each frame is a dictionary mapping feature identifiers to tensors.
        """
        frames = []
        with h5py.File(episode_path, "r") as file:
            for frame_idx in range(file["/action"].shape[0]):
                frame = {}
                for feature_id in features:
                    # Feature ids use '.' where the HDF5 hierarchy uses '/'.
                    feature_name_hd5 = feature_id.replace(".", "/")
                    if "images" in feature_id.split("."):
                        raw = file[feature_name_hd5][frame_idx]
                        # Compressed frames are encoded byte buffers; decode to an HWC array.
                        image = cv2.imdecode(raw, 1) if image_compressed else raw
                        # HWC -> CHW, matching the (channel, height, width) feature shape.
                        frame[feature_id] = torch.from_numpy(image.transpose(2, 0, 1))
                    else:
                        frame[feature_id] = torch.from_numpy(file[feature_name_hd5][frame_idx])
                frames.append(frame)
        return frames

    @staticmethod
    def define_features(
        hdf5_file_path: Path, image_compressed: bool = True, encode_as_video: bool = True
    ) -> dict[str, dict]:
        """
        Define features from an HDF5 file.

        Parameters
        ----------
        hdf5_file_path : Path
            The path to the HDF5 file.
        image_compressed : bool, optional
            Whether the images are compressed, by default True.
        encode_as_video : bool, optional
            Whether to encode images as video or as images, by default True.

        Returns
        -------
        dict[str, dict]
            A dictionary where keys are topic names and values are dictionaries
            containing feature information such as dtype, shape, and names.
        """
        # Initialize lists to store topics and features
        topics = []
        features = {}

        with h5py.File(hdf5_file_path, "r") as hdf5_file:
            # Collect all dataset names in the HDF5 file
            hdf5_file.visititems(lambda name, obj: topics.append(name) if isinstance(obj, h5py.Dataset) else None)

            for topic in topics:
                # If the topic is an image, define it as a video (or image) feature
                if "images" in topic.split("/"):
                    sample = hdf5_file[topic][0]
                    features[topic.replace("/", ".")] = {
                        "dtype": "video" if encode_as_video else "image",
                        # Compressed samples must be decoded to learn the (c, h, w) shape.
                        "shape": cv2.imdecode(hdf5_file[topic][0], 1).transpose(2, 0, 1).shape
                        if image_compressed
                        else sample.shape,
                        "names": [
                            "channel",
                            "height",
                            "width",
                        ],
                    }
                # Skip compressed length bookkeeping topics
                elif "compress_len" in topic.split("/"):
                    continue
                # Otherwise, define it as a regular numeric feature
                else:
                    features[topic.replace("/", ".")] = {
                        "dtype": str(hdf5_file[topic][0].dtype),
                        "shape": (topic_shape := hdf5_file[topic][0].shape),
                        "names": [f"{topic.split('/')[-1]}_{k}" for k in range(topic_shape[0])],
                    }
        return features
class DatasetConverter:
    """
    A class to convert Aloha HDF5 datasets to LeRobot format.

    Parameters
    ----------
    raw_path : Path or str
        The path to the raw dataset (a directory of 'episode_*.hdf5' files).
    dataset_repo_id : str
        The repository ID where the dataset will be stored.
    fps : int
        Frames per second for the dataset.
    robot_type : str, optional
        The type of robot, by default "".
    encode_as_videos : bool, optional
        Whether to encode images as videos, by default True.
    image_compressed : bool, optional
        Whether the images are compressed, by default True.
    image_writer_processes : int, optional
        Number of processes for writing images, by default 0.
    image_writer_threads : int, optional
        Number of threads for writing images, by default 0.

    Methods
    -------
    extract_episode(episode_path, task_description='')
        Extracts frames from a single episode and saves it with a description.
    extract_episodes(episode_description='')
        Extracts frames from all episodes and saves them with a description.
    push_dataset_to_hub(dataset_tags=None, private=False, push_videos=True, license="apache-2.0")
        Pushes the dataset to the Hugging Face Hub.
    init_lerobot_dataset()
        Initializes the LeRobot dataset.
    """

    def __init__(
        self,
        raw_path: Path | str,
        dataset_repo_id: str,
        fps: int,
        robot_type: str = "",
        encode_as_videos: bool = True,
        image_compressed: bool = True,
        image_writer_processes: int = 0,
        image_writer_threads: int = 0,
    ):
        self.raw_path = raw_path if isinstance(raw_path, Path) else Path(raw_path)
        self.dataset_repo_id = dataset_repo_id
        self.fps = fps
        self.robot_type = robot_type
        self.image_compressed = image_compressed
        self.image_writer_threads = image_writer_threads
        self.image_writer_processes = image_writer_processes
        self.encode_as_videos = encode_as_videos

        self.logger = logging.getLogger(self.__class__.__name__)
        self.logger.setLevel(logging.INFO)

        # Add console handler.
        # NOTE(review): a new StreamHandler is attached on every instantiation;
        # constructing two converters in one process duplicates every log line.
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        formatter = logging.Formatter("%(asctime)s - [%(name)s] - %(message)s")
        console_handler.setFormatter(formatter)
        self.logger.addHandler(console_handler)

        # Log the full conversion configuration up front for traceability.
        self.logger.info(f"{'-'*10} Aloha HD5 -> Lerobot Converter {'-'*10}")
        self.logger.info(f"Processing Aloha HD5 dataset from {self.raw_path}")
        self.logger.info(f"Dataset will be stored in {self.dataset_repo_id}")
        self.logger.info(f"FPS: {self.fps}")
        self.logger.info(f"Robot type: {self.robot_type}")
        self.logger.info(f"Image compressed: {self.image_compressed}")
        self.logger.info(f"Encoding images as videos: {self.encode_as_videos}")
        self.logger.info(f"#writer processes: {self.image_writer_processes}")
        self.logger.info(f"#writer threads: {self.image_writer_threads}")

        # Validate all episodes eagerly, then derive the feature schema from
        # the first episode (assumes all episodes share one schema — enforced
        # only per-file by check_format).
        self.episode_list = list(self.raw_path.glob("episode_*.hdf5"))
        AlohaHD5Extractor.check_format(self.episode_list, image_compressed=self.image_compressed)
        self.features = AlohaHD5Extractor.define_features(
            self.episode_list[0],
            image_compressed=self.image_compressed,
            encode_as_video=self.encode_as_videos,
        )

    def extract_episode(self, episode_path, task_description: str = ""):
        """
        Extracts frames from an episode and saves them to the dataset.

        Requires ``init_lerobot_dataset`` to have been called first (uses ``self.dataset``).

        Parameters
        ----------
        episode_path : str
            The path to the episode file.
        task_description : str, optional
            A description of the task associated with the episode (default is an empty string).

        Returns
        -------
        None
        """
        for frame in AlohaHD5Extractor.extract_episode_frames(episode_path, self.features, self.image_compressed):
            self.dataset.add_frame(frame)
        self.logger.info(f"Saving Episode with Description: {task_description} ...")
        self.dataset.save_episode(task=task_description)

    def extract_episodes(self, episode_description: str = ""):
        """
        Extracts episodes from the episode list and processes them.

        Parameters
        ----------
        episode_description : str, optional
            A description of the task to be passed to the extract_episode method (default is '').

        Notes
        -----
        Failures are deliberately best-effort: an error in one episode is
        printed and the episode skipped, so a single corrupt file does not
        abort the whole conversion. After processing all episodes, the
        dataset is consolidated.
        """
        for episode_path in self.episode_list:
            try:
                self.extract_episode(episode_path, task_description=episode_description)
            except Exception as e:
                print(f"Error processing episode {episode_path}", f"{e}")
                traceback.print_exc()
                continue
        self.dataset.consolidate()

    def push_dataset_to_hub(
        self,
        dataset_tags: list[str] | None = None,
        private: bool = False,
        push_videos: bool = True,
        license: str | None = "apache-2.0",
    ):
        """
        Pushes the dataset to the Hugging Face Hub.

        Parameters
        ----------
        dataset_tags : list of str, optional
            A list of tags to associate with the dataset on the Hub. Default is None.
        private : bool, optional
            If True, the dataset will be private. Default is False.
        push_videos : bool, optional
            If True, videos will be pushed along with the dataset. Default is True.
        license : str, optional
            The license under which the dataset is released. Default is "apache-2.0".

        Returns
        -------
        None
        """
        self.logger.info(f"Pushing dataset to Hugging Face Hub. ID: {self.dataset_repo_id} ...")
        self.dataset.push_to_hub(
            tags=dataset_tags,
            license=license,
            push_videos=push_videos,
            private=private,
        )

    def init_lerobot_dataset(self):
        """
        Initializes the LeRobot dataset.

        This method cleans the cache if the dataset already exists and then creates a new LeRobot dataset.

        Returns
        -------
        LeRobotDataset
            The initialized LeRobot dataset (also stored on ``self.dataset``).
        """
        # Clean the cache if the dataset already exists — LeRobotDataset.create
        # would otherwise collide with the stale local copy.
        if os.path.exists(LEROBOT_HOME / self.dataset_repo_id):
            shutil.rmtree(LEROBOT_HOME / self.dataset_repo_id)
        self.dataset = LeRobotDataset.create(
            repo_id=self.dataset_repo_id,
            fps=self.fps,
            robot_type=self.robot_type,
            features=self.features,
            image_writer_threads=self.image_writer_threads,
            image_writer_processes=self.image_writer_processes,
        )

        return self.dataset
def str2bool(value):
    """Coerce a CLI argument to a bool (for use as an argparse ``type=``).

    Accepts actual bools unchanged; otherwise matches common textual spellings
    case-insensitively. Raises ``argparse.ArgumentTypeError`` on anything else.
    """
    if isinstance(value, bool):
        return value
    normalized = value.lower()
    truthy = {"yes", "true", "t", "y", "1"}
    falsy = {"no", "false", "f", "n", "0"}
    if normalized in truthy:
        return True
    if normalized in falsy:
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
def main():
    """
    Convert Aloha HD5 dataset and push to Hugging Face hub.

    This script processes raw HDF5 files from the Aloha dataset, converts them into a specified format,
    and optionally uploads the dataset to the Hugging Face hub.

    Parameters
    ----------
    --raw-path : Path
        Directory containing the raw HDF5 files.
    --dataset-repo-id : str
        Repository ID where the dataset will be stored.
    --fps : int
        Frames per second for the dataset.
    --description : str, optional
        Description of the dataset. Default is "Aloha recorded dataset.".
    --robot-type : str, optional
        Type of robot, either "aloha-stationary" or "aloha-mobile". Default is "aloha-stationary".
    --private : bool, optional
        Set to True to make the dataset private. Default is False.
    --push : bool, optional
        Set to True to push the dataset (including videos) to the hub. Default is True.
    --license : str, optional
        License for the dataset. Default is "apache-2.0".
    --image-compressed : bool, optional
        Set to True if the images are compressed. Default is True.
    --video-encoding : bool, optional
        Set to True to encode images as videos. Default is True.
    --nproc : int, optional
        Number of image writer processes. Default is 10.
    --nthreads : int, optional
        Number of image writer threads. Default is 5.
    """
    parser = argparse.ArgumentParser(description="Convert Aloha HD5 dataset and push to Hugging Face hub.")
    parser.add_argument("--raw-path", type=Path, required=True, help="Directory containing the raw hdf5 files.")
    parser.add_argument(
        "--dataset-repo-id", type=str, required=True, help="Repository ID where the dataset will be stored."
    )
    parser.add_argument("--fps", type=int, required=True, help="Frames per second for the dataset.")
    parser.add_argument(
        "--description", type=str, help="Description of the dataset.", default="Aloha recorded dataset."
    )
    parser.add_argument(
        "--robot-type",
        type=str,
        choices=["aloha-stationary", "aloha-mobile"],
        default="aloha-stationary",
        help="Type of robot.",
    )
    parser.add_argument("--private", type=str2bool, default=False, help="Set to True to make the dataset private.")
    parser.add_argument("--push", type=str2bool, default=True, help="Set to True to push videos to the hub.")
    parser.add_argument("--license", type=str, default="apache-2.0", help="License for the dataset.")
    parser.add_argument(
        "--image-compressed", type=str2bool, default=True, help="Set to True if the images are compressed."
    )
    parser.add_argument("--video-encoding", type=str2bool, default=True, help="Set to True to encode images as videos.")

    parser.add_argument("--nproc", type=int, default=10, help="Number of image writer processes.")
    parser.add_argument("--nthreads", type=int, default=5, help="Number of image writer threads.")

    args = parser.parse_args()
    # (A leftover debug print of args.video_encoding was removed here.)

    converter = DatasetConverter(
        raw_path=args.raw_path,
        dataset_repo_id=args.dataset_repo_id,
        fps=args.fps,
        robot_type=args.robot_type,
        image_compressed=args.image_compressed,
        encode_as_videos=args.video_encoding,
        image_writer_processes=args.nproc,
        image_writer_threads=args.nthreads,
    )
    converter.init_lerobot_dataset()
    converter.extract_episodes(episode_description=args.description)

    if args.push:
        converter.push_dataset_to_hub(
            dataset_tags=AlohaHD5Extractor.TAGS, private=args.private, push_videos=True, license=args.license
        )


if __name__ == "__main__":
    main()
29
scripts/compose.yml
Normal file
29
scripts/compose.yml
Normal file
@@ -0,0 +1,29 @@
|
||||
# Docker Compose definition for the openpi policy server.
# Run with:
# docker compose -f scripts/compose.yml up --build
services:
  openpi_server:
    image: openpi_server
    build:
      context: ..
      dockerfile: scripts/serve_policy.Dockerfile
    init: true
    tty: true
    # Share the host network so clients can reach the server port directly.
    network_mode: host
    # Populate configured openpi data home to /openpi_assets inside the container.
    # Populate aws credential inside the container.
    volumes:
      - $PWD:/app
      - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
      - ~/.aws/:/root/.aws/
    environment:
      # Extra CLI arguments forwarded to serve_policy.py by the image's CMD.
      - SERVER_ARGS
      - OPENPI_DATA_HOME=/openpi_assets

    # Comment out this block if not running on a machine with GPUs.
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
67
scripts/compute_norm_stats.py
Normal file
67
scripts/compute_norm_stats.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Compute normalization statistics for a config.
|
||||
|
||||
This script is used to compute the normalization statistics for a given config. It
|
||||
will compute the mean and standard deviation of the data in the dataset and save it
|
||||
to the config metadata directory.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import tqdm
|
||||
import tyro
|
||||
|
||||
import openpi.shared.normalize as normalize
|
||||
import openpi.training.config as _config
|
||||
import openpi.training.data_loader as _data_loader
|
||||
|
||||
|
||||
def create_dataset(config: _config.TrainConfig) -> tuple[str, _data_loader.Dataset]:
    """Build the transformed dataset described by *config*.

    Returns the dataset's repo id together with the dataset wrapped in the
    config's repack and data input transforms. Raises ``ValueError`` when the
    data config has no ``repo_id``.
    """
    model = config.create_model()
    data_config = config.data.create(config.metadata_dir, model)
    if data_config.repo_id is None:
        raise ValueError("Data config must have a repo_id")
    # Stack the repack transforms first, then the data transforms, on top of
    # the raw dataset.
    input_transforms = list(data_config.repack_transforms.inputs)
    input_transforms.extend(data_config.data_transforms.inputs)
    base_dataset = _data_loader.create_dataset(data_config, model)
    return data_config.repo_id, _data_loader.TransformedDataset(base_dataset, input_transforms)
def main(config_name: str, max_frames: int | None = None):
    """Compute normalization statistics for *config_name* and save them.

    Streams the dataset one frame at a time, accumulates running statistics
    for the "state" and "actions" keys, and writes the result under the
    config's metadata directory. When *max_frames* is given and smaller than
    the dataset, a shuffled subset of that size is used instead.
    """
    config = _config.get_config(config_name)
    repo_id, dataset = create_dataset(config)

    num_frames = len(dataset)
    shuffle = False
    if max_frames is not None and max_frames < num_frames:
        # Subsample: shuffle so the subset is representative.
        num_frames = max_frames
        shuffle = True

    data_loader = _data_loader.TorchDataLoader(
        dataset,
        local_batch_size=1,
        num_workers=8,
        shuffle=shuffle,
        num_batches=num_frames,
    )

    keys = ["state", "actions"]
    stats = {key: normalize.RunningStats() for key in keys}

    for batch in tqdm.tqdm(data_loader, total=num_frames, desc="Computing stats"):
        for key in keys:
            values = np.asarray(batch[key][0])
            # Flatten everything but the last axis so stats are per-dimension.
            stats[key].update(values.reshape(-1, values.shape[-1]))

    norm_stats = {key: tracker.get_statistics() for key, tracker in stats.items()}

    output_path = config.metadata_dir / repo_id
    print(f"Writing stats to: {output_path}")
    normalize.save(output_path, norm_stats)


if __name__ == "__main__":
    tyro.cli(main)
37
scripts/install_docker_ubuntu22.sh
Executable file
37
scripts/install_docker_ubuntu22.sh
Executable file
@@ -0,0 +1,37 @@
|
||||
#!/bin/bash
# Installs Docker Engine on Ubuntu 22.04 and performs the recommended
# post-install steps (docker group membership, start on boot).

# Add Docker's official GPG key:
sudo apt-get update
sudo apt-get install -y ca-certificates curl
sudo install -m 0755 -d /etc/apt/keyrings
sudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc
sudo chmod a+r /etc/apt/keyrings/docker.asc

# Add the repository to Apt sources:
echo \
  "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu \
  $(. /etc/os-release && echo "$VERSION_CODENAME") stable" |
  sudo tee /etc/apt/sources.list.d/docker.list >/dev/null
sudo apt-get update

sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin

# Add current user to the 'docker' group, which allows them to use docker commands (docker build, docker run, etc).
# See https://docs.docker.com/engine/install/linux-postinstall/
username=$(whoami)
sudo usermod -aG docker "$username"

# Configure docker to start automatically on system boot.
sudo systemctl enable docker.service
sudo systemctl enable containerd.service

# https://forums.docker.com/t/docker-credential-desktop-exe-executable-file-not-found-in-path-using-wsl2/100225/5
# FIX: the original test `[ ~/.docker/config.json ]` is a non-empty-string
# test and is therefore always true; use -f to test that the file exists.
if [ -f ~/.docker/config.json ]; then
  sed -i 's/credsStore/credStore/g' ~/.docker/config.json
fi

echo ""
echo "********************************************************************"
echo "**** Restart to allow Docker permission changes to take effect. ****"
echo "********************************************************************"
echo ""
17
scripts/install_nvidia_container_toolkit.sh
Executable file
17
scripts/install_nvidia_container_toolkit.sh
Executable file
@@ -0,0 +1,17 @@
|
||||
#!/bin/bash

# Installs the NVIDIA Container Toolkit, which allows Docker containers to access NVIDIA GPUs.
# NVIDIA's official documentation: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html

# Register the NVIDIA apt repository and its signing key.
curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg &&
  curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list |
  sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' |
  sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list

# Enable the experimental packages line in the repo list.
# NVIDIA's documentation omits 'sudo' in the following command, but it is required.
sudo sed -i -e '/experimental/ s/^#//g' /etc/apt/sources.list.d/nvidia-container-toolkit.list
sudo apt-get update
sudo apt-get install -y nvidia-container-toolkit

# Register the NVIDIA runtime with Docker and restart the daemon so it takes effect.
sudo nvidia-ctk runtime configure --runtime=docker
sudo systemctl restart docker
34
scripts/serve_policy.Dockerfile
Normal file
34
scripts/serve_policy.Dockerfile
Normal file
@@ -0,0 +1,34 @@
|
||||
# Dockerfile for serving a PI policy.
# Based on UV's instructions: https://docs.astral.sh/uv/guides/integration/docker/#developing-in-a-container

# Build the container:
# docker build . -t openpi_server -f scripts/serve_policy.Dockerfile

# Run the container:
# docker run --rm -it --network=host -v .:/app --gpus=all openpi_server /bin/bash

# CUDA runtime base image pinned by digest for reproducible builds.
FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
# Bring in the uv/uvx binaries from the official distroless image.
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/

WORKDIR /app

# Needed because LeRobot uses git-lfs.
RUN apt-get update && apt-get install -y git git-lfs

# Copy from the cache instead of linking since it's a mounted volume
ENV UV_LINK_MODE=copy

# Write the virtual environment outside of the project directory so it doesn't
# leak out of the container when we mount the application code.
ENV UV_PROJECT_ENVIRONMENT=/.venv

# Install the project's dependencies using the lockfile and settings.
# Bind-mounts keep the image free of source copies; the cache mount reuses
# downloaded wheels across builds.
RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
RUN --mount=type=cache,target=/root/.cache/uv \
    --mount=type=bind,source=uv.lock,target=uv.lock \
    --mount=type=bind,source=pyproject.toml,target=pyproject.toml \
    --mount=type=bind,source=packages/openpi-client/pyproject.toml,target=packages/openpi-client/pyproject.toml \
    --mount=type=bind,source=packages/openpi-client/src,target=packages/openpi-client/src \
    GIT_LFS_SKIP_SMUDGE=1 uv sync --frozen --no-install-project --no-dev

# SERVER_ARGS is forwarded from the environment (see scripts/compose.yml).
CMD /bin/bash -c "uv run scripts/serve_policy.py $SERVER_ARGS"
243
scripts/serve_policy.py
Normal file
243
scripts/serve_policy.py
Normal file
@@ -0,0 +1,243 @@
|
||||
from collections.abc import Sequence
|
||||
import dataclasses
|
||||
import enum
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import tyro
|
||||
|
||||
from openpi import transforms
|
||||
from openpi.models import exported as _exported
|
||||
from openpi.models import model as _model
|
||||
from openpi.policies import aloha_policy
|
||||
from openpi.policies import calvin_policy
|
||||
from openpi.policies import droid_policy
|
||||
from openpi.policies import libero_policy
|
||||
from openpi.policies import policy as _policy
|
||||
from openpi.policies import policy_config as _policy_config
|
||||
from openpi.serving import websocket_policy_server
|
||||
from openpi.training import config as _config
|
||||
|
||||
|
||||
class EnvMode(enum.Enum):
    """Supported environments.

    The values are the CLI-facing environment names.
    """

    ALOHA = "aloha"
    ALOHA_SIM = "aloha_sim"
    DROID = "droid"
    CALVIN = "calvin"
    LIBERO = "libero"
@dataclasses.dataclass
class Exported:
    """Load an exported checkpoint."""

    # Checkpoint directory (e.g., "s3://openpi-assets/exported/pi0_aloha/model").
    dir: str
    # Processor name to load the norm stats from. If not provided, the default processor for the environment will be used.
    processor: str | None = None
@dataclasses.dataclass
class Checkpoint:
    """Load a policy from a trained checkpoint."""

    # Training config name (e.g., "pi0_aloha_sim").
    config: str
    # Checkpoint directory (e.g., "checkpoints/pi0_aloha_sim/exp/10000").
    dir: str
@dataclasses.dataclass
class Args:
    """Arguments for the serve_policy script."""

    # Environment to serve the policy for.
    env: EnvMode = EnvMode.ALOHA_SIM
    # Specifies how to load the policy. If not provided, the default policy for the environment will be used.
    policy: Checkpoint | Exported | None = None

    # If provided, overrides the default prompt for the policy.
    default_prompt: str | None = None

    # Port to serve the policy on.
    port: int = 8000
    # Record the policy's behavior for debugging.
    record: bool = False
def repack_from_env(env: EnvMode) -> transforms.Group:
|
||||
"""Creates environment specific repack transforms."""
|
||||
# TODO(ury): Move this to the runtime.
|
||||
match env:
|
||||
case EnvMode.ALOHA:
|
||||
return transforms.Group(
|
||||
inputs=[aloha_policy.ActInputsRepack()],
|
||||
outputs=[aloha_policy.ActOutputsRepack()],
|
||||
)
|
||||
case EnvMode.ALOHA_SIM:
|
||||
return transforms.Group(
|
||||
inputs=[aloha_policy.ActInputsRepack()],
|
||||
outputs=[aloha_policy.ActOutputsRepack()],
|
||||
)
|
||||
case _:
|
||||
return transforms.Group()
|
||||
|
||||
|
||||
# Default exported models.
# Maps each environment to the exported checkpoint (and processor name) used
# when the caller does not specify a policy explicitly.
DEFAULT_EXPORTED: dict[EnvMode, Exported] = {
    EnvMode.ALOHA: Exported(
        dir="s3://openpi-assets/exported/pi0_aloha/model",
        processor="trossen_biarm_single_base_cam_24dim",
    ),
    EnvMode.ALOHA_SIM: Exported(
        dir="s3://openpi-assets/exported/pi0_aloha_sim/model",
        processor="huggingface_aloha_sim_transfer_cube",
    ),
    EnvMode.DROID: Exported(
        dir="s3://openpi-assets/exported/pi0_droid/model",
        processor="openx_droid",
    ),
    EnvMode.CALVIN: Exported(
        dir="s3://openpi-assets/exported/pi0_calvin/model",
        processor="calvin",
    ),
    EnvMode.LIBERO: Exported(
        dir="s3://openpi-assets/exported/pi0_libero/model",
        processor="libero",
    ),
}
||||
def create_default_policy(
    env: EnvMode, *, default_prompt: str | None = None, exported: Exported | None = None
) -> _policy.Policy:
    """Builds the default policy for *env* from an exported checkpoint.

    Args:
        env: Target environment; selects the checkpoint and the input/output
            transform stacks.
        default_prompt: Optional prompt injected into every inference request.
        exported: Optional override of the checkpoint location. Its processor
            falls back to the environment default when unset.

    Returns:
        A ready-to-serve policy.

    Raises:
        ValueError: If *env* is not a known environment mode.
    """
    model: _model.BaseModel
    config: _policy_config.PolicyConfig

    default_exported = DEFAULT_EXPORTED[env]
    if exported:
        checkpoint_dir = exported.dir
        # An explicit override may omit the processor; reuse the env default then.
        processor = exported.processor or default_exported.processor
    else:
        checkpoint_dir = default_exported.dir
        processor = default_exported.processor
    assert processor, "Default processor must be always set"

    logging.info("Loading model...")
    model = _exported.PiModel.from_checkpoint(checkpoint_dir)

    def make_policy_config(
        input_layers: Sequence[transforms.DataTransformFn],
        output_layers: Sequence[transforms.DataTransformFn],
        sample_kwargs: dict[str, Any] | None = None,
    ):
        # Shared scaffolding for every environment; only the transform stacks
        # (and optionally sampling kwargs) differ per env below.
        sample_kwargs = sample_kwargs or {"num_steps": 10}
        return _policy_config.PolicyConfig(
            model=model,
            norm_stats=model.norm_stats(processor),
            default_prompt=default_prompt,
            input_layers=input_layers,
            output_layers=output_layers,
            sample_kwargs=sample_kwargs,
        )

    logging.info("Creating policy...")
    match env:
        case EnvMode.ALOHA:
            # Mask pattern (6, -1, 6, -1): presumably 6 delta joints + 1
            # absolute gripper per arm — TODO confirm against aloha_policy.
            delta_action_mask = _policy_config.make_bool_mask(6, -1, 6, -1)
            config = make_policy_config(
                input_layers=[
                    aloha_policy.ActInputsRepack(),
                    aloha_policy.AlohaInputs(
                        action_dim=model.action_dim,
                        delta_action_mask=delta_action_mask,
                        adapt_to_pi=True,
                    ),
                ],
                output_layers=[
                    aloha_policy.AlohaOutputs(
                        delta_action_mask=delta_action_mask,
                        adapt_to_pi=True,
                    ),
                    aloha_policy.ActOutputsRepack(),
                ],
            )
        case EnvMode.ALOHA_SIM:
            # Sim variant: no delta-action masking or pi-space adaptation.
            config = make_policy_config(
                input_layers=[
                    aloha_policy.ActInputsRepack(),
                    aloha_policy.AlohaInputs(action_dim=model.action_dim),
                ],
                output_layers=[
                    aloha_policy.AlohaOutputs(),
                    aloha_policy.ActOutputsRepack(),
                ],
            )
        case EnvMode.DROID:
            config = make_policy_config(
                input_layers=[
                    droid_policy.DroidInputs(action_dim=model.action_dim),
                ],
                output_layers=[
                    droid_policy.DroidOutputs(),
                    # Only every 5th predicted action is emitted for DROID.
                    transforms.SubsampleActions(stride=5),
                ],
            )
        case EnvMode.CALVIN:
            config = make_policy_config(
                input_layers=[
                    calvin_policy.CalvinInputs(action_dim=model.action_dim),
                ],
                output_layers=[
                    calvin_policy.CalvinOutputs(),
                ],
            )
        case EnvMode.LIBERO:
            config = make_policy_config(
                input_layers=[
                    libero_policy.LiberoInputs(action_dim=model.action_dim),
                ],
                output_layers=[
                    libero_policy.LiberoOutputs(),
                ],
            )
        case _:
            raise ValueError(f"Unknown environment mode: {env}")

    return _policy_config.create_policy(config)
|
||||
|
||||
|
||||
def create_policy(args: Args) -> _policy.Policy:
    """Builds the policy requested by *args*: a trained checkpoint, an exported
    model, or the environment's default exported model."""
    policy_spec = args.policy
    if isinstance(policy_spec, Checkpoint):
        return _policy_config.create_trained_policy(
            _config.get_config(policy_spec.config),
            policy_spec.dir,
            repack_transforms=repack_from_env(args.env),
            default_prompt=args.default_prompt,
        )
    # Either an explicit Exported spec or None; both go through the default
    # factory, which falls back to the per-env checkpoint when exported is None.
    exported = policy_spec if isinstance(policy_spec, Exported) else None
    return create_default_policy(args.env, default_prompt=args.default_prompt, exported=exported)
|
||||
|
||||
|
||||
def main(args: Args) -> None:
    """Builds the requested policy and serves it over a websocket."""
    served_policy = create_policy(args)

    if args.record:
        # Wrap the policy so every inference call is recorded for debugging.
        served_policy = _policy.PolicyRecorder(served_policy, "policy_records")

    logging.info("Creating server...")
    server = websocket_policy_server.WebsocketPolicyServer(
        policy=served_policy,
        host="0.0.0.0",
        port=args.port,
    )

    logging.info("Serving...")
    server.serve_forever()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # force=True replaces any pre-existing root handlers so our level applies.
    logging.basicConfig(level=logging.INFO, force=True)
    main(tyro.cli(Args))
|
||||
284
scripts/train.py
Normal file
284
scripts/train.py
Normal file
@@ -0,0 +1,284 @@
|
||||
import dataclasses
|
||||
from functools import partial
|
||||
import logging
|
||||
import platform
|
||||
from typing import Any
|
||||
|
||||
import etils.epath as epath
|
||||
from flax.training import common_utils
|
||||
import jax
|
||||
import jax._src.tree_util as private_tree_util
|
||||
import jax.experimental
|
||||
import jax.numpy as jnp
|
||||
import optax
|
||||
import tqdm_loggable.auto as tqdm
|
||||
import wandb
|
||||
|
||||
import openpi.models.common as _common
|
||||
import openpi.models.model as _model
|
||||
import openpi.shared.array_typing as at
|
||||
import openpi.training.checkpoints as _checkpoints
|
||||
import openpi.training.config as _config
|
||||
import openpi.training.data_loader as _data_loader
|
||||
import openpi.training.optimizer as _optimizer
|
||||
import openpi.training.sharding as sharding
|
||||
import openpi.training.utils as training_utils
|
||||
import openpi.training.weight_loaders as _weight_loaders
|
||||
|
||||
|
||||
def init_logging():
    """Custom logging format for better readability.

    Reconfigures the root logger to INFO with a compact single-letter level
    name and a fixed-width message format. Safe to call whether or not a
    handler is already installed on the root logger.
    """
    level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}

    class CustomFormatter(logging.Formatter):
        def format(self, record):
            # Collapse the level name to one letter to keep log lines short.
            record.levelname = level_mapping.get(record.levelname, record.levelname)
            return super().format(record)

    formatter = CustomFormatter(
        fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
        datefmt="%H:%M:%S",
    )

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if logger.handlers:
        # Reuse the existing handler (e.g. one installed by basicConfig).
        logger.handlers[0].setFormatter(formatter)
    else:
        # Fix: the original indexed logger.handlers[0] unconditionally, which
        # raises IndexError when the root logger has no handlers yet (no prior
        # basicConfig call). Install a stream handler in that case.
        handler = logging.StreamHandler()
        handler.setFormatter(formatter)
        logger.addHandler(handler)
|
||||
|
||||
|
||||
def init_wandb(config: _config.TrainConfig, *, resuming: bool, log_code: bool = False, enabled: bool = True):
    """Initializes a wandb run for this training job.

    When resuming, the run id persisted in the checkpoint directory is reused
    so metrics continue on the same wandb run; otherwise a new run is created
    and its id is persisted for future resumes.

    Args:
        config: Training configuration (checkpoint dir, experiment/project names).
        resuming: Whether training is resuming from an existing checkpoint.
        log_code: If True, uploads the repository source to wandb.
        enabled: If False, wandb runs in disabled mode and nothing is logged.

    Raises:
        FileNotFoundError: If the checkpoint directory does not exist.
    """
    if not enabled:
        wandb.init(mode="disabled")
        return

    ckpt_dir = config.checkpoint_dir
    if not ckpt_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")
    if resuming:
        # Reattach to the original run; resume="must" fails if the id is unknown.
        run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
        wandb.init(id=run_id, resume="must", project=config.project_name)
    else:
        wandb.init(
            name=config.exp_name,
            config=dataclasses.asdict(config),
            project=config.project_name,
        )
        # Persist the run id so a later resume can reattach to this run.
        (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)

    if log_code:
        # Upload the repository root (two levels up from this file).
        wandb.run.log_code(epath.Path(__file__).parent.parent)
|
||||
|
||||
|
||||
def _load_weights_and_validate(weight_loader: _weight_loaders.WeightLoader, params: at.Params) -> at.Params:
    """Runs the weight loader and validates that the params structure, shapes, and dtypes are unchanged."""
    # Map over the tree to hand the loader a fresh container structure rather
    # than the caller's own tree.
    new_params = weight_loader.load(jax.tree.map(lambda x: x, params))

    # NOTE(review): private_tree_util is a private jax API used here only to
    # produce readable structure diffs; may break across jax upgrades.
    if errors := list(private_tree_util.equality_errors(params, new_params)):
        raise ValueError(
            "Weight loading changed the params structure:\n"
            + (
                "\n".join(
                    f" - {jax.tree_util.keystr(path)} changed from {thing1} to {thing2}, so {explanation}.\n"
                    for path, thing1, thing2, explanation in errors
                )
            )
        )

    def check(kp, x, y):
        # Compare only shape and dtype at each leaf; values may differ.
        if (x := jax.ShapeDtypeStruct(x.shape, x.dtype)) != (y := jax.ShapeDtypeStruct(y.shape, y.dtype)):
            raise ValueError(
                f"Weight loading changed the params structure: expected {y}, got {x} at {jax.tree_util.keystr(kp)}"
            )

    jax.tree_util.tree_map_with_path(check, params, new_params)

    return new_params
|
||||
|
||||
|
||||
@at.typecheck
def init_train_state(
    config: _config.TrainConfig,
    model: _model.Model,
    init_rng: at.KeyArrayLike,
    batch: tuple[_common.Observation, _common.Actions],
    mesh: jax.sharding.Mesh,
    data_sharding: jax.sharding.Sharding,
    *,
    resume: bool,
) -> tuple[training_utils.TrainState, Any]:
    """Builds the initial train state and its FSDP sharding spec.

    Returns:
        A ``(train_state, state_sharding)`` tuple. When ``resume`` is True the
        first element is only a ``jax.eval_shape`` skeleton (no arrays are
        materialized); the caller restores concrete values from a checkpoint.
    """
    # NOTE(review): both masks are unset, so weight decay / freezing fall back
    # to whatever create_optimizer does with None — confirm intended.
    weight_decay_mask = None
    freeze_mask = None
    tx = _optimizer.create_optimizer(config.optimizer, config.lr_schedule, weight_decay_mask, freeze_mask)

    def init(
        rng: at.KeyArrayLike,
        data: tuple[_common.Observation, _common.Actions],
        params_sharding: jax.sharding.Sharding | None = None,
    ) -> training_utils.TrainState:
        rng, model_rng = jax.random.split(rng)
        observation, actions = data
        params = model.init_params(model_rng, observation, actions)
        # jax.experimental.io_callback raises spmd partitioning warnings, setting constraints
        # to replicate params to avoid the warnings. the returned train state will be sharded still
        # since fsdp sharding is specified as output_sharding when jitting this function.
        if params_sharding is not None:
            params = jax.lax.with_sharding_constraint(params, params_sharding)
        # Run the (host-side) weight loader via io_callback; ordered=True keeps
        # it sequenced with respect to other ordered callbacks.
        params = jax.experimental.io_callback(
            partial(_load_weights_and_validate, config.weight_loader),
            params,
            params,
            ordered=True,
        )
        if params_sharding is not None:
            params = jax.lax.with_sharding_constraint(params, params_sharding)
        return training_utils.TrainState(
            step=0,
            params=params,
            opt_state=tx.init(params),
            tx=tx,
            ema_decay=config.ema_decay,
            # EMA params start as a copy of the initial params when enabled.
            ema_params=None if config.ema_decay is None else params,
        )

    replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    # Derive the FSDP sharding from the abstract shape of the train state.
    train_state_shape = jax.eval_shape(init, init_rng, batch)
    state_sharding = sharding.fsdp_sharding(train_state_shape, mesh, log=True)

    if resume:
        # Nothing to materialize; checkpoint restore will fill in values.
        return train_state_shape, state_sharding

    # params_sharding (positional arg 2) is static so the sharding object is
    # available at trace time inside init.
    train_state = jax.jit(
        init,
        in_shardings=(replicated_sharding, data_sharding),
        out_shardings=state_sharding,
        static_argnums=(2,),
    )(init_rng, batch, replicated_sharding)
    return train_state, state_sharding
|
||||
|
||||
|
||||
@at.typecheck
def train_step(
    rng: at.KeyArrayLike,
    state: training_utils.TrainState,
    model: _model.Model,
    batch: tuple[_common.Observation, _common.Actions],
) -> tuple[training_utils.TrainState, dict[str, at.Array]]:
    """Runs one optimizer step and returns the new state plus scalar metrics."""
    def loss_fn(params: at.Params, rng: at.KeyArrayLike, observation: _common.Observation, actions: _common.Actions):
        chunked_loss = model.compute_loss(rng, observation, actions, params=params, train=True)
        # Reduce the per-chunk losses to a single scalar for grad.
        return jnp.mean(chunked_loss)

    # Fold the step counter into the rng so each step gets distinct but
    # deterministic randomness from the same base key.
    train_rng = jax.random.fold_in(rng, state.step)
    observation, actions = batch
    loss, grads = jax.value_and_grad(loss_fn)(state.params, train_rng, observation, actions)
    updates, new_opt_state = state.tx.update(grads, state.opt_state, state.params)
    new_params = optax.apply_updates(state.params, updates)

    new_state = state.replace(step=state.step + 1, params=new_params, opt_state=new_opt_state)
    if state.ema_decay is not None:
        # Exponential moving average of params: ema = d * ema + (1 - d) * new.
        new_state = new_state.replace(
            ema_params=jax.tree.map(
                lambda old, new: state.ema_decay * old + (1 - state.ema_decay) * new, state.ema_params, new_params
            )
        )

    # Report the norm of kernel (weight-matrix) params only, excluding biases
    # and other non-kernel leaves.
    kernel_mask = training_utils.mask_from_regex(r".*\['kernel'\]", state.params)
    kernel_params = jax.tree.map(lambda p, m: p if m else None, state.params, kernel_mask)
    info = {
        "loss": loss,
        "grad_norm": optax.global_norm(grads),  # TODO: do not compute norm for frozen params
        "param_norm": optax.global_norm(kernel_params),
    }
    return new_state, info
|
||||
|
||||
|
||||
def main(config: _config.TrainConfig):
    """Entry point: runs the full training loop described by *config*."""
    init_logging()
    logging.info(f"Running on: {platform.node()}")

    if config.batch_size % jax.device_count() != 0:
        raise ValueError(
            f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}."
        )

    jax.config.update("jax_threefry_partitionable", True)  # noqa: FBT003
    # Persistent compilation cache to speed up repeated runs.
    jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax").expanduser()))

    rng = jax.random.key(config.seed)
    train_rng, init_rng = jax.random.split(rng)

    if jax.device_count() % config.fsdp_devices != 0:
        raise ValueError(
            f"Number of devices {jax.device_count()} must be divisible by the number of FSDP devices {config.fsdp_devices}."
        )
    # 2D device mesh: data-parallel "batch" axis x FSDP "model" axis.
    mesh_shape = (jax.device_count() // config.fsdp_devices, config.fsdp_devices)
    mesh = jax.make_mesh(mesh_shape, ("batch", "model"))
    # Batches are sharded across both mesh axes; rngs etc. stay replicated.
    data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(("batch", "model")))
    replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    checkpoint_manager, resuming = _checkpoints.initialize_checkpoint_dir(
        config.checkpoint_dir,
        keep_interval=config.keep_interval,
        overwrite=config.overwrite,
        resume=config.resume,
    )
    init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)

    model = config.create_model()

    data_loader = _data_loader.create_data_loader(
        config,
        model,
        sharding=data_sharding,
        num_workers=config.num_workers,
        shuffle=True,
    )
    data_iter = iter(data_loader)
    # Pull one batch up front; it also seeds train-state shape inference below.
    batch = next(data_iter)
    logging.info(f"Initialized data loader:\n{training_utils.array_tree_to_info(batch)}")

    train_state, train_state_sharding = init_train_state(
        config, model, init_rng, batch, mesh, data_sharding, resume=resuming
    )
    jax.block_until_ready(train_state)
    logging.info(f"Initialized train state:\n{training_utils.array_tree_to_info(train_state.params)}")

    if resuming:
        train_state = _checkpoints.restore_state(checkpoint_manager, train_state, data_loader)

    # Donate the train state (argnum 1) so the update can reuse its buffers.
    ptrain_step = jax.jit(
        train_step,
        in_shardings=(replicated_sharding, train_state_sharding, None, data_sharding),
        out_shardings=(train_state_sharding, replicated_sharding),
        donate_argnums=(1,),
    )

    # start_step > 0 when resuming; the progress bar continues from there.
    start_step = int(train_state.step)
    pbar = tqdm.tqdm(
        range(start_step, config.num_train_steps),
        initial=start_step,
        total=config.num_train_steps,
        dynamic_ncols=True,
    )

    infos = []
    for step in pbar:
        train_state, info = ptrain_step(train_rng, train_state, model, batch)
        infos.append(info)
        if step % config.log_interval == 0:
            # Average the metrics accumulated since the last log point.
            stacked_infos = common_utils.stack_forest(infos)
            reduced_info = jax.device_get(jax.tree.map(jnp.mean, stacked_infos))
            info_str = ", ".join(f"{k}={v:.4f}" for k, v in reduced_info.items())
            pbar.write(f"Step {step}: {info_str}")
            wandb.log(reduced_info, step=step)
            infos = []
        # Prefetch the next batch while the current step may still be running.
        batch = next(data_iter)

        if (step % config.save_interval == 0 and step > start_step) or step == config.num_train_steps - 1:
            _checkpoints.save_state(checkpoint_manager, train_state, data_loader, step)

    logging.info("Waiting for checkpoint manager to finish")
    checkpoint_manager.wait_until_finished()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Parse the training config from the command line and run training.
    main(_config.cli())
|
||||
27
scripts/train_test.py
Normal file
27
scripts/train_test.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import dataclasses
|
||||
import pathlib
|
||||
|
||||
import pytest
|
||||
|
||||
from openpi.training import config as _config
|
||||
|
||||
from . import train
|
||||
|
||||
|
||||
@pytest.mark.parametrize("config_name", ["debug"])
def test_train(tmp_path: pathlib.Path, config_name: str):
    """Smoke-tests a tiny training run, then resumes it from its own checkpoint."""
    config = dataclasses.replace(
        _config._CONFIGS_DICT[config_name],  # noqa: SLF001
        batch_size=2,
        checkpoint_base_dir=tmp_path / "checkpoint",
        exp_name="test",
        overwrite=False,
        resume=False,
        num_train_steps=2,
        log_interval=1,
    )
    train.main(config)

    # test resuming
    config = dataclasses.replace(config, resume=True, num_train_steps=4)
    train.main(config)
|
||||
Reference in New Issue
Block a user