forked from tangger/lerobot
Merge remote-tracking branch 'origin/main' into user/rcadene/2024_09_10_train_aloha
This commit is contained in:
@@ -32,7 +32,7 @@ DATASET_CARD_TEMPLATE = """
|
|||||||
---
|
---
|
||||||
# Metadata will go there
|
# Metadata will go there
|
||||||
---
|
---
|
||||||
This dataset was created using [🤗 LeRobot](https://github.com/huggingface/lerobot).
|
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
|||||||
|
|
||||||
# sanity check that images are channel last
|
# sanity check that images are channel last
|
||||||
_, h, w, c = img.shape
|
_, h, w, c = img.shape
|
||||||
assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
|
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
||||||
|
|
||||||
# sanity check that images are uint8
|
# sanity check that images are uint8
|
||||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||||
|
|||||||
@@ -120,7 +120,7 @@ eval:
|
|||||||
# `batch_size` specifies the number of environments to use in a gym.vector.VectorEnv.
|
# `batch_size` specifies the number of environments to use in a gym.vector.VectorEnv.
|
||||||
batch_size: 1
|
batch_size: 1
|
||||||
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
|
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
|
||||||
use_async_envs: true
|
use_async_envs: false
|
||||||
|
|
||||||
wandb:
|
wandb:
|
||||||
enable: false
|
enable: false
|
||||||
|
|||||||
5
lerobot/configs/env/aloha.yaml
vendored
5
lerobot/configs/env/aloha.yaml
vendored
@@ -2,11 +2,6 @@
|
|||||||
|
|
||||||
fps: 50
|
fps: 50
|
||||||
|
|
||||||
eval:
|
|
||||||
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
|
|
||||||
# set it to false to avoid some problems of the aloha env
|
|
||||||
use_async_envs: false
|
|
||||||
|
|
||||||
env:
|
env:
|
||||||
name: aloha
|
name: aloha
|
||||||
task: AlohaInsertion-v0
|
task: AlohaInsertion-v0
|
||||||
|
|||||||
5
lerobot/configs/env/xarm.yaml
vendored
5
lerobot/configs/env/xarm.yaml
vendored
@@ -2,11 +2,6 @@
|
|||||||
|
|
||||||
fps: 15
|
fps: 15
|
||||||
|
|
||||||
eval:
|
|
||||||
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
|
|
||||||
# set it to false to avoid some problems of the aloha env
|
|
||||||
use_async_envs: false
|
|
||||||
|
|
||||||
env:
|
env:
|
||||||
name: xarm
|
name: xarm
|
||||||
task: XarmLift-v0
|
task: XarmLift-v0
|
||||||
|
|||||||
@@ -179,13 +179,18 @@ def none_or_int(value):
|
|||||||
def log_control_info(robot, dt_s, episode_index=None, frame_index=None, fps=None):
|
def log_control_info(robot, dt_s, episode_index=None, frame_index=None, fps=None):
|
||||||
log_items = []
|
log_items = []
|
||||||
if episode_index is not None:
|
if episode_index is not None:
|
||||||
log_items += [f"ep:{episode_index}"]
|
log_items.append(f"ep:{episode_index}")
|
||||||
if frame_index is not None:
|
if frame_index is not None:
|
||||||
log_items += [f"frame:{frame_index}"]
|
log_items.append(f"frame:{frame_index}")
|
||||||
|
|
||||||
def log_dt(shortname, dt_val_s):
|
def log_dt(shortname, dt_val_s):
|
||||||
nonlocal log_items
|
nonlocal log_items, fps
|
||||||
log_items += [f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)"]
|
info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)"
|
||||||
|
if fps is not None:
|
||||||
|
actual_fps = 1 / dt_val_s
|
||||||
|
if actual_fps < fps - 1:
|
||||||
|
info_str = colored(info_str, "yellow")
|
||||||
|
log_items.append(info_str)
|
||||||
|
|
||||||
# total step time displayed in milliseconds and its frequency
|
# total step time displayed in milliseconds and its frequency
|
||||||
log_dt("dt", dt_s)
|
log_dt("dt", dt_s)
|
||||||
@@ -210,10 +215,6 @@ def log_control_info(robot, dt_s, episode_index=None, frame_index=None, fps=None
|
|||||||
log_dt(f"dtR{name}", robot.logs[key])
|
log_dt(f"dtR{name}", robot.logs[key])
|
||||||
|
|
||||||
info_str = " ".join(log_items)
|
info_str = " ".join(log_items)
|
||||||
if fps is not None:
|
|
||||||
actual_fps = 1 / dt_s
|
|
||||||
if actual_fps < fps - 1:
|
|
||||||
info_str = colored(info_str, "yellow")
|
|
||||||
logging.info(info_str)
|
logging.info(info_str)
|
||||||
|
|
||||||
|
|
||||||
@@ -320,7 +321,7 @@ def record(
|
|||||||
run_compute_stats=True,
|
run_compute_stats=True,
|
||||||
push_to_hub=True,
|
push_to_hub=True,
|
||||||
tags=None,
|
tags=None,
|
||||||
num_image_writers=8,
|
num_image_writers_per_camera=4,
|
||||||
force_override=False,
|
force_override=False,
|
||||||
):
|
):
|
||||||
# TODO(rcadene): Add option to record logs
|
# TODO(rcadene): Add option to record logs
|
||||||
@@ -442,8 +443,8 @@ def record(
|
|||||||
|
|
||||||
# Save images using threads to reach high fps (30 and more)
|
# Save images using threads to reach high fps (30 and more)
|
||||||
# Using `with` to exist smoothly if an execption is raised.
|
# Using `with` to exist smoothly if an execption is raised.
|
||||||
# Using only 4 worker threads to avoid blocking the main thread.
|
|
||||||
futures = []
|
futures = []
|
||||||
|
num_image_writers = num_image_writers_per_camera * len(robot.cameras)
|
||||||
with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor:
|
with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor:
|
||||||
# Start recording all episodes
|
# Start recording all episodes
|
||||||
while episode_index < num_episodes:
|
while episode_index < num_episodes:
|
||||||
@@ -803,10 +804,14 @@ if __name__ == "__main__":
|
|||||||
help="Add tags to your dataset on the hub.",
|
help="Add tags to your dataset on the hub.",
|
||||||
)
|
)
|
||||||
parser_record.add_argument(
|
parser_record.add_argument(
|
||||||
"--num-image-writers",
|
"--num-image-writers-per-camera",
|
||||||
type=int,
|
type=int,
|
||||||
default=8,
|
default=4,
|
||||||
help="Number of threads writing the frames as png images on disk. Don't set too much as you might get unstable fps due to main thread being blocked.",
|
help=(
|
||||||
|
"Number of threads writing the frames as png images on disk, per camera. "
|
||||||
|
"Too much threads might cause unstable teleoperation fps due to main thread being blocked. "
|
||||||
|
"Not enough threads might cause low camera fps."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser_record.add_argument(
|
parser_record.add_argument(
|
||||||
"--force-override",
|
"--force-override",
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ import gymnasium as gym
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
from huggingface_hub.errors import RepositoryNotFoundError
|
||||||
from huggingface_hub.utils._validators import HFValidationError
|
from huggingface_hub.utils._validators import HFValidationError
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
|
|||||||
@@ -82,13 +82,22 @@
|
|||||||
Episode {{ episode_id }}
|
Episode {{ episode_id }}
|
||||||
</h1>
|
</h1>
|
||||||
|
|
||||||
|
<!-- Error message -->
|
||||||
|
<div class="font-medium text-orange-700 hidden" :class="{ 'hidden': !videoCodecError }">
|
||||||
|
<p>Videos could NOT play because <a href="https://en.wikipedia.org/wiki/AV1" target="_blank" class="underline">AV1</a> decoding is not available on your browser.</p>
|
||||||
|
<ul class="list-decimal list-inside">
|
||||||
|
<li>If iPhone: <span class="italic">It is supported with A17 chip or higher.</span></li>
|
||||||
|
<li>If Mac with Safari: <span class="italic">It is supported on most browsers except Safari with M1 chip or higher and on Safari with M3 chip or higher.</span></li>
|
||||||
|
<li>Other: <span class="italic">Contact the maintainers on LeRobot discord channel:</span> <a href="https://discord.com/invite/s3KuuzsPFb" target="_blank" class="underline">https://discord.com/invite/s3KuuzsPFb</a></li>
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
|
||||||
<!-- Videos -->
|
<!-- Videos -->
|
||||||
<div class="flex flex-wrap gap-1">
|
<div class="flex flex-wrap gap-1">
|
||||||
<p x-show="videoCodecError" class="font-medium text-orange-700">Videos could NOT play because <a href="https://en.wikipedia.org/wiki/AV1" target="_blank" class="underline">AV1</a> decoding is not available on your browser. Learn more about <a href="https://huggingface.co/blog/video-encoding" target="_blank" class="underline">LeRobot video encoding</a>.</p>
|
|
||||||
{% for video_info in videos_info %}
|
{% for video_info in videos_info %}
|
||||||
<div x-show="!videoCodecError" class="max-w-96">
|
<div x-show="!videoCodecError" class="max-w-96">
|
||||||
<p class="text-sm text-gray-300 bg-gray-800 px-2 rounded-t-xl truncate">{{ video_info.filename }}</p>
|
<p class="text-sm text-gray-300 bg-gray-800 px-2 rounded-t-xl truncate">{{ video_info.filename }}</p>
|
||||||
<video muted loop type="video/mp4" class="min-w-64" @canplaythrough="videoCanPlay" @timeupdate="() => {
|
<video muted loop type="video/mp4" class="object-contain w-full h-full" @canplaythrough="videoCanPlay" @timeupdate="() => {
|
||||||
if (video.duration) {
|
if (video.duration) {
|
||||||
const time = video.currentTime;
|
const time = video.currentTime;
|
||||||
const pc = (100 / video.duration) * time;
|
const pc = (100 / video.duration) * time;
|
||||||
|
|||||||
18
poetry.lock
generated
18
poetry.lock
generated
@@ -1360,13 +1360,13 @@ files = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "huggingface-hub"
|
name = "huggingface-hub"
|
||||||
version = "0.23.5"
|
version = "0.25.0"
|
||||||
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
|
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8.0"
|
python-versions = ">=3.8.0"
|
||||||
files = [
|
files = [
|
||||||
{file = "huggingface_hub-0.23.5-py3-none-any.whl", hash = "sha256:d7a7d337615e11a45cc14a0ce5a605db6b038dc24af42866f731684825226e90"},
|
{file = "huggingface_hub-0.25.0-py3-none-any.whl", hash = "sha256:e2f357b35d72d5012cfd127108c4e14abcd61ba4ebc90a5a374dc2456cb34e12"},
|
||||||
{file = "huggingface_hub-0.23.5.tar.gz", hash = "sha256:67a9caba79b71235be3752852ca27da86bd54311d2424ca8afdb8dda056edf98"},
|
{file = "huggingface_hub-0.25.0.tar.gz", hash = "sha256:fb5fbe6c12fcd99d187ec7db95db9110fb1a20505f23040a5449a717c1a0db4d"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@@ -1381,17 +1381,17 @@ tqdm = ">=4.42.1"
|
|||||||
typing-extensions = ">=3.7.4.3"
|
typing-extensions = ">=3.7.4.3"
|
||||||
|
|
||||||
[package.extras]
|
[package.extras]
|
||||||
all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.3.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
|
all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.5.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
|
||||||
cli = ["InquirerPy (==0.3.4)"]
|
cli = ["InquirerPy (==0.3.4)"]
|
||||||
dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.3.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
|
dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.5.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
|
||||||
fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"]
|
fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"]
|
||||||
hf-transfer = ["hf-transfer (>=0.1.4)"]
|
hf-transfer = ["hf-transfer (>=0.1.4)"]
|
||||||
inference = ["aiohttp", "minijinja (>=1.0)"]
|
inference = ["aiohttp", "minijinja (>=1.0)"]
|
||||||
quality = ["mypy (==1.5.1)", "ruff (>=0.3.0)"]
|
quality = ["mypy (==1.5.1)", "ruff (>=0.5.0)"]
|
||||||
tensorflow = ["graphviz", "pydot", "tensorflow"]
|
tensorflow = ["graphviz", "pydot", "tensorflow"]
|
||||||
tensorflow-testing = ["keras (<3.0)", "tensorflow"]
|
tensorflow-testing = ["keras (<3.0)", "tensorflow"]
|
||||||
testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"]
|
testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"]
|
||||||
torch = ["safetensors", "torch"]
|
torch = ["safetensors[torch]", "torch"]
|
||||||
typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"]
|
typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -4586,4 +4586,4 @@ xarm = ["gym-xarm"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.10,<3.13"
|
python-versions = ">=3.10,<3.13"
|
||||||
content-hash = "06a8a1941b75c3ec78ade6f8b2c3ad7b5d2f1516b590fa3d5a773add73f6dbec"
|
content-hash = "c9c3beac71f760738baf2fd169378eefdaef7d3a9cd068270bc5190fbefdb42a"
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ opencv-python = ">=4.9.0"
|
|||||||
diffusers = ">=0.27.2"
|
diffusers = ">=0.27.2"
|
||||||
torchvision = ">=0.17.1"
|
torchvision = ">=0.17.1"
|
||||||
h5py = ">=3.10.0"
|
h5py = ">=3.10.0"
|
||||||
huggingface-hub = {extras = ["hf-transfer", "cli"], version = ">=0.23.0"}
|
huggingface-hub = {extras = ["hf-transfer", "cli"], version = ">=0.25.0"}
|
||||||
gymnasium = ">=0.29.1"
|
gymnasium = ">=0.29.1"
|
||||||
cmake = ">=3.29.0.1"
|
cmake = ">=3.29.0.1"
|
||||||
gym-dora = { git = "https://github.com/dora-rs/dora-lerobot.git", subdirectory = "gym_dora", optional = true }
|
gym-dora = { git = "https://github.com/dora-rs/dora-lerobot.git", subdirectory = "gym_dora", optional = true }
|
||||||
|
|||||||
Reference in New Issue
Block a user