Add Normalize, non_blocking=True in tdmpc, tdmpc run (TODO: diffusion)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
@@ -16,9 +17,10 @@ from torchrl.data.replay_buffers.samplers import Sampler
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
||||
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
||||
|
||||
from diffusion_policy.common.replay_buffer import ReplayBuffer
|
||||
from diffusion_policy.common.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
|
||||
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
|
||||
from lerobot.common.datasets import utils
|
||||
from lerobot.common.datasets.utils import download_and_extract_zip
|
||||
from lerobot.common.envs.transforms import NormalizeTransform
|
||||
|
||||
# as define in env
|
||||
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
|
||||
@@ -132,29 +134,16 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
else:
|
||||
storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
|
||||
|
||||
# if num_slices is not None or slice_len is not None:
|
||||
# if sampler is not None:
|
||||
# raise ValueError(
|
||||
# "`num_slices` and `slice_len` are exclusive with the `sampler` argument."
|
||||
# )
|
||||
|
||||
# if replacement:
|
||||
# if not self.shuffle:
|
||||
# raise RuntimeError(
|
||||
# "shuffle=False can only be used when replacement=False."
|
||||
# )
|
||||
# sampler = SliceSampler(
|
||||
# num_slices=num_slices,
|
||||
# slice_len=slice_len,
|
||||
# strict_length=strict_length,
|
||||
# )
|
||||
# else:
|
||||
# sampler = SliceSamplerWithoutReplacement(
|
||||
# num_slices=num_slices,
|
||||
# slice_len=slice_len,
|
||||
# strict_length=strict_length,
|
||||
# shuffle=self.shuffle,
|
||||
# )
|
||||
mean_std = self._compute_or_load_mean_std(storage)
|
||||
mean_std["next", "observation", "image"] = mean_std["observation", "image"]
|
||||
mean_std["next", "observation", "state"] = mean_std["observation", "state"]
|
||||
transform = NormalizeTransform(mean_std, in_keys=[
|
||||
("observation", "image"),
|
||||
("observation", "state"),
|
||||
("next", "observation", "image"),
|
||||
("next", "observation", "state"),
|
||||
("action"),
|
||||
])
|
||||
|
||||
if writer is None:
|
||||
writer = ImmutableDatasetWriter()
|
||||
@@ -193,10 +182,10 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
|
||||
if not zarr_path.is_dir():
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
utils.download_and_extract_zip(PUSHT_URL, raw_dir)
|
||||
download_and_extract_zip(PUSHT_URL, raw_dir)
|
||||
|
||||
# load
|
||||
dataset_dict = ReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action'])
|
||||
dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action'])
|
||||
|
||||
episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
|
||||
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
|
||||
@@ -287,3 +276,62 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
idxtd = idxtd + len(episode)
|
||||
|
||||
return TensorStorage(td_data.lock_())
|
||||
|
||||
def _compute_mean_std(self, storage, num_batch=10, batch_size=32):
|
||||
rb = TensorDictReplayBuffer(
|
||||
storage=storage,
|
||||
batch_size=batch_size,
|
||||
prefetch=True,
|
||||
)
|
||||
batch = rb.sample()
|
||||
image_mean = torch.zeros(batch["observation", "image"].shape[1])
|
||||
image_std = torch.zeros(batch["observation", "image"].shape[1])
|
||||
state_mean = torch.zeros(batch["observation", "state"].shape[1])
|
||||
state_std = torch.zeros(batch["observation", "state"].shape[1])
|
||||
action_mean = torch.zeros(batch["action"].shape[1])
|
||||
action_std = torch.zeros(batch["action"].shape[1])
|
||||
|
||||
for i in tqdm.tqdm(range(num_batch)):
|
||||
image_mean += einops.reduce(batch["observation", "image"], 'b c h w -> c', reduction='mean')
|
||||
state_mean += batch["observation", "state"].mean(dim=0)
|
||||
action_mean += batch["action"].mean(dim=0)
|
||||
batch = rb.sample()
|
||||
|
||||
image_mean /= num_batch
|
||||
state_mean /= num_batch
|
||||
action_mean /= num_batch
|
||||
|
||||
for i in tqdm.tqdm(range(num_batch)):
|
||||
image_mean_batch = einops.reduce(batch["observation", "image"], 'b c h w -> c', reduction='mean')
|
||||
image_std += (image_mean_batch - image_mean) ** 2
|
||||
state_std += (batch["observation", "state"].mean(dim=0) - state_mean) ** 2
|
||||
action_std += (batch["action"].mean(dim=0) - action_mean) ** 2
|
||||
if i < num_batch - 1:
|
||||
batch = rb.sample()
|
||||
|
||||
image_std = torch.sqrt(image_std / num_batch)
|
||||
state_std = torch.sqrt(state_std / num_batch)
|
||||
action_std = torch.sqrt(action_std / num_batch)
|
||||
|
||||
mean_std = TensorDict(
|
||||
{
|
||||
("observation", "image", "mean"): image_mean[None,:,None,None],
|
||||
("observation", "image", "std"): image_std[None,:,None,None],
|
||||
("observation", "state", "mean"): state_mean[None,:],
|
||||
("observation", "state", "std"): state_std[None,:],
|
||||
("action", "mean"): action_mean[None,:],
|
||||
("action", "std"): action_std[None,:],
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
return mean_std
|
||||
|
||||
def _compute_or_load_mean_std(self, storage) -> TensorDict:
|
||||
mean_std_path = self.root / self.dataset_id / "mean_std.pth"
|
||||
if mean_std_path.exists():
|
||||
mean_std = torch.load(mean_std_path)
|
||||
else:
|
||||
logging.info(f"compute_mean_std and save to {mean_std_path}")
|
||||
mean_std = self._compute_mean_std(storage)
|
||||
torch.save(mean_std, mean_std_path)
|
||||
return mean_std
|
||||
|
||||
Reference in New Issue
Block a user