Fix diffusion (rm transpose), Add prefetch
This commit is contained in:
@@ -1,9 +1,8 @@
|
||||
import torch
|
||||
from torchrl.data.replay_buffers import PrioritizedSliceSampler
|
||||
|
||||
from lerobot.common.datasets.pusht import PushtExperienceReplay
|
||||
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
|
||||
# from rl.torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler
|
||||
from torchrl.data.replay_buffers import PrioritizedSliceSampler
|
||||
|
||||
# TODO(rcadene): implement
|
||||
|
||||
@@ -27,6 +26,17 @@ from torchrl.data.replay_buffers import PrioritizedSliceSampler
|
||||
|
||||
|
||||
def make_offline_buffer(cfg, sampler=None):
|
||||
if cfg.policy.balanced_sampling:
|
||||
assert cfg.online_steps > 0
|
||||
batch_size = None
|
||||
pin_memory = False
|
||||
prefetch = None
|
||||
else:
|
||||
assert cfg.online_steps == 0
|
||||
num_slices = cfg.policy.batch_size
|
||||
batch_size = cfg.policy.horizon * num_slices
|
||||
pin_memory = cfg.device == "cuda"
|
||||
prefetch = cfg.prefetch
|
||||
|
||||
overwrite_sampler = sampler is not None
|
||||
|
||||
@@ -52,6 +62,9 @@ def make_offline_buffer(cfg, sampler=None):
|
||||
streaming=False,
|
||||
root="data",
|
||||
sampler=sampler,
|
||||
batch_size=batch_size,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch,
|
||||
)
|
||||
elif cfg.env.name == "pusht":
|
||||
offline_buffer = PushtExperienceReplay(
|
||||
@@ -61,6 +74,9 @@ def make_offline_buffer(cfg, sampler=None):
|
||||
streaming=False,
|
||||
root="data",
|
||||
sampler=sampler,
|
||||
batch_size=batch_size,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch,
|
||||
)
|
||||
else:
|
||||
raise ValueError(cfg.env.name)
|
||||
|
||||
Reference in New Issue
Block a user