Added pusht dataset auto-download
This commit is contained in:
@@ -1,9 +1,13 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torchrl.data.replay_buffers import PrioritizedSliceSampler
|
||||
|
||||
from lerobot.common.datasets.pusht import PushtExperienceReplay
|
||||
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
|
||||
|
||||
DATA_PATH = Path("data/")
|
||||
|
||||
# TODO(rcadene): implement
|
||||
|
||||
# dataset_d4rl = D4RLExperienceReplay(
|
||||
@@ -60,7 +64,7 @@ def make_offline_buffer(cfg, sampler=None):
|
||||
# download="force",
|
||||
download=True,
|
||||
streaming=False,
|
||||
root="data",
|
||||
root=str(DATA_PATH),
|
||||
sampler=sampler,
|
||||
batch_size=batch_size,
|
||||
pin_memory=pin_memory,
|
||||
@@ -69,11 +73,9 @@ def make_offline_buffer(cfg, sampler=None):
|
||||
elif cfg.env.name == "pusht":
|
||||
offline_buffer = PushtExperienceReplay(
|
||||
"pusht",
|
||||
# download="force",
|
||||
# TODO(aliberts): automate download
|
||||
download=False,
|
||||
download=True,
|
||||
streaming=False,
|
||||
root="data",
|
||||
root=DATA_PATH,
|
||||
sampler=sampler,
|
||||
batch_size=batch_size,
|
||||
pin_memory=pin_memory,
|
||||
|
||||
Reference in New Issue
Block a user