Add mode to NormalizeTransform with mean_std or min_max (Not fully tested)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
@@ -134,18 +135,19 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
else:
|
||||
storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
|
||||
|
||||
mean_std = self._compute_or_load_mean_std(storage)
|
||||
mean_std["next", "observation", "image"] = mean_std["observation", "image"]
|
||||
mean_std["next", "observation", "state"] = mean_std["observation", "state"]
|
||||
stats = self._compute_or_load_stats(storage)
|
||||
stats["next", "observation", "image"] = stats["observation", "image"]
|
||||
stats["next", "observation", "state"] = stats["observation", "state"]
|
||||
transform = NormalizeTransform(
|
||||
mean_std,
|
||||
stats,
|
||||
in_keys=[
|
||||
("observation", "image"),
|
||||
# ("observation", "image"),
|
||||
("observation", "state"),
|
||||
("next", "observation", "image"),
|
||||
# ("next", "observation", "image"),
|
||||
("next", "observation", "state"),
|
||||
("action"),
|
||||
],
|
||||
mode="min_max",
|
||||
)
|
||||
|
||||
if writer is None:
|
||||
@@ -282,7 +284,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
|
||||
return TensorStorage(td_data.lock_())
|
||||
|
||||
def _compute_mean_std(self, storage, num_batch=10, batch_size=32):
|
||||
def _compute_stats(self, storage, num_batch=100, batch_size=32):
|
||||
rb = TensorDictReplayBuffer(
|
||||
storage=storage,
|
||||
batch_size=batch_size,
|
||||
@@ -291,15 +293,27 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
batch = rb.sample()
|
||||
image_mean = torch.zeros(batch["observation", "image"].shape[1])
|
||||
image_std = torch.zeros(batch["observation", "image"].shape[1])
|
||||
image_max = -math.inf
|
||||
image_min = math.inf
|
||||
state_mean = torch.zeros(batch["observation", "state"].shape[1])
|
||||
state_std = torch.zeros(batch["observation", "state"].shape[1])
|
||||
state_max = -math.inf
|
||||
state_min = math.inf
|
||||
action_mean = torch.zeros(batch["action"].shape[1])
|
||||
action_std = torch.zeros(batch["action"].shape[1])
|
||||
action_max = -math.inf
|
||||
action_min = math.inf
|
||||
|
||||
for _ in tqdm.tqdm(range(num_batch)):
|
||||
image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", reduction="mean")
|
||||
state_mean += batch["observation", "state"].mean(dim=0)
|
||||
action_mean += batch["action"].mean(dim=0)
|
||||
image_max = max(image_max, batch["observation", "image"].max().item())
|
||||
image_min = min(image_min, batch["observation", "image"].min().item())
|
||||
state_max = max(state_max, batch["observation", "state"].max().item())
|
||||
state_min = min(state_min, batch["observation", "state"].min().item())
|
||||
action_max = max(action_max, batch["action"].max().item())
|
||||
action_min = min(action_min, batch["action"].min().item())
|
||||
batch = rb.sample()
|
||||
|
||||
image_mean /= num_batch
|
||||
@@ -311,6 +325,12 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
image_std += (image_mean_batch - image_mean) ** 2
|
||||
state_std += (batch["observation", "state"].mean(dim=0) - state_mean) ** 2
|
||||
action_std += (batch["action"].mean(dim=0) - action_mean) ** 2
|
||||
image_max = max(image_max, batch["observation", "image"].max().item())
|
||||
image_min = min(image_min, batch["observation", "image"].min().item())
|
||||
state_max = max(state_max, batch["observation", "state"].max().item())
|
||||
state_min = min(state_min, batch["observation", "state"].min().item())
|
||||
action_max = max(action_max, batch["action"].max().item())
|
||||
action_min = min(action_min, batch["action"].min().item())
|
||||
if i < num_batch - 1:
|
||||
batch = rb.sample()
|
||||
|
||||
@@ -318,25 +338,31 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
state_std = torch.sqrt(state_std / num_batch)
|
||||
action_std = torch.sqrt(action_std / num_batch)
|
||||
|
||||
mean_std = TensorDict(
|
||||
stats = TensorDict(
|
||||
{
|
||||
("observation", "image", "mean"): image_mean[None, :, None, None],
|
||||
("observation", "image", "std"): image_std[None, :, None, None],
|
||||
("observation", "image", "max"): torch.tensor(image_max),
|
||||
("observation", "image", "min"): torch.tensor(image_min),
|
||||
("observation", "state", "mean"): state_mean[None, :],
|
||||
("observation", "state", "std"): state_std[None, :],
|
||||
("observation", "state", "max"): torch.tensor(state_max),
|
||||
("observation", "state", "min"): torch.tensor(state_min),
|
||||
("action", "mean"): action_mean[None, :],
|
||||
("action", "std"): action_std[None, :],
|
||||
("action", "max"): torch.tensor(action_max),
|
||||
("action", "min"): torch.tensor(action_min),
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
return mean_std
|
||||
return stats
|
||||
|
||||
def _compute_or_load_mean_std(self, storage) -> TensorDict:
|
||||
mean_std_path = self.root / self.dataset_id / "mean_std.pth"
|
||||
if mean_std_path.exists():
|
||||
mean_std = torch.load(mean_std_path)
|
||||
def _compute_or_load_stats(self, storage) -> TensorDict:
|
||||
stats_path = self.root / self.dataset_id / "stats.pth"
|
||||
if stats_path.exists():
|
||||
stats = torch.load(stats_path)
|
||||
else:
|
||||
logging.info(f"compute_mean_std and save to {mean_std_path}")
|
||||
mean_std = self._compute_mean_std(storage)
|
||||
torch.save(mean_std, mean_std_path)
|
||||
return mean_std
|
||||
logging.info(f"compute_stats and save to {stats_path}")
|
||||
stats = self._compute_stats(storage)
|
||||
torch.save(stats, stats_path)
|
||||
return stats
|
||||
|
||||
Reference in New Issue
Block a user