forked from tangger/lerobot
Improve visualize_dataset, Improve AbstractReplayBuffer, Small improvements
This commit is contained in:
@@ -10,7 +10,9 @@ from lerobot.common.envs.transforms import NormalizeTransform
|
||||
DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))
|
||||
|
||||
|
||||
def make_offline_buffer(cfg, sampler=None):
|
||||
def make_offline_buffer(
|
||||
cfg, overwrite_sampler=None, normalize=True, overwrite_batch_size=None, overwrite_prefetch=None
|
||||
):
|
||||
if cfg.policy.balanced_sampling:
|
||||
assert cfg.online_steps > 0
|
||||
batch_size = None
|
||||
@@ -23,9 +25,13 @@ def make_offline_buffer(cfg, sampler=None):
|
||||
pin_memory = cfg.device == "cuda"
|
||||
prefetch = cfg.prefetch
|
||||
|
||||
overwrite_sampler = sampler is not None
|
||||
if overwrite_batch_size is not None:
|
||||
batch_size = overwrite_batch_size
|
||||
|
||||
if not overwrite_sampler:
|
||||
if overwrite_prefetch is not None:
|
||||
prefetch = overwrite_prefetch
|
||||
|
||||
if overwrite_sampler is None:
|
||||
# TODO(rcadene): move batch_size outside
|
||||
num_traj_per_batch = cfg.policy.batch_size # // cfg.horizon
|
||||
# TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
|
||||
@@ -46,6 +52,8 @@ def make_offline_buffer(cfg, sampler=None):
|
||||
num_slices=num_traj_per_batch,
|
||||
strict_length=False,
|
||||
)
|
||||
else:
|
||||
sampler = overwrite_sampler
|
||||
|
||||
if cfg.env.name == "simxarm":
|
||||
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
|
||||
@@ -70,30 +78,31 @@ def make_offline_buffer(cfg, sampler=None):
|
||||
prefetch=prefetch if isinstance(prefetch, int) else None,
|
||||
)
|
||||
|
||||
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
|
||||
stats = offline_buffer.compute_or_load_stats()
|
||||
in_keys = [("observation", "state"), ("action")]
|
||||
if normalize:
|
||||
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
|
||||
stats = offline_buffer.compute_or_load_stats()
|
||||
in_keys = [("observation", "state"), ("action")]
|
||||
|
||||
if cfg.policy == "tdmpc":
|
||||
# TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc
|
||||
in_keys.append(("observation", "image"))
|
||||
# since we use next observations in tdmpc
|
||||
in_keys.append(("next", "observation", "image"))
|
||||
in_keys.append(("next", "observation", "state"))
|
||||
if cfg.policy == "tdmpc":
|
||||
for key in offline_buffer.image_keys:
|
||||
# TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc
|
||||
in_keys.append(key)
|
||||
# since we use next observations in tdmpc
|
||||
in_keys.append(("next", *key))
|
||||
in_keys.append(("next", "observation", "state"))
|
||||
|
||||
if cfg.policy == "diffusion" and cfg.env.name == "pusht":
|
||||
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
|
||||
stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
|
||||
stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
|
||||
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
if cfg.policy == "diffusion" and cfg.env.name == "pusht":
|
||||
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
|
||||
stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
|
||||
stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
|
||||
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
|
||||
transform = NormalizeTransform(stats, in_keys, mode="min_max")
|
||||
offline_buffer.set_transform(transform)
|
||||
transform = NormalizeTransform(stats, in_keys, mode="min_max")
|
||||
offline_buffer.set_transform(transform)
|
||||
|
||||
if not overwrite_sampler:
|
||||
num_steps = len(offline_buffer)
|
||||
index = torch.arange(0, num_steps, 1)
|
||||
index = torch.arange(0, offline_buffer.num_frames, 1)
|
||||
sampler.extend(index)
|
||||
|
||||
return offline_buffer
|
||||
|
||||
Reference in New Issue
Block a user