Add Normalize, non_blocking=True in tdmpc, tdmpc run (TODO: diffusion)

This commit is contained in:
Remi Cadene
2024-03-02 15:53:29 +00:00
parent b5a2f460ea
commit 1ae6205269
9 changed files with 183 additions and 67 deletions

View File

@@ -1,3 +1,4 @@
import os
from pathlib import Path
import torch
@@ -6,7 +7,7 @@ from torchrl.data.replay_buffers import PrioritizedSliceSampler
from lerobot.common.datasets.pusht import PushtExperienceReplay
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
DATA_PATH = Path("data/")
DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))
# TODO(rcadene): implement
@@ -64,7 +65,7 @@ def make_offline_buffer(cfg, sampler=None):
# download="force",
download=True,
streaming=False,
root=str(DATA_PATH),
root=str(DATA_DIR),
sampler=sampler,
batch_size=batch_size,
pin_memory=pin_memory,
@@ -74,7 +75,7 @@ def make_offline_buffer(cfg, sampler=None):
offline_buffer = PushtExperienceReplay(
"pusht",
streaming=False,
root=DATA_PATH,
root=DATA_DIR,
sampler=sampler,
batch_size=batch_size,
pin_memory=pin_memory,