Refactor datasets into LeRobotDataset

This commit is contained in:
Cadene
2024-04-21 12:00:32 +00:00
parent 6d56bcb5de
commit 7626b9a4a3
90 changed files with 167 additions and 352 deletions

View File

@@ -16,22 +16,18 @@ from pathlib import Path
from safetensors.torch import save_file
from lerobot.common.datasets.pusht import PushtDataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
def save_dataset_to_safetensors(output_dir, dataset_id="pusht"):
data_dir = Path(output_dir) / dataset_id
def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
data_dir = Path(output_dir) / repo_id
if data_dir.exists():
shutil.rmtree(data_dir)
data_dir.mkdir(parents=True, exist_ok=True)
# TODO(rcadene): make it work for all datasets with LeRobotDataset(repo_id)
dataset = PushtDataset(
dataset_id=dataset_id,
split="train",
)
dataset = LeRobotDataset(repo_id)
# save 2 first frames of first episode
i = dataset.episode_data_index["from"][0].item()