Refactor push_dataset_to_hub (#118)

This commit is contained in:
Remi
2024-04-30 14:25:41 +02:00
committed by GitHub
parent 2765877f28
commit e4e739f4f8
25 changed files with 1089 additions and 1192 deletions

View File

@@ -11,6 +11,7 @@ Example usage:
`python tests/script/save_dataset_to_safetensors.py`
"""
import os
import shutil
from pathlib import Path
@@ -21,54 +22,56 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
data_dir = Path(output_dir) / repo_id
repo_dir = Path(output_dir) / repo_id
if data_dir.exists():
shutil.rmtree(data_dir)
if repo_dir.exists():
shutil.rmtree(repo_dir)
data_dir.mkdir(parents=True, exist_ok=True)
dataset = LeRobotDataset(repo_id=repo_id, root=data_dir)
repo_dir.mkdir(parents=True, exist_ok=True)
dataset = LeRobotDataset(
repo_id=repo_id, root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
)
# save 2 first frames of first episode
i = dataset.episode_data_index["from"][0].item()
save_file(dataset[i], data_dir / f"frame_{i}.safetensors")
save_file(dataset[i + 1], data_dir / f"frame_{i+1}.safetensors")
save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")
# save 2 frames at the middle of first episode
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
save_file(dataset[i], data_dir / f"frame_{i}.safetensors")
save_file(dataset[i + 1], data_dir / f"frame_{i+1}.safetensors")
save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")
# save 2 last frames of first episode
i = dataset.episode_data_index["to"][0].item()
save_file(dataset[i - 2], data_dir / f"frame_{i-2}.safetensors")
save_file(dataset[i - 1], data_dir / f"frame_{i-1}.safetensors")
save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors")
save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors")
# TODO(rcadene): Enable testing on second and last episode
# We currently cant because our test dataset only contains the first episode
# # save 2 first frames of second episode
# i = dataset.episode_data_index["from"][1].item()
# save_file(dataset[i], data_dir / f"frame_{i}.safetensors")
# save_file(dataset[i+1], data_dir / f"frame_{i+1}.safetensors")
# save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
# save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")
# # save 2 last frames of second episode
# i = dataset.episode_data_index["to"][1].item()
# save_file(dataset[i-2], data_dir / f"frame_{i-2}.safetensors")
# save_file(dataset[i-1], data_dir / f"frame_{i-1}.safetensors")
# save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors")
# save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors")
# # save 2 last frames of last episode
# i = dataset.episode_data_index["to"][-1].item()
# save_file(dataset[i-2], data_dir / f"frame_{i-2}.safetensors")
# save_file(dataset[i-1], data_dir / f"frame_{i-1}.safetensors")
# save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors")
# save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors")
if __name__ == "__main__":
available_datasets = [
"lerobot/pusht",
"lerobot/xarm_push_medium",
"lerobot/xarm_lift_medium",
"lerobot/aloha_sim_insertion_human",
"lerobot/umi_cup_in_the_wild",
# "lerobot/umi_cup_in_the_wild",
]
for dataset in available_datasets:
save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors", repo_id=dataset)