Refactor datasets with abstract class

This commit is contained in:
Remi Cadene
2024-03-05 10:20:57 +00:00
parent e132a267aa
commit d4e0849970
4 changed files with 262 additions and 351 deletions

View File

@@ -5,31 +5,10 @@ from pathlib import Path
import torch
from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler
from lerobot.common.datasets.pusht import PushtExperienceReplay
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
from lerobot.common.envs.transforms import NormalizeTransform
DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))
# TODO(rcadene): implement
# dataset_d4rl = D4RLExperienceReplay(
# dataset_id="maze2d-umaze-v1",
# split_trajs=False,
# batch_size=1,
# sampler=SamplerWithoutReplacement(drop_last=False),
# prefetch=4,
# direct_download=True,
# )
# dataset_openx = OpenXExperienceReplay(
# "cmu_stretch",
# batch_size=1,
# num_slices=1,
# #download="force",
# streaming=False,
# root="data",
# )
def make_offline_buffer(cfg, sampler=None):
if cfg.policy.balanced_sampling:
@@ -69,31 +48,49 @@ def make_offline_buffer(cfg, sampler=None):
)
if cfg.env.name == "simxarm":
# TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here
offline_buffer = SimxarmExperienceReplay(
f"xarm_{cfg.env.task}_medium",
# download="force",
download=True,
streaming=False,
root=str(DATA_DIR),
sampler=sampler,
batch_size=batch_size,
pin_memory=pin_memory,
prefetch=prefetch if isinstance(prefetch, int) else None,
)
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
clsfunc = SimxarmExperienceReplay
dataset_id = f"xarm_{cfg.env.task}_medium"
elif cfg.env.name == "pusht":
offline_buffer = PushtExperienceReplay(
"pusht",
streaming=False,
root=DATA_DIR,
sampler=sampler,
batch_size=batch_size,
pin_memory=pin_memory,
prefetch=prefetch if isinstance(prefetch, int) else None,
)
from lerobot.common.datasets.pusht import PushtExperienceReplay
clsfunc = PushtExperienceReplay
dataset_id = "pusht"
else:
raise ValueError(cfg.env.name)
offline_buffer = clsfunc(
dataset_id=dataset_id,
root=DATA_DIR,
sampler=sampler,
batch_size=batch_size,
pin_memory=pin_memory,
prefetch=prefetch if isinstance(prefetch, int) else None,
)
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
stats = offline_buffer.compute_or_load_stats()
in_keys = [("observation", "state"), ("action")]
if cfg.policy == "tdmpc":
# TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc
in_keys.append(("observation", "image"))
# since we use next observations in tdmpc
in_keys.append(("next", "observation", "image"))
in_keys.append(("next", "observation", "state"))
if cfg.policy == "diffusion" and cfg.env.name == "pusht":
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
transform = NormalizeTransform(stats, in_keys, mode="min_max")
offline_buffer.set_transform(transform)
if not overwrite_sampler:
num_steps = len(offline_buffer)
index = torch.arange(0, num_steps, 1)