Refactor datasets with abstract class

This commit is contained in:
Remi Cadene
2024-03-05 10:20:57 +00:00
parent e132a267aa
commit d4e0849970
4 changed files with 262 additions and 351 deletions

View File

@@ -1,6 +1,3 @@
import logging
import math
import os
from pathlib import Path
from typing import Callable
@@ -12,16 +9,14 @@ import torch
import torchrl
import tqdm
from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import Sampler
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
from torchrl.data.replay_buffers.samplers import SliceSampler
from torchrl.data.replay_buffers.storages import TensorStorage
from torchrl.data.replay_buffers.writers import Writer
from diffusion_policy.common.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
from lerobot.common.datasets.abstract import AbstractExperienceReplay
from lerobot.common.datasets.utils import download_and_extract_zip
from lerobot.common.envs.transforms import NormalizeTransform
# as define in env
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
@@ -87,114 +82,36 @@ def add_tee(
return body
class PushtExperienceReplay(TensorDictReplayBuffer):
class PushtExperienceReplay(AbstractExperienceReplay):
def __init__(
self,
dataset_id: str,
batch_size: int = None,
*,
shuffle: bool = True,
num_slices: int = None,
slice_len: int = None,
pad: float = None,
replacement: bool = None,
streaming: bool = False,
root: Path = None,
sampler: Sampler = None,
writer: Writer = None,
collate_fn: Callable = None,
pin_memory: bool = False,
prefetch: int = None,
transform: "torchrl.envs.Transform" = None, # noqa: F821
split_trajs: bool = False,
strict_length: bool = True,
sampler: SliceSampler = None,
collate_fn: Callable = None,
writer: Writer = None,
transform: "torchrl.envs.Transform" = None, # noqa-F821
):
if streaming:
raise NotImplementedError
self.streaming = streaming
self.dataset_id = dataset_id
self.split_trajs = split_trajs
self.shuffle = shuffle
self.num_slices = num_slices
self.slice_len = slice_len
self.pad = pad
self.strict_length = strict_length
if (self.num_slices is not None) and (self.slice_len is not None):
raise ValueError("num_slices or slice_len can be not None, but not both.")
if split_trajs:
raise NotImplementedError
if root is None:
root = _get_root_dir("pusht")
os.makedirs(root, exist_ok=True)
self.root = root
if not self._is_downloaded():
storage = self._download_and_preproc()
else:
storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
stats = self._compute_or_load_stats(storage)
transform = NormalizeTransform(
stats,
in_keys=[
# TODO(rcadene): imagenet normalization is applied inside diffusion policy
# We need to automate this for tdmpc and others
# ("observation", "image"),
("observation", "state"),
# TODO(rcadene): for tdmpc, we might want next image and state
# ("next", "observation", "image"),
# ("next", "observation", "state"),
("action"),
],
mode="min_max",
)
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
transform.stats["observation", "state", "min"] = torch.tensor(
[13.456424, 32.938293], dtype=torch.float32
)
transform.stats["observation", "state", "max"] = torch.tensor(
[496.14618, 510.9579], dtype=torch.float32
)
transform.stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
transform.stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
if writer is None:
writer = ImmutableDatasetWriter()
if collate_fn is None:
collate_fn = _collate_id
super().__init__(
storage=storage,
sampler=sampler,
writer=writer,
collate_fn=collate_fn,
dataset_id,
batch_size,
shuffle=shuffle,
root=root,
pin_memory=pin_memory,
prefetch=prefetch,
batch_size=batch_size,
sampler=sampler,
collate_fn=collate_fn,
writer=writer,
transform=transform,
)
@property
def num_samples(self) -> int:
return len(self)
@property
def num_episodes(self) -> int:
return len(self._storage._storage["episode"].unique())
@property
def data_path_root(self) -> Path:
return None if self.streaming else self.root / self.dataset_id
def _is_downloaded(self) -> bool:
return self.data_path_root.is_dir()
def _download_and_preproc(self):
# download
raw_dir = self.root / "raw"
raw_dir = self.data_dir / "raw"
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
if not zarr_path.is_dir():
raw_dir.mkdir(parents=True, exist_ok=True)
@@ -286,7 +203,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
if episode_id == 0:
# hack to initialize tensordict data structure to store episodes
td_data = episode[0].expand(total_frames).memmap_like(self.root / self.dataset_id)
td_data = episode[0].expand(total_frames).memmap_like(self.data_dir)
td_data[idxtd : idxtd + len(episode)] = episode
@@ -294,112 +211,3 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
idxtd = idxtd + len(episode)
return TensorStorage(td_data.lock_())
def _compute_stats(self, storage, num_batch=100, batch_size=32):
rb = TensorDictReplayBuffer(
storage=storage,
batch_size=batch_size,
prefetch=True,
)
batch = rb.sample()
image_channels = batch["observation", "image"].shape[1]
image_mean = torch.zeros(image_channels)
image_std = torch.zeros(image_channels)
image_max = torch.tensor([-math.inf] * image_channels)
image_min = torch.tensor([math.inf] * image_channels)
state_channels = batch["observation", "state"].shape[1]
state_mean = torch.zeros(state_channels)
state_std = torch.zeros(state_channels)
state_max = torch.tensor([-math.inf] * state_channels)
state_min = torch.tensor([math.inf] * state_channels)
action_channels = batch["action"].shape[1]
action_mean = torch.zeros(action_channels)
action_std = torch.zeros(action_channels)
action_max = torch.tensor([-math.inf] * action_channels)
action_min = torch.tensor([math.inf] * action_channels)
for _ in tqdm.tqdm(range(num_batch)):
image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
state_mean += einops.reduce(batch["observation", "state"], "b c -> c", "mean")
action_mean += einops.reduce(batch["action"], "b c -> c", "mean")
b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
image_max = torch.maximum(image_max, b_image_max)
image_min = torch.maximum(image_min, b_image_min)
state_max = torch.maximum(state_max, b_state_max)
state_min = torch.maximum(state_min, b_state_min)
action_max = torch.maximum(action_max, b_action_max)
action_min = torch.maximum(action_min, b_action_min)
batch = rb.sample()
image_mean /= num_batch
state_mean /= num_batch
action_mean /= num_batch
for i in tqdm.tqdm(range(num_batch)):
b_image_mean = einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
b_state_mean = einops.reduce(batch["observation", "state"], "b c -> c", "mean")
b_action_mean = einops.reduce(batch["action"], "b c -> c", "mean")
image_std += (b_image_mean - image_mean) ** 2
state_std += (b_state_mean - state_mean) ** 2
action_std += (b_action_mean - action_mean) ** 2
b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
image_max = torch.maximum(image_max, b_image_max)
image_min = torch.maximum(image_min, b_image_min)
state_max = torch.maximum(state_max, b_state_max)
state_min = torch.maximum(state_min, b_state_min)
action_max = torch.maximum(action_max, b_action_max)
action_min = torch.maximum(action_min, b_action_min)
if i < num_batch - 1:
batch = rb.sample()
image_std = torch.sqrt(image_std / num_batch)
state_std = torch.sqrt(state_std / num_batch)
action_std = torch.sqrt(action_std / num_batch)
stats = TensorDict(
{
("observation", "image", "mean"): image_mean[None, :, None, None],
("observation", "image", "std"): image_std[None, :, None, None],
("observation", "image", "max"): image_max[None, :, None, None],
("observation", "image", "min"): image_min[None, :, None, None],
("observation", "state", "mean"): state_mean[None, :],
("observation", "state", "std"): state_std[None, :],
("observation", "state", "max"): state_max[None, :],
("observation", "state", "min"): state_min[None, :],
("action", "mean"): action_mean[None, :],
("action", "std"): action_std[None, :],
("action", "max"): action_max[None, :],
("action", "min"): action_min[None, :],
},
batch_size=[],
)
stats["next", "observation", "image"] = stats["observation", "image"]
stats["next", "observation", "state"] = stats["observation", "state"]
return stats
def _compute_or_load_stats(self, storage) -> TensorDict:
stats_path = self.root / self.dataset_id / "stats.pth"
if stats_path.exists():
stats = torch.load(stats_path)
else:
logging.info(f"compute_stats and save to {stats_path}")
stats = self._compute_stats(storage)
torch.save(stats, stats_path)
return stats