forked from tangger/lerobot
Merge (No verify)
This commit is contained in:
@@ -1,3 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script configure a single motor at a time to a given ID and baudrate.
|
||||
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Utilities to control a robot.
|
||||
|
||||
@@ -122,15 +135,19 @@ python lerobot/scripts/control_robot.py \
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import asdict
|
||||
from pprint import pformat
|
||||
|
||||
import rerun as rr
|
||||
|
||||
# from safetensors.torch import load_file, save_file
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.robot_devices.control_configs import (
|
||||
CalibrateControlConfig,
|
||||
ControlConfig,
|
||||
ControlPipelineConfig,
|
||||
RecordControlConfig,
|
||||
RemoteRobotConfig,
|
||||
@@ -140,6 +157,7 @@ from lerobot.common.robot_devices.control_configs import (
|
||||
from lerobot.common.robot_devices.control_utils import (
|
||||
control_loop,
|
||||
init_keyboard_listener,
|
||||
is_headless,
|
||||
log_control_info,
|
||||
record_episode,
|
||||
reset_environment,
|
||||
@@ -219,7 +237,7 @@ def teleoperate(robot: Robot, cfg: TeleoperateControlConfig):
|
||||
control_time_s=cfg.teleop_time_s,
|
||||
fps=cfg.fps,
|
||||
teleoperate=True,
|
||||
display_cameras=cfg.display_cameras,
|
||||
display_data=cfg.display_data,
|
||||
)
|
||||
|
||||
|
||||
@@ -254,7 +272,7 @@ def record(
|
||||
)
|
||||
|
||||
# Load pretrained policy
|
||||
policy = None if cfg.policy is None else make_policy(cfg.policy, cfg.device, ds_meta=dataset.meta)
|
||||
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
@@ -267,7 +285,7 @@ def record(
|
||||
# 3. place the cameras windows on screen
|
||||
enable_teleoperation = policy is None
|
||||
log_say("Warmup record", cfg.play_sounds)
|
||||
warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_cameras, cfg.fps)
|
||||
warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_data, cfg.fps)
|
||||
|
||||
if has_method(robot, "teleop_safety_stop"):
|
||||
robot.teleop_safety_stop()
|
||||
@@ -283,10 +301,8 @@ def record(
|
||||
dataset=dataset,
|
||||
events=events,
|
||||
episode_time_s=cfg.episode_time_s,
|
||||
display_cameras=cfg.display_cameras,
|
||||
display_data=cfg.display_data,
|
||||
policy=policy,
|
||||
device=cfg.device,
|
||||
use_amp=cfg.use_amp,
|
||||
fps=cfg.fps,
|
||||
single_task=cfg.single_task,
|
||||
)
|
||||
@@ -315,7 +331,7 @@ def record(
|
||||
break
|
||||
|
||||
log_say("Stop recording", cfg.play_sounds, blocking=True)
|
||||
stop_recording(robot, listener, cfg.display_cameras)
|
||||
stop_recording(robot, listener, cfg.display_data)
|
||||
|
||||
if cfg.push_to_hub:
|
||||
dataset.push_to_hub(tags=cfg.tags, private=cfg.private)
|
||||
@@ -352,6 +368,40 @@ def replay(
|
||||
log_control_info(robot, dt_s, fps=cfg.fps)
|
||||
|
||||
|
||||
def _init_rerun(control_config: ControlConfig, session_name: str = "lerobot_control_loop") -> None:
|
||||
"""Initializes the Rerun SDK for visualizing the control loop.
|
||||
|
||||
Args:
|
||||
control_config: Configuration determining data display and robot type.
|
||||
session_name: Rerun session name. Defaults to "lerobot_control_loop".
|
||||
|
||||
Raises:
|
||||
ValueError: If viewer IP is missing for non-remote configurations with display enabled.
|
||||
"""
|
||||
if (control_config.display_data and not is_headless()) or (
|
||||
control_config.display_data and isinstance(control_config, RemoteRobotConfig)
|
||||
):
|
||||
# Configure Rerun flush batch size default to 8KB if not set
|
||||
batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000")
|
||||
os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size
|
||||
|
||||
# Initialize Rerun based on configuration
|
||||
rr.init(session_name)
|
||||
if isinstance(control_config, RemoteRobotConfig):
|
||||
viewer_ip = control_config.viewer_ip
|
||||
viewer_port = control_config.viewer_port
|
||||
if not viewer_ip or not viewer_port:
|
||||
raise ValueError(
|
||||
"Viewer IP & Port are required for remote config. Set via config file/CLI or disable control_config.display_data."
|
||||
)
|
||||
logging.info(f"Connecting to viewer at {viewer_ip}:{viewer_port}")
|
||||
rr.connect_tcp(f"{viewer_ip}:{viewer_port}")
|
||||
else:
|
||||
# Get memory limit for rerun viewer parameters
|
||||
memory_limit = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "10%")
|
||||
rr.spawn(memory_limit=memory_limit)
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def control_robot(cfg: ControlPipelineConfig):
|
||||
init_logging()
|
||||
@@ -359,17 +409,22 @@ def control_robot(cfg: ControlPipelineConfig):
|
||||
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
|
||||
# TODO(Steven): Blueprint for fixed window size
|
||||
|
||||
if isinstance(cfg.control, CalibrateControlConfig):
|
||||
calibrate(robot, cfg.control)
|
||||
elif isinstance(cfg.control, TeleoperateControlConfig):
|
||||
_init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_teleop")
|
||||
teleoperate(robot, cfg.control)
|
||||
elif isinstance(cfg.control, RecordControlConfig):
|
||||
_init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_record")
|
||||
record(robot, cfg.control)
|
||||
elif isinstance(cfg.control, ReplayControlConfig):
|
||||
replay(robot, cfg.control)
|
||||
elif isinstance(cfg.control, RemoteRobotConfig):
|
||||
from lerobot.common.robot_devices.robots.lekiwi_remote import run_lekiwi
|
||||
|
||||
_init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_remote")
|
||||
run_lekiwi(cfg.robot)
|
||||
|
||||
if robot.is_connected:
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Utilities to control a robot in simulation.
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ from torch import Tensor, nn
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.utils import preprocess_observation
|
||||
from lerobot.common.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
@@ -124,7 +124,6 @@ def rollout(
|
||||
|
||||
# Reset the policy and environments.
|
||||
policy.reset()
|
||||
|
||||
observation, info = env.reset(seed=seeds)
|
||||
if render_callback is not None:
|
||||
render_callback(env)
|
||||
@@ -145,6 +144,7 @@ def rollout(
|
||||
disable=inside_slurm(), # we dont want progress bar when we use slurm, since it clutters the logs
|
||||
leave=False,
|
||||
)
|
||||
check_env_attributes_and_types(env)
|
||||
while not np.all(done):
|
||||
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
|
||||
observation = preprocess_observation(observation)
|
||||
@@ -155,6 +155,10 @@ def rollout(
|
||||
key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
|
||||
}
|
||||
|
||||
# Infer "task" from attributes of environments.
|
||||
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
|
||||
observation = add_envs_task(env, observation)
|
||||
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation)
|
||||
|
||||
@@ -458,7 +462,7 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
# Check device is available
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
device = get_safe_torch_device(cfg.policy.device, log=True)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
@@ -470,14 +474,14 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
||||
|
||||
logging.info("Making policy.")
|
||||
|
||||
policy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
device=device,
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
policy.eval()
|
||||
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||
info = eval_policy(
|
||||
env,
|
||||
policy,
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
@@ -1,364 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Use this script to convert your dataset into LeRobot dataset format and upload it to the Hugging Face hub,
|
||||
or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any
|
||||
installation of neural net specific packages like pytorch, tensorflow, jax.
|
||||
|
||||
Example of how to download raw datasets, convert them into LeRobotDataset format, and push them to the hub:
|
||||
```
|
||||
python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--raw-dir data/pusht_raw \
|
||||
--raw-format pusht_zarr \
|
||||
--repo-id lerobot/pusht
|
||||
|
||||
python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--raw-dir data/xarm_lift_medium_raw \
|
||||
--raw-format xarm_pkl \
|
||||
--repo-id lerobot/xarm_lift_medium
|
||||
|
||||
python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--raw-dir data/aloha_sim_insertion_scripted_raw \
|
||||
--raw-format aloha_hdf5 \
|
||||
--repo-id lerobot/aloha_sim_insertion_scripted
|
||||
|
||||
python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--raw-dir data/umi_cup_in_the_wild_raw \
|
||||
--raw-format umi_zarr \
|
||||
--repo-id lerobot/umi_cup_in_the_wild
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import shutil
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from lerobot.common.datasets.compute_stats import compute_stats
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
|
||||
from lerobot.common.datasets.utils import create_branch, create_lerobot_dataset_card, flatten_dict
|
||||
|
||||
|
||||
def get_from_raw_to_lerobot_format_fn(raw_format: str):
|
||||
if raw_format == "pusht_zarr":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "umi_zarr":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "aloha_hdf5":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
|
||||
elif raw_format in ["rlds", "openx"]:
|
||||
from lerobot.common.datasets.push_dataset_to_hub.openx_rlds_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "dora_parquet":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "xarm_pkl":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "cam_png":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.cam_png_format import from_raw_to_lerobot_format
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The selected {raw_format} can't be found. Did you add it to `lerobot/scripts/push_dataset_to_hub.py::get_from_raw_to_lerobot_format_fn`?"
|
||||
)
|
||||
|
||||
return from_raw_to_lerobot_format
|
||||
|
||||
|
||||
def save_meta_data(
|
||||
info: dict[str, Any], stats: dict, episode_data_index: dict[str, list], meta_data_dir: Path
|
||||
):
|
||||
meta_data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# save info
|
||||
info_path = meta_data_dir / "info.json"
|
||||
with open(str(info_path), "w") as f:
|
||||
json.dump(info, f, indent=4)
|
||||
|
||||
# save stats
|
||||
stats_path = meta_data_dir / "stats.safetensors"
|
||||
save_file(flatten_dict(stats), stats_path)
|
||||
|
||||
# save episode_data_index
|
||||
episode_data_index = {key: torch.tensor(episode_data_index[key]) for key in episode_data_index}
|
||||
ep_data_idx_path = meta_data_dir / "episode_data_index.safetensors"
|
||||
save_file(episode_data_index, ep_data_idx_path)
|
||||
|
||||
|
||||
def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str | None):
|
||||
"""Expect all meta data files to be all stored in a single "meta_data" directory.
|
||||
On the hugging face repositery, they will be uploaded in a "meta_data" directory at the root.
|
||||
"""
|
||||
api = HfApi()
|
||||
api.upload_folder(
|
||||
folder_path=meta_data_dir,
|
||||
path_in_repo="meta_data",
|
||||
repo_id=repo_id,
|
||||
revision=revision,
|
||||
repo_type="dataset",
|
||||
)
|
||||
|
||||
|
||||
def push_dataset_card_to_hub(
|
||||
repo_id: str,
|
||||
revision: str | None,
|
||||
tags: list | None = None,
|
||||
license: str = "apache-2.0",
|
||||
**card_kwargs,
|
||||
):
|
||||
"""Creates and pushes a LeRobotDataset Card with appropriate tags to easily find it on the hub."""
|
||||
card = create_lerobot_dataset_card(tags=tags, license=license, **card_kwargs)
|
||||
card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=revision)
|
||||
|
||||
|
||||
def push_videos_to_hub(repo_id: str, videos_dir: str | Path, revision: str | None):
|
||||
"""Expect mp4 files to be all stored in a single "videos" directory.
|
||||
On the hugging face repositery, they will be uploaded in a "videos" directory at the root.
|
||||
"""
|
||||
api = HfApi()
|
||||
api.upload_folder(
|
||||
folder_path=videos_dir,
|
||||
path_in_repo="videos",
|
||||
repo_id=repo_id,
|
||||
revision=revision,
|
||||
repo_type="dataset",
|
||||
allow_patterns="*.mp4",
|
||||
)
|
||||
|
||||
|
||||
def push_dataset_to_hub(
|
||||
raw_dir: Path,
|
||||
raw_format: str,
|
||||
repo_id: str,
|
||||
push_to_hub: bool = True,
|
||||
local_dir: Path | None = None,
|
||||
fps: int | None = None,
|
||||
video: bool = True,
|
||||
batch_size: int = 32,
|
||||
num_workers: int = 8,
|
||||
episodes: list[int] | None = None,
|
||||
force_override: bool = False,
|
||||
resume: bool = False,
|
||||
cache_dir: Path = Path("/tmp"),
|
||||
tests_data_dir: Path | None = None,
|
||||
encoding: dict | None = None,
|
||||
):
|
||||
check_repo_id(repo_id)
|
||||
user_id, dataset_id = repo_id.split("/")
|
||||
|
||||
# Robustify when `raw_dir` is str instead of Path
|
||||
raw_dir = Path(raw_dir)
|
||||
if not raw_dir.exists():
|
||||
raise NotADirectoryError(
|
||||
f"{raw_dir} does not exists. Check your paths or run this command to download an existing raw dataset on the hub: "
|
||||
f"`python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py --raw-dir your/raw/dir --repo-id your/repo/id_raw`"
|
||||
)
|
||||
|
||||
if local_dir:
|
||||
# Robustify when `local_dir` is str instead of Path
|
||||
local_dir = Path(local_dir)
|
||||
|
||||
# Send warning if local_dir isn't well formatted
|
||||
if local_dir.parts[-2] != user_id or local_dir.parts[-1] != dataset_id:
|
||||
warnings.warn(
|
||||
f"`local_dir` ({local_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht'). Following this naming convention is advised, but not mandatory.",
|
||||
stacklevel=1,
|
||||
)
|
||||
|
||||
# Check we don't override an existing `local_dir` by mistake
|
||||
if local_dir.exists():
|
||||
if force_override:
|
||||
shutil.rmtree(local_dir)
|
||||
elif not resume:
|
||||
raise ValueError(f"`local_dir` already exists ({local_dir}). Use `--force-override 1`.")
|
||||
|
||||
meta_data_dir = local_dir / "meta_data"
|
||||
videos_dir = local_dir / "videos"
|
||||
else:
|
||||
# Temporary directory used to store images, videos, meta_data
|
||||
meta_data_dir = Path(cache_dir) / "meta_data"
|
||||
videos_dir = Path(cache_dir) / "videos"
|
||||
|
||||
if raw_format is None:
|
||||
# TODO(rcadene, adilzouitine): implement auto_find_raw_format
|
||||
raise NotImplementedError()
|
||||
# raw_format = auto_find_raw_format(raw_dir)
|
||||
|
||||
# convert dataset from original raw format to LeRobot format
|
||||
from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
|
||||
|
||||
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
|
||||
raw_dir,
|
||||
videos_dir,
|
||||
fps,
|
||||
video,
|
||||
episodes,
|
||||
encoding,
|
||||
)
|
||||
|
||||
lerobot_dataset = LeRobotDataset.from_preloaded(
|
||||
repo_id=repo_id,
|
||||
hf_dataset=hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
info=info,
|
||||
videos_dir=videos_dir,
|
||||
)
|
||||
stats = compute_stats(lerobot_dataset, batch_size, num_workers)
|
||||
|
||||
if local_dir:
|
||||
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
|
||||
hf_dataset.save_to_disk(str(local_dir / "train"))
|
||||
|
||||
if push_to_hub or local_dir:
|
||||
# mandatory for upload
|
||||
save_meta_data(info, stats, episode_data_index, meta_data_dir)
|
||||
|
||||
if push_to_hub:
|
||||
hf_dataset.push_to_hub(repo_id, revision="main")
|
||||
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
|
||||
push_dataset_card_to_hub(repo_id, revision="main")
|
||||
if video:
|
||||
push_videos_to_hub(repo_id, videos_dir, revision="main")
|
||||
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
|
||||
|
||||
if tests_data_dir:
|
||||
# get the first episode
|
||||
num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
|
||||
test_hf_dataset = hf_dataset.select(range(num_items_first_ep))
|
||||
episode_data_index = {k: v[:1] for k, v in episode_data_index.items()}
|
||||
|
||||
test_hf_dataset = test_hf_dataset.with_format(None)
|
||||
test_hf_dataset.save_to_disk(str(tests_data_dir / repo_id / "train"))
|
||||
|
||||
tests_meta_data = tests_data_dir / repo_id / "meta_data"
|
||||
save_meta_data(info, stats, episode_data_index, tests_meta_data)
|
||||
|
||||
# copy videos of first episode to tests directory
|
||||
episode_index = 0
|
||||
tests_videos_dir = tests_data_dir / repo_id / "videos"
|
||||
tests_videos_dir.mkdir(parents=True, exist_ok=True)
|
||||
for key in lerobot_dataset.camera_keys:
|
||||
fname = f"{key}_episode_{episode_index:06d}.mp4"
|
||||
shutil.copy(videos_dir / fname, tests_videos_dir / fname)
|
||||
|
||||
if local_dir is None:
|
||||
# clear cache
|
||||
shutil.rmtree(meta_data_dir)
|
||||
shutil.rmtree(videos_dir)
|
||||
|
||||
return lerobot_dataset
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--raw-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw).",
|
||||
)
|
||||
# TODO(rcadene): add automatic detection of the format
|
||||
parser.add_argument(
|
||||
"--raw-format",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`, `rlds`, `openx`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-dir",
|
||||
type=Path,
|
||||
help="When provided, writes the dataset converted to LeRobotDataset format in this directory (e.g. `data/lerobot/aloha_mobile_chair`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Upload to hub.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fps",
|
||||
type=int,
|
||||
help="Frame rate used to collect videos. If not provided, use the default one specified in the code.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--video",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Convert each episode of the raw dataset to an mp4 video. This option allows 60 times lower disk space consumption and 25 faster loading time during training.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Batch size loaded by DataLoader for computing the dataset statistics.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Number of processes of Dataloader for computing the dataset statistics.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--episodes",
|
||||
type=int,
|
||||
nargs="*",
|
||||
help="When provided, only converts the provided episodes (e.g `--episodes 2 3 4`). Useful to test the code on 1 episode.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force-override",
|
||||
type=int,
|
||||
default=0,
|
||||
help="When set to 1, removes provided output directory if it already exists. By default, raises a ValueError exception.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume",
|
||||
type=int,
|
||||
default=0,
|
||||
help="When set to 1, resumes a previous run.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache-dir",
|
||||
type=Path,
|
||||
required=False,
|
||||
default="/tmp",
|
||||
help="Directory to store the temporary videos and images generated while creating the dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tests-data-dir",
|
||||
type=Path,
|
||||
help=(
|
||||
"When provided, save tests artifacts into the given directory "
|
||||
"(e.g. `--tests-data-dir tests/data` will save to tests/data/{--repo-id})."
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
push_dataset_to_hub(**vars(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -120,7 +120,7 @@ def train(cfg: TrainPipelineConfig):
|
||||
set_seed(cfg.seed)
|
||||
|
||||
# Check device is available
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
device = get_safe_torch_device(cfg.policy.device, log=True)
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
@@ -133,18 +133,17 @@ def train(cfg: TrainPipelineConfig):
|
||||
eval_env = None
|
||||
if cfg.eval_freq > 0 and cfg.env is not None:
|
||||
logging.info("Creating env")
|
||||
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size)
|
||||
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
||||
|
||||
logging.info("Creating policy")
|
||||
policy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
device=device,
|
||||
ds_meta=dataset.meta,
|
||||
)
|
||||
|
||||
logging.info("Creating optimizer and scheduler")
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
grad_scaler = GradScaler(device, enabled=cfg.use_amp)
|
||||
grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)
|
||||
|
||||
step = 0 # number of policy updates (forward + backward + optim)
|
||||
|
||||
@@ -219,7 +218,7 @@ def train(cfg: TrainPipelineConfig):
|
||||
cfg.optimizer.grad_clip_norm,
|
||||
grad_scaler=grad_scaler,
|
||||
lr_scheduler=lr_scheduler,
|
||||
use_amp=cfg.use_amp,
|
||||
use_amp=cfg.policy.use_amp,
|
||||
)
|
||||
|
||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||
@@ -250,7 +249,10 @@ def train(cfg: TrainPipelineConfig):
|
||||
if cfg.env and is_eval_step:
|
||||
step_id = get_step_identifier(step, cfg.steps)
|
||||
logging.info(f"Eval policy at step {step}")
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
|
||||
with (
|
||||
torch.no_grad(),
|
||||
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
|
||||
):
|
||||
eval_info = eval_policy(
|
||||
eval_env,
|
||||
policy,
|
||||
|
||||
@@ -265,13 +265,25 @@ def main():
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tolerance-s",
|
||||
type=float,
|
||||
default=1e-4,
|
||||
help=(
|
||||
"Tolerance in seconds used to ensure data timestamps respect the dataset fps value"
|
||||
"This is argument passed to the constructor of LeRobotDataset and maps to its tolerance_s constructor argument"
|
||||
"If not given, defaults to 1e-4."
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
kwargs = vars(args)
|
||||
repo_id = kwargs.pop("repo_id")
|
||||
root = kwargs.pop("root")
|
||||
tolerance_s = kwargs.pop("tolerance_s")
|
||||
|
||||
logging.info("Loading dataset")
|
||||
dataset = LeRobotDataset(repo_id, root=root)
|
||||
dataset = LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s)
|
||||
|
||||
visualize_dataset(dataset, **vars(args))
|
||||
|
||||
|
||||
@@ -158,7 +158,7 @@ def run_server(
|
||||
if major_version < 2:
|
||||
return "Make sure to convert your LeRobotDataset to v2 & above."
|
||||
|
||||
episode_data_csv_str, columns = get_episode_data(dataset, episode_id)
|
||||
episode_data_csv_str, columns, ignored_columns = get_episode_data(dataset, episode_id)
|
||||
dataset_info = {
|
||||
"repo_id": f"{dataset_namespace}/{dataset_name}",
|
||||
"num_samples": dataset.num_frames
|
||||
@@ -218,6 +218,7 @@ def run_server(
|
||||
videos_info=videos_info,
|
||||
episode_data_csv_str=episode_data_csv_str,
|
||||
columns=columns,
|
||||
ignored_columns=ignored_columns,
|
||||
)
|
||||
|
||||
app.run(host=host, port=port)
|
||||
@@ -233,9 +234,17 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
|
||||
This file will be loaded by Dygraph javascript to plot data in real time."""
|
||||
columns = []
|
||||
|
||||
selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] == "float32"]
|
||||
selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] in ["float32", "int32"]]
|
||||
selected_columns.remove("timestamp")
|
||||
|
||||
ignored_columns = []
|
||||
for column_name in selected_columns:
|
||||
shape = dataset.features[column_name]["shape"]
|
||||
shape_dim = len(shape)
|
||||
if shape_dim > 1:
|
||||
selected_columns.remove(column_name)
|
||||
ignored_columns.append(column_name)
|
||||
|
||||
# init header of csv with state and action names
|
||||
header = ["timestamp"]
|
||||
|
||||
@@ -291,7 +300,7 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
|
||||
csv_writer.writerows(rows)
|
||||
csv_string = csv_buffer.getvalue()
|
||||
|
||||
return csv_string, columns
|
||||
return csv_string, columns, ignored_columns
|
||||
|
||||
|
||||
def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
|
||||
@@ -437,15 +446,31 @@ def main():
|
||||
help="Delete the output directory if it exists already.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tolerance-s",
|
||||
type=float,
|
||||
default=1e-4,
|
||||
help=(
|
||||
"Tolerance in seconds used to ensure data timestamps respect the dataset fps value"
|
||||
"This is argument passed to the constructor of LeRobotDataset and maps to its tolerance_s constructor argument"
|
||||
"If not given, defaults to 1e-4."
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
kwargs = vars(args)
|
||||
repo_id = kwargs.pop("repo_id")
|
||||
load_from_hf_hub = kwargs.pop("load_from_hf_hub")
|
||||
root = kwargs.pop("root")
|
||||
tolerance_s = kwargs.pop("tolerance_s")
|
||||
|
||||
dataset = None
|
||||
if repo_id:
|
||||
dataset = LeRobotDataset(repo_id, root=root) if not load_from_hf_hub else get_dataset_info(repo_id)
|
||||
dataset = (
|
||||
LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s)
|
||||
if not load_from_hf_hub
|
||||
else get_dataset_info(repo_id)
|
||||
)
|
||||
|
||||
visualize_dataset_html(dataset, **vars(args))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user