Make sure targets are normalized too (#106)

This commit is contained in:
Alexander Soare
2024-04-26 11:18:39 +01:00
committed by GitHub
parent b980c5dd9e
commit 45f351c618
8 changed files with 116 additions and 92 deletions

View File

@@ -75,13 +75,13 @@ class ActionChunkingTransformerConfig:
) )
# Normalization / Unnormalization # Normalization / Unnormalization
normalize_input_modes: dict[str, str] = field( input_normalization_modes: dict[str, str] = field(
default_factory=lambda: { default_factory=lambda: {
"observation.image": "mean_std", "observation.image": "mean_std",
"observation.state": "mean_std", "observation.state": "mean_std",
} }
) )
unnormalize_output_modes: dict[str, str] = field( output_normalization_modes: dict[str, str] = field(
default_factory=lambda: { default_factory=lambda: {
"action": "mean_std", "action": "mean_std",
} }

View File

@@ -72,8 +72,11 @@ class ActionChunkingTransformerPolicy(nn.Module):
if cfg is None: if cfg is None:
cfg = ActionChunkingTransformerConfig() cfg = ActionChunkingTransformerConfig()
self.cfg = cfg self.cfg = cfg
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats) self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats) self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(
cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
)
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence]. # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]). # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
@@ -216,6 +219,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
self.train() self.train()
batch = self.normalize_inputs(batch) batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
loss_dict = self.forward(batch) loss_dict = self.forward(batch)
# TODO(rcadene): self.unnormalize_outputs(out_dict) # TODO(rcadene): self.unnormalize_outputs(out_dict)

View File

@@ -83,17 +83,13 @@ class DiffusionConfig:
) )
# Normalization / Unnormalization # Normalization / Unnormalization
normalize_input_modes: dict[str, str] = field( input_normalization_modes: dict[str, str] = field(
default_factory=lambda: { default_factory=lambda: {
"observation.image": "mean_std", "observation.image": "mean_std",
"observation.state": "min_max", "observation.state": "min_max",
} }
) )
unnormalize_output_modes: dict[str, str] = field( output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
default_factory=lambda: {
"action": "min_max",
}
)
# Architecture / modeling. # Architecture / modeling.
# Vision backbone. # Vision backbone.

View File

@@ -56,8 +56,11 @@ class DiffusionPolicy(nn.Module):
if cfg is None: if cfg is None:
cfg = DiffusionConfig() cfg = DiffusionConfig()
self.cfg = cfg self.cfg = cfg
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats) self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats) self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(
cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
)
# queues are populated during rollout of the policy, they contain the n latest observations and actions # queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None self._queues = None
@@ -162,6 +165,7 @@ class DiffusionPolicy(nn.Module):
self.diffusion.train() self.diffusion.train()
batch = self.normalize_inputs(batch) batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
loss = self.forward(batch)["loss"] loss = self.forward(batch)["loss"]
loss.backward() loss.backward()

View File

@@ -1,27 +1,21 @@
import torch import torch
from torch import nn from torch import Tensor, nn
def create_stats_buffers(shapes, modes, stats=None): def create_stats_buffers(
shapes: dict[str, list[int]],
modes: dict[str, str],
stats: dict[str, dict[str, Tensor]] | None = None,
) -> dict[str, dict[str, nn.ParameterDict]]:
""" """
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max statistics. Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
statistics.
Parameters: Args: (see Normalize and Unnormalize)
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values are their shapes (e.g. `[3,96,96]`]).
These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
and width, assuming a channel-first (c, h, w) format.
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values are their normalization modes among:
- "mean_std": substract the mean and divide by standard deviation.
- "min_max": map to [-1, 1] range.
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") and values are dictionaries of statistic types and their values
(e.g. `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for training the model for the first time,
these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should to be
be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since
they are already in the policy state_dict.
Returns: Returns:
dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing `nn.Parameters` set to dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing
`requires_grad=False`, suitable to not be updated during backpropagation. `nn.Parameters` set to `requires_grad=False`, suitable to not be updated during backpropagation.
""" """
stats_buffers = {} stats_buffers = {}
@@ -75,24 +69,32 @@ def create_stats_buffers(shapes, modes, stats=None):
class Normalize(nn.Module): class Normalize(nn.Module):
""" """Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""
Normalizes the input data (e.g. "observation.image") for more stable and faster convergence during training.
Parameters: def __init__(
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values are their shapes (e.g. `[3,96,96]`]). self,
These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height shapes: dict[str, list[int]],
and width, assuming a channel-first (c, h, w) format. modes: dict[str, str],
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values are their normalization modes among: stats: dict[str, dict[str, Tensor]] | None = None,
- "mean_std": substract the mean and divide by standard deviation. ):
- "min_max": map to [-1, 1] range. """
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") and values are dictionaries of statistic types and their values Args:
(e.g. `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for training the model for the first time, shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should to be are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
they are already in the policy state_dict. is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
""" modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
are their normalization modes among:
def __init__(self, shapes, modes, stats=None): - "mean_std": subtract the mean and divide by standard deviation.
- "min_max": map to [-1, 1] range.
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
and values are dictionaries of statistic types and their values (e.g.
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
training the model for the first time, these statistics will overwrite the default buffers. If
not provided, as expected for finetuning or evaluation, the default buffers should to be
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
dataset is not needed to get the stats, since they are already in the policy state_dict.
"""
super().__init__() super().__init__()
self.shapes = shapes self.shapes = shapes
self.modes = modes self.modes = modes
@@ -104,29 +106,33 @@ class Normalize(nn.Module):
# TODO(rcadene): should we remove torch.no_grad? # TODO(rcadene): should we remove torch.no_grad?
@torch.no_grad @torch.no_grad
def forward(self, batch): def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
for key, mode in self.modes.items(): for key, mode in self.modes.items():
buffer = getattr(self, "buffer_" + key.replace(".", "_")) buffer = getattr(self, "buffer_" + key.replace(".", "_"))
if mode == "mean_std": if mode == "mean_std":
mean = buffer["mean"] mean = buffer["mean"]
std = buffer["std"] std = buffer["std"]
assert not torch.isinf( assert not torch.isinf(mean).any(), (
mean "`mean` is infinity. You forgot to initialize with `stats` as argument, or called "
).any(), "`mean` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`." "`policy.load_state_dict`."
assert not torch.isinf( )
std assert not torch.isinf(std).any(), (
).any(), "`std` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`." "`std` is infinity. You forgot to initialize with `stats` as argument, or called "
"`policy.load_state_dict`."
)
batch[key] = (batch[key] - mean) / (std + 1e-8) batch[key] = (batch[key] - mean) / (std + 1e-8)
elif mode == "min_max": elif mode == "min_max":
min = buffer["min"] min = buffer["min"]
max = buffer["max"] max = buffer["max"]
assert not torch.isinf( assert not torch.isinf(min).any(), (
min "`min` is infinity. You forgot to initialize with `stats` as argument, or called "
).any(), "`min` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`." "`policy.load_state_dict`."
assert not torch.isinf( )
max assert not torch.isinf(max).any(), (
).any(), "`max` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`." "`max` is infinity. You forgot to initialize with `stats` as argument, or called "
"`policy.load_state_dict`."
)
# normalize to [0,1] # normalize to [0,1]
batch[key] = (batch[key] - min) / (max - min) batch[key] = (batch[key] - min) / (max - min)
# normalize to [-1, 1] # normalize to [-1, 1]
@@ -138,23 +144,34 @@ class Normalize(nn.Module):
class Unnormalize(nn.Module): class Unnormalize(nn.Module):
""" """
Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their original range used by the environment. Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their
original range used by the environment.
Parameters:
shapes (dict): A dictionary where keys are output modalities (e.g. "action") and values are their shapes (e.g. [10]).
These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
and width, assuming a channel-first (c, h, w) format.
modes (dict): A dictionary where keys are output modalities (e.g. "action") and values are their unnormalization modes among:
- "mean_std": multiply by standard deviation and add mean
- "min_max": go from [-1, 1] range to original range.
stats (dict, optional): A dictionary where keys are output modalities (e.g. "action") and values are dictionaries of statistic types and their values
(e.g. `{"max": torch.tensor(1)}, "min": torch.tensor(0)}`). If provided, as expected for training the model for the first time,
these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should to be
be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since
they are already in the policy state_dict.
""" """
def __init__(self, shapes, modes, stats=None): def __init__(
self,
shapes: dict[str, list[int]],
modes: dict[str, str],
stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
are their normalization modes among:
- "mean_std": subtract the mean and divide by standard deviation.
- "min_max": map to [-1, 1] range.
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
and values are dictionaries of statistic types and their values (e.g.
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
training the model for the first time, these statistics will overwrite the default buffers. If
not provided, as expected for finetuning or evaluation, the default buffers should to be
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
dataset is not needed to get the stats, since they are already in the policy state_dict.
"""
super().__init__() super().__init__()
self.shapes = shapes self.shapes = shapes
self.modes = modes self.modes = modes
@@ -166,29 +183,33 @@ class Unnormalize(nn.Module):
# TODO(rcadene): should we remove torch.no_grad? # TODO(rcadene): should we remove torch.no_grad?
@torch.no_grad @torch.no_grad
def forward(self, batch): def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
for key, mode in self.modes.items(): for key, mode in self.modes.items():
buffer = getattr(self, "buffer_" + key.replace(".", "_")) buffer = getattr(self, "buffer_" + key.replace(".", "_"))
if mode == "mean_std": if mode == "mean_std":
mean = buffer["mean"] mean = buffer["mean"]
std = buffer["std"] std = buffer["std"]
assert not torch.isinf( assert not torch.isinf(mean).any(), (
mean "`mean` is infinity. You forgot to initialize with `stats` as argument, or called "
).any(), "`mean` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`." "`policy.load_state_dict`."
assert not torch.isinf( )
std assert not torch.isinf(std).any(), (
).any(), "`std` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`." "`std` is infinity. You forgot to initialize with `stats` as argument, or called "
"`policy.load_state_dict`."
)
batch[key] = batch[key] * std + mean batch[key] = batch[key] * std + mean
elif mode == "min_max": elif mode == "min_max":
min = buffer["min"] min = buffer["min"]
max = buffer["max"] max = buffer["max"]
assert not torch.isinf( assert not torch.isinf(min).any(), (
min "`min` is infinity. You forgot to initialize with `stats` as argument, or called "
).any(), "`min` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`." "`policy.load_state_dict`."
assert not torch.isinf( )
max assert not torch.isinf(max).any(), (
).any(), "`max` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`." "`max` is infinity. You forgot to initialize with `stats` as argument, or called "
"`policy.load_state_dict`."
)
batch[key] = (batch[key] + 1) / 2 batch[key] = (batch[key] + 1) / 2
batch[key] = batch[key] * (max - min) + min batch[key] = batch[key] * (max - min) + min
else: else:

View File

@@ -36,10 +36,10 @@ policy:
action: ["${env.action_dim}"] action: ["${env.action_dim}"]
# Normalization / Unnormalization # Normalization / Unnormalization
normalize_input_modes: input_normalization_modes:
observation.images.top: mean_std observation.images.top: mean_std
observation.state: mean_std observation.state: mean_std
unnormalize_output_modes: output_normalization_modes:
action: mean_std action: mean_std
# Architecture. # Architecture.

View File

@@ -50,10 +50,10 @@ policy:
action: ["${env.action_dim}"] action: ["${env.action_dim}"]
# Normalization / Unnormalization # Normalization / Unnormalization
normalize_input_modes: input_normalization_modes:
observation.image: mean_std observation.image: mean_std
observation.state: min_max observation.state: min_max
unnormalize_output_modes: output_normalization_modes:
action: min_max action: min_max
# Architecture / modeling. # Architecture / modeling.

View File

@@ -258,7 +258,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
policy, policy,
video_dir=Path(out_dir) / "eval", video_dir=Path(out_dir) / "eval",
max_episodes_rendered=4, max_episodes_rendered=4,
transform=offline_dataset.transform,
seed=cfg.seed, seed=cfg.seed,
) )
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline) log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)