Make sure targets are normalized too (#106)

This commit is contained in:
Alexander Soare
2024-04-26 11:18:39 +01:00
committed by GitHub
parent b980c5dd9e
commit 45f351c618
8 changed files with 116 additions and 92 deletions

View File

@@ -75,13 +75,13 @@ class ActionChunkingTransformerConfig:
)
# Normalization / Unnormalization
normalize_input_modes: dict[str, str] = field(
input_normalization_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.image": "mean_std",
"observation.state": "mean_std",
}
)
unnormalize_output_modes: dict[str, str] = field(
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {
"action": "mean_std",
}

View File

@@ -72,8 +72,11 @@ class ActionChunkingTransformerPolicy(nn.Module):
if cfg is None:
cfg = ActionChunkingTransformerConfig()
self.cfg = cfg
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(
cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
)
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
@@ -216,6 +219,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
self.train()
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
loss_dict = self.forward(batch)
# TODO(rcadene): self.unnormalize_outputs(out_dict)

View File

@@ -83,17 +83,13 @@ class DiffusionConfig:
)
# Normalization / Unnormalization
normalize_input_modes: dict[str, str] = field(
input_normalization_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.image": "mean_std",
"observation.state": "min_max",
}
)
unnormalize_output_modes: dict[str, str] = field(
default_factory=lambda: {
"action": "min_max",
}
)
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
# Architecture / modeling.
# Vision backbone.

View File

@@ -56,8 +56,11 @@ class DiffusionPolicy(nn.Module):
if cfg is None:
cfg = DiffusionConfig()
self.cfg = cfg
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(
cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
)
# queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None
@@ -162,6 +165,7 @@ class DiffusionPolicy(nn.Module):
self.diffusion.train()
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
loss = self.forward(batch)["loss"]
loss.backward()

View File

@@ -1,27 +1,21 @@
import torch
from torch import nn
from torch import Tensor, nn
def create_stats_buffers(shapes, modes, stats=None):
def create_stats_buffers(
shapes: dict[str, list[int]],
modes: dict[str, str],
stats: dict[str, dict[str, Tensor]] | None = None,
) -> dict[str, dict[str, nn.ParameterDict]]:
"""
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max statistics.
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
statistics.
Parameters:
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values are their shapes (e.g. `[3,96,96]`]).
These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
and width, assuming a channel-first (c, h, w) format.
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values are their normalization modes among:
- "mean_std": substract the mean and divide by standard deviation.
- "min_max": map to [-1, 1] range.
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") and values are dictionaries of statistic types and their values
(e.g. `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for training the model for the first time,
these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should to be
be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since
they are already in the policy state_dict.
Args: (see Normalize and Unnormalize)
Returns:
dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing `nn.Parameters` set to
`requires_grad=False`, suitable to not be updated during backpropagation.
dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing
`nn.Parameters` set to `requires_grad=False`, suitable to not be updated during backpropagation.
"""
stats_buffers = {}
@@ -75,24 +69,32 @@ def create_stats_buffers(shapes, modes, stats=None):
class Normalize(nn.Module):
"""
Normalizes the input data (e.g. "observation.image") for more stable and faster convergence during training.
"""Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""
Parameters:
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values are their shapes (e.g. `[3,96,96]`]).
These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
and width, assuming a channel-first (c, h, w) format.
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values are their normalization modes among:
- "mean_std": substract the mean and divide by standard deviation.
- "min_max": map to [-1, 1] range.
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") and values are dictionaries of statistic types and their values
(e.g. `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for training the model for the first time,
these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should to be
be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since
they are already in the policy state_dict.
"""
def __init__(self, shapes, modes, stats=None):
def __init__(
self,
shapes: dict[str, list[int]],
modes: dict[str, str],
stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
are their normalization modes among:
- "mean_std": subtract the mean and divide by standard deviation.
- "min_max": map to [-1, 1] range.
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
and values are dictionaries of statistic types and their values (e.g.
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
training the model for the first time, these statistics will overwrite the default buffers. If
not provided, as expected for finetuning or evaluation, the default buffers should to be
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
dataset is not needed to get the stats, since they are already in the policy state_dict.
"""
super().__init__()
self.shapes = shapes
self.modes = modes
@@ -104,29 +106,33 @@ class Normalize(nn.Module):
# TODO(rcadene): should we remove torch.no_grad?
@torch.no_grad
def forward(self, batch):
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
for key, mode in self.modes.items():
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
if mode == "mean_std":
mean = buffer["mean"]
std = buffer["std"]
assert not torch.isinf(
mean
).any(), "`mean` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
assert not torch.isinf(
std
).any(), "`std` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
assert not torch.isinf(mean).any(), (
"`mean` is infinity. You forgot to initialize with `stats` as argument, or called "
"`policy.load_state_dict`."
)
assert not torch.isinf(std).any(), (
"`std` is infinity. You forgot to initialize with `stats` as argument, or called "
"`policy.load_state_dict`."
)
batch[key] = (batch[key] - mean) / (std + 1e-8)
elif mode == "min_max":
min = buffer["min"]
max = buffer["max"]
assert not torch.isinf(
min
).any(), "`min` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
assert not torch.isinf(
max
).any(), "`max` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
assert not torch.isinf(min).any(), (
"`min` is infinity. You forgot to initialize with `stats` as argument, or called "
"`policy.load_state_dict`."
)
assert not torch.isinf(max).any(), (
"`max` is infinity. You forgot to initialize with `stats` as argument, or called "
"`policy.load_state_dict`."
)
# normalize to [0,1]
batch[key] = (batch[key] - min) / (max - min)
# normalize to [-1, 1]
@@ -138,23 +144,34 @@ class Normalize(nn.Module):
class Unnormalize(nn.Module):
"""
Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their original range used by the environment.
Parameters:
shapes (dict): A dictionary where keys are output modalities (e.g. "action") and values are their shapes (e.g. [10]).
These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
and width, assuming a channel-first (c, h, w) format.
modes (dict): A dictionary where keys are output modalities (e.g. "action") and values are their unnormalization modes among:
- "mean_std": multiply by standard deviation and add mean
- "min_max": go from [-1, 1] range to original range.
stats (dict, optional): A dictionary where keys are output modalities (e.g. "action") and values are dictionaries of statistic types and their values
(e.g. `{"max": torch.tensor(1)}, "min": torch.tensor(0)}`). If provided, as expected for training the model for the first time,
these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should to be
be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since
they are already in the policy state_dict.
Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their
original range used by the environment.
"""
def __init__(self, shapes, modes, stats=None):
def __init__(
self,
shapes: dict[str, list[int]],
modes: dict[str, str],
stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
are their normalization modes among:
- "mean_std": subtract the mean and divide by standard deviation.
- "min_max": map to [-1, 1] range.
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
and values are dictionaries of statistic types and their values (e.g.
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
training the model for the first time, these statistics will overwrite the default buffers. If
not provided, as expected for finetuning or evaluation, the default buffers should to be
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
dataset is not needed to get the stats, since they are already in the policy state_dict.
"""
super().__init__()
self.shapes = shapes
self.modes = modes
@@ -166,29 +183,33 @@ class Unnormalize(nn.Module):
# TODO(rcadene): should we remove torch.no_grad?
@torch.no_grad
def forward(self, batch):
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
for key, mode in self.modes.items():
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
if mode == "mean_std":
mean = buffer["mean"]
std = buffer["std"]
assert not torch.isinf(
mean
).any(), "`mean` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
assert not torch.isinf(
std
).any(), "`std` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
assert not torch.isinf(mean).any(), (
"`mean` is infinity. You forgot to initialize with `stats` as argument, or called "
"`policy.load_state_dict`."
)
assert not torch.isinf(std).any(), (
"`std` is infinity. You forgot to initialize with `stats` as argument, or called "
"`policy.load_state_dict`."
)
batch[key] = batch[key] * std + mean
elif mode == "min_max":
min = buffer["min"]
max = buffer["max"]
assert not torch.isinf(
min
).any(), "`min` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
assert not torch.isinf(
max
).any(), "`max` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
assert not torch.isinf(min).any(), (
"`min` is infinity. You forgot to initialize with `stats` as argument, or called "
"`policy.load_state_dict`."
)
assert not torch.isinf(max).any(), (
"`max` is infinity. You forgot to initialize with `stats` as argument, or called "
"`policy.load_state_dict`."
)
batch[key] = (batch[key] + 1) / 2
batch[key] = batch[key] * (max - min) + min
else:

View File

@@ -36,10 +36,10 @@ policy:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
normalize_input_modes:
input_normalization_modes:
observation.images.top: mean_std
observation.state: mean_std
unnormalize_output_modes:
output_normalization_modes:
action: mean_std
# Architecture.

View File

@@ -50,10 +50,10 @@ policy:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
normalize_input_modes:
input_normalization_modes:
observation.image: mean_std
observation.state: min_max
unnormalize_output_modes:
output_normalization_modes:
action: min_max
# Architecture / modeling.

View File

@@ -258,7 +258,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
policy,
video_dir=Path(out_dir) / "eval",
max_episodes_rendered=4,
transform=offline_dataset.transform,
seed=cfg.seed,
)
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)