Add pusht dataset (TODO verify reward is aligned), Refactor visualize_dataset, Add video_dir, fps, state_dim, action_dim to config (Training works)

This commit is contained in:
Cadene
2024-02-21 00:49:40 +00:00
parent 3dc14b5576
commit ece89730e6
8 changed files with 222 additions and 111 deletions

View File

@@ -5,18 +5,21 @@ from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
from rl.torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler
def make_offline_buffer(cfg):
def make_offline_buffer(cfg, sampler=None):
num_traj_per_batch = cfg.batch_size # // cfg.horizon
# TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
# We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
sampler = PrioritizedSliceSampler(
max_capacity=100_000,
alpha=cfg.per_alpha,
beta=cfg.per_beta,
num_slices=num_traj_per_batch,
strict_length=False,
)
overwrite_sampler = sampler is not None
if not overwrite_sampler:
num_traj_per_batch = cfg.batch_size # // cfg.horizon
# TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
# We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
sampler = PrioritizedSliceSampler(
max_capacity=100_000,
alpha=cfg.per_alpha,
beta=cfg.per_beta,
num_slices=num_traj_per_batch,
strict_length=False,
)
if cfg.env == "simxarm":
# TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here
@@ -30,9 +33,9 @@ def make_offline_buffer(cfg):
)
elif cfg.env == "pusht":
offline_buffer = PushtExperienceReplay(
f"xarm_{cfg.task}_medium",
"pusht",
# download="force",
download=True,
download=False,
streaming=False,
root="data",
sampler=sampler,
@@ -40,8 +43,9 @@ def make_offline_buffer(cfg):
else:
raise ValueError(cfg.env)
num_steps = len(offline_buffer)
index = torch.arange(0, num_steps, 1)
sampler.extend(index)
if not overwrite_sampler:
num_steps = len(offline_buffer)
index = torch.arange(0, num_steps, 1)
sampler.extend(index)
return offline_buffer