Improve visualize_dataset and AbstractReplayBuffer; assorted small improvements

This commit is contained in:
Remi Cadene
2024-03-06 10:14:03 +00:00
parent 2f80d71c3e
commit f95ecd66fc
7 changed files with 195 additions and 150 deletions

View File

@@ -10,7 +10,9 @@ from lerobot.common.envs.transforms import NormalizeTransform
DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))
def make_offline_buffer(cfg, sampler=None):
def make_offline_buffer(
cfg, overwrite_sampler=None, normalize=True, overwrite_batch_size=None, overwrite_prefetch=None
):
if cfg.policy.balanced_sampling:
assert cfg.online_steps > 0
batch_size = None
@@ -23,9 +25,13 @@ def make_offline_buffer(cfg, sampler=None):
pin_memory = cfg.device == "cuda"
prefetch = cfg.prefetch
overwrite_sampler = sampler is not None
if overwrite_batch_size is not None:
batch_size = overwrite_batch_size
if not overwrite_sampler:
if overwrite_prefetch is not None:
prefetch = overwrite_prefetch
if overwrite_sampler is None:
# TODO(rcadene): move batch_size outside
num_traj_per_batch = cfg.policy.batch_size # // cfg.horizon
# TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
@@ -46,6 +52,8 @@ def make_offline_buffer(cfg, sampler=None):
num_slices=num_traj_per_batch,
strict_length=False,
)
else:
sampler = overwrite_sampler
if cfg.env.name == "simxarm":
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
@@ -70,30 +78,31 @@ def make_offline_buffer(cfg, sampler=None):
prefetch=prefetch if isinstance(prefetch, int) else None,
)
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
stats = offline_buffer.compute_or_load_stats()
in_keys = [("observation", "state"), ("action")]
if normalize:
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
stats = offline_buffer.compute_or_load_stats()
in_keys = [("observation", "state"), ("action")]
if cfg.policy == "tdmpc":
# TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc
in_keys.append(("observation", "image"))
# since we use next observations in tdmpc
in_keys.append(("next", "observation", "image"))
in_keys.append(("next", "observation", "state"))
if cfg.policy == "tdmpc":
for key in offline_buffer.image_keys:
# TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc
in_keys.append(key)
# since we use next observations in tdmpc
in_keys.append(("next", *key))
in_keys.append(("next", "observation", "state"))
if cfg.policy == "diffusion" and cfg.env.name == "pusht":
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
if cfg.policy == "diffusion" and cfg.env.name == "pusht":
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
transform = NormalizeTransform(stats, in_keys, mode="min_max")
offline_buffer.set_transform(transform)
transform = NormalizeTransform(stats, in_keys, mode="min_max")
offline_buffer.set_transform(transform)
if not overwrite_sampler:
num_steps = len(offline_buffer)
index = torch.arange(0, num_steps, 1)
index = torch.arange(0, offline_buffer.num_frames, 1)
sampler.extend(index)
return offline_buffer