Refactor datasets into LeRobotDataset (#91)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
Remi
2024-04-25 12:23:12 +02:00
committed by GitHub
parent e760e4cd63
commit 659c69a1c0
90 changed files with 167 additions and 352 deletions

View File

@@ -16,22 +16,18 @@ from pathlib import Path
from safetensors.torch import save_file
from lerobot.common.datasets.pusht import PushtDataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
def save_dataset_to_safetensors(output_dir, dataset_id="pusht"):
data_dir = Path(output_dir) / dataset_id
def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
data_dir = Path(output_dir) / repo_id
if data_dir.exists():
shutil.rmtree(data_dir)
data_dir.mkdir(parents=True, exist_ok=True)
# TODO(rcadene): make it work for all datasets with LeRobotDataset(repo_id)
dataset = PushtDataset(
dataset_id=dataset_id,
split="train",
)
dataset = LeRobotDataset(repo_id)
# save the first 2 frames of the first episode
i = dataset.episode_data_index["from"][0].item()