Compare commits

..

9 Commits

Author SHA1 Message Date
Simon Alibert
c6a61e3ba2 WIP 2024-05-21 16:31:48 +02:00
Simon Alibert
62d3546f08 Move dependencies to extra 2024-05-21 16:29:44 +02:00
Simon Alibert
956f035d16 Merge remote-tracking branch 'origin/main' into user/aliberts/2024_05_14_compare_policies 2024-05-21 10:14:10 +02:00
Alexander Soare
b6c216b590 Add Automatic Mixed Precision option for training and evaluation. (#199) 2024-05-20 18:57:54 +01:00
Alexander Soare
2b270d085b Disable online training (#202)
Co-authored-by: Remi <re.cadene@gmail.com>
2024-05-20 18:27:54 +01:00
Simon Alibert
eb530fa595 Add '--independent' flag 2024-05-16 19:31:57 +02:00
Simon Alibert
fe31b7f4b7 Merge remote-tracking branch 'origin/main' into user/aliberts/2024_05_14_compare_policies 2024-05-16 17:04:33 +02:00
Simon Alibert
8f5cfcd73d Add argparse, refactor & cleanup 2024-05-16 16:55:40 +02:00
Simon Alibert
10036c1219 WIP add score tests 2024-05-15 17:50:12 +02:00
19 changed files with 1194 additions and 1338 deletions

View File

@@ -20,6 +20,8 @@ build-gpu:
test-end-to-end:
${MAKE} test-act-ete-train
${MAKE} test-act-ete-eval
${MAKE} test-act-ete-train-amp
${MAKE} test-act-ete-eval-amp
${MAKE} test-diffusion-ete-train
${MAKE} test-diffusion-ete-eval
${MAKE} test-tdmpc-ete-train
@@ -29,6 +31,7 @@ test-end-to-end:
test-act-ete-train:
python lerobot/scripts/train.py \
policy=act \
policy.dim_model=64 \
env=aloha \
wandb.enable=False \
training.offline_steps=2 \
@@ -51,9 +54,40 @@ test-act-ete-eval:
env.episode_length=8 \
device=cpu \
test-act-ete-train-amp:
python lerobot/scripts/train.py \
policy=act \
policy.dim_model=64 \
env=aloha \
wandb.enable=False \
training.offline_steps=2 \
training.online_steps=0 \
eval.n_episodes=1 \
eval.batch_size=1 \
device=cpu \
training.save_model=true \
training.save_freq=2 \
policy.n_action_steps=20 \
policy.chunk_size=20 \
training.batch_size=2 \
hydra.run.dir=tests/outputs/act/ \
use_amp=true
test-act-ete-eval-amp:
python lerobot/scripts/eval.py \
-p tests/outputs/act/checkpoints/000002 \
eval.n_episodes=1 \
eval.batch_size=1 \
env.episode_length=8 \
device=cpu \
use_amp=true
test-diffusion-ete-train:
python lerobot/scripts/train.py \
policy=diffusion \
policy.down_dims=\[64,128,256\] \
policy.diffusion_step_embed_dim=32 \
policy.num_inference_steps=10 \
env=pusht \
wandb.enable=False \
training.offline_steps=2 \
@@ -74,6 +108,7 @@ test-diffusion-ete-eval:
env.episode_length=8 \
device=cpu \
# TODO(alexander-soare): Restore online_steps to 2 when it is reinstated.
test-tdmpc-ete-train:
python lerobot/scripts/train.py \
policy=tdmpc \
@@ -82,7 +117,7 @@ test-tdmpc-ete-train:
dataset_repo_id=lerobot/xarm_lift_medium \
wandb.enable=False \
training.offline_steps=2 \
training.online_steps=2 \
training.online_steps=0 \
eval.n_episodes=1 \
eval.batch_size=1 \
env.episode_length=2 \
@@ -100,7 +135,6 @@ test-tdmpc-ete-eval:
env.episode_length=8 \
device=cpu \
test-default-ete-eval:
python lerobot/scripts/eval.py \
--config lerobot/configs/default.yaml \

View File

@@ -1 +0,0 @@
# gym_dora

View File

@@ -1,17 +0,0 @@
import gymnasium as gym
import gym_dora # noqa: F401
env = gym.make("gym_dora/DoraAloha-v0", disable_env_checker=True)
obs = env.reset()
policy = ... # make_policy
done = False
while not done:
actions = policy.select_action(obs)
observation, reward, terminated, truncated, info = env.step(actions)
done = terminated | truncated | done
env.close()

View File

@@ -1,17 +0,0 @@
from gymnasium.envs.registration import register
register(
id="gym_dora/DoraAloha-v0",
entry_point="gym_dora.env:DoraEnv",
max_episode_steps=300,
nondeterministic=True,
kwargs={"model": "aloha"},
)
register(
id="gym_dora/DoraKoch-v0",
entry_point="gym_dora.env:DoraEnv",
max_episode_steps=300,
nondeterministic=True,
kwargs={"model": "koch"},
)

View File

@@ -1,199 +0,0 @@
import os
import gymnasium as gym
import numpy as np
import pyarrow as pa
from dora import Node
from gymnasium import spaces
FPS = int(os.getenv("FPS", "30"))
IMAGE_WIDTH = int(os.getenv("IMAGE_WIDTH", "640"))
IMAGE_HEIGHT = int(os.getenv("IMAGE_HEIGHT", "480"))
ALOHA_JOINTS = [
# absolute joint position
"left_arm_waist",
"left_arm_shoulder",
"left_arm_elbow",
"left_arm_forearm_roll",
"left_arm_wrist_angle",
"left_arm_wrist_rotate",
# normalized gripper position 0: close, 1: open
"left_arm_gripper",
# absolute joint position
"right_arm_waist",
"right_arm_shoulder",
"right_arm_elbow",
"right_arm_forearm_roll",
"right_arm_wrist_angle",
"right_arm_wrist_rotate",
# normalized gripper position 0: close, 1: open
"right_arm_gripper",
]
ALOHA_ACTIONS = [
# position and quaternion for end effector
"left_arm_waist",
"left_arm_shoulder",
"left_arm_elbow",
"left_arm_forearm_roll",
"left_arm_wrist_angle",
"left_arm_wrist_rotate",
# normalized gripper position (0: close, 1: open)
"left_arm_gripper",
"right_arm_waist",
"right_arm_shoulder",
"right_arm_elbow",
"right_arm_forearm_roll",
"right_arm_wrist_angle",
"right_arm_wrist_rotate",
# normalized gripper position (0: close, 1: open)
"right_arm_gripper",
]
class DoraEnv(gym.Env):
metadata = {"render_modes": ["rgb_array"], "render_fps": FPS}
def __init__(
self,
model="aloha",
observation_width=IMAGE_WIDTH,
observation_height=IMAGE_HEIGHT,
cameras_names=None,
num_joints=None,
num_actions=None,
):
"""Initializes the Dora environment.
Args:
model (str): The model to use. Either 'aloha' or 'custom'.
observation_width (int): The width of the observation image.
observation_height (int): The height of the observation image.
cameras_names (list): A list of camera names to use. If not provided, the default is ['cam_high', 'cam_low', 'cam_left_wrist', 'cam_right_wrist'].
num_joints (int): The number of joints in the model. If not provided, the default is 14 for 'aloha' and 6 for 'fivedof'.
num_actions (int): The number of actions in the model. If not provided, the default is 14 for 'aloha' and 6 for 'fivedof'.
"""
super().__init__()
# Initialize a new node
self.node = Node() if os.environ.get("DORA_NODE_CONFIG", None) is not None else None
self.observation = {"pixels": {}, "agent_pos": None}
self.terminated = False
self.observation_height = observation_height
self.observation_width = observation_width
# Observation space
if model == "aloha":
self.observation_space = spaces.Dict(
{
"pixels": spaces.Dict(
{
"cam_high": spaces.Box(
low=0,
high=255,
shape=(self.observation_height, self.observation_width, 3),
dtype=np.uint8,
),
"cam_low": spaces.Box(
low=0,
high=255,
shape=(self.observation_height, self.observation_width, 3),
dtype=np.uint8,
),
"cam_left_wrist": spaces.Box(
low=0,
high=255,
shape=(self.observation_height, self.observation_width, 3),
dtype=np.uint8,
),
"cam_right_wrist": spaces.Box(
low=0,
high=255,
shape=(self.observation_height, self.observation_width, 3),
dtype=np.uint8,
),
}
),
"agent_pos": spaces.Box(
low=-1000.0,
high=1000.0,
shape=(len(ALOHA_JOINTS),),
dtype=np.float64,
),
}
)
elif model == "custom":
pixel_dict = {}
for camera in cameras_names:
assert camera.startswith("cam"), "Camera names must start with 'cam'"
pixel_dict[camera] = spaces.Box(
low=0,
high=255,
shape=(self.observation_height, self.observation_width, 3),
dtype=np.uint8,
)
self.observation_space = spaces.Dict(
{
"pixels": spaces.Dict(pixel_dict),
"agent_pos": spaces.Box(
low=-1000.0,
high=1000.0,
shape=(num_joints,),
dtype=np.float64,
),
}
)
else:
raise ValueError("Model must be either 'aloha' or 'custom'.")
# Action space
if model == "aloha":
self.action_space = spaces.Box(low=-1, high=1, shape=(len(ALOHA_ACTIONS),), dtype=np.float32)
elif model == "custom":
self.action_space = spaces.Box(low=-1, high=1, shape=(num_actions,), dtype=np.float32)
def _get_obs(self):
while True:
event = self.node.next(timeout=0.001)
## If event is None, the node event stream is closed and we should terminate the env
if event is None:
self.terminated = True
break
if event["type"] == "INPUT":
# Map Image input into pixels key within Aloha environment
if "cam" in event["id"]:
self.observation["pixels"][event["id"]] = (
event["value"].to_numpy().reshape(self.observation_height, self.observation_width, 3)
)
else:
# Map other inputs into the observation dictionary using the event id as key
self.observation[event["id"]] = event["value"].to_numpy()
# If the event is a timeout error break the update loop.
elif event["type"] == "ERROR":
break
def reset(self, seed: int | None = None):
self.node.send_output("reset")
self._get_obs()
self.terminated = False
info = {}
return self.observation, info
def step(self, action: np.ndarray):
# Send the action to the dataflow as action key.
self.node.send_output("action", pa.array(action))
self._get_obs()
reward = 0
terminated = truncated = self.terminated
info = {}
return self.observation, reward, terminated, truncated, info
def render(self): ...
def close(self):
# Drop the node
del self.node

182
gym_dora/poetry.lock generated
View File

@@ -1,182 +0,0 @@
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
[[package]]
name = "cloudpickle"
version = "3.0.0"
description = "Pickler class to extend the standard pickle.Pickler functionality"
optional = false
python-versions = ">=3.8"
files = [
{file = "cloudpickle-3.0.0-py3-none-any.whl", hash = "sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7"},
{file = "cloudpickle-3.0.0.tar.gz", hash = "sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882"},
]
[[package]]
name = "dora-rs"
version = "0.3.4"
description = "`dora` goal is to be a low latency, composable, and distributed data flow."
optional = false
python-versions = "*"
files = [
{file = "dora_rs-0.3.4-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:d1b738eea5a4966d731c26c6b6a0a50a491a24f7e9e335475f983cfc6f0da19e"},
{file = "dora_rs-0.3.4-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:80b724871618c78a4e5863938fa66724176cc40352771087aebe1e62a8141157"},
{file = "dora_rs-0.3.4-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a3919e157b47dc1dbc74c040a73087a4485f0d1bee99b6adcdbc36559400fe2"},
{file = "dora_rs-0.3.4-cp37-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f7c95f6e5858fd651d6cd220e4f052e99db2944b9c37fb0b5402d60ac4b41a63"},
{file = "dora_rs-0.3.4-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37d915fbbca282446235c98a9ca08389aa3ef3155d4e88c6c136326e9a830042"},
{file = "dora_rs-0.3.4-cp37-abi3-win32.whl", hash = "sha256:c9f7f22f65c884ec9bee0245ce98d0c7fad25dec0f982e566f844b5e8e58818f"},
{file = "dora_rs-0.3.4-cp37-abi3-win_amd64.whl", hash = "sha256:0a6a37f96a9f6e13b58b02a6ea75af192af5fbe4f456f6a67b1f239c3cee3276"},
{file = "dora_rs-0.3.4.tar.gz", hash = "sha256:05c5d0db0d23d7c4669995ae34db11cd636dbf91f5705d832669bd04e7452903"},
]
[package.dependencies]
pyarrow = "*"
[[package]]
name = "farama-notifications"
version = "0.0.4"
description = "Notifications for all Farama Foundation maintained libraries."
optional = false
python-versions = "*"
files = [
{file = "Farama-Notifications-0.0.4.tar.gz", hash = "sha256:13fceff2d14314cf80703c8266462ebf3733c7d165336eee998fc58e545efd18"},
{file = "Farama_Notifications-0.0.4-py3-none-any.whl", hash = "sha256:14de931035a41961f7c056361dc7f980762a143d05791ef5794a751a2caf05ae"},
]
[[package]]
name = "gymnasium"
version = "0.29.1"
description = "A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym)."
optional = false
python-versions = ">=3.8"
files = [
{file = "gymnasium-0.29.1-py3-none-any.whl", hash = "sha256:61c3384b5575985bb7f85e43213bcb40f36fcdff388cae6bc229304c71f2843e"},
{file = "gymnasium-0.29.1.tar.gz", hash = "sha256:1a532752efcb7590478b1cc7aa04f608eb7a2fdad5570cd217b66b6a35274bb1"},
]
[package.dependencies]
cloudpickle = ">=1.2.0"
farama-notifications = ">=0.0.1"
numpy = ">=1.21.0"
typing-extensions = ">=4.3.0"
[package.extras]
accept-rom-license = ["autorom[accept-rom-license] (>=0.4.2,<0.5.0)"]
all = ["box2d-py (==2.3.5)", "cython (<3)", "imageio (>=2.14.1)", "jax (>=0.4.0)", "jaxlib (>=0.4.0)", "lz4 (>=3.1.0)", "matplotlib (>=3.0)", "moviepy (>=1.0.0)", "mujoco (>=2.3.3)", "mujoco-py (>=2.1,<2.2)", "opencv-python (>=3.0)", "pygame (>=2.1.3)", "shimmy[atari] (>=0.1.0,<1.0)", "swig (==4.*)", "torch (>=1.0.0)"]
atari = ["shimmy[atari] (>=0.1.0,<1.0)"]
box2d = ["box2d-py (==2.3.5)", "pygame (>=2.1.3)", "swig (==4.*)"]
classic-control = ["pygame (>=2.1.3)", "pygame (>=2.1.3)"]
jax = ["jax (>=0.4.0)", "jaxlib (>=0.4.0)"]
mujoco = ["imageio (>=2.14.1)", "mujoco (>=2.3.3)"]
mujoco-py = ["cython (<3)", "cython (<3)", "mujoco-py (>=2.1,<2.2)", "mujoco-py (>=2.1,<2.2)"]
other = ["lz4 (>=3.1.0)", "matplotlib (>=3.0)", "moviepy (>=1.0.0)", "opencv-python (>=3.0)", "torch (>=1.0.0)"]
testing = ["pytest (==7.1.3)", "scipy (>=1.7.3)"]
toy-text = ["pygame (>=2.1.3)", "pygame (>=2.1.3)"]
[[package]]
name = "numpy"
version = "1.26.4"
description = "Fundamental package for array computing in Python"
optional = false
python-versions = ">=3.9"
files = [
{file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"},
{file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"},
{file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"},
{file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"},
{file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"},
{file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"},
{file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"},
{file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"},
{file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"},
{file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"},
{file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"},
{file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"},
{file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"},
{file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"},
{file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"},
{file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"},
{file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"},
{file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"},
{file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"},
{file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"},
{file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"},
{file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"},
{file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"},
{file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"},
{file = "numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c"},
{file = "numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be"},
{file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764"},
{file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3"},
{file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd"},
{file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c"},
{file = "numpy-1.26.4-cp39-cp39-win32.whl", hash = "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6"},
{file = "numpy-1.26.4-cp39-cp39-win_amd64.whl", hash = "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea"},
{file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"},
{file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"},
{file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"},
{file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"},
]
[[package]]
name = "pyarrow"
version = "16.1.0"
description = "Python library for Apache Arrow"
optional = false
python-versions = ">=3.8"
files = [
{file = "pyarrow-16.1.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:17e23b9a65a70cc733d8b738baa6ad3722298fa0c81d88f63ff94bf25eaa77b9"},
{file = "pyarrow-16.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4740cc41e2ba5d641071d0ab5e9ef9b5e6e8c7611351a5cb7c1d175eaf43674a"},
{file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:98100e0268d04e0eec47b73f20b39c45b4006f3c4233719c3848aa27a03c1aef"},
{file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f68f409e7b283c085f2da014f9ef81e885d90dcd733bd648cfba3ef265961848"},
{file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:a8914cd176f448e09746037b0c6b3a9d7688cef451ec5735094055116857580c"},
{file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:48be160782c0556156d91adbdd5a4a7e719f8d407cb46ae3bb4eaee09b3111bd"},
{file = "pyarrow-16.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:9cf389d444b0f41d9fe1444b70650fea31e9d52cfcb5f818b7888b91b586efff"},
{file = "pyarrow-16.1.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:d0ebea336b535b37eee9eee31761813086d33ed06de9ab6fc6aaa0bace7b250c"},
{file = "pyarrow-16.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e73cfc4a99e796727919c5541c65bb88b973377501e39b9842ea71401ca6c1c"},
{file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf9251264247ecfe93e5f5a0cd43b8ae834f1e61d1abca22da55b20c788417f6"},
{file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddf5aace92d520d3d2a20031d8b0ec27b4395cab9f74e07cc95edf42a5cc0147"},
{file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:25233642583bf658f629eb230b9bb79d9af4d9f9229890b3c878699c82f7d11e"},
{file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a33a64576fddfbec0a44112eaf844c20853647ca833e9a647bfae0582b2ff94b"},
{file = "pyarrow-16.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:185d121b50836379fe012753cf15c4ba9638bda9645183ab36246923875f8d1b"},
{file = "pyarrow-16.1.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:2e51ca1d6ed7f2e9d5c3c83decf27b0d17bb207a7dea986e8dc3e24f80ff7d6f"},
{file = "pyarrow-16.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:06ebccb6f8cb7357de85f60d5da50e83507954af617d7b05f48af1621d331c9a"},
{file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b04707f1979815f5e49824ce52d1dceb46e2f12909a48a6a753fe7cafbc44a0c"},
{file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d32000693deff8dc5df444b032b5985a48592c0697cb6e3071a5d59888714e2"},
{file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:8785bb10d5d6fd5e15d718ee1d1f914fe768bf8b4d1e5e9bf253de8a26cb1628"},
{file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:e1369af39587b794873b8a307cc6623a3b1194e69399af0efd05bb202195a5a7"},
{file = "pyarrow-16.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:febde33305f1498f6df85e8020bca496d0e9ebf2093bab9e0f65e2b4ae2b3444"},
{file = "pyarrow-16.1.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:b5f5705ab977947a43ac83b52ade3b881eb6e95fcc02d76f501d549a210ba77f"},
{file = "pyarrow-16.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0d27bf89dfc2576f6206e9cd6cf7a107c9c06dc13d53bbc25b0bd4556f19cf5f"},
{file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d07de3ee730647a600037bc1d7b7994067ed64d0eba797ac74b2bc77384f4c2"},
{file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fbef391b63f708e103df99fbaa3acf9f671d77a183a07546ba2f2c297b361e83"},
{file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:19741c4dbbbc986d38856ee7ddfdd6a00fc3b0fc2d928795b95410d38bb97d15"},
{file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:f2c5fb249caa17b94e2b9278b36a05ce03d3180e6da0c4c3b3ce5b2788f30eed"},
{file = "pyarrow-16.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:e6b6d3cd35fbb93b70ade1336022cc1147b95ec6af7d36906ca7fe432eb09710"},
{file = "pyarrow-16.1.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:18da9b76a36a954665ccca8aa6bd9f46c1145f79c0bb8f4f244f5f8e799bca55"},
{file = "pyarrow-16.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:99f7549779b6e434467d2aa43ab2b7224dd9e41bdde486020bae198978c9e05e"},
{file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f07fdffe4fd5b15f5ec15c8b64584868d063bc22b86b46c9695624ca3505b7b4"},
{file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddfe389a08ea374972bd4065d5f25d14e36b43ebc22fc75f7b951f24378bf0b5"},
{file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:3b20bd67c94b3a2ea0a749d2a5712fc845a69cb5d52e78e6449bbd295611f3aa"},
{file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:ba8ac20693c0bb0bf4b238751d4409e62852004a8cf031c73b0e0962b03e45e3"},
{file = "pyarrow-16.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:31a1851751433d89a986616015841977e0a188662fcffd1a5677453f1df2de0a"},
{file = "pyarrow-16.1.0.tar.gz", hash = "sha256:15fbb22ea96d11f0b5768504a3f961edab25eaf4197c341720c4a387f6c60315"},
]
[package.dependencies]
numpy = ">=1.16.6"
[[package]]
name = "typing-extensions"
version = "4.11.0"
description = "Backported and Experimental Type Hints for Python 3.8+"
optional = false
python-versions = ">=3.8"
files = [
{file = "typing_extensions-4.11.0-py3-none-any.whl", hash = "sha256:c1f94d72897edaf4ce775bb7558d5b79d8126906a14ea5ed1635921406c0387a"},
{file = "typing_extensions-4.11.0.tar.gz", hash = "sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0"},
]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "7e437b5c547ebe11095f1ce4ff1851d636f8e707ad7de8a6224b0f9ad978240f"

View File

@@ -1,17 +0,0 @@
[tool.poetry]
name = "gym-dora"
version = "0.1.0"
description = ""
authors = ["Simon Alibert <alibert.sim@gmail.com>"]
readme = "README.md"
packages = [{ include = "gym_dora" }]
[tool.poetry.dependencies]
python = "^3.10"
gymnasium = ">=0.29.1"
dora-rs = ">=0.3.4"
pyarrow = ">=12.0.0"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

View File

@@ -1,200 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains utilities to process raw data format from dora-record
"""
import logging
from pathlib import Path
import pandas as pd
import torch
from datasets import Dataset, Features, Image, Sequence, Value
from lerobot.common.datasets.utils import (
hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame
from lerobot.common.utils.utils import init_logging
def check_format(raw_dir) -> bool:
assert raw_dir.exists()
leader_file = list(raw_dir.glob("*.parquet"))
if len(leader_file) == 0:
raise ValueError(f"Missing parquet files in '{raw_dir}'")
return True
def load_from_raw(raw_dir: Path, out_dir: Path):
# Load data stream that will be used as reference for the timestamps synchronization
reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
if len(reference_files) == 0:
raise ValueError(f"Missing reference files for camera, starting with in '{raw_dir}'")
# select first camera in alphanumeric order
reference_key = sorted(reference_files)[0].stem
reference_df = pd.read_parquet(raw_dir / f"{reference_key}.parquet")
reference_df = reference_df[["timestamp_utc", reference_key]]
# Merge all data stream using nearest backward strategy
df = reference_df
for path in raw_dir.glob("*.parquet"):
key = path.stem # action or observation.state or ...
if key == reference_key:
continue
modality_df = pd.read_parquet(path)
modality_df = modality_df[["timestamp_utc", key]]
df = pd.merge_asof(
df,
modality_df,
on="timestamp_utc",
direction="backward",
)
# Remove rows with a NaN in any column. It can happened during the first frames of an episode,
# because some cameras didnt start recording yet.
df = df.dropna(axis=0)
# Remove rows with episode_index -1 which indicates a failed episode
df = df[df["episode_index"] != -1]
# dora only use arrays, so single values are encapsulated into a list
df["episode_index"] = df["episode_index"].map(lambda x: x[0])
df["frame_index"] = df.groupby("episode_index").cumcount()
df = df.reset_index()
df["index"] = df.index
# set 'next.done' to True for the last frame of each episode
df["next.done"] = False
df.loc[df.groupby("episode_index").tail(1).index, "next.done"] = True
df["timestamp"] = df["timestamp_utc"].map(lambda x: x.timestamp())
# each episode starts with timestamp 0 to match the ones from the video
df["timestamp"] = df.groupby("episode_index")["timestamp"].transform(lambda x: x - x.iloc[0])
del df["timestamp_utc"]
# sanity check episode indices go from 0 to n-1
ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
expected_ep_ids = list(range(df["episode_index"].max() + 1))
assert ep_ids == expected_ep_ids, f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}"
# Create symlink to raw videos directory (that needs to be absolute not relative)
out_dir.mkdir(parents=True, exist_ok=True)
videos_dir = out_dir / "videos"
videos_dir.symlink_to((raw_dir / "videos").absolute())
# sanity check the video paths are well formated
for key in df:
if "observation.images." not in key:
continue
for ep_idx in ep_ids:
video_path = videos_dir / f"{key}_episode_{ep_idx:06d}.mp4"
assert video_path.exists(), f"Video file not found in {video_path}"
data_dict = {}
for key in df:
# is video frame
if "observation.images." in key:
# we need `[0] because dora only use arrays, so single values are encapsulated into a list.
# it is the case for video_frame dictionary = [{"path": ..., "timestamp": ...}]
data_dict[key] = [video_frame[0] for video_frame in df[key].values]
# sanity check the video path is well formated
video_path = videos_dir.parent / data_dict[key][0]["path"]
assert video_path.exists(), f"Video file not found in {video_path}"
# is number
elif df[key].iloc[0].ndim == 0 or df[key].iloc[0].shape[0] == 1:
data_dict[key] = torch.from_numpy(df[key].values)
# is vector
elif df[key].iloc[0].shape[0] > 1:
data_dict[key] = torch.stack([torch.from_numpy(x.copy()) for x in df[key].values])
else:
raise ValueError(key)
# Get the episode index containing for each unique episode index
first_ep_index_df = df.groupby("episode_index").agg(start_index=("index", "first")).reset_index()
from_ = first_ep_index_df["start_index"].tolist()
to_ = from_[1:] + [len(df)]
episode_data_index = {
"from": from_,
"to": to_,
}
return data_dict, episode_data_index
def to_hf_dataset(data_dict, video) -> Dataset:
features = {}
keys = [key for key in data_dict if "observation.images." in key]
for key in keys:
if video:
features[key] = VideoFrame()
else:
features[key] = Image()
features["observation.state"] = Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
)
if "observation.velocity" in data_dict:
features["observation.velocity"] = Sequence(
length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
)
if "observation.effort" in data_dict:
features["observation.effort"] = Sequence(
length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
)
features["action"] = Sequence(
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
)
features["episode_index"] = Value(dtype="int64", id=None)
features["frame_index"] = Value(dtype="int64", id=None)
features["timestamp"] = Value(dtype="float32", id=None)
features["next.done"] = Value(dtype="bool", id=None)
features["index"] = Value(dtype="int64", id=None)
hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
init_logging()
if debug:
logging.warning("debug=True not implemented. Falling back to debug=False.")
# sanity check
check_format(raw_dir)
if fps is None:
fps = 30
else:
raise NotImplementedError()
if not video:
raise NotImplementedError()
data_df, episode_data_index = load_from_raw(raw_dir, out_dir)
hf_dataset = to_hf_dataset(data_df, video)
info = {
"fps": fps,
"video": video,
}
return hf_dataset, episode_data_index, info

View File

@@ -28,11 +28,11 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
raise ValueError("`n_envs must be at least 1")
kwargs = {
# "obs_type": "pixels_agent_pos",
# "render_mode": "rgb_array",
"obs_type": "pixels_agent_pos",
"render_mode": "rgb_array",
"max_episode_steps": cfg.env.episode_length,
# "visualization_width": 384,
# "visualization_height": 384,
"visualization_width": 384,
"visualization_height": 384,
}
package_name = f"gym_{cfg.env.name}"

View File

@@ -10,6 +10,9 @@ hydra:
name: default
device: cuda # cpu
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
# automatic gradient scaling is used.
use_amp: false
# `seed` is used for training (eg: model initialization, dataset shuffling)
# AND for the evaluation environments.
seed: ???
@@ -17,6 +20,7 @@ dataset_repo_id: lerobot/pusht
training:
offline_steps: ???
# NOTE: `online_steps` is not implemented yet. It's here as a placeholder.
online_steps: ???
online_steps_between_rollouts: ???
online_sampling_ratio: 0.5

View File

@@ -1,14 +0,0 @@
# @package _global_
fps: 30
env:
name: dora
task: DoraAloha-v0
# from_pixels: True
# pixels_only: False
# image_size: [3, 480, 640]
episode_length: 400
# fps: ${fps}
# state_dim: 14
# action_dim: 14

View File

@@ -1,101 +0,0 @@
# @package _global_
seed: 1000
dataset_repo_id: cadene/aloha_v2_static_dora_test
override_dataset_stats:
observation.images.cam_right_wrist:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
observation.images.cam_left_wrist:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
observation.images.cam_high:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
observation.images.cam_low:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
training:
offline_steps: 80000
online_steps: 0
eval_freq: 99999999999999
save_freq: 1000
log_freq: 100
save_model: true
batch_size: 8
lr: 1e-5
lr_backbone: 1e-5
weight_decay: 1e-4
grad_clip_norm: 10
online_steps_between_rollouts: 1
delta_timestamps:
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
eval:
n_episodes: 50
batch_size: 50
# See `configuration_act.py` for more details.
policy:
name: act
# Input / output structure.
n_obs_steps: 1
chunk_size: 100 # chunk_size
n_action_steps: 100
input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.images.cam_right_wrist: [3, 480, 640]
observation.images.cam_left_wrist: [3, 480, 640]
observation.images.cam_high: [3, 480, 640]
observation.images.cam_low: [3, 480, 640]
observation.state: ["${env.state_dim}"]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes:
observation.images.cam_right_wrist: mean_std
observation.images.cam_left_wrist: mean_std
observation.images.cam_high: mean_std
observation.images.cam_low: mean_std
observation.state: mean_std
output_normalization_modes:
action: mean_std
# Architecture.
# Vision backbone.
vision_backbone: resnet18
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
replace_final_stride_with_dilation: false
# Transformer layers.
pre_norm: false
dim_model: 512
n_heads: 8
dim_feedforward: 3200
feedforward_activation: relu
n_encoder_layers: 4
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
n_decoder_layers: 1
# VAE.
use_vae: true
latent_dim: 32
n_vae_encoder_layers: 4
# Inference.
temporal_ensemble_momentum: null
# Training and loss computation.
dropout: 0.1
kl_weight: 10.0

View File

@@ -5,7 +5,8 @@ dataset_repo_id: lerobot/xarm_lift_medium
training:
offline_steps: 25000
online_steps: 25000
# TODO(alexander-soare): uncomment when online training gets reinstated
online_steps: 0 # 25000 not implemented yet
eval_freq: 5000
online_steps_between_rollouts: 1
online_sampling_ratio: 0.5

View File

@@ -0,0 +1,340 @@
"""Compare two policies on based on metrics computed from an eval.
Usage example:
You just made changes to a policy and you want to assess its new performance against
the reference policy (i.e. before your changes).
```
python lerobot/scripts/compare_policies.py \
output/eval/ref_policy/eval_info.json \
output/eval/new_policy/eval_info.json
```
This script can accept `eval_info.json` dicts with identical seeds between each eval episode of ref_policy and
new_policy (paired-samples) or from evals performed with different seeds (independent samples).
The script will first perform normality tests to determine if parametric tests can be used or not, then
evaluate if policies metrics are significantly different using the appropriate tests.
CAVEATS: by default, this script will compare seeds numbers to determine if samples can be considered paired.
If changes have been made to this environment in-between the ref_policy eval and the new_policy eval, you
should use the `--independent` flag to override this and not pair the samples even if they have identical
seeds.
"""
import argparse
import json
import logging
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as stats
from scipy.stats import anderson, kstest, mannwhitneyu, normaltest, shapiro, ttest_ind, ttest_rel, wilcoxon
from statsmodels.stats.contingency_tables import mcnemar
from termcolor import colored
from terminaltables import AsciiTable
def init_logging() -> None:
logging.basicConfig(
level=logging.INFO,
format="%(message)s",
handlers=[logging.StreamHandler()],
)
logging.getLogger("matplotlib.font_manager").disabled = True
def log_section(title: str) -> None:
section_title = f"\n{'-'*21}\n {title.center(19)} \n{'-'*21}"
logging.info(section_title)
def log_test(msg: str, p_value: float):
if p_value < 0.01:
color, interpretation = "red", "H_0 Rejected"
elif 0.01 <= p_value < 0.05:
color, interpretation = "yellow", "Inconclusive"
else:
color, interpretation = "green", "H_0 Not Rejected"
logging.info(
f"{msg}, p-value = {colored(f'{p_value:.3f}', color)} -> {colored(f'{interpretation}', color, attrs=['bold'])}"
)
def get_eval_info_episodes(eval_info_path: Path) -> dict:
with open(eval_info_path) as f:
eval_info = json.load(f)
return {
"sum_rewards": np.array([ep_stat["sum_reward"] for ep_stat in eval_info["per_episode"]]),
"max_rewards": np.array([ep_stat["max_reward"] for ep_stat in eval_info["per_episode"]]),
"successes": np.array([ep_stat["success"] for ep_stat in eval_info["per_episode"]]),
"seeds": [ep_stat["seed"] for ep_stat in eval_info["per_episode"]],
"num_episodes": len(eval_info["per_episode"]),
}
def append_table_metric(table: list, metric: str, ref_sample: dict, new_sample: dict, mean_std: bool = False):
if mean_std:
ref_metric = f"{np.mean(ref_sample[metric]):.3f} ({np.std(ref_sample[metric]):.3f})"
new_metric = f"{np.mean(new_sample[metric]):.3f} ({np.std(new_sample[metric]):.3f})"
row_header = f"{metric} - mean (std)"
else:
ref_metric = ref_sample[metric]
new_metric = new_sample[metric]
row_header = metric
row = [row_header, ref_metric, new_metric]
table.append(row)
return table
def cohens_d(x, y):
return (np.mean(x) - np.mean(y)) / np.sqrt((np.std(x, ddof=1) ** 2 + np.std(y, ddof=1) ** 2) / 2)
def normality_tests(array: np.ndarray, name: str):
ap_stat, ap_p = normaltest(array)
sw_stat, sw_p = shapiro(array)
ks_stat, ks_p = kstest(array, "norm", args=(np.mean(array), np.std(array)))
ad_stat = anderson(array)
log_test(f"{name} - D'Agostino and Pearson test: statistic = {ap_stat:.3f}", ap_p)
log_test(f"{name} - Shapiro-Wilk test: statistic = {sw_stat:.3f}", sw_p)
log_test(f"{name} - Kolmogorov-Smirnov test: statistic = {ks_stat:.3f}", ks_p)
logging.info(f"{name} - Anderson-Darling test: statistic = {ad_stat.statistic:.3f}")
for i in range(len(ad_stat.critical_values)):
cv, sl = ad_stat.critical_values[i], ad_stat.significance_level[i]
logging.info(f" Critical value at {sl}%: {cv:.3f}")
return sw_p > 0.05 and ks_p > 0.05
def perform_tests(ref_sample: dict, new_sample: dict, output_dir: Path, independent: bool = False):
seeds_a, seeds_b = ref_sample["seeds"], new_sample["seeds"]
if (seeds_a == seeds_b) and not independent:
logging.info("\nSamples are paired (identical seeds).")
paired = True
else:
logging.info("\nSamples are considered independent (seeds are different).")
paired = False
table_data = [["Metric", "Ref.", "New"]]
table_data = append_table_metric(table_data, "num_episodes", ref_sample, new_sample)
table_data = append_table_metric(table_data, "successes", ref_sample, new_sample, mean_std=True)
table_data = append_table_metric(table_data, "max_rewards", ref_sample, new_sample, mean_std=True)
table_data = append_table_metric(table_data, "sum_rewards", ref_sample, new_sample, mean_std=True)
table = AsciiTable(table_data)
print(table.table)
log_section("Effect Size")
d_max_reward = cohens_d(ref_sample["max_rewards"], new_sample["max_rewards"])
d_sum_reward = cohens_d(ref_sample["sum_rewards"], new_sample["sum_rewards"])
logging.info(f"Cohen's d for Max Reward: {d_max_reward:.3f}")
logging.info(f"Cohen's d for Sum Reward: {d_sum_reward:.3f}")
if paired:
paired_sample_tests(ref_sample, new_sample)
else:
independent_sample_tests(ref_sample, new_sample)
output_dir.mkdir(exist_ok=True, parents=True)
plot_boxplot(
ref_sample["max_rewards"],
new_sample["max_rewards"],
["Ref Sample Max Reward", "New Sample Max Reward"],
"Boxplot of Max Rewards",
f"{output_dir}/boxplot_max_reward.png",
)
plot_boxplot(
ref_sample["sum_rewards"],
new_sample["sum_rewards"],
["Ref Sample Sum Reward", "New Sample Sum Reward"],
"Boxplot of Sum Rewards",
f"{output_dir}/boxplot_sum_reward.png",
)
plot_histogram(
ref_sample["max_rewards"],
new_sample["max_rewards"],
["Ref Sample Max Reward", "New Sample Max Reward"],
"Histogram of Max Rewards",
f"{output_dir}/histogram_max_reward.png",
)
plot_histogram(
ref_sample["sum_rewards"],
new_sample["sum_rewards"],
["Ref Sample Sum Reward", "New Sample Sum Reward"],
"Histogram of Sum Rewards",
f"{output_dir}/histogram_sum_reward.png",
)
plot_qqplot(
ref_sample["max_rewards"],
"Q-Q Plot of Ref Sample Max Rewards",
f"{output_dir}/qqplot_sample_a_max_reward.png",
)
plot_qqplot(
new_sample["max_rewards"],
"Q-Q Plot of New Sample Max Rewards",
f"{output_dir}/qqplot_sample_b_max_reward.png",
)
plot_qqplot(
ref_sample["sum_rewards"],
"Q-Q Plot of Ref Sample Sum Rewards",
f"{output_dir}/qqplot_sample_a_sum_reward.png",
)
plot_qqplot(
new_sample["sum_rewards"],
"Q-Q Plot of New Sample Sum Rewards",
f"{output_dir}/qqplot_sample_b_sum_reward.png",
)
def paired_sample_tests(ref_sample: dict, new_sample: dict):
log_section("Normality tests")
max_reward_diff = ref_sample["max_rewards"] - new_sample["max_rewards"]
sum_reward_diff = ref_sample["sum_rewards"] - new_sample["sum_rewards"]
normal_max_reward_diff = normality_tests(max_reward_diff, "Max Reward Difference")
normal_sum_reward_diff = normality_tests(sum_reward_diff, "Sum Reward Difference")
log_section("Paired-sample tests")
if normal_max_reward_diff:
t_stat_max_reward, p_val_max_reward = ttest_rel(ref_sample["max_rewards"], new_sample["max_rewards"])
log_test(f"Paired t-test for Max Reward: t-statistic = {t_stat_max_reward:.3f}", p_val_max_reward)
else:
w_stat_max_reward, p_wilcox_max_reward = wilcoxon(
ref_sample["max_rewards"], new_sample["max_rewards"]
)
log_test(f"Wilcoxon test for Max Reward: statistic = {w_stat_max_reward:.3f}", p_wilcox_max_reward)
if normal_sum_reward_diff:
t_stat_sum_reward, p_val_sum_reward = ttest_rel(ref_sample["sum_rewards"], new_sample["sum_rewards"])
log_test(f"Paired t-test for Sum Reward: t-statistic = {t_stat_sum_reward:.3f}", p_val_sum_reward)
else:
w_stat_sum_reward, p_wilcox_sum_reward = wilcoxon(
ref_sample["sum_rewards"], new_sample["sum_rewards"]
)
log_test(f"Wilcoxon test for Sum Reward: statistic = {w_stat_sum_reward:.3f}", p_wilcox_sum_reward)
table = np.array(
[
[
np.sum((ref_sample["successes"] == 1) & (new_sample["successes"] == 1)),
np.sum((ref_sample["successes"] == 1) & (new_sample["successes"] == 0)),
],
[
np.sum((ref_sample["successes"] == 0) & (new_sample["successes"] == 1)),
np.sum((ref_sample["successes"] == 0) & (new_sample["successes"] == 0)),
],
]
)
mcnemar_result = mcnemar(table, exact=True)
log_test(f"McNemar's test for Success: statistic = {mcnemar_result.statistic:.3f}", mcnemar_result.pvalue)
def independent_sample_tests(ref_sample: dict, new_sample: dict):
log_section("Normality tests")
normal_max_rewards_a = normality_tests(ref_sample["max_rewards"], "Max Rewards Ref Sample")
normal_max_rewards_b = normality_tests(new_sample["max_rewards"], "Max Rewards New Sample")
normal_sum_rewards_a = normality_tests(ref_sample["sum_rewards"], "Sum Rewards Ref Sample")
normal_sum_rewards_b = normality_tests(new_sample["sum_rewards"], "Sum Rewards New Sample")
log_section("Independent samples tests")
table = [["Test", "max_rewards", "sum_rewards"]]
if normal_max_rewards_a and normal_max_rewards_b:
table = append_independent_test(
table, ref_sample, new_sample, ttest_ind, "Two-Sample t-test", kwargs={"equal_var": False}
)
t_stat_max_reward, p_val_max_reward = ttest_ind(
ref_sample["max_rewards"], new_sample["max_rewards"], equal_var=False
)
log_test(f"Two-Sample t-test for Max Reward: t-statistic = {t_stat_max_reward:.3f}", p_val_max_reward)
else:
table = append_independent_test(table, ref_sample, new_sample, mannwhitneyu, "Mann-Whitney U")
u_stat_max_reward, p_u_max_reward = mannwhitneyu(ref_sample["max_rewards"], new_sample["max_rewards"])
log_test(f"Mann-Whitney U test for Max Reward: U-statistic = {u_stat_max_reward:.3f}", p_u_max_reward)
if normal_sum_rewards_a and normal_sum_rewards_b:
t_stat_sum_reward, p_val_sum_reward = ttest_ind(
ref_sample["sum_rewards"], new_sample["sum_rewards"], equal_var=False
)
log_test(f"Two-Sample t-test for Sum Reward: t-statistic = {t_stat_sum_reward:.3f}", p_val_sum_reward)
else:
u_stat_sum_reward, p_u_sum_reward = mannwhitneyu(ref_sample["sum_rewards"], new_sample["sum_rewards"])
log_test(f"Mann-Whitney U test for Sum Reward: U-statistic = {u_stat_sum_reward:.3f}", p_u_sum_reward)
table = AsciiTable(table)
print(table.table)
def append_independent_test(
table: list,
ref_sample: dict,
new_sample: dict,
test: callable,
test_name: str,
kwargs: dict | None = None,
) -> list:
kwargs = {} if kwargs is None else kwargs
row = [f"{test_name}: p-value ≥ alpha"]
for metric in table[0][1:]:
_, p_val = test(ref_sample[metric], new_sample[metric], **kwargs)
alpha = 0.05
status = "" if p_val >= alpha else ""
row.append(f"{status} {p_val:.3f}{alpha}")
table.append(row)
return table
def plot_boxplot(data_a: np.ndarray, data_b: np.ndarray, labels: list[str], title: str, filename: str):
plt.boxplot([data_a, data_b], labels=labels)
plt.title(title)
plt.savefig(filename)
plt.close()
def plot_histogram(data_a: np.ndarray, data_b: np.ndarray, labels: list[str], title: str, filename: str):
plt.hist(data_a, bins=30, alpha=0.7, label=labels[0])
plt.hist(data_b, bins=30, alpha=0.7, label=labels[1])
plt.title(title)
plt.legend()
plt.savefig(filename)
plt.close()
def plot_qqplot(data: np.ndarray, title: str, filename: str):
stats.probplot(data, dist="norm", plot=plt)
plt.title(title)
plt.savefig(filename)
plt.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument("ref_sample_path", type=Path, help="Path to the reference sample JSON file.")
parser.add_argument("new_sample_path", type=Path, help="Path to the new sample JSON file.")
parser.add_argument(
"--independent",
action="store_true",
help="Ignore seeds and consider samples to be independent (unpaired).",
)
parser.add_argument(
"--output_dir",
type=Path,
default=Path("outputs/compare/"),
help="Directory to save the output results. Defaults to outputs/compare/",
)
args = parser.parse_args()
init_logging()
ref_sample = get_eval_info_episodes(args.ref_sample_path)
new_sample = get_eval_info_episodes(args.new_sample_path)
perform_tests(ref_sample, new_sample, args.output_dir, args.independent)

View File

@@ -46,6 +46,7 @@ import json
import logging
import threading
import time
from contextlib import nullcontext
from copy import deepcopy
from datetime import datetime as dt
from pathlib import Path
@@ -520,7 +521,7 @@ def eval(
raise NotImplementedError()
# Check device is available
get_safe_torch_device(hydra_cfg.device, log=True)
device = get_safe_torch_device(hydra_cfg.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@@ -539,16 +540,17 @@ def eval(
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
policy.eval()
info = eval_policy(
env,
policy,
hydra_cfg.eval.n_episodes,
max_episodes_rendered=10,
video_dir=Path(out_dir) / "eval",
start_seed=hydra_cfg.seed,
enable_progbar=True,
enable_inner_progbar=True,
)
with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
info = eval_policy(
env,
policy,
hydra_cfg.eval.n_episodes,
max_episodes_rendered=10,
video_dir=Path(out_dir) / "eval",
start_seed=hydra_cfg.seed,
enable_progbar=True,
enable_inner_progbar=True,
)
print(info["aggregated"])
# Save info

View File

@@ -84,14 +84,10 @@ def get_from_raw_to_lerobot_format_fn(raw_format):
from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
elif raw_format == "aloha_hdf5":
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
elif raw_format == "aloha_dora":
from lerobot.common.datasets.push_dataset_to_hub.aloha_dora_format import from_raw_to_lerobot_format
elif raw_format == "xarm_pkl":
from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
else:
raise ValueError(
f"The selected {raw_format} can't be found. Did you add it to `lerobot/scripts/push_dataset_to_hub.py::get_from_raw_to_lerobot_format_fn`?"
)
raise ValueError(raw_format)
return from_raw_to_lerobot_format
@@ -144,8 +140,7 @@ def push_videos_to_hub(repo_id, videos_dir, revision):
def push_dataset_to_hub(
input_data_dir: Path,
output_data_dir: Path,
data_dir: Path,
dataset_id: str,
raw_format: str | None,
community_id: str,
@@ -162,33 +157,34 @@ def push_dataset_to_hub(
):
repo_id = f"{community_id}/{dataset_id}"
meta_data_dir = output_data_dir / "meta_data"
videos_dir = output_data_dir / "videos"
raw_dir = data_dir / f"{dataset_id}_raw"
out_dir = data_dir / repo_id
meta_data_dir = out_dir / "meta_data"
videos_dir = out_dir / "videos"
tests_out_dir = tests_data_dir / repo_id
tests_meta_data_dir = tests_out_dir / "meta_data"
tests_videos_dir = tests_out_dir / "videos"
if output_data_dir.exists():
shutil.rmtree(output_data_dir)
if out_dir.exists():
shutil.rmtree(out_dir)
if tests_out_dir.exists() and save_tests_to_disk:
shutil.rmtree(tests_out_dir)
if not input_data_dir.exists():
download_raw(input_data_dir, dataset_id)
if not raw_dir.exists():
download_raw(raw_dir, dataset_id)
if raw_format is None:
# TODO(rcadene, adilzouitine): implement auto_find_raw_format
raise NotImplementedError()
# raw_format = auto_find_raw_format(input_data_dir)
# raw_format = auto_find_raw_format(raw_dir)
from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
# convert dataset from original raw format to LeRobot format
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
input_data_dir, output_data_dir, fps, video, debug
)
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(raw_dir, out_dir, fps, video, debug)
lerobot_dataset = LeRobotDataset.from_preloaded(
repo_id=repo_id,
@@ -202,7 +198,7 @@ def push_dataset_to_hub(
if save_to_disk:
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
hf_dataset.save_to_disk(str(output_data_dir / "train"))
hf_dataset.save_to_disk(str(out_dir / "train"))
if not dry_run or save_to_disk:
# mandatory for upload
@@ -236,25 +232,19 @@ def push_dataset_to_hub(
fname = f"{key}_episode_{episode_index:06d}.mp4"
shutil.copy(videos_dir / fname, tests_videos_dir / fname)
if not save_to_disk and output_data_dir.exists():
if not save_to_disk and out_dir.exists():
# remove possible temporary files remaining in the output directory
shutil.rmtree(output_data_dir)
shutil.rmtree(out_dir)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--input-data-dir",
"--data-dir",
type=Path,
required=True,
help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw`).",
)
parser.add_argument(
"--output-data-dir",
type=Path,
required=True,
help="Root directory containing output dataset (e.g. `data/lerobot/aloha_mobile_chair` or `data/lerobot/pusht`).",
help="Root directory containing datasets (e.g. `data` or `tmp/data` or `/tmp/lerobot/data`).",
)
parser.add_argument(
"--dataset-id",

View File

@@ -15,15 +15,14 @@
# limitations under the License.
import logging
import time
from contextlib import nullcontext
from copy import deepcopy
from pathlib import Path
import datasets
import hydra
import torch
from datasets import concatenate_datasets
from datasets.utils import disable_progress_bars, enable_progress_bars
from omegaconf import DictConfig
from torch.cuda.amp import GradScaler
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle
@@ -31,6 +30,7 @@ from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.policy_protocol import PolicyWithUpdate
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.utils import (
format_big_number,
get_safe_torch_device,
@@ -69,7 +69,6 @@ def make_optimizer_and_scheduler(cfg, policy):
cfg.training.adam_eps,
cfg.training.adam_weight_decay,
)
assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
from diffusers.optimization import get_scheduler
lr_scheduler = get_scheduler(
@@ -87,21 +86,40 @@ def make_optimizer_and_scheduler(cfg, policy):
return optimizer, lr_scheduler
def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
def update_policy(
policy,
batch,
optimizer,
grad_clip_norm,
grad_scaler: GradScaler,
lr_scheduler=None,
use_amp: bool = False,
):
"""Returns a dictionary of items for logging."""
start_time = time.time()
start_time = time.perf_counter()
device = get_device_from_parameters(policy)
policy.train()
output_dict = policy.forward(batch)
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
loss = output_dict["loss"]
loss.backward()
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
output_dict = policy.forward(batch)
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
loss = output_dict["loss"]
grad_scaler.scale(loss).backward()
# Unscale the graident of the optimzer's assigned params in-place **prior to gradient clipping**.
grad_scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(
policy.parameters(),
grad_clip_norm,
error_if_nonfinite=False,
)
optimizer.step()
# Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
grad_scaler.step(optimizer)
# Updates the scale for next iteration.
grad_scaler.update()
optimizer.zero_grad()
if lr_scheduler is not None:
@@ -115,7 +133,7 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
"loss": loss.item(),
"grad_norm": float(grad_norm),
"lr": optimizer.param_groups[0]["lr"],
"update_s": time.time() - start_time,
"update_s": time.perf_counter() - start_time,
**{k: v for k, v in output_dict.items() if k != "loss"},
}
@@ -211,103 +229,6 @@ def log_eval_info(logger, info, step, cfg, dataset, is_offline):
logger.log_dict(info, step, mode="eval")
def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
"""
Calculate the sampling weight to be assigned to samples so that a specified percentage of the batch comes from online dataset (on average).
Parameters:
- n_off (int): Number of offline samples, each with a sampling weight of 1.
- n_on (int): Number of online samples.
- pc_on (float): Desired percentage of online samples in decimal form (e.g., 50% as 0.5).
The total weight of offline samples is n_off * 1.0.
The total weight of offline samples is n_on * w.
The total combined weight of all samples is n_off + n_on * w.
The fraction of the weight that is online is n_on * w / (n_off + n_on * w).
We want this fraction to equal pc_on, so we set up the equation n_on * w / (n_off + n_on * w) = pc_on.
The solution is w = - (n_off * pc_on) / (n_on * (pc_on - 1))
"""
assert 0.0 <= pc_on <= 1.0
return -(n_off * pc_on) / (n_on * (pc_on - 1))
def add_episodes_inplace(
online_dataset: torch.utils.data.Dataset,
concat_dataset: torch.utils.data.ConcatDataset,
sampler: torch.utils.data.WeightedRandomSampler,
hf_dataset: datasets.Dataset,
episode_data_index: dict[str, torch.Tensor],
pc_online_samples: float,
):
"""
Modifies the online_dataset, concat_dataset, and sampler in place by integrating
new episodes from hf_dataset into the online_dataset, updating the concatenated
dataset's structure and adjusting the sampling strategy based on the specified
percentage of online samples.
Parameters:
- online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated.
- concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines
offline and online datasets, used for sampling purposes.
- sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to
reflect changes in the dataset sizes and specified sampling weights.
- hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
- episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
They indicate the start index and end index of each episode in the dataset.
- pc_online_samples (float): The target percentage of samples that should come from
the online dataset during sampling operations.
Raises:
- AssertionError: If the first episode_id or index in hf_dataset is not 0
"""
first_episode_idx = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
last_episode_idx = hf_dataset.select_columns("episode_index")[-1]["episode_index"].item()
first_index = hf_dataset.select_columns("index")[0]["index"].item()
last_index = hf_dataset.select_columns("index")[-1]["index"].item()
# sanity check
assert first_episode_idx == 0, f"{first_episode_idx=} is not 0"
assert first_index == 0, f"{first_index=} is not 0"
assert first_index == episode_data_index["from"][first_episode_idx].item()
assert last_index == episode_data_index["to"][last_episode_idx].item() - 1
if len(online_dataset) == 0:
# initialize online dataset
online_dataset.hf_dataset = hf_dataset
online_dataset.episode_data_index = episode_data_index
else:
# get the starting indices of the new episodes and frames to be added
start_episode_idx = last_episode_idx + 1
start_index = last_index + 1
def shift_indices(episode_index, index):
# note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to
example = {"episode_index": episode_index + start_episode_idx, "index": index + start_index}
return example
disable_progress_bars() # map has a tqdm progress bar
hf_dataset = hf_dataset.map(shift_indices, input_columns=["episode_index", "index"])
enable_progress_bars()
episode_data_index["from"] += start_index
episode_data_index["to"] += start_index
# extend online dataset
online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset])
# update the concatenated dataset length used during sampling
concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
# update the sampling weights for each frame so that online frames get sampled a certain percentage of times
len_online = len(online_dataset)
len_offline = len(concat_dataset) - len_online
weight_offline = 1.0
weight_online = calculate_online_sample_weight(len_offline, len_online, pc_online_samples)
sampler.weights = torch.tensor([weight_offline] * len_offline + [weight_online] * len(online_dataset))
# update the total number of samples used during sampling
sampler.num_samples = len(concat_dataset)
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
if out_dir is None:
raise NotImplementedError()
@@ -316,11 +237,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
init_logging()
if cfg.training.online_steps > 0 and cfg.eval.batch_size > 1:
logging.warning("eval.batch_size > 1 not supported for online training steps")
if cfg.training.online_steps > 0:
raise NotImplementedError("Online training is not implemented yet.")
# Check device is available
get_safe_torch_device(cfg.device, log=True)
device = get_safe_torch_device(cfg.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@@ -338,6 +259,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# Create optimizer and scheduler
# Temporary hack to move optimizer out of policy
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
grad_scaler = GradScaler(enabled=cfg.use_amp)
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
@@ -358,14 +280,15 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
def evaluate_and_checkpoint_if_needed(step):
if step % cfg.training.eval_freq == 0:
logging.info(f"Eval policy at step {step}")
eval_info = eval_policy(
eval_env,
policy,
cfg.eval.n_episodes,
video_dir=Path(out_dir) / "eval",
max_episodes_rendered=4,
start_seed=cfg.seed,
)
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
eval_info = eval_policy(
eval_env,
policy,
cfg.eval.n_episodes,
video_dir=Path(out_dir) / "eval",
max_episodes_rendered=4,
start_seed=cfg.seed,
)
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
if cfg.wandb.enable:
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
@@ -389,23 +312,30 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
num_workers=4,
batch_size=cfg.training.batch_size,
shuffle=True,
pin_memory=cfg.device != "cpu",
pin_memory=device.type != "cpu",
drop_last=False,
)
dl_iter = cycle(dataloader)
policy.train()
step = 0 # number of policy update (forward + backward + optim)
is_offline = True
for offline_step in range(cfg.training.offline_steps):
if offline_step == 0:
for step in range(cfg.training.offline_steps):
if step == 0:
logging.info("Start offline training on a fixed dataset")
batch = next(dl_iter)
for key in batch:
batch[key] = batch[key].to(cfg.device, non_blocking=True)
batch[key] = batch[key].to(device, non_blocking=True)
train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)
train_info = update_policy(
policy,
batch,
optimizer,
cfg.training.grad_clip_norm,
grad_scaler=grad_scaler,
lr_scheduler=lr_scheduler,
use_amp=cfg.use_amp,
)
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
if step % cfg.training.log_freq == 0:
@@ -415,11 +345,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# so we pass in step + 1.
evaluate_and_checkpoint_if_needed(step + 1)
step += 1
# create an env dedicated to online episodes collection from policy rollout
online_training_env = make_env(cfg, n_envs=1)
# create an empty online dataset similar to offline dataset
online_dataset = deepcopy(offline_dataset)
online_dataset.hf_dataset = {}
@@ -436,58 +361,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
num_workers=4,
batch_size=cfg.training.batch_size,
sampler=sampler,
pin_memory=cfg.device != "cpu",
pin_memory=device.type != "cpu",
drop_last=False,
)
dl_iter = cycle(dataloader)
online_step = 0
is_offline = False
for env_step in range(cfg.training.online_steps):
if env_step == 0:
logging.info("Start online training by interacting with environment")
policy.eval()
with torch.no_grad():
eval_info = eval_policy(
online_training_env,
policy,
n_episodes=1,
return_episode_data=True,
start_seed=cfg.training.online_env_seed,
enable_progbar=True,
)
add_episodes_inplace(
online_dataset,
concat_dataset,
sampler,
hf_dataset=eval_info["episodes"]["hf_dataset"],
episode_data_index=eval_info["episodes"]["episode_data_index"],
pc_online_samples=cfg.training.online_sampling_ratio,
)
policy.train()
for _ in range(cfg.training.online_steps_between_rollouts):
batch = next(dl_iter)
for key in batch:
batch[key] = batch[key].to(cfg.device, non_blocking=True)
train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)
if step % cfg.training.log_freq == 0:
log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
# so we pass in step + 1.
evaluate_and_checkpoint_if_needed(step + 1)
step += 1
online_step += 1
eval_env.close()
online_training_env.close()
logging.info("End of training")

1077
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -46,7 +46,6 @@ h5py = ">=3.10.0"
huggingface-hub = {extras = ["hf-transfer"], version = "^0.23.0"}
gymnasium = ">=0.29.1"
cmake = ">=3.29.0.1"
gym-dora = { path = "gym_dora", optional = true, develop = true}
gym-pusht = { version = ">=0.1.3", optional = true}
gym-xarm = { version = ">=0.1.1", optional = true}
gym-aloha = { version = ">=0.1.1", optional = true}
@@ -59,16 +58,19 @@ imagecodecs = { version = ">=2024.1.1", optional = true }
pyav = ">=12.0.5"
moviepy = ">=1.0.3"
rerun-sdk = ">=0.15.1"
statsmodels = {version = ">=0.14.2", optional = true}
matplotlib = {version = ">=3.8.4", optional = true}
terminaltables = {version = ">=3.1.10", optional = true}
[tool.poetry.extras]
dora = ["gym-dora"]
pusht = ["gym-pusht"]
xarm = ["gym-xarm"]
aloha = ["gym-aloha"]
umi = ["imagecodecs"]
compare = ["statsmodels", "matplotlib", "terminaltables"]
dev = ["pre-commit", "debugpy"]
test = ["pytest", "pytest-cov"]
umi = ["imagecodecs"]
[tool.ruff]
line-length = 110