forked from tangger/lerobot
Merge remote-tracking branch 'origin/main' into user/rcadene/2024_09_10_train_aloha
This commit is contained in:
@@ -32,7 +32,7 @@ DATASET_CARD_TEMPLATE = """
|
||||
---
|
||||
# Metadata will go there
|
||||
---
|
||||
This dataset was created using [🤗 LeRobot](https://github.com/huggingface/lerobot).
|
||||
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
|
||||
# sanity check that images are channel last
|
||||
_, h, w, c = img.shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
|
||||
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
||||
|
||||
# sanity check that images are uint8
|
||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||
|
||||
@@ -120,7 +120,7 @@ eval:
|
||||
# `batch_size` specifies the number of environments to use in a gym.vector.VectorEnv.
|
||||
batch_size: 1
|
||||
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
|
||||
use_async_envs: true
|
||||
use_async_envs: false
|
||||
|
||||
wandb:
|
||||
enable: false
|
||||
|
||||
5
lerobot/configs/env/aloha.yaml
vendored
5
lerobot/configs/env/aloha.yaml
vendored
@@ -2,11 +2,6 @@
|
||||
|
||||
fps: 50
|
||||
|
||||
eval:
|
||||
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
|
||||
# set it to false to avoid some problems of the aloha env
|
||||
use_async_envs: false
|
||||
|
||||
env:
|
||||
name: aloha
|
||||
task: AlohaInsertion-v0
|
||||
|
||||
5
lerobot/configs/env/xarm.yaml
vendored
5
lerobot/configs/env/xarm.yaml
vendored
@@ -2,11 +2,6 @@
|
||||
|
||||
fps: 15
|
||||
|
||||
eval:
|
||||
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
|
||||
# set it to false to avoid some problems of the aloha env
|
||||
use_async_envs: false
|
||||
|
||||
env:
|
||||
name: xarm
|
||||
task: XarmLift-v0
|
||||
|
||||
@@ -179,13 +179,18 @@ def none_or_int(value):
|
||||
def log_control_info(robot, dt_s, episode_index=None, frame_index=None, fps=None):
|
||||
log_items = []
|
||||
if episode_index is not None:
|
||||
log_items += [f"ep:{episode_index}"]
|
||||
log_items.append(f"ep:{episode_index}")
|
||||
if frame_index is not None:
|
||||
log_items += [f"frame:{frame_index}"]
|
||||
log_items.append(f"frame:{frame_index}")
|
||||
|
||||
def log_dt(shortname, dt_val_s):
|
||||
nonlocal log_items
|
||||
log_items += [f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)"]
|
||||
nonlocal log_items, fps
|
||||
info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)"
|
||||
if fps is not None:
|
||||
actual_fps = 1 / dt_val_s
|
||||
if actual_fps < fps - 1:
|
||||
info_str = colored(info_str, "yellow")
|
||||
log_items.append(info_str)
|
||||
|
||||
# total step time displayed in milliseconds and its frequency
|
||||
log_dt("dt", dt_s)
|
||||
@@ -210,10 +215,6 @@ def log_control_info(robot, dt_s, episode_index=None, frame_index=None, fps=None
|
||||
log_dt(f"dtR{name}", robot.logs[key])
|
||||
|
||||
info_str = " ".join(log_items)
|
||||
if fps is not None:
|
||||
actual_fps = 1 / dt_s
|
||||
if actual_fps < fps - 1:
|
||||
info_str = colored(info_str, "yellow")
|
||||
logging.info(info_str)
|
||||
|
||||
|
||||
@@ -320,7 +321,7 @@ def record(
|
||||
run_compute_stats=True,
|
||||
push_to_hub=True,
|
||||
tags=None,
|
||||
num_image_writers=8,
|
||||
num_image_writers_per_camera=4,
|
||||
force_override=False,
|
||||
):
|
||||
# TODO(rcadene): Add option to record logs
|
||||
@@ -442,8 +443,8 @@ def record(
|
||||
|
||||
# Save images using threads to reach high fps (30 and more)
|
||||
# Using `with` to exist smoothly if an execption is raised.
|
||||
# Using only 4 worker threads to avoid blocking the main thread.
|
||||
futures = []
|
||||
num_image_writers = num_image_writers_per_camera * len(robot.cameras)
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor:
|
||||
# Start recording all episodes
|
||||
while episode_index < num_episodes:
|
||||
@@ -803,10 +804,14 @@ if __name__ == "__main__":
|
||||
help="Add tags to your dataset on the hub.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--num-image-writers",
|
||||
"--num-image-writers-per-camera",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Number of threads writing the frames as png images on disk. Don't set too much as you might get unstable fps due to main thread being blocked.",
|
||||
default=4,
|
||||
help=(
|
||||
"Number of threads writing the frames as png images on disk, per camera. "
|
||||
"Too much threads might cause unstable teleoperation fps due to main thread being blocked. "
|
||||
"Not enough threads might cause low camera fps."
|
||||
),
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--force-override",
|
||||
|
||||
@@ -57,7 +57,7 @@ import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
||||
from huggingface_hub.errors import RepositoryNotFoundError
|
||||
from huggingface_hub.utils._validators import HFValidationError
|
||||
from torch import Tensor, nn
|
||||
from tqdm import trange
|
||||
|
||||
@@ -82,13 +82,22 @@
|
||||
Episode {{ episode_id }}
|
||||
</h1>
|
||||
|
||||
<!-- Error message -->
|
||||
<div class="font-medium text-orange-700 hidden" :class="{ 'hidden': !videoCodecError }">
|
||||
<p>Videos could NOT play because <a href="https://en.wikipedia.org/wiki/AV1" target="_blank" class="underline">AV1</a> decoding is not available on your browser.</p>
|
||||
<ul class="list-decimal list-inside">
|
||||
<li>If iPhone: <span class="italic">It is supported with A17 chip or higher.</span></li>
|
||||
<li>If Mac with Safari: <span class="italic">It is supported on most browsers except Safari with M1 chip or higher and on Safari with M3 chip or higher.</span></li>
|
||||
<li>Other: <span class="italic">Contact the maintainers on LeRobot discord channel:</span> <a href="https://discord.com/invite/s3KuuzsPFb" target="_blank" class="underline">https://discord.com/invite/s3KuuzsPFb</a></li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<!-- Videos -->
|
||||
<div class="flex flex-wrap gap-1">
|
||||
<p x-show="videoCodecError" class="font-medium text-orange-700">Videos could NOT play because <a href="https://en.wikipedia.org/wiki/AV1" target="_blank" class="underline">AV1</a> decoding is not available on your browser. Learn more about <a href="https://huggingface.co/blog/video-encoding" target="_blank" class="underline">LeRobot video encoding</a>.</p>
|
||||
{% for video_info in videos_info %}
|
||||
<div x-show="!videoCodecError" class="max-w-96">
|
||||
<p class="text-sm text-gray-300 bg-gray-800 px-2 rounded-t-xl truncate">{{ video_info.filename }}</p>
|
||||
<video muted loop type="video/mp4" class="min-w-64" @canplaythrough="videoCanPlay" @timeupdate="() => {
|
||||
<video muted loop type="video/mp4" class="object-contain w-full h-full" @canplaythrough="videoCanPlay" @timeupdate="() => {
|
||||
if (video.duration) {
|
||||
const time = video.currentTime;
|
||||
const pc = (100 / video.duration) * time;
|
||||
|
||||
18
poetry.lock
generated
18
poetry.lock
generated
@@ -1360,13 +1360,13 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "huggingface-hub"
|
||||
version = "0.23.5"
|
||||
version = "0.25.0"
|
||||
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
|
||||
optional = false
|
||||
python-versions = ">=3.8.0"
|
||||
files = [
|
||||
{file = "huggingface_hub-0.23.5-py3-none-any.whl", hash = "sha256:d7a7d337615e11a45cc14a0ce5a605db6b038dc24af42866f731684825226e90"},
|
||||
{file = "huggingface_hub-0.23.5.tar.gz", hash = "sha256:67a9caba79b71235be3752852ca27da86bd54311d2424ca8afdb8dda056edf98"},
|
||||
{file = "huggingface_hub-0.25.0-py3-none-any.whl", hash = "sha256:e2f357b35d72d5012cfd127108c4e14abcd61ba4ebc90a5a374dc2456cb34e12"},
|
||||
{file = "huggingface_hub-0.25.0.tar.gz", hash = "sha256:fb5fbe6c12fcd99d187ec7db95db9110fb1a20505f23040a5449a717c1a0db4d"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -1381,17 +1381,17 @@ tqdm = ">=4.42.1"
|
||||
typing-extensions = ">=3.7.4.3"
|
||||
|
||||
[package.extras]
|
||||
all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.3.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
|
||||
all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.5.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
|
||||
cli = ["InquirerPy (==0.3.4)"]
|
||||
dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.3.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
|
||||
dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.5.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
|
||||
fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"]
|
||||
hf-transfer = ["hf-transfer (>=0.1.4)"]
|
||||
inference = ["aiohttp", "minijinja (>=1.0)"]
|
||||
quality = ["mypy (==1.5.1)", "ruff (>=0.3.0)"]
|
||||
quality = ["mypy (==1.5.1)", "ruff (>=0.5.0)"]
|
||||
tensorflow = ["graphviz", "pydot", "tensorflow"]
|
||||
tensorflow-testing = ["keras (<3.0)", "tensorflow"]
|
||||
testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"]
|
||||
torch = ["safetensors", "torch"]
|
||||
testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"]
|
||||
torch = ["safetensors[torch]", "torch"]
|
||||
typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"]
|
||||
|
||||
[[package]]
|
||||
@@ -4586,4 +4586,4 @@ xarm = ["gym-xarm"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.10,<3.13"
|
||||
content-hash = "06a8a1941b75c3ec78ade6f8b2c3ad7b5d2f1516b590fa3d5a773add73f6dbec"
|
||||
content-hash = "c9c3beac71f760738baf2fd169378eefdaef7d3a9cd068270bc5190fbefdb42a"
|
||||
|
||||
@@ -43,7 +43,7 @@ opencv-python = ">=4.9.0"
|
||||
diffusers = ">=0.27.2"
|
||||
torchvision = ">=0.17.1"
|
||||
h5py = ">=3.10.0"
|
||||
huggingface-hub = {extras = ["hf-transfer", "cli"], version = ">=0.23.0"}
|
||||
huggingface-hub = {extras = ["hf-transfer", "cli"], version = ">=0.25.0"}
|
||||
gymnasium = ">=0.29.1"
|
||||
cmake = ">=3.29.0.1"
|
||||
gym-dora = { git = "https://github.com/dora-rs/dora-lerobot.git", subdirectory = "gym_dora", optional = true }
|
||||
|
||||
Reference in New Issue
Block a user