This commit is contained in:
Remi Cadene
2024-07-02 21:35:24 +02:00
parent 47aac0dff7
commit 8a7aa50e97
9 changed files with 207 additions and 140 deletions

View File

@@ -62,20 +62,22 @@ python lerobot/scripts/control_robot.py run_policy \
"""
import argparse
from contextlib import nullcontext
import concurrent.futures
import os
from pathlib import Path
import shutil
import time
from contextlib import nullcontext
from pathlib import Path
from PIL import Image
from omegaconf import DictConfig
import torch
from omegaconf import DictConfig
from PIL import Image
from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import to_hf_dataset
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
from lerobot.common.datasets.utils import calculate_episode_data_index, load_hf_dataset
from lerobot.common.datasets.utils import calculate_episode_data_index
from lerobot.common.datasets.video_utils import encode_video_frames
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.robots.factory import make_robot
@@ -83,14 +85,12 @@ from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, set_global_seed
from lerobot.scripts.eval import get_pretrained_policy_path
from lerobot.scripts.push_dataset_to_hub import save_meta_data
from lerobot.scripts.robot_controls.record_dataset import record_dataset
import concurrent.futures
########################################################################################
# Utilities
########################################################################################
def save_image(img_tensor, key, frame_index, episode_index, videos_dir):
img = Image.fromarray(img_tensor.numpy())
path = videos_dir / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png"
@@ -106,15 +106,18 @@ def busy_wait(seconds):
while time.perf_counter() < end_time:
pass
def none_or_int(value):
if value == 'None':
if value == "None":
return None
return int(value)
########################################################################################
# Control modes
########################################################################################
def teleoperate(robot: Robot, fps: int | None = None):
robot.init_teleop()
@@ -123,14 +126,24 @@ def teleoperate(robot: Robot, fps: int | None = None):
robot.teleop_step()
if fps is not None:
dt_s = (time.perf_counter() - now)
dt_s = time.perf_counter() - now
busy_wait(1 / fps - dt_s)
dt_s = (time.perf_counter() - now)
dt_s = time.perf_counter() - now
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="lerobot/debug", warmup_time_s=2, episode_time_s=10, num_episodes=50, video=True, run_compute_stats=True):
def record_dataset(
robot: Robot,
fps: int | None = None,
root="data",
repo_id="lerobot/debug",
warmup_time_s=2,
episode_time_s=10,
num_episodes=50,
video=True,
run_compute_stats=True,
):
if not video:
raise NotImplementedError()
@@ -143,7 +156,6 @@ def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="l
videos_dir = local_dir / "videos"
videos_dir.mkdir(parents=True, exist_ok=True)
start_time = time.perf_counter()
is_warmup_print = False
@@ -154,7 +166,6 @@ def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="l
# Using `with` ensures the program exists smoothly if an execption is raised.
with concurrent.futures.ThreadPoolExecutor() as executor:
for episode_index in range(num_episodes):
ep_dict = {}
frame_index = 0
@@ -169,10 +180,10 @@ def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="l
timestamp = time.perf_counter() - start_time
if timestamp < warmup_time_s:
dt_s = (time.perf_counter() - now)
dt_s = time.perf_counter() - now
busy_wait(1 / fps - dt_s)
dt_s = (time.perf_counter() - now)
dt_s = time.perf_counter() - now
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f} (Warmup)")
continue
@@ -199,10 +210,10 @@ def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="l
frame_index += 1
dt_s = (time.perf_counter() - now)
dt_s = time.perf_counter() - now
busy_wait(1 / fps - dt_s)
dt_s = (time.perf_counter() - now)
dt_s = time.perf_counter() - now
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
if timestamp > episode_time_s - warmup_time_s:
@@ -229,7 +240,7 @@ def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="l
for key in not_image_keys:
ep_dict[key] = torch.stack(ep_dict[key])
for key in action:
ep_dict[key] = torch.stack(ep_dict[key])
@@ -269,10 +280,7 @@ def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="l
info=info,
videos_dir=videos_dir,
)
if run_compute_stats:
stats = compute_stats(lerobot_dataset)
else:
stats = {}
stats = compute_stats(lerobot_dataset) if run_compute_stats else {}
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
hf_dataset.save_to_disk(str(local_dir / "train"))
@@ -293,7 +301,7 @@ def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="dat
to_idx = dataset.episode_data_index["to"][episode].item()
robot.init_teleop()
print("Replaying episode")
os.system('say "Replaying episode"')
@@ -303,10 +311,10 @@ def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="dat
action = items[idx]["action"]
robot.send_action(action)
dt_s = (time.perf_counter() - now)
dt_s = time.perf_counter() - now
busy_wait(1 / fps - dt_s)
dt_s = (time.perf_counter() - now)
dt_s = time.perf_counter() - now
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
@@ -327,15 +335,18 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig):
observation = robot.capture_observation()
with torch.inference_mode(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
with (
torch.inference_mode(),
torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext(),
):
action = policy.select_action(observation)
robot.send_action(action)
dt_s = (time.perf_counter() - now)
dt_s = time.perf_counter() - now
busy_wait(1 / fps - dt_s)
dt_s = (time.perf_counter() - now)
dt_s = time.perf_counter() - now
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
@@ -345,32 +356,46 @@ if __name__ == "__main__":
# Set common options for all the subparsers
base_parser = argparse.ArgumentParser(add_help=False)
base_parser.add_argument("--robot", type=str, default="koch", help="Name of the robot provided to the `make_robot(name)` factory function.")
base_parser.add_argument(
"--robot",
type=str,
default="koch",
help="Name of the robot provided to the `make_robot(name)` factory function.",
)
parser_teleop = subparsers.add_parser("teleoperate", parents=[base_parser])
parser_teleop.add_argument('--fps', type=none_or_int, default=None, help='Frames per second (set to None to disable)')
parser_teleop.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
)
parser_record = subparsers.add_parser("record_dataset", parents=[base_parser])
parser_record.add_argument('--fps', type=none_or_int, default=None, help='Frames per second (set to None to disable)')
parser_record.add_argument('--root', type=Path, default="data", help='')
parser_record.add_argument('--repo-id', type=str, default="lerobot/test", help='')
parser_record.add_argument('--warmup-time-s', type=int, default=2, help='')
parser_record.add_argument('--episode-time-s', type=int, default=10, help='')
parser_record.add_argument('--num-episodes', type=int, default=50, help='')
parser_record.add_argument('--run-compute-stats', type=int, default=1, help='')
parser_record.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
)
parser_record.add_argument("--root", type=Path, default="data", help="")
parser_record.add_argument("--repo-id", type=str, default="lerobot/test", help="")
parser_record.add_argument("--warmup-time-s", type=int, default=2, help="")
parser_record.add_argument("--episode-time-s", type=int, default=10, help="")
parser_record.add_argument("--num-episodes", type=int, default=50, help="")
parser_record.add_argument("--run-compute-stats", type=int, default=1, help="")
parser_replay = subparsers.add_parser("replay_episode", parents=[base_parser])
parser_replay.add_argument('--fps', type=none_or_int, default=None, help='Frames per second (set to None to disable)')
parser_replay.add_argument('--root', type=Path, default="data", help='')
parser_replay.add_argument('--repo-id', type=str, default="lerobot/test", help='')
parser_replay.add_argument('--episode', type=int, default=0, help='')
parser_replay.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
)
parser_replay.add_argument("--root", type=Path, default="data", help="")
parser_replay.add_argument("--repo-id", type=str, default="lerobot/test", help="")
parser_replay.add_argument("--episode", type=int, default=0, help="")
parser_policy = subparsers.add_parser("run_policy", parents=[base_parser])
parser_policy.add_argument('-p', '--pretrained-policy-name-or-path', type=str,
parser_policy.add_argument(
"-p",
"--pretrained-policy-name-or-path",
type=str,
help=(
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
"saved using `Policy.save_pretrained`."
)
),
)
parser_policy.add_argument(
"overrides",

View File

@@ -580,9 +580,7 @@ def main(
def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None):
try:
pretrained_policy_path = Path(
snapshot_download(pretrained_policy_name_or_path, revision=revision)
)
pretrained_policy_path = Path(snapshot_download(pretrained_policy_name_or_path, revision=revision))
except (HFValidationError, RepositoryNotFoundError) as e:
if isinstance(e, HFValidationError):
error_message = (
@@ -644,7 +642,9 @@ if __name__ == "__main__":
if args.pretrained_policy_name_or_path is None:
main(hydra_cfg_path=args.config, out_dir=args.out_dir, config_overrides=args.overrides)
else:
pretrained_policy_path = get_pretrained_policy_path(args.pretrained_policy_name_or_path, revision=args.revision)
pretrained_policy_path = get_pretrained_policy_path(
args.pretrained_policy_name_or_path, revision=args.revision
)
main(
pretrained_policy_path=pretrained_policy_path,