Improve visualize_dataset, Improve AbstractReplayBuffer, Small improvements
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
import abc
|
||||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
@@ -50,6 +49,22 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
|
||||
transform=transform,
|
||||
)
|
||||
|
||||
@property
|
||||
def stats_patterns(self) -> dict:
|
||||
return {
|
||||
("observation", "state"): "b c -> 1 c",
|
||||
("observation", "image"): "b c h w -> 1 c 1 1",
|
||||
("action"): "b c -> 1 c",
|
||||
}
|
||||
|
||||
@property
|
||||
def image_keys(self) -> list:
|
||||
return [("observation", "image")]
|
||||
|
||||
@property
|
||||
def num_cameras(self) -> int:
|
||||
return len(self.image_keys)
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self)
|
||||
@@ -67,7 +82,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
|
||||
stats = torch.load(stats_path)
|
||||
else:
|
||||
logging.info(f"compute_stats and save to {stats_path}")
|
||||
stats = self._compute_stats(self._storage._storage, num_batch, batch_size)
|
||||
stats = self._compute_stats(num_batch, batch_size)
|
||||
torch.save(stats, stats_path)
|
||||
return stats
|
||||
|
||||
@@ -85,101 +100,59 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
|
||||
def _is_downloaded(self) -> bool:
|
||||
return self.data_dir.is_dir()
|
||||
|
||||
def _compute_stats(self, storage, num_batch=100, batch_size=32):
|
||||
def _compute_stats(self, num_batch=100, batch_size=32):
|
||||
rb = TensorDictReplayBuffer(
|
||||
storage=storage,
|
||||
storage=self._storage,
|
||||
batch_size=batch_size,
|
||||
prefetch=True,
|
||||
)
|
||||
batch = rb.sample()
|
||||
|
||||
image_channels = batch["observation", "image"].shape[1]
|
||||
image_mean = torch.zeros(image_channels)
|
||||
image_std = torch.zeros(image_channels)
|
||||
image_max = torch.tensor([-math.inf] * image_channels)
|
||||
image_min = torch.tensor([math.inf] * image_channels)
|
||||
|
||||
state_channels = batch["observation", "state"].shape[1]
|
||||
state_mean = torch.zeros(state_channels)
|
||||
state_std = torch.zeros(state_channels)
|
||||
state_max = torch.tensor([-math.inf] * state_channels)
|
||||
state_min = torch.tensor([math.inf] * state_channels)
|
||||
|
||||
action_channels = batch["action"].shape[1]
|
||||
action_mean = torch.zeros(action_channels)
|
||||
action_std = torch.zeros(action_channels)
|
||||
action_max = torch.tensor([-math.inf] * action_channels)
|
||||
action_min = torch.tensor([math.inf] * action_channels)
|
||||
mean, std, max, min = {}, {}, {}, {}
|
||||
|
||||
# compute mean, min, max
|
||||
for _ in tqdm.tqdm(range(num_batch)):
|
||||
image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
|
||||
state_mean += einops.reduce(batch["observation", "state"], "b c -> c", "mean")
|
||||
action_mean += einops.reduce(batch["action"], "b c -> c", "mean")
|
||||
|
||||
b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
|
||||
b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
|
||||
b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
|
||||
b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
|
||||
b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
|
||||
b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
|
||||
image_max = torch.maximum(image_max, b_image_max)
|
||||
image_min = torch.maximum(image_min, b_image_min)
|
||||
state_max = torch.maximum(state_max, b_state_max)
|
||||
state_min = torch.maximum(state_min, b_state_min)
|
||||
action_max = torch.maximum(action_max, b_action_max)
|
||||
action_min = torch.maximum(action_min, b_action_min)
|
||||
|
||||
batch = rb.sample()
|
||||
|
||||
image_mean /= num_batch
|
||||
state_mean /= num_batch
|
||||
action_mean /= num_batch
|
||||
|
||||
for i in tqdm.tqdm(range(num_batch)):
|
||||
b_image_mean = einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
|
||||
b_state_mean = einops.reduce(batch["observation", "state"], "b c -> c", "mean")
|
||||
b_action_mean = einops.reduce(batch["action"], "b c -> c", "mean")
|
||||
image_std += (b_image_mean - image_mean) ** 2
|
||||
state_std += (b_state_mean - state_mean) ** 2
|
||||
action_std += (b_action_mean - action_mean) ** 2
|
||||
|
||||
b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
|
||||
b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
|
||||
b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
|
||||
b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
|
||||
b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
|
||||
b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
|
||||
image_max = torch.maximum(image_max, b_image_max)
|
||||
image_min = torch.maximum(image_min, b_image_min)
|
||||
state_max = torch.maximum(state_max, b_state_max)
|
||||
state_min = torch.maximum(state_min, b_state_min)
|
||||
action_max = torch.maximum(action_max, b_action_max)
|
||||
action_min = torch.maximum(action_min, b_action_min)
|
||||
|
||||
if i < num_batch - 1:
|
||||
for key, pattern in self.stats_patterns.items():
|
||||
batch[key] = batch[key].float()
|
||||
if key not in mean:
|
||||
# first batch initialize mean, min, max
|
||||
mean[key] = einops.reduce(batch[key], pattern, "mean")
|
||||
max[key] = einops.reduce(batch[key], pattern, "max")
|
||||
min[key] = einops.reduce(batch[key], pattern, "min")
|
||||
else:
|
||||
mean[key] += einops.reduce(batch[key], pattern, "mean")
|
||||
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
|
||||
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
|
||||
batch = rb.sample()
|
||||
|
||||
image_std = torch.sqrt(image_std / num_batch)
|
||||
state_std = torch.sqrt(state_std / num_batch)
|
||||
action_std = torch.sqrt(action_std / num_batch)
|
||||
for key in self.stats_patterns:
|
||||
mean[key] /= num_batch
|
||||
|
||||
stats = TensorDict(
|
||||
{
|
||||
("observation", "image", "mean"): image_mean[None, :, None, None],
|
||||
("observation", "image", "std"): image_std[None, :, None, None],
|
||||
("observation", "image", "max"): image_max[None, :, None, None],
|
||||
("observation", "image", "min"): image_min[None, :, None, None],
|
||||
("observation", "state", "mean"): state_mean[None, :],
|
||||
("observation", "state", "std"): state_std[None, :],
|
||||
("observation", "state", "max"): state_max[None, :],
|
||||
("observation", "state", "min"): state_min[None, :],
|
||||
("action", "mean"): action_mean[None, :],
|
||||
("action", "std"): action_std[None, :],
|
||||
("action", "max"): action_max[None, :],
|
||||
("action", "min"): action_min[None, :],
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
stats["next", "observation", "image"] = stats["observation", "image"]
|
||||
stats["next", "observation", "state"] = stats["observation", "state"]
|
||||
# compute std, min, max
|
||||
for _ in tqdm.tqdm(range(num_batch)):
|
||||
batch = rb.sample()
|
||||
for key, pattern in self.stats_patterns.items():
|
||||
batch[key] = batch[key].float()
|
||||
batch_mean = einops.reduce(batch[key], pattern, "mean")
|
||||
if key not in std:
|
||||
# first batch initialize std
|
||||
std[key] = (batch_mean - mean[key]) ** 2
|
||||
else:
|
||||
std[key] += (batch_mean - mean[key]) ** 2
|
||||
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
|
||||
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
|
||||
|
||||
for key in self.stats_patterns:
|
||||
std[key] = torch.sqrt(std[key] / num_batch)
|
||||
|
||||
stats = TensorDict({}, batch_size=[])
|
||||
for key in self.stats_patterns:
|
||||
stats[(*key, "mean")] = mean[key]
|
||||
stats[(*key, "std")] = std[key]
|
||||
stats[(*key, "max")] = max[key]
|
||||
stats[(*key, "min")] = min[key]
|
||||
|
||||
if key[0] == "observation":
|
||||
# use same stats for the next observations
|
||||
stats[("next", *key)] = stats[key]
|
||||
return stats
|
||||
|
||||
@@ -10,7 +10,9 @@ from lerobot.common.envs.transforms import NormalizeTransform
|
||||
DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))
|
||||
|
||||
|
||||
def make_offline_buffer(cfg, sampler=None):
|
||||
def make_offline_buffer(
|
||||
cfg, overwrite_sampler=None, normalize=True, overwrite_batch_size=None, overwrite_prefetch=None
|
||||
):
|
||||
if cfg.policy.balanced_sampling:
|
||||
assert cfg.online_steps > 0
|
||||
batch_size = None
|
||||
@@ -23,9 +25,13 @@ def make_offline_buffer(cfg, sampler=None):
|
||||
pin_memory = cfg.device == "cuda"
|
||||
prefetch = cfg.prefetch
|
||||
|
||||
overwrite_sampler = sampler is not None
|
||||
if overwrite_batch_size is not None:
|
||||
batch_size = overwrite_batch_size
|
||||
|
||||
if not overwrite_sampler:
|
||||
if overwrite_prefetch is not None:
|
||||
prefetch = overwrite_prefetch
|
||||
|
||||
if overwrite_sampler is None:
|
||||
# TODO(rcadene): move batch_size outside
|
||||
num_traj_per_batch = cfg.policy.batch_size # // cfg.horizon
|
||||
# TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
|
||||
@@ -46,6 +52,8 @@ def make_offline_buffer(cfg, sampler=None):
|
||||
num_slices=num_traj_per_batch,
|
||||
strict_length=False,
|
||||
)
|
||||
else:
|
||||
sampler = overwrite_sampler
|
||||
|
||||
if cfg.env.name == "simxarm":
|
||||
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
|
||||
@@ -70,30 +78,31 @@ def make_offline_buffer(cfg, sampler=None):
|
||||
prefetch=prefetch if isinstance(prefetch, int) else None,
|
||||
)
|
||||
|
||||
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
|
||||
stats = offline_buffer.compute_or_load_stats()
|
||||
in_keys = [("observation", "state"), ("action")]
|
||||
if normalize:
|
||||
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
|
||||
stats = offline_buffer.compute_or_load_stats()
|
||||
in_keys = [("observation", "state"), ("action")]
|
||||
|
||||
if cfg.policy == "tdmpc":
|
||||
# TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc
|
||||
in_keys.append(("observation", "image"))
|
||||
# since we use next observations in tdmpc
|
||||
in_keys.append(("next", "observation", "image"))
|
||||
in_keys.append(("next", "observation", "state"))
|
||||
if cfg.policy == "tdmpc":
|
||||
for key in offline_buffer.image_keys:
|
||||
# TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc
|
||||
in_keys.append(key)
|
||||
# since we use next observations in tdmpc
|
||||
in_keys.append(("next", *key))
|
||||
in_keys.append(("next", "observation", "state"))
|
||||
|
||||
if cfg.policy == "diffusion" and cfg.env.name == "pusht":
|
||||
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
|
||||
stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
|
||||
stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
|
||||
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
if cfg.policy == "diffusion" and cfg.env.name == "pusht":
|
||||
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
|
||||
stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
|
||||
stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
|
||||
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
|
||||
transform = NormalizeTransform(stats, in_keys, mode="min_max")
|
||||
offline_buffer.set_transform(transform)
|
||||
transform = NormalizeTransform(stats, in_keys, mode="min_max")
|
||||
offline_buffer.set_transform(transform)
|
||||
|
||||
if not overwrite_sampler:
|
||||
num_steps = len(offline_buffer)
|
||||
index = torch.arange(0, num_steps, 1)
|
||||
index = torch.arange(0, offline_buffer.num_frames, 1)
|
||||
sampler.extend(index)
|
||||
|
||||
return offline_buffer
|
||||
|
||||
@@ -183,8 +183,7 @@ class PushtExperienceReplay(AbstractExperienceReplay):
|
||||
# last step of demonstration is considered done
|
||||
done[-1] = True
|
||||
|
||||
print("before " + """episode = TensorDict(""")
|
||||
episode = TensorDict(
|
||||
ep_td = TensorDict(
|
||||
{
|
||||
("observation", "image"): image[:-1],
|
||||
("observation", "state"): agent_pos[:-1],
|
||||
@@ -203,11 +202,11 @@ class PushtExperienceReplay(AbstractExperienceReplay):
|
||||
|
||||
if episode_id == 0:
|
||||
# hack to initialize tensordict data structure to store episodes
|
||||
td_data = episode[0].expand(total_frames).memmap_like(self.data_dir)
|
||||
td_data = ep_td[0].expand(total_frames).memmap_like(self.data_dir)
|
||||
|
||||
td_data[idxtd : idxtd + len(episode)] = episode
|
||||
td_data[idxtd : idxtd + len(ep_td)] = ep_td
|
||||
|
||||
idx0 = idx1
|
||||
idxtd = idxtd + len(episode)
|
||||
idxtd = idxtd + len(ep_td)
|
||||
|
||||
return TensorStorage(td_data.lock_())
|
||||
|
||||
Reference in New Issue
Block a user