Compare commits
4 Commits
thom_arm
...
user/alibe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5df98c3877 | ||
|
|
1e6b7d249b | ||
|
|
7d24e0ee13 | ||
|
|
512df4b468 |
@@ -4,155 +4,21 @@ import gymnasium as gym
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
from dora import Node
|
||||
from gymnasium import spaces
|
||||
|
||||
FPS = int(os.getenv("FPS", "30"))
|
||||
IMAGE_WIDTH = int(os.getenv("IMAGE_WIDTH", "640"))
|
||||
IMAGE_HEIGHT = int(os.getenv("IMAGE_HEIGHT", "480"))
|
||||
|
||||
ALOHA_JOINTS = [
|
||||
# absolute joint position
|
||||
"left_arm_waist",
|
||||
"left_arm_shoulder",
|
||||
"left_arm_elbow",
|
||||
"left_arm_forearm_roll",
|
||||
"left_arm_wrist_angle",
|
||||
"left_arm_wrist_rotate",
|
||||
# normalized gripper position 0: close, 1: open
|
||||
"left_arm_gripper",
|
||||
# absolute joint position
|
||||
"right_arm_waist",
|
||||
"right_arm_shoulder",
|
||||
"right_arm_elbow",
|
||||
"right_arm_forearm_roll",
|
||||
"right_arm_wrist_angle",
|
||||
"right_arm_wrist_rotate",
|
||||
# normalized gripper position 0: close, 1: open
|
||||
"right_arm_gripper",
|
||||
]
|
||||
ALOHA_ACTIONS = [
|
||||
# position and quaternion for end effector
|
||||
"left_arm_waist",
|
||||
"left_arm_shoulder",
|
||||
"left_arm_elbow",
|
||||
"left_arm_forearm_roll",
|
||||
"left_arm_wrist_angle",
|
||||
"left_arm_wrist_rotate",
|
||||
# normalized gripper position (0: close, 1: open)
|
||||
"left_arm_gripper",
|
||||
"right_arm_waist",
|
||||
"right_arm_shoulder",
|
||||
"right_arm_elbow",
|
||||
"right_arm_forearm_roll",
|
||||
"right_arm_wrist_angle",
|
||||
"right_arm_wrist_rotate",
|
||||
# normalized gripper position (0: close, 1: open)
|
||||
"right_arm_gripper",
|
||||
]
|
||||
FPS = int(os.getenv("FPS", "30"))
|
||||
|
||||
|
||||
class DoraEnv(gym.Env):
|
||||
metadata = {"render_modes": ["rgb_array"], "render_fps": FPS}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model="aloha",
|
||||
observation_width=IMAGE_WIDTH,
|
||||
observation_height=IMAGE_HEIGHT,
|
||||
cameras_names=None,
|
||||
num_joints=None,
|
||||
num_actions=None,
|
||||
):
|
||||
"""Initializes the Dora environment.
|
||||
|
||||
Args:
|
||||
model (str): The model to use. Either 'aloha' or 'custom'.
|
||||
observation_width (int): The width of the observation image.
|
||||
observation_height (int): The height of the observation image.
|
||||
cameras_names (list): A list of camera names to use. If not provided, the default is ['cam_high', 'cam_low', 'cam_left_wrist', 'cam_right_wrist'].
|
||||
num_joints (int): The number of joints in the model. If not provided, the default is 14 for 'aloha' and 6 for 'fivedof'.
|
||||
num_actions (int): The number of actions in the model. If not provided, the default is 14 for 'aloha' and 6 for 'fivedof'.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
def __init__(self, model="aloha"):
|
||||
# Initialize a new node
|
||||
self.node = Node() if os.environ.get("DORA_NODE_CONFIG", None) is not None else None
|
||||
self.node = Node()
|
||||
self.observation = {"pixels": {}, "agent_pos": None}
|
||||
self.terminated = False
|
||||
|
||||
self.observation_height = observation_height
|
||||
self.observation_width = observation_width
|
||||
|
||||
# Observation space
|
||||
if model == "aloha":
|
||||
self.observation_space = spaces.Dict(
|
||||
{
|
||||
"pixels": spaces.Dict(
|
||||
{
|
||||
"cam_high": spaces.Box(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=(self.observation_height, self.observation_width, 3),
|
||||
dtype=np.uint8,
|
||||
),
|
||||
"cam_low": spaces.Box(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=(self.observation_height, self.observation_width, 3),
|
||||
dtype=np.uint8,
|
||||
),
|
||||
"cam_left_wrist": spaces.Box(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=(self.observation_height, self.observation_width, 3),
|
||||
dtype=np.uint8,
|
||||
),
|
||||
"cam_right_wrist": spaces.Box(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=(self.observation_height, self.observation_width, 3),
|
||||
dtype=np.uint8,
|
||||
),
|
||||
}
|
||||
),
|
||||
"agent_pos": spaces.Box(
|
||||
low=-1000.0,
|
||||
high=1000.0,
|
||||
shape=(len(ALOHA_JOINTS),),
|
||||
dtype=np.float64,
|
||||
),
|
||||
}
|
||||
)
|
||||
elif model == "custom":
|
||||
pixel_dict = {}
|
||||
for camera in cameras_names:
|
||||
assert camera.startswith("cam"), "Camera names must start with 'cam'"
|
||||
pixel_dict[camera] = spaces.Box(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=(self.observation_height, self.observation_width, 3),
|
||||
dtype=np.uint8,
|
||||
)
|
||||
self.observation_space = spaces.Dict(
|
||||
{
|
||||
"pixels": spaces.Dict(pixel_dict),
|
||||
"agent_pos": spaces.Box(
|
||||
low=-1000.0,
|
||||
high=1000.0,
|
||||
shape=(num_joints,),
|
||||
dtype=np.float64,
|
||||
),
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise ValueError("Model must be either 'aloha' or 'custom'.")
|
||||
|
||||
# Action space
|
||||
if model == "aloha":
|
||||
self.action_space = spaces.Box(low=-1, high=1, shape=(len(ALOHA_ACTIONS),), dtype=np.float32)
|
||||
elif model == "custom":
|
||||
self.action_space = spaces.Box(low=-1, high=1, shape=(num_actions,), dtype=np.float32)
|
||||
|
||||
def _get_obs(self):
|
||||
while True:
|
||||
event = self.node.next(timeout=0.001)
|
||||
@@ -166,7 +32,7 @@ class DoraEnv(gym.Env):
|
||||
# Map Image input into pixels key within Aloha environment
|
||||
if "cam" in event["id"]:
|
||||
self.observation["pixels"][event["id"]] = (
|
||||
event["value"].to_numpy().reshape(self.observation_height, self.observation_width, 3)
|
||||
event["value"].to_numpy().reshape(IMAGE_HEIGHT, IMAGE_WIDTH, 3)
|
||||
)
|
||||
else:
|
||||
# Map other inputs into the observation dictionary using the event id as key
|
||||
|
||||
@@ -1,200 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Contains utilities to process raw data format from dora-record
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
|
||||
from lerobot.common.datasets.utils import (
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame
|
||||
from lerobot.common.utils.utils import init_logging
|
||||
|
||||
|
||||
def check_format(raw_dir) -> bool:
|
||||
assert raw_dir.exists()
|
||||
|
||||
leader_file = list(raw_dir.glob("*.parquet"))
|
||||
if len(leader_file) == 0:
|
||||
raise ValueError(f"Missing parquet files in '{raw_dir}'")
|
||||
return True
|
||||
|
||||
|
||||
def load_from_raw(raw_dir: Path, out_dir: Path):
|
||||
# Load data stream that will be used as reference for the timestamps synchronization
|
||||
reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
|
||||
if len(reference_files) == 0:
|
||||
raise ValueError(f"Missing reference files for camera, starting with in '{raw_dir}'")
|
||||
# select first camera in alphanumeric order
|
||||
reference_key = sorted(reference_files)[0].stem
|
||||
reference_df = pd.read_parquet(raw_dir / f"{reference_key}.parquet")
|
||||
reference_df = reference_df[["timestamp_utc", reference_key]]
|
||||
|
||||
# Merge all data stream using nearest backward strategy
|
||||
df = reference_df
|
||||
for path in raw_dir.glob("*.parquet"):
|
||||
key = path.stem # action or observation.state or ...
|
||||
if key == reference_key:
|
||||
continue
|
||||
modality_df = pd.read_parquet(path)
|
||||
modality_df = modality_df[["timestamp_utc", key]]
|
||||
df = pd.merge_asof(
|
||||
df,
|
||||
modality_df,
|
||||
on="timestamp_utc",
|
||||
direction="backward",
|
||||
)
|
||||
|
||||
# Remove rows with a NaN in any column. It can happened during the first frames of an episode,
|
||||
# because some cameras didnt start recording yet.
|
||||
df = df.dropna(axis=0)
|
||||
|
||||
# Remove rows with episode_index -1 which indicates a failed episode
|
||||
df = df[df["episode_index"] != -1]
|
||||
|
||||
# dora only use arrays, so single values are encapsulated into a list
|
||||
df["episode_index"] = df["episode_index"].map(lambda x: x[0])
|
||||
df["frame_index"] = df.groupby("episode_index").cumcount()
|
||||
df = df.reset_index()
|
||||
df["index"] = df.index
|
||||
|
||||
# set 'next.done' to True for the last frame of each episode
|
||||
df["next.done"] = False
|
||||
df.loc[df.groupby("episode_index").tail(1).index, "next.done"] = True
|
||||
|
||||
df["timestamp"] = df["timestamp_utc"].map(lambda x: x.timestamp())
|
||||
# each episode starts with timestamp 0 to match the ones from the video
|
||||
df["timestamp"] = df.groupby("episode_index")["timestamp"].transform(lambda x: x - x.iloc[0])
|
||||
|
||||
del df["timestamp_utc"]
|
||||
|
||||
# sanity check episode indices go from 0 to n-1
|
||||
ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
|
||||
expected_ep_ids = list(range(df["episode_index"].max() + 1))
|
||||
assert ep_ids == expected_ep_ids, f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}"
|
||||
|
||||
# Create symlink to raw videos directory (that needs to be absolute not relative)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
videos_dir = out_dir / "videos"
|
||||
videos_dir.symlink_to((raw_dir / "videos").absolute())
|
||||
|
||||
# sanity check the video paths are well formated
|
||||
for key in df:
|
||||
if "observation.images." not in key:
|
||||
continue
|
||||
for ep_idx in ep_ids:
|
||||
video_path = videos_dir / f"{key}_episode_{ep_idx:06d}.mp4"
|
||||
assert video_path.exists(), f"Video file not found in {video_path}"
|
||||
|
||||
data_dict = {}
|
||||
for key in df:
|
||||
# is video frame
|
||||
if "observation.images." in key:
|
||||
# we need `[0] because dora only use arrays, so single values are encapsulated into a list.
|
||||
# it is the case for video_frame dictionary = [{"path": ..., "timestamp": ...}]
|
||||
data_dict[key] = [video_frame[0] for video_frame in df[key].values]
|
||||
|
||||
# sanity check the video path is well formated
|
||||
video_path = videos_dir.parent / data_dict[key][0]["path"]
|
||||
assert video_path.exists(), f"Video file not found in {video_path}"
|
||||
# is number
|
||||
elif df[key].iloc[0].ndim == 0 or df[key].iloc[0].shape[0] == 1:
|
||||
data_dict[key] = torch.from_numpy(df[key].values)
|
||||
# is vector
|
||||
elif df[key].iloc[0].shape[0] > 1:
|
||||
data_dict[key] = torch.stack([torch.from_numpy(x.copy()) for x in df[key].values])
|
||||
else:
|
||||
raise ValueError(key)
|
||||
|
||||
# Get the episode index containing for each unique episode index
|
||||
first_ep_index_df = df.groupby("episode_index").agg(start_index=("index", "first")).reset_index()
|
||||
from_ = first_ep_index_df["start_index"].tolist()
|
||||
to_ = from_[1:] + [len(df)]
|
||||
episode_data_index = {
|
||||
"from": from_,
|
||||
"to": to_,
|
||||
}
|
||||
|
||||
return data_dict, episode_data_index
|
||||
|
||||
|
||||
def to_hf_dataset(data_dict, video) -> Dataset:
|
||||
features = {}
|
||||
|
||||
keys = [key for key in data_dict if "observation.images." in key]
|
||||
for key in keys:
|
||||
if video:
|
||||
features[key] = VideoFrame()
|
||||
else:
|
||||
features[key] = Image()
|
||||
|
||||
features["observation.state"] = Sequence(
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
if "observation.velocity" in data_dict:
|
||||
features["observation.velocity"] = Sequence(
|
||||
length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
if "observation.effort" in data_dict:
|
||||
features["observation.effort"] = Sequence(
|
||||
length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
features["action"] = Sequence(
|
||||
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
features["episode_index"] = Value(dtype="int64", id=None)
|
||||
features["frame_index"] = Value(dtype="int64", id=None)
|
||||
features["timestamp"] = Value(dtype="float32", id=None)
|
||||
features["next.done"] = Value(dtype="bool", id=None)
|
||||
features["index"] = Value(dtype="int64", id=None)
|
||||
|
||||
hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
|
||||
def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
|
||||
init_logging()
|
||||
|
||||
if debug:
|
||||
logging.warning("debug=True not implemented. Falling back to debug=False.")
|
||||
|
||||
# sanity check
|
||||
check_format(raw_dir)
|
||||
|
||||
if fps is None:
|
||||
fps = 30
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
if not video:
|
||||
raise NotImplementedError()
|
||||
|
||||
data_df, episode_data_index = load_from_raw(raw_dir, out_dir)
|
||||
hf_dataset = to_hf_dataset(data_df, video)
|
||||
|
||||
info = {
|
||||
"fps": fps,
|
||||
"video": video,
|
||||
}
|
||||
return hf_dataset, episode_data_index, info
|
||||
@@ -28,11 +28,11 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
|
||||
raise ValueError("`n_envs must be at least 1")
|
||||
|
||||
kwargs = {
|
||||
# "obs_type": "pixels_agent_pos",
|
||||
# "render_mode": "rgb_array",
|
||||
"obs_type": "pixels_agent_pos",
|
||||
"render_mode": "rgb_array",
|
||||
"max_episode_steps": cfg.env.episode_length,
|
||||
# "visualization_width": 384,
|
||||
# "visualization_height": 384,
|
||||
"visualization_width": 384,
|
||||
"visualization_height": 384,
|
||||
}
|
||||
|
||||
package_name = f"gym_{cfg.env.name}"
|
||||
|
||||
14
lerobot/configs/env/dora.yaml
vendored
14
lerobot/configs/env/dora.yaml
vendored
@@ -1,14 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
fps: 30
|
||||
|
||||
env:
|
||||
name: dora
|
||||
task: DoraAloha-v0
|
||||
# from_pixels: True
|
||||
# pixels_only: False
|
||||
# image_size: [3, 480, 640]
|
||||
episode_length: 400
|
||||
# fps: ${fps}
|
||||
# state_dim: 14
|
||||
# action_dim: 14
|
||||
@@ -1,101 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
seed: 1000
|
||||
dataset_repo_id: cadene/aloha_v2_static_dora_test
|
||||
|
||||
override_dataset_stats:
|
||||
observation.images.cam_right_wrist:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.cam_left_wrist:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.cam_high:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.cam_low:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
training:
|
||||
offline_steps: 80000
|
||||
online_steps: 0
|
||||
eval_freq: 99999999999999
|
||||
save_freq: 1000
|
||||
log_freq: 100
|
||||
save_model: true
|
||||
|
||||
batch_size: 8
|
||||
lr: 1e-5
|
||||
lr_backbone: 1e-5
|
||||
weight_decay: 1e-4
|
||||
grad_clip_norm: 10
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
delta_timestamps:
|
||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
batch_size: 50
|
||||
|
||||
# See `configuration_act.py` for more details.
|
||||
policy:
|
||||
name: act
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: 1
|
||||
chunk_size: 100 # chunk_size
|
||||
n_action_steps: 100
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.images.cam_right_wrist: [3, 480, 640]
|
||||
observation.images.cam_left_wrist: [3, 480, 640]
|
||||
observation.images.cam_high: [3, 480, 640]
|
||||
observation.images.cam_low: [3, 480, 640]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.images.cam_right_wrist: mean_std
|
||||
observation.images.cam_left_wrist: mean_std
|
||||
observation.images.cam_high: mean_std
|
||||
observation.images.cam_low: mean_std
|
||||
observation.state: mean_std
|
||||
output_normalization_modes:
|
||||
action: mean_std
|
||||
|
||||
# Architecture.
|
||||
# Vision backbone.
|
||||
vision_backbone: resnet18
|
||||
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
|
||||
replace_final_stride_with_dilation: false
|
||||
# Transformer layers.
|
||||
pre_norm: false
|
||||
dim_model: 512
|
||||
n_heads: 8
|
||||
dim_feedforward: 3200
|
||||
feedforward_activation: relu
|
||||
n_encoder_layers: 4
|
||||
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||||
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||||
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||||
n_decoder_layers: 1
|
||||
# VAE.
|
||||
use_vae: true
|
||||
latent_dim: 32
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_momentum: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
kl_weight: 10.0
|
||||
@@ -84,14 +84,10 @@ def get_from_raw_to_lerobot_format_fn(raw_format):
|
||||
from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "aloha_hdf5":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "aloha_dora":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_dora_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "xarm_pkl":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The selected {raw_format} can't be found. Did you add it to `lerobot/scripts/push_dataset_to_hub.py::get_from_raw_to_lerobot_format_fn`?"
|
||||
)
|
||||
raise ValueError(raw_format)
|
||||
|
||||
return from_raw_to_lerobot_format
|
||||
|
||||
@@ -144,8 +140,7 @@ def push_videos_to_hub(repo_id, videos_dir, revision):
|
||||
|
||||
|
||||
def push_dataset_to_hub(
|
||||
input_data_dir: Path,
|
||||
output_data_dir: Path,
|
||||
data_dir: Path,
|
||||
dataset_id: str,
|
||||
raw_format: str | None,
|
||||
community_id: str,
|
||||
@@ -162,33 +157,34 @@ def push_dataset_to_hub(
|
||||
):
|
||||
repo_id = f"{community_id}/{dataset_id}"
|
||||
|
||||
meta_data_dir = output_data_dir / "meta_data"
|
||||
videos_dir = output_data_dir / "videos"
|
||||
raw_dir = data_dir / f"{dataset_id}_raw"
|
||||
|
||||
out_dir = data_dir / repo_id
|
||||
meta_data_dir = out_dir / "meta_data"
|
||||
videos_dir = out_dir / "videos"
|
||||
|
||||
tests_out_dir = tests_data_dir / repo_id
|
||||
tests_meta_data_dir = tests_out_dir / "meta_data"
|
||||
tests_videos_dir = tests_out_dir / "videos"
|
||||
|
||||
if output_data_dir.exists():
|
||||
shutil.rmtree(output_data_dir)
|
||||
if out_dir.exists():
|
||||
shutil.rmtree(out_dir)
|
||||
|
||||
if tests_out_dir.exists() and save_tests_to_disk:
|
||||
shutil.rmtree(tests_out_dir)
|
||||
|
||||
if not input_data_dir.exists():
|
||||
download_raw(input_data_dir, dataset_id)
|
||||
if not raw_dir.exists():
|
||||
download_raw(raw_dir, dataset_id)
|
||||
|
||||
if raw_format is None:
|
||||
# TODO(rcadene, adilzouitine): implement auto_find_raw_format
|
||||
raise NotImplementedError()
|
||||
# raw_format = auto_find_raw_format(input_data_dir)
|
||||
# raw_format = auto_find_raw_format(raw_dir)
|
||||
|
||||
from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
|
||||
|
||||
# convert dataset from original raw format to LeRobot format
|
||||
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
|
||||
input_data_dir, output_data_dir, fps, video, debug
|
||||
)
|
||||
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(raw_dir, out_dir, fps, video, debug)
|
||||
|
||||
lerobot_dataset = LeRobotDataset.from_preloaded(
|
||||
repo_id=repo_id,
|
||||
@@ -202,7 +198,7 @@ def push_dataset_to_hub(
|
||||
|
||||
if save_to_disk:
|
||||
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
|
||||
hf_dataset.save_to_disk(str(output_data_dir / "train"))
|
||||
hf_dataset.save_to_disk(str(out_dir / "train"))
|
||||
|
||||
if not dry_run or save_to_disk:
|
||||
# mandatory for upload
|
||||
@@ -236,25 +232,19 @@ def push_dataset_to_hub(
|
||||
fname = f"{key}_episode_{episode_index:06d}.mp4"
|
||||
shutil.copy(videos_dir / fname, tests_videos_dir / fname)
|
||||
|
||||
if not save_to_disk and output_data_dir.exists():
|
||||
if not save_to_disk and out_dir.exists():
|
||||
# remove possible temporary files remaining in the output directory
|
||||
shutil.rmtree(output_data_dir)
|
||||
shutil.rmtree(out_dir)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--input-data-dir",
|
||||
"--data-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-data-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Root directory containing output dataset (e.g. `data/lerobot/aloha_mobile_chair` or `data/lerobot/pusht`).",
|
||||
help="Root directory containing datasets (e.g. `data` or `tmp/data` or `/tmp/lerobot/data`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-id",
|
||||
|
||||
Reference in New Issue
Block a user