Refactor train, eval_policy, logger, Add diffusion.yaml (WIP)

This commit is contained in:
Cadene
2024-02-26 01:10:09 +00:00
parent 5a219fed6e
commit 21670dce90
12 changed files with 306 additions and 443 deletions

View File

@@ -4,6 +4,26 @@ from lerobot.common.datasets.pusht import PushtExperienceReplay
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
from rl.torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler
# TODO(rcadene): implement
# dataset_d4rl = D4RLExperienceReplay(
# dataset_id="maze2d-umaze-v1",
# split_trajs=False,
# batch_size=1,
# sampler=SamplerWithoutReplacement(drop_last=False),
# prefetch=4,
# direct_download=True,
# )
# dataset_openx = OpenXExperienceReplay(
# "cmu_stretch",
# batch_size=1,
# num_slices=1,
# #download="force",
# streaming=False,
# root="data",
# )
def make_offline_buffer(cfg, sampler=None):