Improve visualize_dataset, Improve AbstractReplayBuffer, Small improvements
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
import abc
|
||||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
@@ -50,6 +49,22 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
|
||||
transform=transform,
|
||||
)
|
||||
|
||||
@property
|
||||
def stats_patterns(self) -> dict:
|
||||
return {
|
||||
("observation", "state"): "b c -> 1 c",
|
||||
("observation", "image"): "b c h w -> 1 c 1 1",
|
||||
("action"): "b c -> 1 c",
|
||||
}
|
||||
|
||||
@property
|
||||
def image_keys(self) -> list:
|
||||
return [("observation", "image")]
|
||||
|
||||
@property
|
||||
def num_cameras(self) -> int:
|
||||
return len(self.image_keys)
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self)
|
||||
@@ -67,7 +82,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
|
||||
stats = torch.load(stats_path)
|
||||
else:
|
||||
logging.info(f"compute_stats and save to {stats_path}")
|
||||
stats = self._compute_stats(self._storage._storage, num_batch, batch_size)
|
||||
stats = self._compute_stats(num_batch, batch_size)
|
||||
torch.save(stats, stats_path)
|
||||
return stats
|
||||
|
||||
@@ -85,101 +100,59 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
|
||||
def _is_downloaded(self) -> bool:
|
||||
return self.data_dir.is_dir()
|
||||
|
||||
def _compute_stats(self, storage, num_batch=100, batch_size=32):
|
||||
def _compute_stats(self, num_batch=100, batch_size=32):
|
||||
rb = TensorDictReplayBuffer(
|
||||
storage=storage,
|
||||
storage=self._storage,
|
||||
batch_size=batch_size,
|
||||
prefetch=True,
|
||||
)
|
||||
batch = rb.sample()
|
||||
|
||||
image_channels = batch["observation", "image"].shape[1]
|
||||
image_mean = torch.zeros(image_channels)
|
||||
image_std = torch.zeros(image_channels)
|
||||
image_max = torch.tensor([-math.inf] * image_channels)
|
||||
image_min = torch.tensor([math.inf] * image_channels)
|
||||
|
||||
state_channels = batch["observation", "state"].shape[1]
|
||||
state_mean = torch.zeros(state_channels)
|
||||
state_std = torch.zeros(state_channels)
|
||||
state_max = torch.tensor([-math.inf] * state_channels)
|
||||
state_min = torch.tensor([math.inf] * state_channels)
|
||||
|
||||
action_channels = batch["action"].shape[1]
|
||||
action_mean = torch.zeros(action_channels)
|
||||
action_std = torch.zeros(action_channels)
|
||||
action_max = torch.tensor([-math.inf] * action_channels)
|
||||
action_min = torch.tensor([math.inf] * action_channels)
|
||||
mean, std, max, min = {}, {}, {}, {}
|
||||
|
||||
# compute mean, min, max
|
||||
for _ in tqdm.tqdm(range(num_batch)):
|
||||
image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
|
||||
state_mean += einops.reduce(batch["observation", "state"], "b c -> c", "mean")
|
||||
action_mean += einops.reduce(batch["action"], "b c -> c", "mean")
|
||||
|
||||
b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
|
||||
b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
|
||||
b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
|
||||
b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
|
||||
b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
|
||||
b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
|
||||
image_max = torch.maximum(image_max, b_image_max)
|
||||
image_min = torch.maximum(image_min, b_image_min)
|
||||
state_max = torch.maximum(state_max, b_state_max)
|
||||
state_min = torch.maximum(state_min, b_state_min)
|
||||
action_max = torch.maximum(action_max, b_action_max)
|
||||
action_min = torch.maximum(action_min, b_action_min)
|
||||
|
||||
batch = rb.sample()
|
||||
|
||||
image_mean /= num_batch
|
||||
state_mean /= num_batch
|
||||
action_mean /= num_batch
|
||||
|
||||
for i in tqdm.tqdm(range(num_batch)):
|
||||
b_image_mean = einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
|
||||
b_state_mean = einops.reduce(batch["observation", "state"], "b c -> c", "mean")
|
||||
b_action_mean = einops.reduce(batch["action"], "b c -> c", "mean")
|
||||
image_std += (b_image_mean - image_mean) ** 2
|
||||
state_std += (b_state_mean - state_mean) ** 2
|
||||
action_std += (b_action_mean - action_mean) ** 2
|
||||
|
||||
b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
|
||||
b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
|
||||
b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
|
||||
b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
|
||||
b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
|
||||
b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
|
||||
image_max = torch.maximum(image_max, b_image_max)
|
||||
image_min = torch.maximum(image_min, b_image_min)
|
||||
state_max = torch.maximum(state_max, b_state_max)
|
||||
state_min = torch.maximum(state_min, b_state_min)
|
||||
action_max = torch.maximum(action_max, b_action_max)
|
||||
action_min = torch.maximum(action_min, b_action_min)
|
||||
|
||||
if i < num_batch - 1:
|
||||
for key, pattern in self.stats_patterns.items():
|
||||
batch[key] = batch[key].float()
|
||||
if key not in mean:
|
||||
# first batch initialize mean, min, max
|
||||
mean[key] = einops.reduce(batch[key], pattern, "mean")
|
||||
max[key] = einops.reduce(batch[key], pattern, "max")
|
||||
min[key] = einops.reduce(batch[key], pattern, "min")
|
||||
else:
|
||||
mean[key] += einops.reduce(batch[key], pattern, "mean")
|
||||
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
|
||||
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
|
||||
batch = rb.sample()
|
||||
|
||||
image_std = torch.sqrt(image_std / num_batch)
|
||||
state_std = torch.sqrt(state_std / num_batch)
|
||||
action_std = torch.sqrt(action_std / num_batch)
|
||||
for key in self.stats_patterns:
|
||||
mean[key] /= num_batch
|
||||
|
||||
stats = TensorDict(
|
||||
{
|
||||
("observation", "image", "mean"): image_mean[None, :, None, None],
|
||||
("observation", "image", "std"): image_std[None, :, None, None],
|
||||
("observation", "image", "max"): image_max[None, :, None, None],
|
||||
("observation", "image", "min"): image_min[None, :, None, None],
|
||||
("observation", "state", "mean"): state_mean[None, :],
|
||||
("observation", "state", "std"): state_std[None, :],
|
||||
("observation", "state", "max"): state_max[None, :],
|
||||
("observation", "state", "min"): state_min[None, :],
|
||||
("action", "mean"): action_mean[None, :],
|
||||
("action", "std"): action_std[None, :],
|
||||
("action", "max"): action_max[None, :],
|
||||
("action", "min"): action_min[None, :],
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
stats["next", "observation", "image"] = stats["observation", "image"]
|
||||
stats["next", "observation", "state"] = stats["observation", "state"]
|
||||
# compute std, min, max
|
||||
for _ in tqdm.tqdm(range(num_batch)):
|
||||
batch = rb.sample()
|
||||
for key, pattern in self.stats_patterns.items():
|
||||
batch[key] = batch[key].float()
|
||||
batch_mean = einops.reduce(batch[key], pattern, "mean")
|
||||
if key not in std:
|
||||
# first batch initialize std
|
||||
std[key] = (batch_mean - mean[key]) ** 2
|
||||
else:
|
||||
std[key] += (batch_mean - mean[key]) ** 2
|
||||
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
|
||||
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
|
||||
|
||||
for key in self.stats_patterns:
|
||||
std[key] = torch.sqrt(std[key] / num_batch)
|
||||
|
||||
stats = TensorDict({}, batch_size=[])
|
||||
for key in self.stats_patterns:
|
||||
stats[(*key, "mean")] = mean[key]
|
||||
stats[(*key, "std")] = std[key]
|
||||
stats[(*key, "max")] = max[key]
|
||||
stats[(*key, "min")] = min[key]
|
||||
|
||||
if key[0] == "observation":
|
||||
# use same stats for the next observations
|
||||
stats[("next", *key)] = stats[key]
|
||||
return stats
|
||||
|
||||
Reference in New Issue
Block a user