Compare commits
44 Commits
user/alibe
...
thom_arm
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b670d3c43e | ||
|
|
efd3357124 | ||
|
|
785db44bd5 | ||
|
|
e2f690e779 | ||
|
|
68a680a9eb | ||
|
|
ce5329cf44 | ||
|
|
f409bee6b1 | ||
|
|
c91ececc75 | ||
|
|
8a10e8442b | ||
|
|
10c5151fc6 | ||
|
|
a7d24f6dc2 | ||
|
|
642f1d0328 | ||
|
|
4843988d81 | ||
|
|
772927616a | ||
|
|
d52c6037e8 | ||
|
|
b0cb342795 | ||
|
|
8460ea6f83 | ||
|
|
9a3b0b738a | ||
|
|
c4da689171 | ||
|
|
9b62c25f6c | ||
|
|
01eae09ba6 | ||
|
|
19dfb9144a | ||
|
|
096149b118 | ||
|
|
5ec0af62c6 | ||
|
|
625f0557ef | ||
|
|
4d7d41cdee | ||
|
|
c9069df9f1 | ||
|
|
68c1b13406 | ||
|
|
f52f4f2cd2 | ||
|
|
89c6be84ca | ||
|
|
fc5cf3d84a | ||
|
|
29a196c5dd | ||
|
|
ced3de4c94 | ||
|
|
7b47ab211b | ||
|
|
1249aee3ac | ||
|
|
b187942db4 | ||
|
|
473345fdf6 | ||
|
|
e89521dfa0 | ||
|
|
7bb5b15f4c | ||
|
|
df914aa76c | ||
|
|
0ea7a8b2a3 | ||
|
|
460df2ccea | ||
|
|
f5de57b385 | ||
|
|
47de07658c |
4
.gitattributes
vendored
4
.gitattributes
vendored
@@ -1,2 +1,6 @@
|
||||
*.memmap filter=lfs diff=lfs merge=lfs -text
|
||||
*.stl filter=lfs diff=lfs merge=lfs -text
|
||||
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
||||
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||
*.json filter=lfs diff=lfs merge=lfs -text
|
||||
|
||||
38
.github/workflows/test.yml
vendored
38
.github/workflows/test.yml
vendored
@@ -29,6 +29,8 @@ jobs:
|
||||
MUJOCO_GL: egl
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true # Ensure LFS files are pulled
|
||||
|
||||
- name: Install EGL
|
||||
run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev
|
||||
@@ -57,6 +59,40 @@ jobs:
|
||||
&& rm -rf tests/outputs outputs
|
||||
|
||||
|
||||
pytest-minimal:
|
||||
name: Pytest (minimal install)
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
DATA_DIR: tests/data
|
||||
MUJOCO_GL: egl
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true # Ensure LFS files are pulled
|
||||
|
||||
- name: Install poetry
|
||||
run: |
|
||||
pipx install poetry && poetry config virtualenvs.in-project true
|
||||
echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.10"
|
||||
|
||||
- name: Install poetry dependencies
|
||||
run: |
|
||||
poetry install --extras "test"
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
pytest tests -v --cov=./lerobot --durations=0 \
|
||||
-W ignore::DeprecationWarning:imageio_ffmpeg._utils:7 \
|
||||
-W ignore::UserWarning:torch.utils.data.dataloader:558 \
|
||||
-W ignore::UserWarning:gymnasium.utils.env_checker:247 \
|
||||
&& rm -rf tests/outputs outputs
|
||||
|
||||
|
||||
end-to-end:
|
||||
name: End-to-end
|
||||
runs-on: ubuntu-latest
|
||||
@@ -65,6 +101,8 @@ jobs:
|
||||
MUJOCO_GL: egl
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true # Ensure LFS files are pulled
|
||||
|
||||
- name: Install EGL
|
||||
run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev
|
||||
|
||||
@@ -195,6 +195,11 @@ Follow these steps to start contributing:
|
||||
git commit
|
||||
```
|
||||
|
||||
Note, if you already commited some changes that have a wrong formatting, you can use:
|
||||
```bash
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
Please write [good commit messages](https://chris.beams.io/posts/git-commit/).
|
||||
|
||||
It is a good idea to sync your copy of the code with the original
|
||||
|
||||
7
Makefile
7
Makefile
@@ -22,9 +22,8 @@ test-end-to-end:
|
||||
${MAKE} test-act-ete-eval
|
||||
${MAKE} test-diffusion-ete-train
|
||||
${MAKE} test-diffusion-ete-eval
|
||||
# TODO(rcadene, alexander-soare): enable end-to-end tests for tdmpc
|
||||
# ${MAKE} test-tdmpc-ete-train
|
||||
# ${MAKE} test-tdmpc-ete-eval
|
||||
${MAKE} test-tdmpc-ete-train
|
||||
${MAKE} test-tdmpc-ete-eval
|
||||
${MAKE} test-default-ete-eval
|
||||
|
||||
test-act-ete-train:
|
||||
@@ -80,7 +79,7 @@ test-tdmpc-ete-train:
|
||||
policy=tdmpc \
|
||||
env=xarm \
|
||||
env.task=XarmLift-v0 \
|
||||
dataset_repo_id=lerobot/xarm_lift_medium_replay \
|
||||
dataset_repo_id=lerobot/xarm_lift_medium \
|
||||
wandb.enable=False \
|
||||
training.offline_steps=2 \
|
||||
training.online_steps=2 \
|
||||
|
||||
23
README.md
23
README.md
@@ -57,7 +57,6 @@
|
||||
- Thanks to Tony Zaho, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
|
||||
- Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io).
|
||||
- Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM).
|
||||
- Thanks to Vincent Moens and colleagues for open sourcing [TorchRL](https://github.com/pytorch/rl). It allowed for quick experimentations on the design of `LeRobot`.
|
||||
- Thanks to Antonio Loquercio and Ashish Kumar for their early support.
|
||||
|
||||
|
||||
@@ -93,6 +92,8 @@ To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tra
|
||||
wandb login
|
||||
```
|
||||
|
||||
(note: you will also need to enable WandB in the configuration. See below.)
|
||||
|
||||
## Walkthrough
|
||||
|
||||
```
|
||||
@@ -159,13 +160,13 @@ See `python lerobot/scripts/eval.py --help` for more instructions.
|
||||
|
||||
Check out [example 3](./examples/3_train_policy.py) that illustrates how to start training a model.
|
||||
|
||||
In general, you can use our training script to easily train any policy. To use wandb for logging training and evaluation curves, make sure you ran `wandb login`. Here is an example of training the ACT policy on trajectories collected by humans on the Aloha simulation environment for the insertion task:
|
||||
In general, you can use our training script to easily train any policy. Here is an example of training the ACT policy on trajectories collected by humans on the Aloha simulation environment for the insertion task:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
policy=act \
|
||||
env=aloha \
|
||||
env.task=AlohaInsertion-v0 \
|
||||
dataset_repo_id=lerobot/aloha_sim_insertion_human
|
||||
dataset_repo_id=lerobot/aloha_sim_insertion_human \
|
||||
```
|
||||
|
||||
The experiment directory is automatically generated and will show up in yellow in your terminal. It looks like `outputs/train/2024-05-05/20-21-12_aloha_act_default`. You can manually specify an experiment directory by adding this argument to the `train.py` python command:
|
||||
@@ -173,17 +174,17 @@ The experiment directory is automatically generated and will show up in yellow i
|
||||
hydra.run.dir=your/new/experiment/dir
|
||||
```
|
||||
|
||||
A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of logs from wandb:
|
||||

|
||||
To use wandb for logging training and evaluation curves, make sure you've run `wandb login` as a one-time setup step. Then, when running the training command above, enable WandB in the configuration by adding:
|
||||
|
||||
You can deactivate wandb by adding these arguments to the `train.py` python command:
|
||||
```bash
|
||||
wandb.disable_artifact=true \
|
||||
wandb.enable=false
|
||||
wandb.enable=true
|
||||
```
|
||||
|
||||
Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. After training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python lerobot/scripts/eval.py --help` for more instructions.
|
||||
A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser:
|
||||
|
||||

|
||||
|
||||
Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. After training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python lerobot/scripts/eval.py --help` for more instructions.
|
||||
|
||||
## Contribute
|
||||
|
||||
@@ -196,11 +197,11 @@ To add a dataset to the hub, you need to login using a write-access token, which
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Then move your dataset folder in `data` directory (e.g. `data/aloha_ping_pong`), and push your dataset to the hub with:
|
||||
Then move your dataset folder in `data` directory (e.g. `data/aloha_static_pingpong_test`), and push your dataset to the hub with:
|
||||
```bash
|
||||
python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--data-dir data \
|
||||
--dataset-id aloha_ping_ping \
|
||||
--dataset-id aloha_static_pingpong_test \
|
||||
--raw-format aloha_hdf5 \
|
||||
--community-id lerobot
|
||||
```
|
||||
|
||||
@@ -7,6 +7,11 @@ ARG DEBIAN_FRONTEND=noninteractive
|
||||
# Install apt dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential cmake \
|
||||
git git-lfs openssh-client \
|
||||
nano vim ffmpeg \
|
||||
htop atop nvtop \
|
||||
sed gawk grep curl wget \
|
||||
tcpdump sysstat screen \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
|
||||
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
@@ -18,7 +23,8 @@ ENV PATH="/opt/venv/bin:$PATH"
|
||||
RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
|
||||
|
||||
# Install LeRobot
|
||||
COPY . /lerobot
|
||||
RUN git lfs install
|
||||
RUN git clone https://github.com/huggingface/lerobot.git
|
||||
WORKDIR /lerobot
|
||||
RUN pip install --upgrade --no-cache-dir pip
|
||||
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht]"
|
||||
|
||||
90
examples/4_calculate_validation_loss.py
Normal file
90
examples/4_calculate_validation_loss.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""This script demonstrates how to slice a dataset and calculate the loss on a subset of the data.
|
||||
|
||||
This technique can be useful for debugging and testing purposes, as well as identifying whether a policy
|
||||
is learning effectively.
|
||||
|
||||
Furthermore, relying on validation loss to evaluate performance is generally not considered a good practice,
|
||||
especially in the context of imitation learning. The most reliable approach is to evaluate the policy directly
|
||||
on the target environment, whether that be in simulation or the real world.
|
||||
"""
|
||||
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
|
||||
device = torch.device("cuda")
|
||||
|
||||
# Download the diffusion policy for pusht environment
|
||||
pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))
|
||||
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
|
||||
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
|
||||
|
||||
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
|
||||
policy.eval()
|
||||
policy.to(device)
|
||||
|
||||
# Set up the dataset.
|
||||
delta_timestamps = {
|
||||
# Load the previous image and state at -0.1 seconds before current frame,
|
||||
# then load current image and state corresponding to 0.0 second.
|
||||
"observation.image": [-0.1, 0.0],
|
||||
"observation.state": [-0.1, 0.0],
|
||||
# Load the previous action (-0.1), the next action to be executed (0.0),
|
||||
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
|
||||
# used to calculate the loss.
|
||||
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
|
||||
}
|
||||
|
||||
# Load the last 10% of episodes of the dataset as a validation set.
|
||||
# - Load full dataset
|
||||
full_dataset = LeRobotDataset("lerobot/pusht", split="train")
|
||||
# - Calculate train and val subsets
|
||||
num_train_episodes = math.floor(full_dataset.num_episodes * 90 / 100)
|
||||
num_val_episodes = full_dataset.num_episodes - num_train_episodes
|
||||
print(f"Number of episodes in full dataset: {full_dataset.num_episodes}")
|
||||
print(f"Number of episodes in training dataset (90% subset): {num_train_episodes}")
|
||||
print(f"Number of episodes in validation dataset (10% subset): {num_val_episodes}")
|
||||
# - Get first frame index of the validation set
|
||||
first_val_frame_index = full_dataset.episode_data_index["from"][num_train_episodes].item()
|
||||
# - Load frames subset belonging to validation set using the `split` argument.
|
||||
# It utilizes the `datasets` library's syntax for slicing datasets.
|
||||
# For more information on the Slice API, please see:
|
||||
# https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
|
||||
train_dataset = LeRobotDataset(
|
||||
"lerobot/pusht", split=f"train[:{first_val_frame_index}]", delta_timestamps=delta_timestamps
|
||||
)
|
||||
val_dataset = LeRobotDataset(
|
||||
"lerobot/pusht", split=f"train[{first_val_frame_index}:]", delta_timestamps=delta_timestamps
|
||||
)
|
||||
print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
|
||||
print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")
|
||||
|
||||
# Create dataloader for evaluation.
|
||||
val_dataloader = torch.utils.data.DataLoader(
|
||||
val_dataset,
|
||||
num_workers=4,
|
||||
batch_size=64,
|
||||
shuffle=False,
|
||||
pin_memory=device != torch.device("cpu"),
|
||||
drop_last=False,
|
||||
)
|
||||
|
||||
# Run validation loop.
|
||||
loss_cumsum = 0
|
||||
n_examples_evaluated = 0
|
||||
for batch in val_dataloader:
|
||||
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
|
||||
output_dict = policy.forward(batch)
|
||||
|
||||
loss_cumsum += output_dict["loss"].item()
|
||||
n_examples_evaluated += batch["index"].shape[0]
|
||||
|
||||
# Calculate the average loss over the validation set.
|
||||
average_loss = loss_cumsum / n_examples_evaluated
|
||||
|
||||
print(f"Average loss on validation set: {average_loss:.4f}")
|
||||
1
gym_dora/README.md
Normal file
1
gym_dora/README.md
Normal file
@@ -0,0 +1 @@
|
||||
# gym_dora
|
||||
17
gym_dora/example.py
Normal file
17
gym_dora/example.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import gymnasium as gym
|
||||
|
||||
import gym_dora # noqa: F401
|
||||
|
||||
env = gym.make("gym_dora/DoraAloha-v0", disable_env_checker=True)
|
||||
obs = env.reset()
|
||||
|
||||
policy = ... # make_policy
|
||||
|
||||
done = False
|
||||
while not done:
|
||||
actions = policy.select_action(obs)
|
||||
observation, reward, terminated, truncated, info = env.step(actions)
|
||||
|
||||
done = terminated | truncated | done
|
||||
|
||||
env.close()
|
||||
17
gym_dora/gym_dora/__init__.py
Normal file
17
gym_dora/gym_dora/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from gymnasium.envs.registration import register
|
||||
|
||||
register(
|
||||
id="gym_dora/DoraAloha-v0",
|
||||
entry_point="gym_dora.env:DoraEnv",
|
||||
max_episode_steps=300,
|
||||
nondeterministic=True,
|
||||
kwargs={"model": "aloha"},
|
||||
)
|
||||
|
||||
register(
|
||||
id="gym_dora/DoraKoch-v0",
|
||||
entry_point="gym_dora.env:DoraEnv",
|
||||
max_episode_steps=300,
|
||||
nondeterministic=True,
|
||||
kwargs={"model": "koch"},
|
||||
)
|
||||
199
gym_dora/gym_dora/env.py
Normal file
199
gym_dora/gym_dora/env.py
Normal file
@@ -0,0 +1,199 @@
|
||||
import os
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
from dora import Node
|
||||
from gymnasium import spaces
|
||||
|
||||
FPS = int(os.getenv("FPS", "30"))
|
||||
IMAGE_WIDTH = int(os.getenv("IMAGE_WIDTH", "640"))
|
||||
IMAGE_HEIGHT = int(os.getenv("IMAGE_HEIGHT", "480"))
|
||||
|
||||
ALOHA_JOINTS = [
|
||||
# absolute joint position
|
||||
"left_arm_waist",
|
||||
"left_arm_shoulder",
|
||||
"left_arm_elbow",
|
||||
"left_arm_forearm_roll",
|
||||
"left_arm_wrist_angle",
|
||||
"left_arm_wrist_rotate",
|
||||
# normalized gripper position 0: close, 1: open
|
||||
"left_arm_gripper",
|
||||
# absolute joint position
|
||||
"right_arm_waist",
|
||||
"right_arm_shoulder",
|
||||
"right_arm_elbow",
|
||||
"right_arm_forearm_roll",
|
||||
"right_arm_wrist_angle",
|
||||
"right_arm_wrist_rotate",
|
||||
# normalized gripper position 0: close, 1: open
|
||||
"right_arm_gripper",
|
||||
]
|
||||
ALOHA_ACTIONS = [
|
||||
# position and quaternion for end effector
|
||||
"left_arm_waist",
|
||||
"left_arm_shoulder",
|
||||
"left_arm_elbow",
|
||||
"left_arm_forearm_roll",
|
||||
"left_arm_wrist_angle",
|
||||
"left_arm_wrist_rotate",
|
||||
# normalized gripper position (0: close, 1: open)
|
||||
"left_arm_gripper",
|
||||
"right_arm_waist",
|
||||
"right_arm_shoulder",
|
||||
"right_arm_elbow",
|
||||
"right_arm_forearm_roll",
|
||||
"right_arm_wrist_angle",
|
||||
"right_arm_wrist_rotate",
|
||||
# normalized gripper position (0: close, 1: open)
|
||||
"right_arm_gripper",
|
||||
]
|
||||
|
||||
|
||||
class DoraEnv(gym.Env):
|
||||
metadata = {"render_modes": ["rgb_array"], "render_fps": FPS}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model="aloha",
|
||||
observation_width=IMAGE_WIDTH,
|
||||
observation_height=IMAGE_HEIGHT,
|
||||
cameras_names=None,
|
||||
num_joints=None,
|
||||
num_actions=None,
|
||||
):
|
||||
"""Initializes the Dora environment.
|
||||
|
||||
Args:
|
||||
model (str): The model to use. Either 'aloha' or 'custom'.
|
||||
observation_width (int): The width of the observation image.
|
||||
observation_height (int): The height of the observation image.
|
||||
cameras_names (list): A list of camera names to use. If not provided, the default is ['cam_high', 'cam_low', 'cam_left_wrist', 'cam_right_wrist'].
|
||||
num_joints (int): The number of joints in the model. If not provided, the default is 14 for 'aloha' and 6 for 'fivedof'.
|
||||
num_actions (int): The number of actions in the model. If not provided, the default is 14 for 'aloha' and 6 for 'fivedof'.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# Initialize a new node
|
||||
self.node = Node() if os.environ.get("DORA_NODE_CONFIG", None) is not None else None
|
||||
self.observation = {"pixels": {}, "agent_pos": None}
|
||||
self.terminated = False
|
||||
|
||||
self.observation_height = observation_height
|
||||
self.observation_width = observation_width
|
||||
|
||||
# Observation space
|
||||
if model == "aloha":
|
||||
self.observation_space = spaces.Dict(
|
||||
{
|
||||
"pixels": spaces.Dict(
|
||||
{
|
||||
"cam_high": spaces.Box(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=(self.observation_height, self.observation_width, 3),
|
||||
dtype=np.uint8,
|
||||
),
|
||||
"cam_low": spaces.Box(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=(self.observation_height, self.observation_width, 3),
|
||||
dtype=np.uint8,
|
||||
),
|
||||
"cam_left_wrist": spaces.Box(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=(self.observation_height, self.observation_width, 3),
|
||||
dtype=np.uint8,
|
||||
),
|
||||
"cam_right_wrist": spaces.Box(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=(self.observation_height, self.observation_width, 3),
|
||||
dtype=np.uint8,
|
||||
),
|
||||
}
|
||||
),
|
||||
"agent_pos": spaces.Box(
|
||||
low=-1000.0,
|
||||
high=1000.0,
|
||||
shape=(len(ALOHA_JOINTS),),
|
||||
dtype=np.float64,
|
||||
),
|
||||
}
|
||||
)
|
||||
elif model == "custom":
|
||||
pixel_dict = {}
|
||||
for camera in cameras_names:
|
||||
assert camera.startswith("cam"), "Camera names must start with 'cam'"
|
||||
pixel_dict[camera] = spaces.Box(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=(self.observation_height, self.observation_width, 3),
|
||||
dtype=np.uint8,
|
||||
)
|
||||
self.observation_space = spaces.Dict(
|
||||
{
|
||||
"pixels": spaces.Dict(pixel_dict),
|
||||
"agent_pos": spaces.Box(
|
||||
low=-1000.0,
|
||||
high=1000.0,
|
||||
shape=(num_joints,),
|
||||
dtype=np.float64,
|
||||
),
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise ValueError("Model must be either 'aloha' or 'custom'.")
|
||||
|
||||
# Action space
|
||||
if model == "aloha":
|
||||
self.action_space = spaces.Box(low=-1, high=1, shape=(len(ALOHA_ACTIONS),), dtype=np.float32)
|
||||
elif model == "custom":
|
||||
self.action_space = spaces.Box(low=-1, high=1, shape=(num_actions,), dtype=np.float32)
|
||||
|
||||
def _get_obs(self):
|
||||
while True:
|
||||
event = self.node.next(timeout=0.001)
|
||||
|
||||
## If event is None, the node event stream is closed and we should terminate the env
|
||||
if event is None:
|
||||
self.terminated = True
|
||||
break
|
||||
|
||||
if event["type"] == "INPUT":
|
||||
# Map Image input into pixels key within Aloha environment
|
||||
if "cam" in event["id"]:
|
||||
self.observation["pixels"][event["id"]] = (
|
||||
event["value"].to_numpy().reshape(self.observation_height, self.observation_width, 3)
|
||||
)
|
||||
else:
|
||||
# Map other inputs into the observation dictionary using the event id as key
|
||||
self.observation[event["id"]] = event["value"].to_numpy()
|
||||
|
||||
# If the event is a timeout error break the update loop.
|
||||
elif event["type"] == "ERROR":
|
||||
break
|
||||
|
||||
def reset(self, seed: int | None = None):
|
||||
self.node.send_output("reset")
|
||||
self._get_obs()
|
||||
self.terminated = False
|
||||
info = {}
|
||||
return self.observation, info
|
||||
|
||||
def step(self, action: np.ndarray):
|
||||
# Send the action to the dataflow as action key.
|
||||
self.node.send_output("action", pa.array(action))
|
||||
self._get_obs()
|
||||
reward = 0
|
||||
terminated = truncated = self.terminated
|
||||
info = {}
|
||||
return self.observation, reward, terminated, truncated, info
|
||||
|
||||
def render(self): ...
|
||||
|
||||
def close(self):
|
||||
# Drop the node
|
||||
del self.node
|
||||
182
gym_dora/poetry.lock
generated
Normal file
182
gym_dora/poetry.lock
generated
Normal file
@@ -0,0 +1,182 @@
|
||||
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "cloudpickle"
|
||||
version = "3.0.0"
|
||||
description = "Pickler class to extend the standard pickle.Pickler functionality"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "cloudpickle-3.0.0-py3-none-any.whl", hash = "sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7"},
|
||||
{file = "cloudpickle-3.0.0.tar.gz", hash = "sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dora-rs"
|
||||
version = "0.3.4"
|
||||
description = "`dora` goal is to be a low latency, composable, and distributed data flow."
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "dora_rs-0.3.4-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:d1b738eea5a4966d731c26c6b6a0a50a491a24f7e9e335475f983cfc6f0da19e"},
|
||||
{file = "dora_rs-0.3.4-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:80b724871618c78a4e5863938fa66724176cc40352771087aebe1e62a8141157"},
|
||||
{file = "dora_rs-0.3.4-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a3919e157b47dc1dbc74c040a73087a4485f0d1bee99b6adcdbc36559400fe2"},
|
||||
{file = "dora_rs-0.3.4-cp37-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f7c95f6e5858fd651d6cd220e4f052e99db2944b9c37fb0b5402d60ac4b41a63"},
|
||||
{file = "dora_rs-0.3.4-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37d915fbbca282446235c98a9ca08389aa3ef3155d4e88c6c136326e9a830042"},
|
||||
{file = "dora_rs-0.3.4-cp37-abi3-win32.whl", hash = "sha256:c9f7f22f65c884ec9bee0245ce98d0c7fad25dec0f982e566f844b5e8e58818f"},
|
||||
{file = "dora_rs-0.3.4-cp37-abi3-win_amd64.whl", hash = "sha256:0a6a37f96a9f6e13b58b02a6ea75af192af5fbe4f456f6a67b1f239c3cee3276"},
|
||||
{file = "dora_rs-0.3.4.tar.gz", hash = "sha256:05c5d0db0d23d7c4669995ae34db11cd636dbf91f5705d832669bd04e7452903"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pyarrow = "*"
|
||||
|
||||
[[package]]
|
||||
name = "farama-notifications"
|
||||
version = "0.0.4"
|
||||
description = "Notifications for all Farama Foundation maintained libraries."
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "Farama-Notifications-0.0.4.tar.gz", hash = "sha256:13fceff2d14314cf80703c8266462ebf3733c7d165336eee998fc58e545efd18"},
|
||||
{file = "Farama_Notifications-0.0.4-py3-none-any.whl", hash = "sha256:14de931035a41961f7c056361dc7f980762a143d05791ef5794a751a2caf05ae"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gymnasium"
|
||||
version = "0.29.1"
|
||||
description = "A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym)."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "gymnasium-0.29.1-py3-none-any.whl", hash = "sha256:61c3384b5575985bb7f85e43213bcb40f36fcdff388cae6bc229304c71f2843e"},
|
||||
{file = "gymnasium-0.29.1.tar.gz", hash = "sha256:1a532752efcb7590478b1cc7aa04f608eb7a2fdad5570cd217b66b6a35274bb1"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
cloudpickle = ">=1.2.0"
|
||||
farama-notifications = ">=0.0.1"
|
||||
numpy = ">=1.21.0"
|
||||
typing-extensions = ">=4.3.0"
|
||||
|
||||
[package.extras]
|
||||
accept-rom-license = ["autorom[accept-rom-license] (>=0.4.2,<0.5.0)"]
|
||||
all = ["box2d-py (==2.3.5)", "cython (<3)", "imageio (>=2.14.1)", "jax (>=0.4.0)", "jaxlib (>=0.4.0)", "lz4 (>=3.1.0)", "matplotlib (>=3.0)", "moviepy (>=1.0.0)", "mujoco (>=2.3.3)", "mujoco-py (>=2.1,<2.2)", "opencv-python (>=3.0)", "pygame (>=2.1.3)", "shimmy[atari] (>=0.1.0,<1.0)", "swig (==4.*)", "torch (>=1.0.0)"]
|
||||
atari = ["shimmy[atari] (>=0.1.0,<1.0)"]
|
||||
box2d = ["box2d-py (==2.3.5)", "pygame (>=2.1.3)", "swig (==4.*)"]
|
||||
classic-control = ["pygame (>=2.1.3)", "pygame (>=2.1.3)"]
|
||||
jax = ["jax (>=0.4.0)", "jaxlib (>=0.4.0)"]
|
||||
mujoco = ["imageio (>=2.14.1)", "mujoco (>=2.3.3)"]
|
||||
mujoco-py = ["cython (<3)", "cython (<3)", "mujoco-py (>=2.1,<2.2)", "mujoco-py (>=2.1,<2.2)"]
|
||||
other = ["lz4 (>=3.1.0)", "matplotlib (>=3.0)", "moviepy (>=1.0.0)", "opencv-python (>=3.0)", "torch (>=1.0.0)"]
|
||||
testing = ["pytest (==7.1.3)", "scipy (>=1.7.3)"]
|
||||
toy-text = ["pygame (>=2.1.3)", "pygame (>=2.1.3)"]
|
||||
|
||||
[[package]]
|
||||
name = "numpy"
|
||||
version = "1.26.4"
|
||||
description = "Fundamental package for array computing in Python"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"},
|
||||
{file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"},
|
||||
{file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"},
|
||||
{file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"},
|
||||
{file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"},
|
||||
{file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"},
|
||||
{file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"},
|
||||
{file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"},
|
||||
{file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"},
|
||||
{file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"},
|
||||
{file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"},
|
||||
{file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"},
|
||||
{file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"},
|
||||
{file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"},
|
||||
{file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"},
|
||||
{file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"},
|
||||
{file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"},
|
||||
{file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"},
|
||||
{file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"},
|
||||
{file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"},
|
||||
{file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"},
|
||||
{file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"},
|
||||
{file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"},
|
||||
{file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"},
|
||||
{file = "numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c"},
|
||||
{file = "numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be"},
|
||||
{file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764"},
|
||||
{file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3"},
|
||||
{file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd"},
|
||||
{file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c"},
|
||||
{file = "numpy-1.26.4-cp39-cp39-win32.whl", hash = "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6"},
|
||||
{file = "numpy-1.26.4-cp39-cp39-win_amd64.whl", hash = "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea"},
|
||||
{file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"},
|
||||
{file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"},
|
||||
{file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"},
|
||||
{file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyarrow"
|
||||
version = "16.1.0"
|
||||
description = "Python library for Apache Arrow"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pyarrow-16.1.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:17e23b9a65a70cc733d8b738baa6ad3722298fa0c81d88f63ff94bf25eaa77b9"},
|
||||
{file = "pyarrow-16.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4740cc41e2ba5d641071d0ab5e9ef9b5e6e8c7611351a5cb7c1d175eaf43674a"},
|
||||
{file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:98100e0268d04e0eec47b73f20b39c45b4006f3c4233719c3848aa27a03c1aef"},
|
||||
{file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f68f409e7b283c085f2da014f9ef81e885d90dcd733bd648cfba3ef265961848"},
|
||||
{file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:a8914cd176f448e09746037b0c6b3a9d7688cef451ec5735094055116857580c"},
|
||||
{file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:48be160782c0556156d91adbdd5a4a7e719f8d407cb46ae3bb4eaee09b3111bd"},
|
||||
{file = "pyarrow-16.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:9cf389d444b0f41d9fe1444b70650fea31e9d52cfcb5f818b7888b91b586efff"},
|
||||
{file = "pyarrow-16.1.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:d0ebea336b535b37eee9eee31761813086d33ed06de9ab6fc6aaa0bace7b250c"},
|
||||
{file = "pyarrow-16.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e73cfc4a99e796727919c5541c65bb88b973377501e39b9842ea71401ca6c1c"},
|
||||
{file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf9251264247ecfe93e5f5a0cd43b8ae834f1e61d1abca22da55b20c788417f6"},
|
||||
{file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddf5aace92d520d3d2a20031d8b0ec27b4395cab9f74e07cc95edf42a5cc0147"},
|
||||
{file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:25233642583bf658f629eb230b9bb79d9af4d9f9229890b3c878699c82f7d11e"},
|
||||
{file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a33a64576fddfbec0a44112eaf844c20853647ca833e9a647bfae0582b2ff94b"},
|
||||
{file = "pyarrow-16.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:185d121b50836379fe012753cf15c4ba9638bda9645183ab36246923875f8d1b"},
|
||||
{file = "pyarrow-16.1.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:2e51ca1d6ed7f2e9d5c3c83decf27b0d17bb207a7dea986e8dc3e24f80ff7d6f"},
|
||||
{file = "pyarrow-16.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:06ebccb6f8cb7357de85f60d5da50e83507954af617d7b05f48af1621d331c9a"},
|
||||
{file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b04707f1979815f5e49824ce52d1dceb46e2f12909a48a6a753fe7cafbc44a0c"},
|
||||
{file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d32000693deff8dc5df444b032b5985a48592c0697cb6e3071a5d59888714e2"},
|
||||
{file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:8785bb10d5d6fd5e15d718ee1d1f914fe768bf8b4d1e5e9bf253de8a26cb1628"},
|
||||
{file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:e1369af39587b794873b8a307cc6623a3b1194e69399af0efd05bb202195a5a7"},
|
||||
{file = "pyarrow-16.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:febde33305f1498f6df85e8020bca496d0e9ebf2093bab9e0f65e2b4ae2b3444"},
|
||||
{file = "pyarrow-16.1.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:b5f5705ab977947a43ac83b52ade3b881eb6e95fcc02d76f501d549a210ba77f"},
|
||||
{file = "pyarrow-16.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0d27bf89dfc2576f6206e9cd6cf7a107c9c06dc13d53bbc25b0bd4556f19cf5f"},
|
||||
{file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d07de3ee730647a600037bc1d7b7994067ed64d0eba797ac74b2bc77384f4c2"},
|
||||
{file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fbef391b63f708e103df99fbaa3acf9f671d77a183a07546ba2f2c297b361e83"},
|
||||
{file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:19741c4dbbbc986d38856ee7ddfdd6a00fc3b0fc2d928795b95410d38bb97d15"},
|
||||
{file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:f2c5fb249caa17b94e2b9278b36a05ce03d3180e6da0c4c3b3ce5b2788f30eed"},
|
||||
{file = "pyarrow-16.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:e6b6d3cd35fbb93b70ade1336022cc1147b95ec6af7d36906ca7fe432eb09710"},
|
||||
{file = "pyarrow-16.1.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:18da9b76a36a954665ccca8aa6bd9f46c1145f79c0bb8f4f244f5f8e799bca55"},
|
||||
{file = "pyarrow-16.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:99f7549779b6e434467d2aa43ab2b7224dd9e41bdde486020bae198978c9e05e"},
|
||||
{file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f07fdffe4fd5b15f5ec15c8b64584868d063bc22b86b46c9695624ca3505b7b4"},
|
||||
{file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddfe389a08ea374972bd4065d5f25d14e36b43ebc22fc75f7b951f24378bf0b5"},
|
||||
{file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:3b20bd67c94b3a2ea0a749d2a5712fc845a69cb5d52e78e6449bbd295611f3aa"},
|
||||
{file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:ba8ac20693c0bb0bf4b238751d4409e62852004a8cf031c73b0e0962b03e45e3"},
|
||||
{file = "pyarrow-16.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:31a1851751433d89a986616015841977e0a188662fcffd1a5677453f1df2de0a"},
|
||||
{file = "pyarrow-16.1.0.tar.gz", hash = "sha256:15fbb22ea96d11f0b5768504a3f961edab25eaf4197c341720c4a387f6c60315"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
numpy = ">=1.16.6"
|
||||
|
||||
[[package]]
|
||||
name = "typing-extensions"
|
||||
version = "4.11.0"
|
||||
description = "Backported and Experimental Type Hints for Python 3.8+"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "typing_extensions-4.11.0-py3-none-any.whl", hash = "sha256:c1f94d72897edaf4ce775bb7558d5b79d8126906a14ea5ed1635921406c0387a"},
|
||||
{file = "typing_extensions-4.11.0.tar.gz", hash = "sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0"},
|
||||
]
|
||||
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "7e437b5c547ebe11095f1ce4ff1851d636f8e707ad7de8a6224b0f9ad978240f"
|
||||
17
gym_dora/pyproject.toml
Normal file
17
gym_dora/pyproject.toml
Normal file
@@ -0,0 +1,17 @@
|
||||
[tool.poetry]
|
||||
name = "gym-dora"
|
||||
version = "0.1.0"
|
||||
description = ""
|
||||
authors = ["Simon Alibert <alibert.sim@gmail.com>"]
|
||||
readme = "README.md"
|
||||
packages = [{ include = "gym_dora" }]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.10"
|
||||
gymnasium = ">=0.29.1"
|
||||
dora-rs = ">=0.3.4"
|
||||
pyarrow = ">=12.0.0"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
@@ -1,3 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This file contains lists of available environments, dataset and policies to reflect the current state of LeRobot library.
|
||||
We do not want to import all the dependencies, but instead we keep it lightweight to ensure fast access to these variables.
|
||||
@@ -46,13 +61,21 @@ available_datasets_per_env = {
|
||||
"lerobot/aloha_sim_insertion_scripted",
|
||||
"lerobot/aloha_sim_transfer_cube_human",
|
||||
"lerobot/aloha_sim_transfer_cube_scripted",
|
||||
"lerobot/aloha_sim_insertion_human_image",
|
||||
"lerobot/aloha_sim_insertion_scripted_image",
|
||||
"lerobot/aloha_sim_transfer_cube_human_image",
|
||||
"lerobot/aloha_sim_transfer_cube_scripted_image",
|
||||
],
|
||||
"pusht": ["lerobot/pusht"],
|
||||
"pusht": ["lerobot/pusht", "lerobot/pusht_image"],
|
||||
"xarm": [
|
||||
"lerobot/xarm_lift_medium",
|
||||
"lerobot/xarm_lift_medium_replay",
|
||||
"lerobot/xarm_push_medium",
|
||||
"lerobot/xarm_push_medium_replay",
|
||||
"lerobot/xarm_lift_medium_image",
|
||||
"lerobot/xarm_lift_medium_replay_image",
|
||||
"lerobot/xarm_push_medium_image",
|
||||
"lerobot/xarm_push_medium_replay_image",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""To enable `lerobot.__version__`"""
|
||||
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import random
|
||||
import shutil
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
@@ -5,17 +20,19 @@ import datasets
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
load_episode_data_index,
|
||||
load_hf_dataset,
|
||||
load_info,
|
||||
load_previous_and_future_frames,
|
||||
load_stats,
|
||||
load_videos,
|
||||
reset_episode_index,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
|
||||
|
||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
CODEBASE_VERSION = "v1.3"
|
||||
CODEBASE_VERSION = "v1.4"
|
||||
|
||||
|
||||
class LeRobotDataset(torch.utils.data.Dataset):
|
||||
@@ -39,7 +56,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# TODO(rcadene, aliberts): implement faster transfer
|
||||
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
|
||||
self.hf_dataset = load_hf_dataset(repo_id, version, root, split)
|
||||
self.episode_data_index = load_episode_data_index(repo_id, version, root)
|
||||
if split == "train":
|
||||
self.episode_data_index = load_episode_data_index(repo_id, version, root)
|
||||
else:
|
||||
self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
|
||||
self.hf_dataset = reset_episode_index(self.hf_dataset)
|
||||
self.stats = load_stats(repo_id, version, root)
|
||||
self.info = load_info(repo_id, version, root)
|
||||
if self.video:
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Helper code for loading PushT dataset from Diffusion Policy (https://diffusion-policy.cs.columbia.edu/)
|
||||
|
||||
Copied from the original Diffusion Policy repository and used in our `download_and_upload_dataset.py` script.
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This file contains all obsolete download scripts. They are centralized here to not have to load
|
||||
useless dependencies when using datasets.
|
||||
@@ -9,17 +24,16 @@ import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import tqdm
|
||||
|
||||
ALOHA_RAW_URLS_DIR = "lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls"
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
|
||||
def download_raw(raw_dir, dataset_id):
|
||||
if "pusht" in dataset_id:
|
||||
if "aloha" in dataset_id or "image" in dataset_id:
|
||||
download_hub(raw_dir, dataset_id)
|
||||
elif "pusht" in dataset_id:
|
||||
download_pusht(raw_dir)
|
||||
elif "xarm" in dataset_id:
|
||||
download_xarm(raw_dir)
|
||||
elif "aloha" in dataset_id:
|
||||
download_aloha(raw_dir, dataset_id)
|
||||
elif "umi" in dataset_id:
|
||||
download_umi(raw_dir)
|
||||
else:
|
||||
@@ -88,37 +102,13 @@ def download_xarm(raw_dir: Path):
|
||||
zip_path.unlink()
|
||||
|
||||
|
||||
def download_aloha(raw_dir: Path, dataset_id: str):
|
||||
import gdown
|
||||
|
||||
subset_id = dataset_id.replace("aloha_", "")
|
||||
urls_path = Path(ALOHA_RAW_URLS_DIR) / f"{subset_id}.txt"
|
||||
assert urls_path.exists(), f"{subset_id}.txt not found in '{ALOHA_RAW_URLS_DIR}' directory."
|
||||
|
||||
with open(urls_path) as f:
|
||||
# strip lines and ignore empty lines
|
||||
urls = [url.strip() for url in f if url.strip()]
|
||||
|
||||
# sanity check
|
||||
for url in urls:
|
||||
assert (
|
||||
"drive.google.com/drive/folders" in url or "drive.google.com/file" in url
|
||||
), f"Wrong url provided '{url}' in file '{urls_path}'."
|
||||
|
||||
def download_hub(raw_dir: Path, dataset_id: str):
|
||||
raw_dir = Path(raw_dir)
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logging.info(f"Start downloading from google drive for {dataset_id}")
|
||||
for url in urls:
|
||||
if "drive.google.com/drive/folders" in url:
|
||||
# when a folder url is given, download up to 50 files from the folder
|
||||
gdown.download_folder(url, output=str(raw_dir), remaining_ok=True)
|
||||
|
||||
elif "drive.google.com/file" in url:
|
||||
# because of the 50 files limit per folder, we download the remaining files (file by file)
|
||||
gdown.download(url, output=str(raw_dir), fuzzy=True)
|
||||
|
||||
logging.info(f"End downloading from google drive for {dataset_id}")
|
||||
logging.info(f"Start downloading from huggingface.co/cadene for {dataset_id}")
|
||||
snapshot_download(f"cadene/{dataset_id}_raw", repo_type="dataset", local_dir=raw_dir)
|
||||
logging.info(f"Finish downloading from huggingface.co/cadene for {dataset_id}")
|
||||
|
||||
|
||||
def download_umi(raw_dir: Path):
|
||||
@@ -133,21 +123,30 @@ def download_umi(raw_dir: Path):
|
||||
if __name__ == "__main__":
|
||||
data_dir = Path("data")
|
||||
dataset_ids = [
|
||||
"pusht_image",
|
||||
"xarm_lift_medium_image",
|
||||
"xarm_lift_medium_replay_image",
|
||||
"xarm_push_medium_image",
|
||||
"xarm_push_medium_replay_image",
|
||||
"aloha_sim_insertion_human_image",
|
||||
"aloha_sim_insertion_scripted_image",
|
||||
"aloha_sim_transfer_cube_human_image",
|
||||
"aloha_sim_transfer_cube_scripted_image",
|
||||
"pusht",
|
||||
"xarm_lift_medium",
|
||||
"xarm_lift_medium_replay",
|
||||
"xarm_push_medium",
|
||||
"xarm_push_medium_replay",
|
||||
"aloha_sim_insertion_human",
|
||||
"aloha_sim_insertion_scripted",
|
||||
"aloha_sim_transfer_cube_human",
|
||||
"aloha_sim_transfer_cube_scripted",
|
||||
"aloha_mobile_cabinet",
|
||||
"aloha_mobile_chair",
|
||||
"aloha_mobile_elevator",
|
||||
"aloha_mobile_shrimp",
|
||||
"aloha_mobile_wash_pan",
|
||||
"aloha_mobile_wipe_wine",
|
||||
"aloha_sim_insertion_human",
|
||||
"aloha_sim_insertion_scripted",
|
||||
"aloha_sim_transfer_cube_human",
|
||||
"aloha_sim_transfer_cube_scripted",
|
||||
"aloha_static_battery",
|
||||
"aloha_static_candy",
|
||||
"aloha_static_coffee",
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# imagecodecs/numcodecs.py
|
||||
|
||||
# Copyright (c) 2021-2022, Christoph Gohlke
|
||||
|
||||
200
lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py
Normal file
200
lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py
Normal file
@@ -0,0 +1,200 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Contains utilities to process raw data format from dora-record
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
|
||||
from lerobot.common.datasets.utils import (
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame
|
||||
from lerobot.common.utils.utils import init_logging
|
||||
|
||||
|
||||
def check_format(raw_dir) -> bool:
|
||||
assert raw_dir.exists()
|
||||
|
||||
leader_file = list(raw_dir.glob("*.parquet"))
|
||||
if len(leader_file) == 0:
|
||||
raise ValueError(f"Missing parquet files in '{raw_dir}'")
|
||||
return True
|
||||
|
||||
|
||||
def load_from_raw(raw_dir: Path, out_dir: Path):
|
||||
# Load data stream that will be used as reference for the timestamps synchronization
|
||||
reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
|
||||
if len(reference_files) == 0:
|
||||
raise ValueError(f"Missing reference files for camera, starting with in '{raw_dir}'")
|
||||
# select first camera in alphanumeric order
|
||||
reference_key = sorted(reference_files)[0].stem
|
||||
reference_df = pd.read_parquet(raw_dir / f"{reference_key}.parquet")
|
||||
reference_df = reference_df[["timestamp_utc", reference_key]]
|
||||
|
||||
# Merge all data stream using nearest backward strategy
|
||||
df = reference_df
|
||||
for path in raw_dir.glob("*.parquet"):
|
||||
key = path.stem # action or observation.state or ...
|
||||
if key == reference_key:
|
||||
continue
|
||||
modality_df = pd.read_parquet(path)
|
||||
modality_df = modality_df[["timestamp_utc", key]]
|
||||
df = pd.merge_asof(
|
||||
df,
|
||||
modality_df,
|
||||
on="timestamp_utc",
|
||||
direction="backward",
|
||||
)
|
||||
|
||||
# Remove rows with a NaN in any column. It can happened during the first frames of an episode,
|
||||
# because some cameras didnt start recording yet.
|
||||
df = df.dropna(axis=0)
|
||||
|
||||
# Remove rows with episode_index -1 which indicates a failed episode
|
||||
df = df[df["episode_index"] != -1]
|
||||
|
||||
# dora only use arrays, so single values are encapsulated into a list
|
||||
df["episode_index"] = df["episode_index"].map(lambda x: x[0])
|
||||
df["frame_index"] = df.groupby("episode_index").cumcount()
|
||||
df = df.reset_index()
|
||||
df["index"] = df.index
|
||||
|
||||
# set 'next.done' to True for the last frame of each episode
|
||||
df["next.done"] = False
|
||||
df.loc[df.groupby("episode_index").tail(1).index, "next.done"] = True
|
||||
|
||||
df["timestamp"] = df["timestamp_utc"].map(lambda x: x.timestamp())
|
||||
# each episode starts with timestamp 0 to match the ones from the video
|
||||
df["timestamp"] = df.groupby("episode_index")["timestamp"].transform(lambda x: x - x.iloc[0])
|
||||
|
||||
del df["timestamp_utc"]
|
||||
|
||||
# sanity check episode indices go from 0 to n-1
|
||||
ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
|
||||
expected_ep_ids = list(range(df["episode_index"].max() + 1))
|
||||
assert ep_ids == expected_ep_ids, f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}"
|
||||
|
||||
# Create symlink to raw videos directory (that needs to be absolute not relative)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
videos_dir = out_dir / "videos"
|
||||
videos_dir.symlink_to((raw_dir / "videos").absolute())
|
||||
|
||||
# sanity check the video paths are well formated
|
||||
for key in df:
|
||||
if "observation.images." not in key:
|
||||
continue
|
||||
for ep_idx in ep_ids:
|
||||
video_path = videos_dir / f"{key}_episode_{ep_idx:06d}.mp4"
|
||||
assert video_path.exists(), f"Video file not found in {video_path}"
|
||||
|
||||
data_dict = {}
|
||||
for key in df:
|
||||
# is video frame
|
||||
if "observation.images." in key:
|
||||
# we need `[0] because dora only use arrays, so single values are encapsulated into a list.
|
||||
# it is the case for video_frame dictionary = [{"path": ..., "timestamp": ...}]
|
||||
data_dict[key] = [video_frame[0] for video_frame in df[key].values]
|
||||
|
||||
# sanity check the video path is well formated
|
||||
video_path = videos_dir.parent / data_dict[key][0]["path"]
|
||||
assert video_path.exists(), f"Video file not found in {video_path}"
|
||||
# is number
|
||||
elif df[key].iloc[0].ndim == 0 or df[key].iloc[0].shape[0] == 1:
|
||||
data_dict[key] = torch.from_numpy(df[key].values)
|
||||
# is vector
|
||||
elif df[key].iloc[0].shape[0] > 1:
|
||||
data_dict[key] = torch.stack([torch.from_numpy(x.copy()) for x in df[key].values])
|
||||
else:
|
||||
raise ValueError(key)
|
||||
|
||||
# Get the episode index containing for each unique episode index
|
||||
first_ep_index_df = df.groupby("episode_index").agg(start_index=("index", "first")).reset_index()
|
||||
from_ = first_ep_index_df["start_index"].tolist()
|
||||
to_ = from_[1:] + [len(df)]
|
||||
episode_data_index = {
|
||||
"from": from_,
|
||||
"to": to_,
|
||||
}
|
||||
|
||||
return data_dict, episode_data_index
|
||||
|
||||
|
||||
def to_hf_dataset(data_dict, video) -> Dataset:
|
||||
features = {}
|
||||
|
||||
keys = [key for key in data_dict if "observation.images." in key]
|
||||
for key in keys:
|
||||
if video:
|
||||
features[key] = VideoFrame()
|
||||
else:
|
||||
features[key] = Image()
|
||||
|
||||
features["observation.state"] = Sequence(
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
if "observation.velocity" in data_dict:
|
||||
features["observation.velocity"] = Sequence(
|
||||
length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
if "observation.effort" in data_dict:
|
||||
features["observation.effort"] = Sequence(
|
||||
length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
features["action"] = Sequence(
|
||||
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
features["episode_index"] = Value(dtype="int64", id=None)
|
||||
features["frame_index"] = Value(dtype="int64", id=None)
|
||||
features["timestamp"] = Value(dtype="float32", id=None)
|
||||
features["next.done"] = Value(dtype="bool", id=None)
|
||||
features["index"] = Value(dtype="int64", id=None)
|
||||
|
||||
hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
|
||||
def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
|
||||
init_logging()
|
||||
|
||||
if debug:
|
||||
logging.warning("debug=True not implemented. Falling back to debug=False.")
|
||||
|
||||
# sanity check
|
||||
check_format(raw_dir)
|
||||
|
||||
if fps is None:
|
||||
fps = 30
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
if not video:
|
||||
raise NotImplementedError()
|
||||
|
||||
data_df, episode_data_index = load_from_raw(raw_dir, out_dir)
|
||||
hf_dataset = to_hf_dataset(data_df, video)
|
||||
|
||||
info = {
|
||||
"fps": fps,
|
||||
"video": video,
|
||||
}
|
||||
return hf_dataset, episode_data_index, info
|
||||
@@ -1,8 +1,23 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Contains utilities to process raw data format of HDF5 files like in: https://github.com/tonyzhaozh/act
|
||||
"""
|
||||
|
||||
import re
|
||||
import gc
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
@@ -64,10 +79,8 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
|
||||
id_from = 0
|
||||
|
||||
for ep_path in tqdm.tqdm(hdf5_files, total=len(hdf5_files)):
|
||||
for ep_idx, ep_path in tqdm.tqdm(enumerate(hdf5_files), total=len(hdf5_files)):
|
||||
with h5py.File(ep_path, "r") as ep:
|
||||
ep_idx = int(re.search(r"episode_(\d+)", ep_path.name).group(1))
|
||||
num_frames = ep["/action"].shape[0]
|
||||
|
||||
# last step of demonstration is considered done
|
||||
@@ -76,6 +89,10 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
|
||||
|
||||
state = torch.from_numpy(ep["/observations/qpos"][:])
|
||||
action = torch.from_numpy(ep["/action"][:])
|
||||
if "/observations/qvel" in ep:
|
||||
velocity = torch.from_numpy(ep["/observations/qvel"][:])
|
||||
if "/observations/effort" in ep:
|
||||
effort = torch.from_numpy(ep["/observations/effort"][:])
|
||||
|
||||
ep_dict = {}
|
||||
|
||||
@@ -116,6 +133,10 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
|
||||
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
||||
|
||||
ep_dict["observation.state"] = state
|
||||
if "/observations/velocity" in ep:
|
||||
ep_dict["observation.velocity"] = velocity
|
||||
if "/observations/effort" in ep:
|
||||
ep_dict["observation.effort"] = effort
|
||||
ep_dict["action"] = action
|
||||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
|
||||
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||
@@ -131,6 +152,8 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
|
||||
|
||||
id_from += num_frames
|
||||
|
||||
gc.collect()
|
||||
|
||||
# process first episode only
|
||||
if debug:
|
||||
break
|
||||
@@ -152,6 +175,14 @@ def to_hf_dataset(data_dict, video) -> Dataset:
|
||||
features["observation.state"] = Sequence(
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
if "observation.velocity" in data_dict:
|
||||
features["observation.velocity"] = Sequence(
|
||||
length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
if "observation.effort" in data_dict:
|
||||
features["observation.effort"] = Sequence(
|
||||
length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
features["action"] = Sequence(
|
||||
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from copy import deepcopy
|
||||
from math import ceil
|
||||
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Process zarr files formatted like in: https://github.com/real-stanford/diffusion_policy"""
|
||||
|
||||
import shutil
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Process UMI (Universal Manipulation Interface) data stored in Zarr format like in: https://github.com/real-stanford/universal_manipulation_interface"""
|
||||
|
||||
import logging
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Process pickle files formatted like in: https://github.com/fyhMer/fowm"""
|
||||
|
||||
import pickle
|
||||
|
||||
@@ -1,5 +1,22 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
@@ -64,7 +81,23 @@ def hf_transform_to_torch(items_dict):
|
||||
def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
|
||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||
if root is not None:
|
||||
hf_dataset = load_from_disk(str(Path(root) / repo_id / split))
|
||||
hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
|
||||
# TODO(rcadene): clean this which enables getting a subset of dataset
|
||||
if split != "train":
|
||||
if "%" in split:
|
||||
raise NotImplementedError(f"We dont support splitting based on percentage for now ({split}).")
|
||||
match_from = re.search(r"train\[(\d+):\]", split)
|
||||
match_to = re.search(r"train\[:(\d+)\]", split)
|
||||
if match_from:
|
||||
from_frame_index = int(match_from.group(1))
|
||||
hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset)))
|
||||
elif match_to:
|
||||
to_frame_index = int(match_to.group(1))
|
||||
hf_dataset = hf_dataset.select(range(to_frame_index))
|
||||
else:
|
||||
raise ValueError(
|
||||
f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"'
|
||||
)
|
||||
else:
|
||||
hf_dataset = load_dataset(repo_id, revision=version, split=split)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
@@ -230,6 +263,84 @@ def load_previous_and_future_frames(
|
||||
return item
|
||||
|
||||
|
||||
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
|
||||
|
||||
Parameters:
|
||||
- hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index.
|
||||
|
||||
Returns:
|
||||
- episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys:
|
||||
- "from": A tensor containing the starting index of each episode.
|
||||
- "to": A tensor containing the ending index of each episode.
|
||||
"""
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
|
||||
current_episode = None
|
||||
"""
|
||||
The episode_index is a list of integers, each representing the episode index of the corresponding example.
|
||||
For instance, the following is a valid episode_index:
|
||||
[0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
|
||||
|
||||
Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
|
||||
ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
|
||||
{
|
||||
"from": [0, 3, 7],
|
||||
"to": [3, 7, 12]
|
||||
}
|
||||
"""
|
||||
if len(hf_dataset) == 0:
|
||||
episode_data_index = {
|
||||
"from": torch.tensor([]),
|
||||
"to": torch.tensor([]),
|
||||
}
|
||||
return episode_data_index
|
||||
for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
|
||||
if episode_idx != current_episode:
|
||||
# We encountered a new episode, so we append its starting location to the "from" list
|
||||
episode_data_index["from"].append(idx)
|
||||
# If this is not the first episode, we append the ending location of the previous episode to the "to" list
|
||||
if current_episode is not None:
|
||||
episode_data_index["to"].append(idx)
|
||||
# Let's keep track of the current episode index
|
||||
current_episode = episode_idx
|
||||
else:
|
||||
# We are still in the same episode, so there is nothing for us to do here
|
||||
pass
|
||||
# We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list
|
||||
episode_data_index["to"].append(idx + 1)
|
||||
|
||||
for k in ["from", "to"]:
|
||||
episode_data_index[k] = torch.tensor(episode_data_index[k])
|
||||
|
||||
return episode_data_index
|
||||
|
||||
|
||||
def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
|
||||
"""
|
||||
Reset the `episode_index` of the provided HuggingFace Dataset.
|
||||
|
||||
`episode_data_index` (and related functionality such as `load_previous_and_future_frames`) requires the
|
||||
`episode_index` to be sorted, continuous (1,1,1 and not 1,2,1) and start at 0.
|
||||
|
||||
This brings the `episode_index` to the required format.
|
||||
"""
|
||||
if len(hf_dataset) == 0:
|
||||
return hf_dataset
|
||||
unique_episode_idxs = torch.stack(hf_dataset["episode_index"]).unique().tolist()
|
||||
episode_idx_to_reset_idx_mapping = {
|
||||
ep_id: reset_ep_id for reset_ep_id, ep_id in enumerate(unique_episode_idxs)
|
||||
}
|
||||
|
||||
def modify_ep_idx_func(example):
|
||||
example["episode_index"] = episode_idx_to_reset_idx_mapping[example["episode_index"].item()]
|
||||
return example
|
||||
|
||||
hf_dataset = hf_dataset.map(modify_ep_idx_func)
|
||||
return hf_dataset
|
||||
|
||||
|
||||
def cycle(iterable):
|
||||
"""The equivalent of itertools.cycle, but safe for Pytorch dataloaders.
|
||||
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import subprocess
|
||||
import warnings
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
|
||||
import gymnasium as gym
|
||||
@@ -13,11 +28,11 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
|
||||
raise ValueError("`n_envs must be at least 1")
|
||||
|
||||
kwargs = {
|
||||
"obs_type": "pixels_agent_pos",
|
||||
"render_mode": "rgb_array",
|
||||
# "obs_type": "pixels_agent_pos",
|
||||
# "render_mode": "rgb_array",
|
||||
"max_episode_steps": cfg.env.episode_length,
|
||||
"visualization_width": 384,
|
||||
"visualization_height": 384,
|
||||
# "visualization_width": 384,
|
||||
# "visualization_height": 384,
|
||||
}
|
||||
|
||||
package_name = f"gym_{cfg.env.name}"
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# TODO(rcadene, alexander-soare): clean this file
|
||||
"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py"""
|
||||
|
||||
@@ -82,9 +97,9 @@ class Logger:
|
||||
# Also save the full Hydra config for the env configuration.
|
||||
OmegaConf.save(self._cfg, save_dir / "config.yaml")
|
||||
if self._wandb and not self._disable_wandb_artifact:
|
||||
# note wandb artifact does not accept ":" in its name
|
||||
# note wandb artifact does not accept ":" or "/" in its name
|
||||
artifact = self._wandb.Artifact(
|
||||
self._group.replace(":", "_") + "-" + str(self._seed) + "-" + str(identifier),
|
||||
f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
|
||||
type="model",
|
||||
)
|
||||
artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
|
||||
@@ -94,9 +109,10 @@ class Logger:
|
||||
self._buffer_dir.mkdir(parents=True, exist_ok=True)
|
||||
fp = self._buffer_dir / f"{str(identifier)}.pkl"
|
||||
buffer.save(fp)
|
||||
if self._wandb:
|
||||
if self._wandb and not self._disable_wandb_artifact:
|
||||
# note wandb artifact does not accept ":" or "/" in its name
|
||||
artifact = self._wandb.Artifact(
|
||||
self._group + "-" + str(self._seed) + "-" + str(identifier),
|
||||
f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
|
||||
type="buffer",
|
||||
)
|
||||
artifact.add_file(fp)
|
||||
@@ -114,6 +130,11 @@ class Logger:
|
||||
assert mode in {"train", "eval"}
|
||||
if self._wandb is not None:
|
||||
for k, v in d.items():
|
||||
if not isinstance(v, (int, float, str)):
|
||||
logging.warning(
|
||||
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
|
||||
)
|
||||
continue
|
||||
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
||||
|
||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@@ -51,8 +66,12 @@ class ACTConfig:
|
||||
documentation in the policy class).
|
||||
latent_dim: The VAE's latent dimension.
|
||||
n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
|
||||
use_temporal_aggregation: Whether to blend the actions of multiple policy invocations for any given
|
||||
environment step.
|
||||
temporal_ensemble_momentum: Exponential moving average (EMA) momentum parameter (α) for ensembling
|
||||
actions for a given time step over multiple policy invocations. Updates are calculated as:
|
||||
x⁻ₙ = αx⁻ₙ₋₁ + (1-α)xₙ. Note that the ACT paper and original ACT code describes a different
|
||||
parameter here: they refer to a weighting scheme wᵢ = exp(-m⋅i) and set m = 0.01. With our
|
||||
formulation, this is equivalent to α = exp(-0.01) ≈ 0.99. When this parameter is provided, we
|
||||
require `n_action_steps == 1` (since we need to query the policy every step anyway).
|
||||
dropout: Dropout to use in the transformer layers (see code for details).
|
||||
kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
|
||||
is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
|
||||
@@ -100,6 +119,9 @@ class ACTConfig:
|
||||
dim_feedforward: int = 3200
|
||||
feedforward_activation: str = "relu"
|
||||
n_encoder_layers: int = 4
|
||||
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||||
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||||
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||||
n_decoder_layers: int = 1
|
||||
# VAE.
|
||||
use_vae: bool = True
|
||||
@@ -107,7 +129,7 @@ class ACTConfig:
|
||||
n_vae_encoder_layers: int = 4
|
||||
|
||||
# Inference.
|
||||
use_temporal_aggregation: bool = False
|
||||
temporal_ensemble_momentum: float | None = None
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: float = 0.1
|
||||
@@ -119,8 +141,11 @@ class ACTConfig:
|
||||
raise ValueError(
|
||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||
)
|
||||
if self.use_temporal_aggregation:
|
||||
raise NotImplementedError("Temporal aggregation is not yet implemented.")
|
||||
if self.temporal_ensemble_momentum is not None and self.n_action_steps > 1:
|
||||
raise NotImplementedError(
|
||||
"`n_action_steps` must be 1 when using temporal ensembling. This is "
|
||||
"because the policy needs to be queried every step to compute the ensembled action."
|
||||
)
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
|
||||
@@ -130,10 +155,3 @@ class ACTConfig:
|
||||
raise ValueError(
|
||||
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
|
||||
)
|
||||
# Check that there is only one image.
|
||||
# TODO(alexander-soare): generalize this to multiple images.
|
||||
if (
|
||||
sum(k.startswith("observation.images.") for k in self.input_shapes) != 1
|
||||
or "observation.images.top" not in self.input_shapes
|
||||
):
|
||||
raise ValueError('For now, only "observation.images.top" is accepted for an image input.')
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Action Chunking Transformer Policy
|
||||
|
||||
As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705).
|
||||
@@ -46,7 +61,8 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
super().__init__()
|
||||
if config is None:
|
||||
config = ACTConfig()
|
||||
self.config = config
|
||||
self.config: ACTConfig = config
|
||||
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes, config.input_normalization_modes, dataset_stats
|
||||
)
|
||||
@@ -56,11 +72,18 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
self.model = ACT(config)
|
||||
|
||||
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
if self.config.n_action_steps is not None:
|
||||
if self.config.temporal_ensemble_momentum is not None:
|
||||
self._ensembled_actions = None
|
||||
else:
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
@torch.no_grad
|
||||
@@ -71,37 +94,56 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
||||
queue is empty.
|
||||
"""
|
||||
assert "observation.images.top" in batch
|
||||
assert "observation.state" in batch
|
||||
|
||||
self.eval()
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
self._stack_images(batch)
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
|
||||
# If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return
|
||||
# the first action.
|
||||
if self.config.temporal_ensemble_momentum is not None:
|
||||
actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim)
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
if self._ensembled_actions is None:
|
||||
# Initializes `self._ensembled_action` to the sequence of actions predicted during the first
|
||||
# time step of the episode.
|
||||
self._ensembled_actions = actions.clone()
|
||||
else:
|
||||
# self._ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
|
||||
# the EMA update for those entries.
|
||||
alpha = self.config.temporal_ensemble_momentum
|
||||
self._ensembled_actions = alpha * self._ensembled_actions + (1 - alpha) * actions[:, :-1]
|
||||
# The last action, which has no prior moving average, needs to get concatenated onto the end.
|
||||
self._ensembled_actions = torch.cat([self._ensembled_actions, actions[:, -1:]], dim=1)
|
||||
# "Consume" the first action.
|
||||
action, self._ensembled_actions = self._ensembled_actions[:, 0], self._ensembled_actions[:, 1:]
|
||||
return action
|
||||
|
||||
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
||||
# querying the policy.
|
||||
if len(self._action_queue) == 0:
|
||||
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
||||
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||
actions = self.model(batch)[0][: self.config.n_action_steps]
|
||||
actions = self.model(batch)[0][:, : self.config.n_action_steps]
|
||||
|
||||
# TODO(rcadene): make _forward return output dictionary?
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
||||
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch = self.normalize_targets(batch)
|
||||
self._stack_images(batch)
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||
|
||||
l1_loss = (
|
||||
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
||||
).mean()
|
||||
|
||||
loss_dict = {"l1_loss": l1_loss}
|
||||
loss_dict = {"l1_loss": l1_loss.item()}
|
||||
if self.config.use_vae:
|
||||
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
||||
# each dimension independently, we sum over the latent dimension to get the total
|
||||
@@ -110,28 +152,13 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
mean_kld = (
|
||||
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
||||
)
|
||||
loss_dict["kld_loss"] = mean_kld
|
||||
loss_dict["kld_loss"] = mean_kld.item()
|
||||
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
|
||||
else:
|
||||
loss_dict["loss"] = l1_loss
|
||||
|
||||
return loss_dict
|
||||
|
||||
def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Stacks all the images in a batch and puts them in a new key: "observation.images".
|
||||
|
||||
This function expects `batch` to have (at least):
|
||||
{
|
||||
"observation.state": (B, state_dim) batch of robot states.
|
||||
"observation.images.{name}": (B, C, H, W) tensor of images.
|
||||
}
|
||||
"""
|
||||
# Stack images in the order dictated by input_shapes.
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[k] for k in self.config.input_shapes if k.startswith("observation.images.")],
|
||||
dim=-4,
|
||||
)
|
||||
|
||||
|
||||
class ACT(nn.Module):
|
||||
"""Action Chunking Transformer: The underlying neural network for ACTPolicy.
|
||||
@@ -161,10 +188,10 @@ class ACT(nn.Module):
|
||||
│ encoder │ │ │ │Transf.│ │
|
||||
│ │ │ │ │encoder│ │
|
||||
└───▲─────┘ │ │ │ │ │
|
||||
│ │ │ └───▲───┘ │
|
||||
│ │ │ │ │
|
||||
inputs └─────┼─────┘ │
|
||||
│ │
|
||||
│ │ │ └▲──▲─▲─┘ │
|
||||
│ │ │ │ │ │ │
|
||||
inputs └─────┼──┘ │ image emb. │
|
||||
│ state emb. │
|
||||
└───────────────────────┘
|
||||
"""
|
||||
|
||||
@@ -306,18 +333,18 @@ class ACT(nn.Module):
|
||||
all_cam_features.append(cam_features)
|
||||
all_cam_pos_embeds.append(cam_pos_embed)
|
||||
# Concatenate camera observation feature maps and positional embeddings along the width dimension.
|
||||
encoder_in = torch.cat(all_cam_features, axis=3)
|
||||
cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=3)
|
||||
encoder_in = torch.cat(all_cam_features, axis=-1)
|
||||
cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)
|
||||
|
||||
# Get positional embeddings for robot state and latent.
|
||||
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])
|
||||
latent_embed = self.encoder_latent_input_proj(latent_sample)
|
||||
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) # (B, C)
|
||||
latent_embed = self.encoder_latent_input_proj(latent_sample) # (B, C)
|
||||
|
||||
# Stack encoder input and positional embeddings moving to (S, B, C).
|
||||
encoder_in = torch.cat(
|
||||
[
|
||||
torch.stack([latent_embed, robot_state_embed], axis=0),
|
||||
encoder_in.flatten(2).permute(2, 0, 1),
|
||||
einops.rearrange(encoder_in, "b c h w -> (h w) b c"),
|
||||
]
|
||||
)
|
||||
pos_embed = torch.cat(
|
||||
|
||||
@@ -1,3 +1,19 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@@ -51,6 +67,7 @@ class DiffusionConfig:
|
||||
use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning.
|
||||
Bias modulation is used be default, while this parameter indicates whether to also use scale
|
||||
modulation.
|
||||
noise_scheduler_type: Name of the noise scheduler to use. Supported options: ["DDPM", "DDIM"].
|
||||
num_train_timesteps: Number of diffusion steps for the forward diffusion schedule.
|
||||
beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers.
|
||||
beta_start: Beta value for the first forward-diffusion step.
|
||||
@@ -110,6 +127,7 @@ class DiffusionConfig:
|
||||
diffusion_step_embed_dim: int = 128
|
||||
use_film_scale_modulation: bool = True
|
||||
# Noise scheduler.
|
||||
noise_scheduler_type: str = "DDPM"
|
||||
num_train_timesteps: int = 100
|
||||
beta_schedule: str = "squaredcos_cap_v2"
|
||||
beta_start: float = 0.0001
|
||||
@@ -130,17 +148,30 @@ class DiffusionConfig:
|
||||
raise ValueError(
|
||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||
)
|
||||
# There should only be one image key.
|
||||
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
|
||||
if len(image_keys) != 1:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
|
||||
)
|
||||
image_key = next(iter(image_keys))
|
||||
if (
|
||||
self.crop_shape[0] > self.input_shapes["observation.image"][1]
|
||||
or self.crop_shape[1] > self.input_shapes["observation.image"][2]
|
||||
self.crop_shape[0] > self.input_shapes[image_key][1]
|
||||
or self.crop_shape[1] > self.input_shapes[image_key][2]
|
||||
):
|
||||
raise ValueError(
|
||||
f'`crop_shape` should fit within `input_shapes["observation.image"]`. Got {self.crop_shape} '
|
||||
f'for `crop_shape` and {self.input_shapes["observation.image"]} for '
|
||||
'`input_shapes["observation.image"]`.'
|
||||
f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {self.input_shapes[image_key]} for "
|
||||
"`input_shapes[{image_key}]`."
|
||||
)
|
||||
supported_prediction_types = ["epsilon", "sample"]
|
||||
if self.prediction_type not in supported_prediction_types:
|
||||
raise ValueError(
|
||||
f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
|
||||
)
|
||||
supported_noise_schedulers = ["DDPM", "DDIM"]
|
||||
if self.noise_scheduler_type not in supported_noise_schedulers:
|
||||
raise ValueError(
|
||||
f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. "
|
||||
f"Got {self.noise_scheduler_type}."
|
||||
)
|
||||
|
||||
@@ -1,8 +1,24 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
|
||||
|
||||
TODO(alexander-soare):
|
||||
- Remove reliance on Robomimic for SpatialSoftmax.
|
||||
- Remove reliance on diffusers for DDPMScheduler and LR scheduler.
|
||||
- Make compatible with multiple image keys.
|
||||
"""
|
||||
|
||||
import math
|
||||
@@ -10,12 +26,13 @@ from collections import deque
|
||||
from typing import Callable
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torchvision
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from robomimic.models.base_nets import SpatialSoftmax
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
@@ -66,10 +83,18 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
|
||||
self.diffusion = DiffusionModel(config)
|
||||
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
|
||||
if len(image_keys) != 1:
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
|
||||
)
|
||||
self.input_image_key = image_keys[0]
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Clear observation and action queues. Should be called on `env.reset()`
|
||||
"""
|
||||
"""Clear observation and action queues. Should be called on `env.reset()`"""
|
||||
self._queues = {
|
||||
"observation.image": deque(maxlen=self.config.n_obs_steps),
|
||||
"observation.state": deque(maxlen=self.config.n_obs_steps),
|
||||
@@ -98,16 +123,14 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
"horizon" may not the best name to describe what the variable actually means, because this period is
|
||||
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
|
||||
"""
|
||||
assert "observation.image" in batch
|
||||
assert "observation.state" in batch
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
if len(self._queues["action"]) == 0:
|
||||
# stack n latest observations from the queue
|
||||
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
|
||||
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||
actions = self.diffusion.generate_actions(batch)
|
||||
|
||||
# TODO(rcadene): make above methods return output dictionary?
|
||||
@@ -121,11 +144,25 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
batch = self.normalize_targets(batch)
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
return {"loss": loss}
|
||||
|
||||
|
||||
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
|
||||
"""
|
||||
Factory for noise scheduler instances of the requested type. All kwargs are passed
|
||||
to the scheduler.
|
||||
"""
|
||||
if name == "DDPM":
|
||||
return DDPMScheduler(**kwargs)
|
||||
elif name == "DDIM":
|
||||
return DDIMScheduler(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unsupported noise scheduler type {name}")
|
||||
|
||||
|
||||
class DiffusionModel(nn.Module):
|
||||
def __init__(self, config: DiffusionConfig):
|
||||
super().__init__()
|
||||
@@ -138,12 +175,12 @@ class DiffusionModel(nn.Module):
|
||||
* config.n_obs_steps,
|
||||
)
|
||||
|
||||
self.noise_scheduler = DDPMScheduler(
|
||||
self.noise_scheduler = _make_noise_scheduler(
|
||||
config.noise_scheduler_type,
|
||||
num_train_timesteps=config.num_train_timesteps,
|
||||
beta_start=config.beta_start,
|
||||
beta_end=config.beta_end,
|
||||
beta_schedule=config.beta_schedule,
|
||||
variance_type="fixed_small",
|
||||
clip_sample=config.clip_sample,
|
||||
clip_sample_range=config.clip_sample_range,
|
||||
prediction_type=config.prediction_type,
|
||||
@@ -185,13 +222,12 @@ class DiffusionModel(nn.Module):
|
||||
|
||||
def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""
|
||||
This function expects `batch` to have (at least):
|
||||
This function expects `batch` to have:
|
||||
{
|
||||
"observation.state": (B, n_obs_steps, state_dim)
|
||||
"observation.image": (B, n_obs_steps, C, H, W)
|
||||
}
|
||||
"""
|
||||
assert set(batch).issuperset({"observation.state", "observation.image"})
|
||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||
assert n_obs_steps == self.config.n_obs_steps
|
||||
|
||||
@@ -275,6 +311,77 @@ class DiffusionModel(nn.Module):
|
||||
return loss.mean()
|
||||
|
||||
|
||||
class SpatialSoftmax(nn.Module):
|
||||
"""
|
||||
Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
|
||||
(https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation.
|
||||
|
||||
At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass"
|
||||
of activations of each channel, i.e., keypoints in the image space for the policy to focus on.
|
||||
|
||||
Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2):
|
||||
-----------------------------------------------------
|
||||
| (-1., -1.) | (-0.82, -1.) | ... | (1., -1.) |
|
||||
| (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) |
|
||||
| ... | ... | ... | ... |
|
||||
| (-1., 1.) | (-0.82, 1.) | ... | (1., 1.) |
|
||||
-----------------------------------------------------
|
||||
This is achieved by applying channel-wise softmax over the activations (512x120) and computing the dot
|
||||
product with the coordinates (120x2) to get expected points of maximal activation (512x2).
|
||||
|
||||
The example above results in 512 keypoints (corresponding to the 512 input channels). We can optionally
|
||||
provide num_kp != None to control the number of keypoints. This is achieved by a first applying a learnable
|
||||
linear mapping (in_channels, H, W) -> (num_kp, H, W).
|
||||
"""
|
||||
|
||||
def __init__(self, input_shape, num_kp=None):
|
||||
"""
|
||||
Args:
|
||||
input_shape (list): (C, H, W) input feature map shape.
|
||||
num_kp (int): number of keypoints in output. If None, output will have the same number of channels as input.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
assert len(input_shape) == 3
|
||||
self._in_c, self._in_h, self._in_w = input_shape
|
||||
|
||||
if num_kp is not None:
|
||||
self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
|
||||
self._out_c = num_kp
|
||||
else:
|
||||
self.nets = None
|
||||
self._out_c = self._in_c
|
||||
|
||||
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
|
||||
# and causes a small degradation in pc_success of pre-trained models.
|
||||
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
|
||||
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
|
||||
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
|
||||
# register as buffer so it's moved to the correct device.
|
||||
self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1))
|
||||
|
||||
def forward(self, features: Tensor) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
features: (B, C, H, W) input feature maps.
|
||||
Returns:
|
||||
(B, K, 2) image-space coordinates of keypoints.
|
||||
"""
|
||||
if self.nets is not None:
|
||||
features = self.nets(features)
|
||||
|
||||
# [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
|
||||
features = features.reshape(-1, self._in_h * self._in_w)
|
||||
# 2d softmax normalization
|
||||
attention = F.softmax(features, dim=-1)
|
||||
# [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions
|
||||
expected_xy = attention @ self.pos_grid
|
||||
# reshape to [B, K, 2]
|
||||
feature_keypoints = expected_xy.view(-1, self._out_c, 2)
|
||||
|
||||
return feature_keypoints
|
||||
|
||||
|
||||
class DiffusionRgbEncoder(nn.Module):
|
||||
"""Encoder an RGB image into a 1D feature vector.
|
||||
|
||||
@@ -315,11 +422,16 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
|
||||
# Set up pooling and final layers.
|
||||
# Use a dry run to get the feature map shape.
|
||||
# The dummy input should take the number of image channels from `config.input_shapes` and it should
|
||||
# use the height and width from `config.crop_shape`.
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
assert len(image_keys) == 1
|
||||
image_key = image_keys[0]
|
||||
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape))
|
||||
with torch.inference_mode():
|
||||
feat_map_shape = tuple(
|
||||
self.backbone(torch.zeros(size=(1, *config.input_shapes["observation.image"]))).shape[1:]
|
||||
)
|
||||
self.pool = SpatialSoftmax(feat_map_shape, num_kp=config.spatial_softmax_num_keypoints)
|
||||
dummy_feature_map = self.backbone(dummy_input)
|
||||
feature_map_shape = tuple(dummy_feature_map.shape[1:])
|
||||
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
|
||||
self.feature_dim = config.spatial_softmax_num_keypoints * 2
|
||||
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
@@ -1,4 +1,20 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
@@ -8,9 +24,10 @@ from lerobot.common.utils.utils import get_safe_torch_device
|
||||
|
||||
def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
|
||||
expected_kwargs = set(inspect.signature(policy_cfg_class).parameters)
|
||||
assert set(hydra_cfg.policy).issuperset(
|
||||
expected_kwargs
|
||||
), f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
|
||||
if not set(hydra_cfg.policy).issuperset(expected_kwargs):
|
||||
logging.warning(
|
||||
f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
|
||||
)
|
||||
policy_cfg = policy_cfg_class(
|
||||
**{
|
||||
k: v
|
||||
@@ -62,11 +79,18 @@ def make_policy(
|
||||
|
||||
policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)
|
||||
|
||||
policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
|
||||
if pretrained_policy_name_or_path is None:
|
||||
policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
|
||||
# Make a fresh policy.
|
||||
policy = policy_cls(policy_cfg, dataset_stats)
|
||||
else:
|
||||
policy = policy_cls.from_pretrained(pretrained_policy_name_or_path)
|
||||
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
|
||||
# hyperparameters that we want to vary).
|
||||
# TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with, pretrained
|
||||
# weights which are then loaded into a fresh policy with the desired config. This PR in huggingface_hub should
|
||||
# make it possible to avoid the hack: https://github.com/huggingface/huggingface_hub/pull/2274.
|
||||
policy = policy_cls(policy_cfg)
|
||||
policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict())
|
||||
|
||||
policy.to(get_safe_torch_device(hydra_cfg.device))
|
||||
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""A protocol that all policies should follow.
|
||||
|
||||
This provides a mechanism for type-hinting and isinstance checks without requiring the policies classes
|
||||
@@ -38,7 +53,8 @@ class Policy(Protocol):
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict:
|
||||
"""Run the batch through the model and compute the loss for training or validation.
|
||||
|
||||
Returns a dictionary with "loss" and maybe other information.
|
||||
Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all
|
||||
other items should be logging-friendly, native Python types.
|
||||
"""
|
||||
|
||||
def select_action(self, batch: dict[str, Tensor]):
|
||||
|
||||
@@ -1,3 +1,19 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@@ -47,7 +63,7 @@ class TDMPCConfig:
|
||||
elite_weighting_temperature: The temperature to use for softmax weighting (by trajectory value) of the
|
||||
elites, when updating the gaussian parameters for CEM.
|
||||
gaussian_mean_momentum: Momentum (α) used for EMA updates of the mean parameter μ of the gaussian
|
||||
paramters optimized in CEM. Updates are calculated as μ⁻ ← αμ⁻ + (1-α)μ.
|
||||
parameters optimized in CEM. Updates are calculated as μ⁻ ← αμ⁻ + (1-α)μ.
|
||||
max_random_shift_ratio: Maximum random shift (as a proportion of the image size) to apply to the
|
||||
image(s) (in units of pixels) for training-time augmentation. If set to 0, no such augmentation
|
||||
is applied. Note that the input images are assumed to be square for this augmentation.
|
||||
@@ -131,12 +147,18 @@ class TDMPCConfig:
|
||||
|
||||
def __post_init__(self):
|
||||
"""Input validation (not exhaustive)."""
|
||||
if self.input_shapes["observation.image"][-2] != self.input_shapes["observation.image"][-1]:
|
||||
# There should only be one image key.
|
||||
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
|
||||
if len(image_keys) != 1:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
|
||||
)
|
||||
image_key = next(iter(image_keys))
|
||||
if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]:
|
||||
# TODO(alexander-soare): This limitation is solely because of code in the random shift
|
||||
# augmentation. It should be able to be removed.
|
||||
raise ValueError(
|
||||
"Only square images are handled now. Got image shape "
|
||||
f"{self.input_shapes['observation.image']}."
|
||||
f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}."
|
||||
)
|
||||
if self.n_gaussian_samples <= 0:
|
||||
raise ValueError(
|
||||
|
||||
@@ -1,3 +1,19 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Implementation of Finetuning Offline World Models in the Real World.
|
||||
|
||||
The comments in this code may sometimes refer to these references:
|
||||
@@ -96,13 +112,12 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
def save(self, fp):
|
||||
"""Save state dict of TOLD model to filepath."""
|
||||
torch.save(self.state_dict(), fp)
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
|
||||
assert len(image_keys) == 1
|
||||
self.input_image_key = image_keys[0]
|
||||
|
||||
def load(self, fp):
|
||||
"""Load a saved state dict from filepath into current agent."""
|
||||
self.load_state_dict(torch.load(fp))
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
@@ -121,10 +136,8 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]):
|
||||
"""Select a single action given environment observations."""
|
||||
assert "observation.image" in batch
|
||||
assert "observation.state" in batch
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
@@ -303,13 +316,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
device = get_device_from_parameters(self)
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
info = {}
|
||||
|
||||
# TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation.
|
||||
batch_size = batch["index"].shape[0]
|
||||
|
||||
# (b, t) -> (t, b)
|
||||
for key in batch:
|
||||
if batch[key].ndim > 1:
|
||||
@@ -337,6 +348,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
# Run latent rollout using the latent dynamics model and policy model.
|
||||
# Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
|
||||
# gives us a next `z`.
|
||||
batch_size = batch["index"].shape[0]
|
||||
z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
|
||||
z_preds[0] = self.model.encode(current_observation)
|
||||
reward_preds = torch.empty_like(reward, device=device)
|
||||
|
||||
@@ -1,9 +1,28 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
def populate_queues(queues, batch):
|
||||
for key in batch:
|
||||
# Ignore keys not in the queues already (leaving the responsibility to the caller to make sure the
|
||||
# queues have the keys they want).
|
||||
if key not in queues:
|
||||
continue
|
||||
if len(queues[key]) != queues[key].maxlen:
|
||||
# initialize by copying the first observation several times until the queue is full
|
||||
while len(queues[key]) != queues[key].maxlen:
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import logging
|
||||
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import warnings
|
||||
|
||||
import imageio
|
||||
|
||||
@@ -1,8 +1,25 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import os.path as osp
|
||||
import random
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
@@ -39,6 +56,31 @@ def set_global_seed(seed):
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def seeded_context(seed: int) -> Generator[None, None, None]:
|
||||
"""Set the seed when entering a context, and restore the prior random state at exit.
|
||||
|
||||
Example usage:
|
||||
|
||||
```
|
||||
a = random.random() # produces some random number
|
||||
with seeded_context(1337):
|
||||
b = random.random() # produces some other random number
|
||||
c = random.random() # produces yet another random number, but the same it would have if we never made `b`
|
||||
```
|
||||
"""
|
||||
random_state = random.getstate()
|
||||
np_random_state = np.random.get_state()
|
||||
torch_random_state = torch.random.get_rng_state()
|
||||
torch_cuda_random_state = torch.cuda.random.get_rng_state()
|
||||
set_global_seed(seed)
|
||||
yield None
|
||||
random.setstate(random_state)
|
||||
np.random.set_state(np_random_state)
|
||||
torch.random.set_rng_state(torch_random_state)
|
||||
torch.cuda.random.set_rng_state(torch_cuda_random_state)
|
||||
|
||||
|
||||
def init_logging():
|
||||
def custom_format(record):
|
||||
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
@@ -35,7 +35,7 @@ eval:
|
||||
use_async_envs: false
|
||||
|
||||
wandb:
|
||||
enable: true
|
||||
enable: false
|
||||
# Set to true to disable saving an artifact despite save_model == True
|
||||
disable_artifact: false
|
||||
project: lerobot
|
||||
|
||||
14
lerobot/configs/env/dora.yaml
vendored
Normal file
14
lerobot/configs/env/dora.yaml
vendored
Normal file
@@ -0,0 +1,14 @@
|
||||
# @package _global_
|
||||
|
||||
fps: 30
|
||||
|
||||
env:
|
||||
name: dora
|
||||
task: DoraAloha-v0
|
||||
# from_pixels: True
|
||||
# pixels_only: False
|
||||
# image_size: [3, 480, 640]
|
||||
episode_length: 400
|
||||
# fps: ${fps}
|
||||
# state_dim: 14
|
||||
# action_dim: 14
|
||||
@@ -3,6 +3,12 @@
|
||||
seed: 1000
|
||||
dataset_repo_id: lerobot/aloha_sim_insertion_human
|
||||
|
||||
override_dataset_stats:
|
||||
observation.images.top:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
training:
|
||||
offline_steps: 80000
|
||||
online_steps: 0
|
||||
@@ -18,12 +24,6 @@ training:
|
||||
grad_clip_norm: 10
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
override_dataset_stats:
|
||||
observation.images.top:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
delta_timestamps:
|
||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||
|
||||
@@ -66,6 +66,9 @@ policy:
|
||||
dim_feedforward: 3200
|
||||
feedforward_activation: relu
|
||||
n_encoder_layers: 4
|
||||
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||||
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||||
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||||
n_decoder_layers: 1
|
||||
# VAE.
|
||||
use_vae: true
|
||||
@@ -73,7 +76,7 @@ policy:
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
use_temporal_aggregation: false
|
||||
temporal_ensemble_momentum: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
|
||||
101
lerobot/configs/policy/act_real_world.yaml
Normal file
101
lerobot/configs/policy/act_real_world.yaml
Normal file
@@ -0,0 +1,101 @@
|
||||
# @package _global_
|
||||
|
||||
seed: 1000
|
||||
dataset_repo_id: cadene/aloha_v2_static_dora_test
|
||||
|
||||
override_dataset_stats:
|
||||
observation.images.cam_right_wrist:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.cam_left_wrist:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.cam_high:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.cam_low:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
training:
|
||||
offline_steps: 80000
|
||||
online_steps: 0
|
||||
eval_freq: 99999999999999
|
||||
save_freq: 1000
|
||||
log_freq: 100
|
||||
save_model: true
|
||||
|
||||
batch_size: 8
|
||||
lr: 1e-5
|
||||
lr_backbone: 1e-5
|
||||
weight_decay: 1e-4
|
||||
grad_clip_norm: 10
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
delta_timestamps:
|
||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
batch_size: 50
|
||||
|
||||
# See `configuration_act.py` for more details.
|
||||
policy:
|
||||
name: act
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: 1
|
||||
chunk_size: 100 # chunk_size
|
||||
n_action_steps: 100
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.images.cam_right_wrist: [3, 480, 640]
|
||||
observation.images.cam_left_wrist: [3, 480, 640]
|
||||
observation.images.cam_high: [3, 480, 640]
|
||||
observation.images.cam_low: [3, 480, 640]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.images.cam_right_wrist: mean_std
|
||||
observation.images.cam_left_wrist: mean_std
|
||||
observation.images.cam_high: mean_std
|
||||
observation.images.cam_low: mean_std
|
||||
observation.state: mean_std
|
||||
output_normalization_modes:
|
||||
action: mean_std
|
||||
|
||||
# Architecture.
|
||||
# Vision backbone.
|
||||
vision_backbone: resnet18
|
||||
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
|
||||
replace_final_stride_with_dilation: false
|
||||
# Transformer layers.
|
||||
pre_norm: false
|
||||
dim_model: 512
|
||||
n_heads: 8
|
||||
dim_feedforward: 3200
|
||||
feedforward_activation: relu
|
||||
n_encoder_layers: 4
|
||||
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||||
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||||
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||||
n_decoder_layers: 1
|
||||
# VAE.
|
||||
use_vae: true
|
||||
latent_dim: 32
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_momentum: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
kl_weight: 10.0
|
||||
@@ -7,6 +7,20 @@
|
||||
seed: 100000
|
||||
dataset_repo_id: lerobot/pusht
|
||||
|
||||
override_dataset_stats:
|
||||
# TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
|
||||
observation.image:
|
||||
mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
|
||||
std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
|
||||
# TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model
|
||||
# from the original codebase, but we should remove these and train our own pretrained model
|
||||
observation.state:
|
||||
min: [13.456424, 32.938293]
|
||||
max: [496.14618, 510.9579]
|
||||
action:
|
||||
min: [12.0, 25.0]
|
||||
max: [511.0, 511.0]
|
||||
|
||||
training:
|
||||
offline_steps: 200000
|
||||
online_steps: 0
|
||||
@@ -34,20 +48,6 @@ eval:
|
||||
n_episodes: 50
|
||||
batch_size: 50
|
||||
|
||||
override_dataset_stats:
|
||||
# TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
|
||||
observation.image:
|
||||
mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
|
||||
std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
|
||||
# TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model
|
||||
# from the original codebase, but we should remove these and train our own pretrained model
|
||||
observation.state:
|
||||
min: [13.456424, 32.938293]
|
||||
max: [496.14618, 510.9579]
|
||||
action:
|
||||
min: [12.0, 25.0]
|
||||
max: [511.0, 511.0]
|
||||
|
||||
policy:
|
||||
name: diffusion
|
||||
|
||||
@@ -85,6 +85,7 @@ policy:
|
||||
diffusion_step_embed_dim: 128
|
||||
use_film_scale_modulation: True
|
||||
# Noise scheduler.
|
||||
noise_scheduler_type: DDPM
|
||||
num_train_timesteps: 100
|
||||
beta_schedule: squaredcos_cap_v2
|
||||
beta_start: 0.0001
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# @package _global_
|
||||
|
||||
seed: 1
|
||||
dataset_repo_id: lerobot/xarm_lift_medium_replay
|
||||
dataset_repo_id: lerobot/xarm_lift_medium
|
||||
|
||||
training:
|
||||
offline_steps: 25000
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import platform
|
||||
|
||||
import huggingface_hub
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Evaluate a policy on an environment by running rollouts and computing metrics.
|
||||
|
||||
Usage examples:
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Use this script to convert your dataset into LeRobot dataset format and upload it to the Hugging Face hub,
|
||||
or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any
|
||||
@@ -10,7 +25,6 @@ python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--dataset-id pusht \
|
||||
--raw-format pusht_zarr \
|
||||
--community-id lerobot \
|
||||
--revision v1.2 \
|
||||
--dry-run 1 \
|
||||
--save-to-disk 1 \
|
||||
--save-tests-to-disk 0 \
|
||||
@@ -21,7 +35,6 @@ python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--dataset-id xarm_lift_medium \
|
||||
--raw-format xarm_pkl \
|
||||
--community-id lerobot \
|
||||
--revision v1.2 \
|
||||
--dry-run 1 \
|
||||
--save-to-disk 1 \
|
||||
--save-tests-to-disk 0 \
|
||||
@@ -32,7 +45,6 @@ python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--dataset-id aloha_sim_insertion_scripted \
|
||||
--raw-format aloha_hdf5 \
|
||||
--community-id lerobot \
|
||||
--revision v1.2 \
|
||||
--dry-run 1 \
|
||||
--save-to-disk 1 \
|
||||
--save-tests-to-disk 0 \
|
||||
@@ -43,7 +55,6 @@ python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--dataset-id umi_cup_in_the_wild \
|
||||
--raw-format umi_zarr \
|
||||
--community-id lerobot \
|
||||
--revision v1.2 \
|
||||
--dry-run 1 \
|
||||
--save-to-disk 1 \
|
||||
--save-tests-to-disk 0 \
|
||||
@@ -73,10 +84,14 @@ def get_from_raw_to_lerobot_format_fn(raw_format):
|
||||
from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "aloha_hdf5":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "aloha_dora":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_dora_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "xarm_pkl":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
|
||||
else:
|
||||
raise ValueError(raw_format)
|
||||
raise ValueError(
|
||||
f"The selected {raw_format} can't be found. Did you add it to `lerobot/scripts/push_dataset_to_hub.py::get_from_raw_to_lerobot_format_fn`?"
|
||||
)
|
||||
|
||||
return from_raw_to_lerobot_format
|
||||
|
||||
@@ -129,7 +144,8 @@ def push_videos_to_hub(repo_id, videos_dir, revision):
|
||||
|
||||
|
||||
def push_dataset_to_hub(
|
||||
data_dir: Path,
|
||||
input_data_dir: Path,
|
||||
output_data_dir: Path,
|
||||
dataset_id: str,
|
||||
raw_format: str | None,
|
||||
community_id: str,
|
||||
@@ -146,34 +162,33 @@ def push_dataset_to_hub(
|
||||
):
|
||||
repo_id = f"{community_id}/{dataset_id}"
|
||||
|
||||
raw_dir = data_dir / f"{dataset_id}_raw"
|
||||
|
||||
out_dir = data_dir / repo_id
|
||||
meta_data_dir = out_dir / "meta_data"
|
||||
videos_dir = out_dir / "videos"
|
||||
meta_data_dir = output_data_dir / "meta_data"
|
||||
videos_dir = output_data_dir / "videos"
|
||||
|
||||
tests_out_dir = tests_data_dir / repo_id
|
||||
tests_meta_data_dir = tests_out_dir / "meta_data"
|
||||
tests_videos_dir = tests_out_dir / "videos"
|
||||
|
||||
if out_dir.exists():
|
||||
shutil.rmtree(out_dir)
|
||||
if output_data_dir.exists():
|
||||
shutil.rmtree(output_data_dir)
|
||||
|
||||
if tests_out_dir.exists() and save_tests_to_disk:
|
||||
shutil.rmtree(tests_out_dir)
|
||||
|
||||
if not raw_dir.exists():
|
||||
download_raw(raw_dir, dataset_id)
|
||||
if not input_data_dir.exists():
|
||||
download_raw(input_data_dir, dataset_id)
|
||||
|
||||
if raw_format is None:
|
||||
# TODO(rcadene, adilzouitine): implement auto_find_raw_format
|
||||
raise NotImplementedError()
|
||||
# raw_format = auto_find_raw_format(raw_dir)
|
||||
# raw_format = auto_find_raw_format(input_data_dir)
|
||||
|
||||
from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
|
||||
|
||||
# convert dataset from original raw format to LeRobot format
|
||||
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(raw_dir, out_dir, fps, video, debug)
|
||||
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
|
||||
input_data_dir, output_data_dir, fps, video, debug
|
||||
)
|
||||
|
||||
lerobot_dataset = LeRobotDataset.from_preloaded(
|
||||
repo_id=repo_id,
|
||||
@@ -187,7 +202,7 @@ def push_dataset_to_hub(
|
||||
|
||||
if save_to_disk:
|
||||
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
|
||||
hf_dataset.save_to_disk(str(out_dir / "train"))
|
||||
hf_dataset.save_to_disk(str(output_data_dir / "train"))
|
||||
|
||||
if not dry_run or save_to_disk:
|
||||
# mandatory for upload
|
||||
@@ -212,8 +227,7 @@ def push_dataset_to_hub(
|
||||
test_hf_dataset = test_hf_dataset.with_format(None)
|
||||
test_hf_dataset.save_to_disk(str(tests_out_dir / "train"))
|
||||
|
||||
# copy meta data to tests directory
|
||||
shutil.copytree(meta_data_dir, tests_meta_data_dir)
|
||||
save_meta_data(info, stats, episode_data_index, tests_meta_data_dir)
|
||||
|
||||
# copy videos of first episode to tests directory
|
||||
episode_index = 0
|
||||
@@ -222,15 +236,25 @@ def push_dataset_to_hub(
|
||||
fname = f"{key}_episode_{episode_index:06d}.mp4"
|
||||
shutil.copy(videos_dir / fname, tests_videos_dir / fname)
|
||||
|
||||
if not save_to_disk and output_data_dir.exists():
|
||||
# remove possible temporary files remaining in the output directory
|
||||
shutil.rmtree(output_data_dir)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--data-dir",
|
||||
"--input-data-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Root directory containing datasets (e.g. `data` or `tmp/data` or `/tmp/lerobot/data`).",
|
||||
help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-data-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Root directory containing output dataset (e.g. `data/lerobot/aloha_mobile_chair` or `data/lerobot/pusht`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-id",
|
||||
@@ -299,7 +323,7 @@ def main():
|
||||
parser.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=16,
|
||||
default=8,
|
||||
help="Number of processes of Dataloader for computing the dataset statistics.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import time
|
||||
from copy import deepcopy
|
||||
@@ -8,6 +23,7 @@ import hydra
|
||||
import torch
|
||||
from datasets import concatenate_datasets
|
||||
from datasets.utils import disable_progress_bars, enable_progress_bars
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.utils import cycle
|
||||
@@ -72,6 +88,7 @@ def make_optimizer_and_scheduler(cfg, policy):
|
||||
|
||||
|
||||
def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
|
||||
"""Returns a dictionary of items for logging."""
|
||||
start_time = time.time()
|
||||
policy.train()
|
||||
output_dict = policy.forward(batch)
|
||||
@@ -99,6 +116,7 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
|
||||
"grad_norm": float(grad_norm),
|
||||
"lr": optimizer.param_groups[0]["lr"],
|
||||
"update_s": time.time() - start_time,
|
||||
**{k: v for k, v in output_dict.items() if k != "loss"},
|
||||
}
|
||||
|
||||
return info
|
||||
@@ -122,7 +140,7 @@ def train_notebook(out_dir=None, job_name=None, config_name="default", config_pa
|
||||
train(cfg, out_dir=out_dir, job_name=job_name)
|
||||
|
||||
|
||||
def log_train_info(logger, info, step, cfg, dataset, is_offline):
|
||||
def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
|
||||
loss = info["loss"]
|
||||
grad_norm = info["grad_norm"]
|
||||
lr = info["lr"]
|
||||
@@ -290,7 +308,7 @@ def add_episodes_inplace(
|
||||
sampler.num_samples = len(concat_dataset)
|
||||
|
||||
|
||||
def train(cfg: dict, out_dir=None, job_name=None):
|
||||
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
|
||||
if out_dir is None:
|
||||
raise NotImplementedError()
|
||||
if job_name is None:
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
|
||||
|
||||
Note: The last frame of the episode doesnt always correspond to a final state.
|
||||
@@ -47,6 +62,7 @@ local$ rerun ws://localhost:9087
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
@@ -115,15 +131,17 @@ def visualize_dataset(
|
||||
|
||||
spawn_local_viewer = mode == "local" and not save
|
||||
rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)
|
||||
|
||||
# Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush
|
||||
# when iterating on a dataloader with `num_workers` > 0
|
||||
# TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix
|
||||
gc.collect()
|
||||
|
||||
if mode == "distant":
|
||||
rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
|
||||
|
||||
logging.info("Logging to Rerun")
|
||||
|
||||
if num_workers > 0:
|
||||
# TODO(rcadene): fix data workers hanging when `rr.init` is called
|
||||
logging.warning("If data loader is hanging, try `--num-workers 0`.")
|
||||
|
||||
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
|
||||
# iterate over the batch
|
||||
for i in range(len(batch["index"])):
|
||||
@@ -196,7 +214,7 @@ def main():
|
||||
parser.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=0,
|
||||
default=4,
|
||||
help="Number of processes of Dataloader for loading the data.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
1158
poetry.lock
generated
1158
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -28,40 +28,41 @@ packages = [{include = "lerobot"}]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<3.13"
|
||||
termcolor = "^2.4.0"
|
||||
omegaconf = "^2.3.0"
|
||||
wandb = "^0.16.3"
|
||||
imageio = {extras = ["ffmpeg"], version = "^2.34.0"}
|
||||
gdown = "^5.1.0"
|
||||
hydra-core = "^1.3.2"
|
||||
einops = "^0.8.0"
|
||||
pymunk = "^6.6.0"
|
||||
zarr = "^2.17.0"
|
||||
numba = "^0.59.0"
|
||||
termcolor = ">=2.4.0"
|
||||
omegaconf = ">=2.3.0"
|
||||
wandb = ">=0.16.3"
|
||||
imageio = {extras = ["ffmpeg"], version = ">=2.34.0"}
|
||||
gdown = ">=5.1.0"
|
||||
hydra-core = ">=1.3.2"
|
||||
einops = ">=0.8.0"
|
||||
pymunk = ">=6.6.0"
|
||||
zarr = ">=2.17.0"
|
||||
numba = ">=0.59.0"
|
||||
torch = "^2.2.1"
|
||||
opencv-python = "^4.9.0.80"
|
||||
opencv-python = ">=4.9.0"
|
||||
diffusers = "^0.27.2"
|
||||
torchvision = "^0.18.0"
|
||||
h5py = "^3.10.0"
|
||||
huggingface-hub = "^0.21.4"
|
||||
robomimic = "0.2.0"
|
||||
gymnasium = "^0.29.1"
|
||||
cmake = "^3.29.0.1"
|
||||
gym-pusht = { version = "^0.1.1", optional = true}
|
||||
gym-xarm = { version = "^0.1.0", optional = true}
|
||||
gym-aloha = { version = "^0.1.0", optional = true}
|
||||
pre-commit = {version = "^3.7.0", optional = true}
|
||||
debugpy = {version = "^1.8.1", optional = true}
|
||||
pytest = {version = "^8.1.0", optional = true}
|
||||
pytest-cov = {version = "^5.0.0", optional = true}
|
||||
torchvision = ">=0.18.0"
|
||||
h5py = ">=3.10.0"
|
||||
huggingface-hub = {extras = ["hf-transfer"], version = "^0.23.0"}
|
||||
gymnasium = ">=0.29.1"
|
||||
cmake = ">=3.29.0.1"
|
||||
gym-dora = { path = "gym_dora", optional = true, develop = true}
|
||||
gym-pusht = { version = ">=0.1.3", optional = true}
|
||||
gym-xarm = { version = ">=0.1.1", optional = true}
|
||||
gym-aloha = { version = ">=0.1.1", optional = true}
|
||||
pre-commit = {version = ">=3.7.0", optional = true}
|
||||
debugpy = {version = ">=1.8.1", optional = true}
|
||||
pytest = {version = ">=8.1.0", optional = true}
|
||||
pytest-cov = {version = ">=5.0.0", optional = true}
|
||||
datasets = "^2.19.0"
|
||||
imagecodecs = { version = "^2024.1.1", optional = true }
|
||||
pyav = "^12.0.5"
|
||||
moviepy = "^1.0.3"
|
||||
rerun-sdk = "^0.15.1"
|
||||
imagecodecs = { version = ">=2024.1.1", optional = true }
|
||||
pyav = ">=12.0.5"
|
||||
moviepy = ">=1.0.3"
|
||||
rerun-sdk = ">=0.15.1"
|
||||
|
||||
|
||||
[tool.poetry.extras]
|
||||
dora = ["gym-dora"]
|
||||
pusht = ["gym-pusht"]
|
||||
xarm = ["gym-xarm"]
|
||||
aloha = ["gym-aloha"]
|
||||
@@ -104,5 +105,5 @@ ignore-init-module-imports = true
|
||||
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.5.0"]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .utils import DEVICE
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9f9347c8d9ac90ee44e6dd86f65043438168df6bbe4bab2d2b875e55ef7376ef
|
||||
size 1488
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb
|
||||
size 33
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:02fc4ea25766269f65752a60b0594c43d799b0ae528cd773bf024b064b5aa329
|
||||
size 4344
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:55d7b1a06fe3e3051482752740074348bdb5fc98fb2e305b06d6203994117b27
|
||||
size 592448
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:8b7fbedfdb3d536847bc6fadf2cbabb9f2b5492edf3e2c274a3e8ffb447105e8
|
||||
size 1166
|
||||
3
tests/data/lerobot/aloha_mobile_cabinet/train/state.json
Normal file
3
tests/data/lerobot/aloha_mobile_cabinet/train/state.json
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:98329e4b40e9be0d63f7d36da9d86c44bbe7eeeb1b10d3ba973c923f3be70867
|
||||
size 247
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:54e42cdfd016a0ced2ab1fe2966a8c15a2384e0dbe1a2fe87433a2d1b8209ac0
|
||||
size 5220057
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:af1ded2a244cb47a96255b75f584a643edf6967e13bb5464b330ffdd9d7ad859
|
||||
size 5284692
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:13d1bebabd79984fd6715971be758ef9a354495adea5e8d33f4e7904365e112b
|
||||
size 5258380
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:f33bc6810f0b91817a42610364cb49ed1b99660f058f0f9407e6f5920d0aee02
|
||||
size 1008
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb
|
||||
size 33
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:7b58d6c89e936a781a307805ebecf0dd473fbc02d52a7094da62e54bffb9454a
|
||||
size 4344
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a08be578285cbe2d35b78f150d464ff3e10604a9865398c976983e0d711774f9
|
||||
size 788528
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:8b7fbedfdb3d536847bc6fadf2cbabb9f2b5492edf3e2c274a3e8ffb447105e8
|
||||
size 1166
|
||||
3
tests/data/lerobot/aloha_mobile_chair/train/state.json
Normal file
3
tests/data/lerobot/aloha_mobile_chair/train/state.json
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:34e36233477c8aa0b0840314ddace072062d4f486d06546bbd6550832c370065
|
||||
size 247
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:66e7349a4a82ca6042a7189608d01eb1cfa38d100d039b5445ae1a9e65d824ab
|
||||
size 14470946
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a2146f0c10c9f2611e57e617983aa4f91ad681b4fc50d91b992b97abd684f926
|
||||
size 11662185
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5affbaf1c48895ba3c626e0d8cf1309e5f4ec6bbaa135313096f52a22de66c05
|
||||
size 11410342
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:6c2b195ca91b88fd16422128d386d2cabd808a1862c6d127e6bf2e83e1fe819a
|
||||
size 448
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb
|
||||
size 33
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:b360b6b956d2adcb20589947c553348ef1eb6b70743c989dcbe95243d8592ce5
|
||||
size 4344
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:3f5c3926b4d4da9271abefcdf6a8952bb1f13258a9c39fe0fd223f548dc89dcb
|
||||
size 887728
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:8b7fbedfdb3d536847bc6fadf2cbabb9f2b5492edf3e2c274a3e8ffb447105e8
|
||||
size 1166
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:4993b05fb026619eec5eb70db8cadaa041ba4ab92d38b4a387167ace03b1018b
|
||||
size 247
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:bd25d17ef5b7500386761b5e32920879bbdcafe0e17a8a8845628525d861e644
|
||||
size 10231081
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5b557acbfeb0681c0a38e47263d945f6cd3a03461298d8b17209c81e3fd0aae8
|
||||
size 9701371
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:da8f3b4f9f965da63819652b2c042d4cf7e07d14631113ea072087d56370310e
|
||||
size 10473741
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a053506017d8a78cfd307b2912eeafa1ac1485a280cf90913985fcc40120b5ec
|
||||
size 416
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb
|
||||
size 33
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d6d172d1bca02face22ceb4c21ea2b054cf3463025485dce64711b6f36b31f8a
|
||||
size 4344
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:7e5ce817a2c188041f57f8d4c465dab3b9c3e4e1aeb7a9fb270230d1b36df530
|
||||
size 1477064
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:8b7fbedfdb3d536847bc6fadf2cbabb9f2b5492edf3e2c274a3e8ffb447105e8
|
||||
size 1166
|
||||
3
tests/data/lerobot/aloha_mobile_shrimp/train/state.json
Normal file
3
tests/data/lerobot/aloha_mobile_shrimp/train/state.json
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:4eb2dc373e4ea7d474742590f9073d66a773f6ab94b9e73a8673df19f93fae6d
|
||||
size 247
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d2c55b146fabe78b18c8a28a7746ab56e1ee7a6918e9e3dad9bd196f97975895
|
||||
size 26158915
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:71e1958d77f56843acf1ec48da4f04311a5836c87a0e77dbe26aa47c27c6347e
|
||||
size 18786848
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:20780718399b5759ff9a3a79824986310524793066198e3b9a307222f11a93df
|
||||
size 17769988
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:279916f7689ae46af90e92a46eba9486a71fc762e3e2679ab5441eb37126827b
|
||||
size 928
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb
|
||||
size 33
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:7a7731051b521694b52b5631470720a7f05331915f4ac4e7f8cd83f9ff459bce
|
||||
size 4344
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:99608258e8c9fe5191f1a12edc29b47d307790104149dffb6d3046ddad6aeb1b
|
||||
size 435600
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user