Add Normalize, non_blocking=True in tdmpc, tdmpc run (TODO: diffusion)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
@@ -6,7 +7,7 @@ from torchrl.data.replay_buffers import PrioritizedSliceSampler
|
||||
from lerobot.common.datasets.pusht import PushtExperienceReplay
|
||||
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
|
||||
|
||||
DATA_PATH = Path("data/")
|
||||
DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))
|
||||
|
||||
# TODO(rcadene): implement
|
||||
|
||||
@@ -64,7 +65,7 @@ def make_offline_buffer(cfg, sampler=None):
|
||||
# download="force",
|
||||
download=True,
|
||||
streaming=False,
|
||||
root=str(DATA_PATH),
|
||||
root=str(DATA_DIR),
|
||||
sampler=sampler,
|
||||
batch_size=batch_size,
|
||||
pin_memory=pin_memory,
|
||||
@@ -74,7 +75,7 @@ def make_offline_buffer(cfg, sampler=None):
|
||||
offline_buffer = PushtExperienceReplay(
|
||||
"pusht",
|
||||
streaming=False,
|
||||
root=DATA_PATH,
|
||||
root=DATA_DIR,
|
||||
sampler=sampler,
|
||||
batch_size=batch_size,
|
||||
pin_memory=pin_memory,
|
||||
|
||||
Reference in New Issue
Block a user