Fix aloha real-world datasets (#175)
@@ -61,13 +61,21 @@ available_datasets_per_env = {
         "lerobot/aloha_sim_insertion_scripted",
         "lerobot/aloha_sim_transfer_cube_human",
         "lerobot/aloha_sim_transfer_cube_scripted",
+        "lerobot/aloha_sim_insertion_human_image",
+        "lerobot/aloha_sim_insertion_scripted_image",
+        "lerobot/aloha_sim_transfer_cube_human_image",
+        "lerobot/aloha_sim_transfer_cube_scripted_image",
     ],
-    "pusht": ["lerobot/pusht"],
+    "pusht": ["lerobot/pusht", "lerobot/pusht_image"],
    "xarm": [
         "lerobot/xarm_lift_medium",
         "lerobot/xarm_lift_medium_replay",
         "lerobot/xarm_push_medium",
         "lerobot/xarm_push_medium_replay",
+        "lerobot/xarm_lift_medium_image",
+        "lerobot/xarm_lift_medium_replay_image",
+        "lerobot/xarm_push_medium_image",
+        "lerobot/xarm_push_medium_replay_image",
     ],
 }
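With these entries registered, the new *_image variants resolve like any other dataset id. A minimal sketch, assuming this registry lives in lerobot/__init__.py as the hunk context suggests:

    import lerobot

    # the pusht entry now lists both the video-backed and the image-backed dataset
    print(lerobot.available_datasets_per_env["pusht"])
    # expected: ['lerobot/pusht', 'lerobot/pusht_image']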
@@ -30,7 +30,7 @@ from lerobot.common.datasets.utils import (
 from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos

 DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
-CODEBASE_VERSION = "v1.3"
+CODEBASE_VERSION = "v1.4"


 class LeRobotDataset(torch.utils.data.Dataset):
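Presumably CODEBASE_VERSION is the hub revision the re-uploaded datasets are tagged with and the default a LeRobotDataset resolves. A minimal loading sketch; the import path and the version keyword are assumptions not shown in this hunk:

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset  # path assumed

    # pin explicitly to the bumped revision; by default the class would use CODEBASE_VERSION
    ds = LeRobotDataset("lerobot/aloha_sim_insertion_human", version="v1.4")
    print(len(ds))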
@@ -24,17 +24,16 @@ import shutil
 from pathlib import Path

 import tqdm
-
-ALOHA_RAW_URLS_DIR = "lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls"
+from huggingface_hub import snapshot_download


 def download_raw(raw_dir, dataset_id):
-    if "pusht" in dataset_id:
+    if "aloha" in dataset_id or "image" in dataset_id:
+        download_hub(raw_dir, dataset_id)
+    elif "pusht" in dataset_id:
         download_pusht(raw_dir)
     elif "xarm" in dataset_id:
         download_xarm(raw_dir)
-    elif "aloha" in dataset_id:
-        download_aloha(raw_dir, dataset_id)
     elif "umi" in dataset_id:
         download_umi(raw_dir)
     else:
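The practical effect of the new dispatch is that every aloha id and every *_image id is fetched from the Hugging Face Hub instead of Google Drive. A pure-Python restatement of the if/elif chain above, just to show where a few ids route:

    def route(dataset_id):
        # mirrors the updated dispatch in download_raw
        if "aloha" in dataset_id or "image" in dataset_id:
            return "download_hub"
        elif "pusht" in dataset_id:
            return "download_pusht"
        elif "xarm" in dataset_id:
            return "download_xarm"
        elif "umi" in dataset_id:
            return "download_umi"
        return "unknown"

    for dataset_id in ["pusht", "pusht_image", "xarm_lift_medium", "aloha_mobile_cabinet"]:
        print(dataset_id, "->", route(dataset_id))
    # pusht -> download_pusht, pusht_image -> download_hub,
    # xarm_lift_medium -> download_xarm, aloha_mobile_cabinet -> download_hub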
@@ -103,37 +102,13 @@ def download_xarm(raw_dir: Path):
     zip_path.unlink()


-def download_aloha(raw_dir: Path, dataset_id: str):
-    import gdown
-
-    subset_id = dataset_id.replace("aloha_", "")
-    urls_path = Path(ALOHA_RAW_URLS_DIR) / f"{subset_id}.txt"
-    assert urls_path.exists(), f"{subset_id}.txt not found in '{ALOHA_RAW_URLS_DIR}' directory."
-
-    with open(urls_path) as f:
-        # strip lines and ignore empty lines
-        urls = [url.strip() for url in f if url.strip()]
-
-    # sanity check
-    for url in urls:
-        assert (
-            "drive.google.com/drive/folders" in url or "drive.google.com/file" in url
-        ), f"Wrong url provided '{url}' in file '{urls_path}'."
-
+def download_hub(raw_dir: Path, dataset_id: str):
     raw_dir = Path(raw_dir)
     raw_dir.mkdir(parents=True, exist_ok=True)

-    logging.info(f"Start downloading from google drive for {dataset_id}")
-    for url in urls:
-        if "drive.google.com/drive/folders" in url:
-            # when a folder url is given, download up to 50 files from the folder
-            gdown.download_folder(url, output=str(raw_dir), remaining_ok=True)
-
-        elif "drive.google.com/file" in url:
-            # because of the 50 files limit per folder, we download the remaining files (file by file)
-            gdown.download(url, output=str(raw_dir), fuzzy=True)
-
-    logging.info(f"End downloading from google drive for {dataset_id}")
+    logging.info(f"Start downloading from huggingface.co/cadene for {dataset_id}")
+    snapshot_download(f"cadene/{dataset_id}_raw", repo_type="dataset", local_dir=raw_dir)
+    logging.info(f"Finish downloading from huggingface.co/cadene for {dataset_id}")


 def download_umi(raw_dir: Path):
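For a quick manual check outside of lerobot, the new hub path boils down to a single snapshot_download call; the repo naming follows the f"cadene/{dataset_id}_raw" pattern used above, and the local directory below is only an example:

    from pathlib import Path

    from huggingface_hub import snapshot_download

    raw_dir = Path("data/aloha_mobile_cabinet_raw")  # example destination
    raw_dir.mkdir(parents=True, exist_ok=True)
    # pulls the raw episodes from the hub dataset repo that backs this id
    snapshot_download("cadene/aloha_mobile_cabinet_raw", repo_type="dataset", local_dir=raw_dir)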
@@ -148,21 +123,30 @@ def download_umi(raw_dir: Path):
 if __name__ == "__main__":
     data_dir = Path("data")
     dataset_ids = [
+        "pusht_image",
+        "xarm_lift_medium_image",
+        "xarm_lift_medium_replay_image",
+        "xarm_push_medium_image",
+        "xarm_push_medium_replay_image",
+        "aloha_sim_insertion_human_image",
+        "aloha_sim_insertion_scripted_image",
+        "aloha_sim_transfer_cube_human_image",
+        "aloha_sim_transfer_cube_scripted_image",
         "pusht",
         "xarm_lift_medium",
         "xarm_lift_medium_replay",
         "xarm_push_medium",
         "xarm_push_medium_replay",
+        "aloha_sim_insertion_human",
+        "aloha_sim_insertion_scripted",
+        "aloha_sim_transfer_cube_human",
+        "aloha_sim_transfer_cube_scripted",
         "aloha_mobile_cabinet",
         "aloha_mobile_chair",
         "aloha_mobile_elevator",
         "aloha_mobile_shrimp",
         "aloha_mobile_wash_pan",
         "aloha_mobile_wipe_wine",
-        "aloha_sim_insertion_human",
-        "aloha_sim_insertion_scripted",
-        "aloha_sim_transfer_cube_human",
-        "aloha_sim_transfer_cube_scripted",
         "aloha_static_battery",
         "aloha_static_candy",
         "aloha_static_coffee",
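Presumably the __main__ block then loops over this list and pulls each raw dataset. A minimal driver in that spirit, shown only to make the list above concrete; the module path and the data/{id}_raw naming are assumptions:

    from pathlib import Path

    from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw  # path assumed

    data_dir = Path("data")
    for dataset_id in ["pusht_image", "aloha_sim_insertion_human_image"]:  # short subset of the list above
        raw_dir = data_dir / f"{dataset_id}_raw"
        download_raw(raw_dir, dataset_id)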
@@ -17,7 +17,7 @@
 Contains utilities to process raw data format of HDF5 files like in: https://github.com/tonyzhaozh/act
 """

-import re
+import gc
 import shutil
 from pathlib import Path
@@ -79,10 +79,8 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
     episode_data_index = {"from": [], "to": []}

     id_from = 0
-
-    for ep_path in tqdm.tqdm(hdf5_files, total=len(hdf5_files)):
+    for ep_idx, ep_path in tqdm.tqdm(enumerate(hdf5_files), total=len(hdf5_files)):
         with h5py.File(ep_path, "r") as ep:
-            ep_idx = int(re.search(r"episode_(\d+)", ep_path.name).group(1))
             num_frames = ep["/action"].shape[0]

             # last step of demonstration is considered done
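One behavioral consequence worth noting: episode_index now comes from the enumeration order of hdf5_files rather than from the number embedded in each filename, so the file list has to be sorted consistently upstream. A tiny illustration of why lexicographic order alone is not enough (stand-in filenames, not lerobot code):

    files = ["episode_1.hdf5", "episode_10.hdf5", "episode_2.hdf5"]  # lexicographic order
    files = sorted(files, key=lambda name: int(name.split("_")[1].split(".")[0]))  # numeric order
    print(list(enumerate(files)))
    # [(0, 'episode_1.hdf5'), (1, 'episode_2.hdf5'), (2, 'episode_10.hdf5')]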
@@ -91,6 +89,10 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):

             state = torch.from_numpy(ep["/observations/qpos"][:])
             action = torch.from_numpy(ep["/action"][:])
+            if "/observations/qvel" in ep:
+                velocity = torch.from_numpy(ep["/observations/qvel"][:])
+            if "/observations/effort" in ep:
+                effort = torch.from_numpy(ep["/observations/effort"][:])

             ep_dict = {}
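The real-robot ALOHA recordings carry extra signals (joint velocities, motor efforts) that other raw files may not provide, hence the membership checks before reading them. A quick way to inspect which optional keys a raw episode actually contains (the file path is just an example):

    import h5py

    with h5py.File("data/aloha_mobile_cabinet_raw/episode_0.hdf5", "r") as ep:
        for key in ("/observations/qpos", "/observations/qvel", "/observations/effort"):
            # prints the dataset shape when present, False otherwise
            print(key, key in ep and ep[key].shape)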
@@ -131,6 +133,10 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
                 ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]

             ep_dict["observation.state"] = state
+            if "/observations/velocity" in ep:
+                ep_dict["observation.velocity"] = velocity
+            if "/observations/effort" in ep:
+                ep_dict["observation.effort"] = effort
             ep_dict["action"] = action
             ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
             ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
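Putting the pieces together, each episode contributes one dict of per-frame columns. A shape sketch for an ALOHA episode with T frames and 14-dimensional qpos/action (camera name and dimensions are assumptions for illustration):

    ep_dict_schema = {
        "observation.images.top": "list of T PIL images",
        "observation.state": "float32 tensor (T, 14)",
        "observation.velocity": "float32 tensor (T, 14), only when the raw file provides it",
        "observation.effort": "float32 tensor (T, 14), only when the raw file provides it",
        "action": "float32 tensor (T, 14)",
        "episode_index": "int64 tensor (T,)",
        "frame_index": "int64 tensor (T,)",
    }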
@@ -146,6 +152,8 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):

         id_from += num_frames

+        gc.collect()
+
         # process first episode only
         if debug:
             break
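Each iteration materializes every frame of an episode as PIL images plus several tensors, so forcing a garbage-collection pass between episodes keeps the peak memory footprint bounded. An isolated illustration of the pattern (sizes are arbitrary, not lerobot code):

    import gc

    for _ in range(3):
        frames = [bytearray(1_000_000) for _ in range(50)]  # stand-in for decoded images
        # ... convert frames, append results somewhere ...
        del frames
        gc.collect()  # release this episode's buffers before loading the next one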
@@ -167,6 +175,14 @@ def to_hf_dataset(data_dict, video) -> Dataset:
     features["observation.state"] = Sequence(
         length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
     )
+    if "observation.velocity" in data_dict:
+        features["observation.velocity"] = Sequence(
+            length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
+        )
+    if "observation.effort" in data_dict:
+        features["observation.effort"] = Sequence(
+            length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
+        )
     features["action"] = Sequence(
         length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
     )
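These Sequence/Value declarations become the schema of the pushed datasets.Dataset, with the velocity and effort columns only declared when the data actually carries them. A self-contained toy version of the same construction (dimensions and values are made up):

    from datasets import Dataset, Features, Sequence, Value

    features = Features(
        {
            "observation.state": Sequence(length=14, feature=Value(dtype="float32")),
            "action": Sequence(length=14, feature=Value(dtype="float32")),
        }
    )
    ds = Dataset.from_dict(
        {"observation.state": [[0.0] * 14], "action": [[0.0] * 14]},
        features=features,
    )
    print(ds.features)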
@@ -25,7 +25,6 @@ python lerobot/scripts/push_dataset_to_hub.py \
     --dataset-id pusht \
     --raw-format pusht_zarr \
     --community-id lerobot \
-    --revision v1.2 \
     --dry-run 1 \
     --save-to-disk 1 \
     --save-tests-to-disk 0 \
@@ -36,7 +35,6 @@ python lerobot/scripts/push_dataset_to_hub.py \
     --dataset-id xarm_lift_medium \
     --raw-format xarm_pkl \
     --community-id lerobot \
-    --revision v1.2 \
     --dry-run 1 \
     --save-to-disk 1 \
     --save-tests-to-disk 0 \
@@ -47,7 +45,6 @@ python lerobot/scripts/push_dataset_to_hub.py \
     --dataset-id aloha_sim_insertion_scripted \
     --raw-format aloha_hdf5 \
     --community-id lerobot \
-    --revision v1.2 \
     --dry-run 1 \
     --save-to-disk 1 \
     --save-tests-to-disk 0 \
@@ -58,7 +55,6 @@ python lerobot/scripts/push_dataset_to_hub.py \
     --dataset-id umi_cup_in_the_wild \
     --raw-format umi_zarr \
     --community-id lerobot \
-    --revision v1.2 \
     --dry-run 1 \
     --save-to-disk 1 \
     --save-tests-to-disk 0 \
@@ -227,8 +223,7 @@ def push_dataset_to_hub(
         test_hf_dataset = test_hf_dataset.with_format(None)
         test_hf_dataset.save_to_disk(str(tests_out_dir / "train"))

-        # copy meta data to tests directory
-        shutil.copytree(meta_data_dir, tests_meta_data_dir)
+        save_meta_data(info, stats, episode_data_index, tests_meta_data_dir)

         # copy videos of first episode to tests directory
         episode_index = 0
@@ -237,6 +232,10 @@ def push_dataset_to_hub(
             fname = f"{key}_episode_{episode_index:06d}.mp4"
             shutil.copy(videos_dir / fname, tests_videos_dir / fname)

+    if not save_to_disk and out_dir.exists():
+        # remove possible temporary files remaining in the output directory
+        shutil.rmtree(out_dir)
+

 def main():
     parser = argparse.ArgumentParser()
@@ -314,7 +313,7 @@ def main():
     parser.add_argument(
         "--num-workers",
         type=int,
-        default=16,
+        default=8,
         help="Number of processes of Dataloader for computing the dataset statistics.",
     )
     parser.add_argument(
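Per the help text above, --num-workers only controls the DataLoader used to compute dataset statistics. A rough sketch of that kind of stats pass with a dummy dataset (the wiring is an assumption, not lerobot's actual compute_stats code):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    if __name__ == "__main__":
        dataset = TensorDataset(torch.randn(128, 14))
        loader = DataLoader(dataset, batch_size=32, num_workers=8, shuffle=False)
        mean = torch.cat([batch[0] for batch in loader]).mean(dim=0)
        print(mean.shape)  # torch.Size([14])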