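"""Abstract experience-replay dataset class built on TorchRL replay buffers."""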
import abc
import logging
import math
from pathlib import Path
from typing import Callable

import einops
import torch
import torchrl
import tqdm
from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SliceSampler
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer


class AbstractExperienceReplay(TensorDictReplayBuffer):
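    """Abstract base class for offline datasets exposed as TorchRL replay buffers.

    Subclasses implement `_download_and_preproc` to fetch the raw data and
    materialize it under `self.data_dir` as a memory-mapped TensorDict; this
    base class then handles storage loading, sampling, and the computation and
    caching of normalization statistics.

    A minimal usage sketch (`MyExperienceReplay` is a hypothetical subclass,
    shown for illustration only):

        buffer = MyExperienceReplay("my_dataset", batch_size=32)
        stats = buffer.compute_or_load_stats()
        batch = buffer.sample()
    """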
    def __init__(
        self,
        dataset_id: str,
        batch_size: int = None,
        *,
        shuffle: bool = True,
        root: Path = None,
        pin_memory: bool = False,
        prefetch: int = None,
        sampler: SliceSampler = None,
        collate_fn: Callable = None,
        writer: Writer = None,
        transform: "torchrl.envs.Transform" = None,
    ):
        self.dataset_id = dataset_id
        self.shuffle = shuffle
        self.root = _get_root_dir(self.dataset_id) if root is None else root
        self.root = Path(self.root)
        self.data_dir = self.root / self.dataset_id

        storage = self._download_or_load_storage()

        super().__init__(
            storage=storage,
            sampler=sampler,
            writer=ImmutableDatasetWriter() if writer is None else writer,
            collate_fn=_collate_id if collate_fn is None else collate_fn,
            pin_memory=pin_memory,
            prefetch=prefetch,
            batch_size=batch_size,
            transform=transform,
        )

    @property
    def num_samples(self) -> int:
        return len(self)

    @property
    def num_episodes(self) -> int:
        return len(self._storage._storage["episode"].unique())

    def set_transform(self, transform):
        self.transform = transform

    def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
        """Return normalization statistics, computing and caching them on first use."""
        stats_path = self.data_dir / "stats.pth"
        if stats_path.exists():
            stats = torch.load(stats_path)
        else:
            logging.info(f"Computing stats and saving them to {stats_path}")
            stats = self._compute_stats(self._storage._storage, num_batch, batch_size)
            torch.save(stats, stats_path)
        return stats

    @abc.abstractmethod
    def _download_and_preproc(self) -> TensorStorage:
        raise NotImplementedError()

    def _download_or_load_storage(self):
        if not self._is_downloaded():
            storage = self._download_and_preproc()
        else:
            storage = TensorStorage(TensorDict.load_memmap(self.data_dir))
        return storage

    def _is_downloaded(self) -> bool:
        return self.data_dir.is_dir()

    def _compute_stats(self, storage, num_batch=100, batch_size=32):
        rb = TensorDictReplayBuffer(
            storage=storage,
            batch_size=batch_size,
            prefetch=True,
        )
        batch = rb.sample()

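        # Two-pass estimate over `num_batch` sampled batches: the first pass
        # accumulates per-channel means and extrema, the second pass
        # accumulates squared deviations for the standard deviation. Because
        # batches are sampled rather than scanned exhaustively, the results
        # are approximations that improve with num_batch * batch_size.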
        image_channels = batch["observation", "image"].shape[1]
        image_mean = torch.zeros(image_channels)
        image_std = torch.zeros(image_channels)
        image_max = torch.tensor([-math.inf] * image_channels)
        image_min = torch.tensor([math.inf] * image_channels)

        state_channels = batch["observation", "state"].shape[1]
        state_mean = torch.zeros(state_channels)
        state_std = torch.zeros(state_channels)
        state_max = torch.tensor([-math.inf] * state_channels)
        state_min = torch.tensor([math.inf] * state_channels)

        action_channels = batch["action"].shape[1]
        action_mean = torch.zeros(action_channels)
        action_std = torch.zeros(action_channels)
        action_max = torch.tensor([-math.inf] * action_channels)
        action_min = torch.tensor([math.inf] * action_channels)

        # First pass: per-channel means and extrema.
        for _ in tqdm.tqdm(range(num_batch)):
            image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
            state_mean += einops.reduce(batch["observation", "state"], "b c -> c", "mean")
            action_mean += einops.reduce(batch["action"], "b c -> c", "mean")

            b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
            b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
            b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
            b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
            b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
            b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
            image_max = torch.maximum(image_max, b_image_max)
            image_min = torch.minimum(image_min, b_image_min)
            state_max = torch.maximum(state_max, b_state_max)
            state_min = torch.minimum(state_min, b_state_min)
            action_max = torch.maximum(action_max, b_action_max)
            action_min = torch.minimum(action_min, b_action_min)

            batch = rb.sample()

        image_mean /= num_batch
        state_mean /= num_batch
        action_mean /= num_batch

        # Second pass: per-channel squared deviations from the mean (for the
        # std), while continuing to fold new batches into the extrema.
        for i in tqdm.tqdm(range(num_batch)):
            image_std += einops.reduce(
                (batch["observation", "image"] - image_mean[None, :, None, None]) ** 2,
                "b c h w -> c",
                "mean",
            )
            state_std += einops.reduce(
                (batch["observation", "state"] - state_mean[None, :]) ** 2, "b c -> c", "mean"
            )
            action_std += einops.reduce(
                (batch["action"] - action_mean[None, :]) ** 2, "b c -> c", "mean"
            )

            b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
            b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
            b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
            b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
            b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
            b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
            image_max = torch.maximum(image_max, b_image_max)
            image_min = torch.minimum(image_min, b_image_min)
            state_max = torch.maximum(state_max, b_state_max)
            state_min = torch.minimum(state_min, b_state_min)
            action_max = torch.maximum(action_max, b_action_max)
            action_min = torch.minimum(action_min, b_action_min)

            # Avoid sampling a batch that would go unused after the last iteration.
            if i < num_batch - 1:
                batch = rb.sample()

        image_std = torch.sqrt(image_std / num_batch)
        state_std = torch.sqrt(state_std / num_batch)
        action_std = torch.sqrt(action_std / num_batch)

        stats = TensorDict(
            {
                # Image stats are shaped (1, c, 1, 1) and vector stats (1, c)
                # so that they broadcast directly against batched tensors.
                ("observation", "image", "mean"): image_mean[None, :, None, None],
                ("observation", "image", "std"): image_std[None, :, None, None],
                ("observation", "image", "max"): image_max[None, :, None, None],
                ("observation", "image", "min"): image_min[None, :, None, None],
                ("observation", "state", "mean"): state_mean[None, :],
                ("observation", "state", "std"): state_std[None, :],
                ("observation", "state", "max"): state_max[None, :],
                ("observation", "state", "min"): state_min[None, :],
                ("action", "mean"): action_mean[None, :],
                ("action", "std"): action_std[None, :],
                ("action", "max"): action_max[None, :],
                ("action", "min"): action_min[None, :],
            },
            batch_size=[],
        )
        # Reuse the same statistics for the "next" observations.
        stats["next", "observation", "image"] = stats["observation", "image"]
        stats["next", "observation", "state"] = stats["observation", "state"]
        return stats