Fix diffusion (rm transpose), Add prefetch

This commit is contained in:
Cadene
2024-02-28 17:45:01 +00:00
parent cf5063e50e
commit ac90b9c3ee
6 changed files with 52 additions and 11 deletions

View File

@@ -1,9 +1,8 @@
import torch
from torchrl.data.replay_buffers import PrioritizedSliceSampler
from lerobot.common.datasets.pusht import PushtExperienceReplay
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
# from rl.torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler
from torchrl.data.replay_buffers import PrioritizedSliceSampler
# TODO(rcadene): implement
@@ -27,6 +26,17 @@ from torchrl.data.replay_buffers import PrioritizedSliceSampler
def make_offline_buffer(cfg, sampler=None):
if cfg.policy.balanced_sampling:
assert cfg.online_steps > 0
batch_size = None
pin_memory = False
prefetch = None
else:
assert cfg.online_steps == 0
num_slices = cfg.policy.batch_size
batch_size = cfg.policy.horizon * num_slices
pin_memory = cfg.device == "cuda"
prefetch = cfg.prefetch
overwrite_sampler = sampler is not None
@@ -52,6 +62,9 @@ def make_offline_buffer(cfg, sampler=None):
streaming=False,
root="data",
sampler=sampler,
batch_size=batch_size,
pin_memory=pin_memory,
prefetch=prefetch,
)
elif cfg.env.name == "pusht":
offline_buffer = PushtExperienceReplay(
@@ -61,6 +74,9 @@ def make_offline_buffer(cfg, sampler=None):
streaming=False,
root="data",
sampler=sampler,
batch_size=batch_size,
pin_memory=pin_memory,
prefetch=prefetch,
)
else:
raise ValueError(cfg.env.name)