Added pusht dataset auto-download

This commit is contained in:
Simon Alibert
2024-03-01 14:31:54 +01:00
parent ca948c1e5b
commit b862145e22
3 changed files with 57 additions and 28 deletions

View File

@@ -1,9 +1,13 @@
from pathlib import Path
import torch
from torchrl.data.replay_buffers import PrioritizedSliceSampler
from lerobot.common.datasets.pusht import PushtExperienceReplay
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
DATA_PATH = Path("data/")
# TODO(rcadene): implement
# dataset_d4rl = D4RLExperienceReplay(
@@ -60,7 +64,7 @@ def make_offline_buffer(cfg, sampler=None):
# download="force",
download=True,
streaming=False,
root="data",
root=str(DATA_PATH),
sampler=sampler,
batch_size=batch_size,
pin_memory=pin_memory,
@@ -69,11 +73,9 @@ def make_offline_buffer(cfg, sampler=None):
elif cfg.env.name == "pusht":
offline_buffer = PushtExperienceReplay(
"pusht",
# download="force",
# TODO(aliberts): automate download
download=False,
download=True,
streaming=False,
root="data",
root=DATA_PATH,
sampler=sampler,
batch_size=batch_size,
pin_memory=pin_memory,