updating with adding masking in ACT - start adding some tests

This commit is contained in:
Thomas Wolf
2024-06-10 15:30:57 +02:00
parent ef074d7281
commit ddaaa9f279
12 changed files with 237 additions and 42 deletions

View File

@@ -1,5 +1,6 @@
# Using `lerobot` on a real world arm # Using `lerobot` on a real world arm
In this example, we'll be using `lerobot` on a real world arm to: In this example, we'll be using `lerobot` on a real world arm to:
- record a dataset in the `lerobot` format - record a dataset in the `lerobot` format
- (soon) train a policy on it - (soon) train a policy on it
@@ -25,7 +26,9 @@ Follow these steps:
- install `lerobot` - install `lerobot`
- install the Dynamixel-sdk: ` pip install dynamixel-sdk` - install the Dynamixel-sdk: ` pip install dynamixel-sdk`
## 0 - record examples ## Usage
### 0 - record examples
Run the `record_training_data.py` example, selecting the duration and number of episodes you want to record, e.g. Run the `record_training_data.py` example, selecting the duration and number of episodes you want to record, e.g.
``` ```
@@ -40,7 +43,7 @@ TODO:
- being able to drop episodes - being able to drop episodes
- checking uploading to the hub - checking uploading to the hub
## 1 - visualize the dataset ### 1 - visualize the dataset
Use the standard dataset visualization script pointing it to the right folder: Use the standard dataset visualization script pointing it to the right folder:
``` ```
@@ -49,7 +52,7 @@ DATA_DIR='./data' python ../../lerobot/scripts/visualize_dataset.py \
--episode-index 0 --episode-index 0
``` ```
## 2 - Train a policy ### 2 - Train a policy
From the example directory let's run this command to train a model using ACT From the example directory let's run this command to train a model using ACT
@@ -64,7 +67,7 @@ DATA_DIR='./data' python ../../lerobot/scripts/train.py \
wandb.enable=false wandb.enable=false
``` ```
## 3 - Evaluate the policy in the real world ### 3 - Evaluate the policy in the real world
From the example directory let's run this command to evaluate our policy. From the example directory let's run this command to evaluate our policy.
The configuration for running the policy is in the checkpoint of the model. The configuration for running the policy is in the checkpoint of the model.
@@ -75,3 +78,12 @@ python run_policy.py \
-p ./outputs/train/blue_red_sort/checkpoints/last/pretrained_model/ -p ./outputs/train/blue_red_sort/checkpoints/last/pretrained_model/
env.episode_length=1000 env.episode_length=1000
``` ```
## Convert an HDF5 dataset recorded with the original ACT repo
You can convert a dataset from the raw HDF5 file format used in https://github.com/tonyzhaozh/act with the following command:
```
python ./lerobot/scripts/push_dataset_to_hub.py
```

View File

@@ -1,4 +1,5 @@
import time import time
from unittest.mock import MagicMock
import cv2 import cv2
import gymnasium as gym import gymnasium as gym
@@ -23,6 +24,14 @@ CAMERAS_PORTS = {
LEADER_PORT = "/dev/ttyACM1" LEADER_PORT = "/dev/ttyACM1"
FOLLOWER_PORT = "/dev/ttyACM0" FOLLOWER_PORT = "/dev/ttyACM0"
MockRobot = MagicMock()
MockRobot.read_position = MagicMock()
MockRobot.read_position.return_value = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
MockCamera = MagicMock()
MockCamera.isOpened = MagicMock(return_value=True)
MockCamera.read = MagicMock(return_value=(True, np.zeros((480, 640, 3), dtype=np.uint8)))
def capture_image(cam, cam_width, cam_height): def capture_image(cam, cam_width, cam_height):
# Capture a single frame # Capture a single frame
@@ -54,6 +63,7 @@ class RealEnv(gym.Env):
trigger_torque=70, trigger_torque=70,
fps: int = FPS, fps: int = FPS,
fps_tolerance: float = 0.1, fps_tolerance: float = 0.1,
mock: bool = False,
): ):
self.num_joints = num_joints self.num_joints = num_joints
self.cameras_shapes = cameras_shapes self.cameras_shapes = cameras_shapes
@@ -68,15 +78,15 @@ class RealEnv(gym.Env):
self.fps_tolerance = fps_tolerance self.fps_tolerance = fps_tolerance
# Initialize the robot # Initialize the robot
self.follower = Robot(device_name=self.follower_port) self.follower = Robot(device_name=self.follower_port) if not mock else MockRobot
if self.record: if self.record:
self.leader = Robot(device_name=self.leader_port) self.leader = Robot(device_name=self.leader_port) if not mock else MockRobot
self.leader.set_trigger_torque(trigger_torque) self.leader.set_trigger_torque(trigger_torque)
# Initialize the cameras - sorted by camera names # Initialize the cameras - sorted by camera names
self.cameras = {} self.cameras = {}
for cn, p in sorted(self.cameras_ports.items()): for cn, p in sorted(self.cameras_ports.items()):
self.cameras[cn] = cv2.VideoCapture(p) self.cameras[cn] = cv2.VideoCapture(p) if not mock else MockCamera
if not self.cameras[cn].isOpened(): if not self.cameras[cn].isOpened():
raise OSError( raise OSError(
f"Cannot open camera port {p} for {cn}." f"Cannot open camera port {p} for {cn}."
@@ -118,7 +128,6 @@ class RealEnv(gym.Env):
self._observation = {} self._observation = {}
self._terminated = False self._terminated = False
self.starting_time = time.time()
self.timestamps = [] self.timestamps = []
def _get_obs(self): def _get_obs(self):
@@ -146,13 +155,8 @@ class RealEnv(gym.Env):
if self.timestamps: if self.timestamps:
# wait the right amount of time to stay at the desired fps # wait the right amount of time to stay at the desired fps
time.sleep(max(0, 1 / self.fps - (time.time() - self.timestamps[-1]))) time.sleep(max(0, 1 / self.fps - (time.time() - self.timestamps[-1])))
recording_time = time.time() - self.starting_time
else:
# it's the first step so we start the timer
self.starting_time = time.time()
recording_time = 0
self.timestamps.append(recording_time) self.timestamps.append(time.time())
# Get the observation # Get the observation
self._get_obs() self._get_obs()
@@ -165,13 +169,15 @@ class RealEnv(gym.Env):
reward = 0 reward = 0
terminated = truncated = self._terminated terminated = truncated = self._terminated
info = {"timestamp": recording_time, "fps_error": False} info = {"timestamp": self.timestamps[-1] - self.timestamps[0], "fps_error": False}
# Check if we are able to keep up with the desired fps # Check if we are able to keep up with the desired fps
if recording_time - self.timestamps[-1] > 1 / (self.fps - self.fps_tolerance): if len(self.timestamps) > 1 and (self.timestamps[-1] - self.timestamps[-2]) > 1 / (
self.fps - self.fps_tolerance
):
print( print(
f"Error: recording time interval {recording_time - self.timestamps[-1]:.2f} is greater" f"Error: recording fps {1 / (self.timestamps[-1] - self.timestamps[-2]):.5f} is lower"
f"than expected {1 / (self.fps - self.fps_tolerance):.2f}" f" than min admitted fps {(self.fps - self.fps_tolerance):.5f}"
f" at frame {len(self.timestamps)}" f" at frame {len(self.timestamps)}"
) )
info["fps_error"] = True info["fps_error"] = True

View File

@@ -6,12 +6,14 @@ using a very simple gym environment (see in examples/real_robot_example/gym_real
import argparse import argparse
import copy import copy
import os import os
from pathlib import Path
import gym_real_world # noqa: F401 import gym_real_world # noqa: F401
import gymnasium as gym import gymnasium as gym
import numpy as np import numpy as np
import torch import torch
from datasets import Dataset, Features, Sequence, Value from datasets import Dataset, Features, Sequence, Value
from omegaconf import OmegaConf
from tqdm import tqdm from tqdm import tqdm
from lerobot.common.datasets.compute_stats import compute_stats from lerobot.common.datasets.compute_stats import compute_stats
@@ -30,17 +32,20 @@ parser.add_argument("--num-episodes", type=int, default=2)
parser.add_argument("--num-frames", type=int, default=400) parser.add_argument("--num-frames", type=int, default=400)
parser.add_argument("--num-workers", type=int, default=16) parser.add_argument("--num-workers", type=int, default=16)
parser.add_argument("--keep-last", action="store_true") parser.add_argument("--keep-last", action="store_true")
parser.add_argument("--data_dir", type=str, default=None)
parser.add_argument("--push-to-hub", action="store_true") parser.add_argument("--push-to-hub", action="store_true")
parser.add_argument("--fps", type=int, default=30, help="Frames per second of the recording.") parser.add_argument("--fps", type=int, default=30, help="Frames per second of the recording.")
parser.add_argument( parser.add_argument(
"--fps_tolerance", "--fps_tolerance",
type=float, type=float,
default=0.1, default=0.5,
help="Tolerance in fps for the recording before dropping episodes.", help="Tolerance in fps for the recording before dropping episodes.",
) )
parser.add_argument( parser.add_argument(
"--revision", type=str, default=CODEBASE_VERSION, help="Codebase version used to generate the dataset." "--revision", type=str, default=CODEBASE_VERSION, help="Codebase version used to generate the dataset."
) )
parser.add_argument("--gym-config", type=str, default=None, help="Path to the gym config file.")
parser.add_argument("--mock_robot", action="store_true")
args = parser.parse_args() args = parser.parse_args()
repo_id = args.repo_id repo_id = args.repo_id
@@ -50,7 +55,7 @@ revision = args.revision
fps = args.fps fps = args.fps
fps_tolerance = args.fps_tolerance fps_tolerance = args.fps_tolerance
out_data = DATA_DIR / repo_id out_data = DATA_DIR / repo_id if args.data_dir is None else Path(args.data_dir)
# During data collection, frames are stored as png images in `images_dir` # During data collection, frames are stored as png images in `images_dir`
images_dir = out_data / "images" images_dir = out_data / "images"
@@ -58,6 +63,9 @@ images_dir = out_data / "images"
videos_dir = out_data / "videos" videos_dir = out_data / "videos"
meta_data_dir = out_data / "meta_data" meta_data_dir = out_data / "meta_data"
gym_config = None
if args.config is not None:
gym_config = OmegaConf.load(args.config)
# Create image and video directories # Create image and video directories
if not os.path.exists(images_dir): if not os.path.exists(images_dir):
@@ -68,7 +76,12 @@ if not os.path.exists(videos_dir):
if __name__ == "__main__": if __name__ == "__main__":
# Create the gym environment - check the kwargs in gym_real_world/gym_environment.py # Create the gym environment - check the kwargs in gym_real_world/gym_environment.py
gym_handle = "gym_real_world/RealEnv-v0" gym_handle = "gym_real_world/RealEnv-v0"
env = gym.make(gym_handle, disable_env_checker=True, record=True, fps=fps, fps_tolerance=fps_tolerance) gym_kwargs = {}
if gym_config is not None:
gym_kwargs = OmegaConf.to_container(gym_config.gym_kwargs)
env = gym.make(
gym_handle, disable_env_checker=True, record=True, fps=fps, fps_tolerance=fps_tolerance, mock=True
)
ep_dicts = [] ep_dicts = []
episode_data_index = {"from": [], "to": []} episode_data_index = {"from": [], "to": []}

View File

@@ -10,3 +10,10 @@ env:
fps: ${fps} fps: ${fps}
episode_length: 200 episode_length: 200
real_world: true real_world: true
gym:
cameras_shapes:
images.high: [480, 640, 3]
images.low: [480, 640, 3]
cameras_ports:
images.high: /dev/video6
images.low: /dev/video0

View File

@@ -0,0 +1,19 @@
# @package _global_
fps: 30
env:
name: real_world
task: RealEnv-v0
state_dim: 6
action_dim: 6
fps: ${fps}
episode_length: 200
real_world: true
gym:
cameras_shapes:
images.top: [480, 640, 3]
images.front: [480, 640, 3]
cameras_ports:
images.top: /dev/video6
images.front: /dev/video0

View File

@@ -0,0 +1,103 @@
# @package _global_
# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets.
# Compared to `act.yaml`, it contains 2 cameras (i.e. top, front) instead of 1 camera (i.e. top).
# Also, `training.eval_freq` is set to -1. This config is used
# to evaluate checkpoints at a certain frequency of training steps. When it is set to -1, it deactivates evaluation.
# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
# Look at its README for more information on how to evaluate a checkpoint in the real-world.
#
# Example of usage for training:
# ```bash
# python lerobot/scripts/train.py \
# policy=act_real \
# env=aloha_real
# ```
seed: 1000
dataset_repo_id: ???
override_dataset_stats:
observation.images.top:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
observation.images.front:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
training:
offline_steps: 1000
online_steps: 0
eval_freq: -1
save_freq: 1000
log_freq: 100
save_checkpoint: true
batch_size: 8
lr: 1e-5
lr_backbone: 1e-5
weight_decay: 1e-4
grad_clip_norm: 10
online_steps_between_rollouts: 1
delta_timestamps:
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
eval:
n_episodes: 1
batch_size: 1
# See `configuration_act.py` for more details.
policy:
name: act
# Input / output structure.
n_obs_steps: 1
chunk_size: 100 # chunk_size
n_action_steps: 100
input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.images.top: [3, 480, 640]
observation.images.front: [3, 480, 640]
observation.state: ["${env.state_dim}"]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes:
observation.images.top: mean_std
observation.images.front: mean_std
observation.state: mean_std
output_normalization_modes:
action: mean_std
# Architecture.
# Vision backbone.
vision_backbone: resnet18
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
replace_final_stride_with_dilation: false
# Transformer layers.
pre_norm: false
dim_model: 512
n_heads: 8
dim_feedforward: 3200
feedforward_activation: relu
n_encoder_layers: 4
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
n_decoder_layers: 1
# VAE.
use_vae: true
latent_dim: 32
n_vae_encoder_layers: 4
# Inference.
temporal_ensemble_momentum: null
# Training and loss computation.
dropout: 0.1
kl_weight: 10.0

View File

@@ -43,9 +43,6 @@ def get_cameras(hdf5_data):
def check_format(raw_dir) -> bool: def check_format(raw_dir) -> bool:
# only frames from simulation are uncompressed
compressed_images = "sim" not in raw_dir.name
hdf5_paths = list(raw_dir.glob("episode_*.hdf5")) hdf5_paths = list(raw_dir.glob("episode_*.hdf5"))
assert len(hdf5_paths) != 0 assert len(hdf5_paths) != 0
for hdf5_path in hdf5_paths: for hdf5_path in hdf5_paths:
@@ -62,17 +59,15 @@ def check_format(raw_dir) -> bool:
for camera in get_cameras(data): for camera in get_cameras(data):
assert num_frames == data[f"/observations/images/{camera}"].shape[0] assert num_frames == data[f"/observations/images/{camera}"].shape[0]
if compressed_images: # ndim 2 when image are compressed and 4 when uncompressed
assert data[f"/observations/images/{camera}"].ndim == 2 assert data[f"/observations/images/{camera}"].ndim in [2, 4]
else: if data[f"/observations/images/{camera}"].ndim == 4:
assert data[f"/observations/images/{camera}"].ndim == 4
b, h, w, c = data[f"/observations/images/{camera}"].shape b, h, w, c = data[f"/observations/images/{camera}"].shape
assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided." assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
def load_from_raw(raw_dir, out_dir, fps, video, debug): def load_from_raw(raw_dir, out_dir, fps, video, debug):
# only frames from simulation are uncompressed # only frames from simulation are uncompressed
compressed_images = "sim" not in raw_dir.name
hdf5_files = list(raw_dir.glob("*.hdf5")) hdf5_files = list(raw_dir.glob("*.hdf5"))
ep_dicts = [] ep_dicts = []
@@ -99,7 +94,7 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
for camera in get_cameras(ep): for camera in get_cameras(ep):
img_key = f"observation.images.{camera}" img_key = f"observation.images.{camera}"
if compressed_images: if ep[f"/observations/images/{camera}"].ndim == 2:
import cv2 import cv2
# load one compressed image after the other in RAM and uncompress # load one compressed image after the other in RAM and uncompress

View File

@@ -129,7 +129,9 @@ class ACTConfig:
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
# that means only the first layer is used. Here we match the original implementation by setting this to 1. # that means only the first layer is used. Here we match the original implementation by setting this to 1.
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521. # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
# As a consequence we also remove the final, unused layer normalization, by default
n_decoder_layers: int = 1 n_decoder_layers: int = 1
decoder_norm: bool = False
# VAE. # VAE.
use_vae: bool = True use_vae: bool = True
latent_dim: int = 32 latent_dim: int = 32

View File

@@ -315,8 +315,14 @@ class ACT(nn.Module):
pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D) pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D)
# Forward pass through VAE encoder to get the latent PDF parameters. # Forward pass through VAE encoder to get the latent PDF parameters.
cls_joint_is_pad = torch.full((batch_size, 2), False).to(
batch["observation.state"].device
) # False: not a padding
key_padding_mask = torch.cat([cls_joint_is_pad, batch["action_is_pad"]], axis=1) # (bs, seq+1)
cls_token_out = self.vae_encoder( cls_token_out = self.vae_encoder(
vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2) vae_encoder_input.permute(1, 0, 2),
pos_embed=pos_embed.permute(1, 0, 2),
key_padding_mask=key_padding_mask,
)[0] # select the class token, with shape (B, D) )[0] # select the class token, with shape (B, D)
latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out) latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
mu = latent_pdf_params[:, : self.config.latent_dim] mu = latent_pdf_params[:, : self.config.latent_dim]
@@ -402,9 +408,11 @@ class ACTEncoder(nn.Module):
self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(config.n_encoder_layers)]) self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(config.n_encoder_layers)])
self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity() self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor: def forward(
self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
) -> Tensor:
for layer in self.layers: for layer in self.layers:
x = layer(x, pos_embed=pos_embed) x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
x = self.norm(x) x = self.norm(x)
return x return x
@@ -427,12 +435,14 @@ class ACTEncoderLayer(nn.Module):
self.activation = get_activation_fn(config.feedforward_activation) self.activation = get_activation_fn(config.feedforward_activation)
self.pre_norm = config.pre_norm self.pre_norm = config.pre_norm
def forward(self, x, pos_embed: Tensor | None = None) -> Tensor: def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor:
skip = x skip = x
if self.pre_norm: if self.pre_norm:
x = self.norm1(x) x = self.norm1(x)
q = k = x if pos_embed is None else x + pos_embed q = k = x if pos_embed is None else x + pos_embed
x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights x = self.self_attn(q, k, value=x, key_padding_mask=key_padding_mask)[
0
] # select just the output, not the attention weights
x = skip + self.dropout1(x) x = skip + self.dropout1(x)
if self.pre_norm: if self.pre_norm:
skip = x skip = x
@@ -452,7 +462,10 @@ class ACTDecoder(nn.Module):
"""Convenience module for running multiple decoder layers followed by normalization.""" """Convenience module for running multiple decoder layers followed by normalization."""
super().__init__() super().__init__()
self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)]) self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)])
self.norm = nn.LayerNorm(config.dim_model) if config.decoder_norm:
self.norm = nn.LayerNorm(config.dim_model)
else:
self.norm = nn.Identity()
def forward( def forward(
self, self,
@@ -465,8 +478,7 @@ class ACTDecoder(nn.Module):
x = layer( x = layer(
x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
) )
if self.norm is not None: x = self.norm(x)
x = self.norm(x)
return x return x

View File

@@ -50,6 +50,8 @@ eval:
batch_size: 1 batch_size: 1
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing). # `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
use_async_envs: false use_async_envs: false
# Specify the number of episodes to render during evaluation.
max_episodes_rendered: 10
wandb: wandb:
enable: false enable: false

View File

@@ -44,6 +44,7 @@ https://huggingface.co/lerobot/diffusion_pusht/tree/main.
import argparse import argparse
import json import json
import logging import logging
import os
import threading import threading
import time import time
from contextlib import nullcontext from contextlib import nullcontext
@@ -164,7 +165,10 @@ def rollout(
# VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't # VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
# available of none of the envs finished. # available of none of the envs finished.
if "final_info" in info: if "final_info" in info:
successes = [i["is_success"] if i is not None else False for i in info["final_info"]] successes = [
i["is_success"] if (i is not None and "is_success" in i) else False
for i in info["final_info"]
]
else: else:
successes = [False] * env.num_envs successes = [False] * env.num_envs
@@ -516,6 +520,7 @@ def eval(
out_dir = ( out_dir = (
f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}" f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"
) )
os.makedirs(out_dir, exist_ok=True)
if out_dir is None: if out_dir is None:
raise NotImplementedError() raise NotImplementedError()
@@ -545,7 +550,7 @@ def eval(
env, env,
policy, policy,
hydra_cfg.eval.n_episodes, hydra_cfg.eval.n_episodes,
max_episodes_rendered=10, max_episodes_rendered=hydra_cfg.eval.max_episodes_rendered,
video_dir=Path(out_dir) / "eval", video_dir=Path(out_dir) / "eval",
start_seed=hydra_cfg.seed, start_seed=hydra_cfg.seed,
enable_progbar=True, enable_progbar=True,

View File

@@ -29,8 +29,8 @@ def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> s
return text return text
def _run_script(path): def _run_script(path, args=None):
subprocess.run([sys.executable, path], check=True) subprocess.run([sys.executable, path] + args if args is not None else [], check=True)
def _read_file(path): def _read_file(path):
@@ -126,3 +126,22 @@ def test_examples_basic2_basic3_advanced1():
# Restore stdout to its original state # Restore stdout to its original state
sys.stdout = sys.__stdout__ sys.stdout = sys.__stdout__
assert "Average loss on validation set" in printed_output assert "Average loss on validation set" in printed_output
def test_real_world_recording():
path = "examples/real_robot_example/record_training_data.py"
_run_script(
path,
[
"--data_dir",
"outputs/examples",
"--repo-id",
"real_world_debug",
"--num-episodes",
"2",
"--num-frames",
"10",
"--mock-robot",
],
)
assert Path("outputs/examples/real_world_debug/video/episode_0.mp4").exists()