Use v1.1, hf_transform_to_torch, Add 3 xarm datasets

Author: Cadene
Date: 2024-04-19 18:17:13 +00:00
Parent: 714a776277
Commit: 35a573c98e
12 changed files with 122 additions and 74 deletions


@@ -19,7 +19,7 @@ from huggingface_hub import HfApi
from PIL import Image as PILImage
from safetensors.torch import save_file
-from lerobot.common.datasets.utils import compute_stats, convert_images_to_channel_first_tensors, flatten_dict
+from lerobot.common.datasets.utils import compute_stats, flatten_dict, hf_transform_to_torch
def download_and_upload(root, revision, dataset_id):
@@ -127,7 +127,11 @@ def push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dat
# copy in tests folder, the first episode and the meta_data directory
num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
-hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train")
+hf_dataset.select(range(num_items_first_ep)).with_format("torch").save_to_disk(
+    f"tests/data/{dataset_id}/train"
+)
if Path(f"tests/data/{dataset_id}/meta_data").exists():
shutil.rmtree(f"tests/data/{dataset_id}/meta_data")
shutil.copytree(meta_data_dir, f"tests/data/{dataset_id}/meta_data")
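For context (not part of the diff): a hedged sketch of reading the fixture back, where `pusht` stands in for `dataset_id`. The `torch` format set by `with_format` above should be persisted in the saved `state.json`, so reloaded items come back as tensors:

```python
from datasets import load_from_disk

# Hedged sketch: reload the first-episode fixture written by save_to_disk.
hf_dataset = load_from_disk("tests/data/pusht/train")  # "pusht" is an assumed dataset_id
print(hf_dataset.format["type"])  # expected: "torch"
print(hf_dataset[0])  # values come back as torch tensors
```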
@@ -262,8 +266,7 @@ def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
}
features = Features(features)
hf_dataset = Dataset.from_dict(data_dict, features=features)
-hf_dataset = hf_dataset.with_format("torch")
-hf_dataset.set_transform(convert_images_to_channel_first_tensors)
+hf_dataset.set_transform(hf_transform_to_torch)
info = {
"fps": fps,
@@ -274,13 +277,14 @@ def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
def download_and_upload_xarm(root, revision, dataset_id, fps=15):
root = Path(root)
-raw_dir = root / f"{dataset_id}_raw"
+raw_dir = root / "xarm_datasets_raw"
if not raw_dir.exists():
import zipfile
import gdown
raw_dir.mkdir(parents=True, exist_ok=True)
+# from https://github.com/fyhMer/fowm/blob/main/scripts/download_datasets.py
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
zip_path = raw_dir / "data.zip"
gdown.download(url, str(zip_path), quiet=False)
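The hunk ends at the download; the `import zipfile` above suggests an extraction step follows outside the visible context. A hedged sketch of that step:

```python
import zipfile

# Assumed continuation (not shown in the hunk): unpack the archive into raw_dir.
with zipfile.ZipFile(zip_path, "r") as zip_f:
    zip_f.extractall(raw_dir)
```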
@@ -361,8 +365,7 @@ def download_and_upload_xarm(root, revision, dataset_id, fps=15):
}
features = Features(features)
hf_dataset = Dataset.from_dict(data_dict, features=features)
-hf_dataset = hf_dataset.with_format("torch")
-hf_dataset.set_transform(convert_images_to_channel_first_tensors)
+hf_dataset.set_transform(hf_transform_to_torch)
info = {
"fps": fps,
@@ -468,8 +471,6 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
# "next.reward": reward,
"next.done": done,
# "next.success": success,
"episode_data_index_from": torch.tensor([id_from] * num_frames),
"episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
}
)
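Dropping the per-frame `episode_data_index_from`/`_to` columns implies the boundaries are now tracked once per episode instead, e.g. in the `episode_data_index` dict that `push_to_hub` receives above. A hedged sketch of accumulating such an index:

```python
import torch

# Sketch (names assumed): store one (from, to) pair per episode rather
# than repeating the boundaries on every frame.
episode_lengths = [100, 120]  # hypothetical episode lengths
episode_data_index = {"from": [], "to": []}
id_from = 0
for num_frames in episode_lengths:
    episode_data_index["from"].append(id_from)
    episode_data_index["to"].append(id_from + num_frames)
    id_from += num_frames
episode_data_index = {key: torch.tensor(vals) for key, vals in episode_data_index.items()}
```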
@@ -499,8 +500,7 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
}
features = Features(features)
hf_dataset = Dataset.from_dict(data_dict, features=features)
-hf_dataset = hf_dataset.with_format("torch")
-hf_dataset.set_transform(convert_images_to_channel_first_tensors)
+hf_dataset.set_transform(hf_transform_to_torch)
info = {
"fps": fps,
@@ -515,11 +515,14 @@ if __name__ == "__main__":
dataset_ids = [
"pusht",
# "xarm_lift_medium",
# "aloha_sim_insertion_human",
# "aloha_sim_insertion_scripted",
# "aloha_sim_transfer_cube_human",
# "aloha_sim_transfer_cube_scripted",
"xarm_lift_medium",
"xarm_lift_medium_replay",
"xarm_push_medium",
"xarm_push_medium_replay",
"aloha_sim_insertion_human",
"aloha_sim_insertion_scripted",
"aloha_sim_transfer_cube_human",
"aloha_sim_transfer_cube_scripted",
]
for dataset_id in dataset_ids:
download_and_upload(root, revision, dataset_id)
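Given the v1.1 revision in the commit title, a downstream consumer could load one of the newly added datasets roughly like this (the `lerobot/` hub namespace and the `v1.1` tag name are assumptions):

```python
from datasets import load_dataset
from lerobot.common.datasets.utils import hf_transform_to_torch

# Assumed hub naming ("lerobot/<dataset_id>") and revision tag "v1.1".
hf_dataset = load_dataset("lerobot/xarm_lift_medium_replay", revision="v1.1", split="train")
hf_dataset.set_transform(hf_transform_to_torch)  # same transform as in the diff
```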