Make sure targets are normalized too (#106)
This commit is contained in:
@@ -75,13 +75,13 @@ class ActionChunkingTransformerConfig:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Normalization / Unnormalization
|
# Normalization / Unnormalization
|
||||||
normalize_input_modes: dict[str, str] = field(
|
input_normalization_modes: dict[str, str] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"observation.image": "mean_std",
|
"observation.image": "mean_std",
|
||||||
"observation.state": "mean_std",
|
"observation.state": "mean_std",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
unnormalize_output_modes: dict[str, str] = field(
|
output_normalization_modes: dict[str, str] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"action": "mean_std",
|
"action": "mean_std",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -72,8 +72,11 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||||||
if cfg is None:
|
if cfg is None:
|
||||||
cfg = ActionChunkingTransformerConfig()
|
cfg = ActionChunkingTransformerConfig()
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
|
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
|
||||||
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
|
self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
|
||||||
|
self.unnormalize_outputs = Unnormalize(
|
||||||
|
cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
|
||||||
|
)
|
||||||
|
|
||||||
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
|
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
|
||||||
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
||||||
@@ -216,6 +219,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||||||
self.train()
|
self.train()
|
||||||
|
|
||||||
batch = self.normalize_inputs(batch)
|
batch = self.normalize_inputs(batch)
|
||||||
|
batch = self.normalize_targets(batch)
|
||||||
|
|
||||||
loss_dict = self.forward(batch)
|
loss_dict = self.forward(batch)
|
||||||
# TODO(rcadene): self.unnormalize_outputs(out_dict)
|
# TODO(rcadene): self.unnormalize_outputs(out_dict)
|
||||||
|
|||||||
@@ -83,17 +83,13 @@ class DiffusionConfig:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Normalization / Unnormalization
|
# Normalization / Unnormalization
|
||||||
normalize_input_modes: dict[str, str] = field(
|
input_normalization_modes: dict[str, str] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"observation.image": "mean_std",
|
"observation.image": "mean_std",
|
||||||
"observation.state": "min_max",
|
"observation.state": "min_max",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
unnormalize_output_modes: dict[str, str] = field(
|
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
|
||||||
default_factory=lambda: {
|
|
||||||
"action": "min_max",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Architecture / modeling.
|
# Architecture / modeling.
|
||||||
# Vision backbone.
|
# Vision backbone.
|
||||||
|
|||||||
@@ -56,8 +56,11 @@ class DiffusionPolicy(nn.Module):
|
|||||||
if cfg is None:
|
if cfg is None:
|
||||||
cfg = DiffusionConfig()
|
cfg = DiffusionConfig()
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
|
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
|
||||||
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
|
self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
|
||||||
|
self.unnormalize_outputs = Unnormalize(
|
||||||
|
cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
|
||||||
|
)
|
||||||
|
|
||||||
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||||
self._queues = None
|
self._queues = None
|
||||||
@@ -162,6 +165,7 @@ class DiffusionPolicy(nn.Module):
|
|||||||
self.diffusion.train()
|
self.diffusion.train()
|
||||||
|
|
||||||
batch = self.normalize_inputs(batch)
|
batch = self.normalize_inputs(batch)
|
||||||
|
batch = self.normalize_targets(batch)
|
||||||
|
|
||||||
loss = self.forward(batch)["loss"]
|
loss = self.forward(batch)["loss"]
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|||||||
@@ -1,27 +1,21 @@
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
|
||||||
def create_stats_buffers(shapes, modes, stats=None):
|
def create_stats_buffers(
|
||||||
|
shapes: dict[str, list[int]],
|
||||||
|
modes: dict[str, str],
|
||||||
|
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||||
|
) -> dict[str, dict[str, nn.ParameterDict]]:
|
||||||
"""
|
"""
|
||||||
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max statistics.
|
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
|
||||||
|
statistics.
|
||||||
|
|
||||||
Parameters:
|
Args: (see Normalize and Unnormalize)
|
||||||
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values are their shapes (e.g. `[3,96,96]`]).
|
|
||||||
These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
|
|
||||||
and width, assuming a channel-first (c, h, w) format.
|
|
||||||
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values are their normalization modes among:
|
|
||||||
- "mean_std": substract the mean and divide by standard deviation.
|
|
||||||
- "min_max": map to [-1, 1] range.
|
|
||||||
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") and values are dictionaries of statistic types and their values
|
|
||||||
(e.g. `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for training the model for the first time,
|
|
||||||
these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should to be
|
|
||||||
be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since
|
|
||||||
they are already in the policy state_dict.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing `nn.Parameters` set to
|
dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing
|
||||||
`requires_grad=False`, suitable to not be updated during backpropagation.
|
`nn.Parameters` set to `requires_grad=False`, suitable to not be updated during backpropagation.
|
||||||
"""
|
"""
|
||||||
stats_buffers = {}
|
stats_buffers = {}
|
||||||
|
|
||||||
@@ -75,24 +69,32 @@ def create_stats_buffers(shapes, modes, stats=None):
|
|||||||
|
|
||||||
|
|
||||||
class Normalize(nn.Module):
|
class Normalize(nn.Module):
|
||||||
"""
|
"""Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""
|
||||||
Normalizes the input data (e.g. "observation.image") for more stable and faster convergence during training.
|
|
||||||
|
|
||||||
Parameters:
|
def __init__(
|
||||||
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values are their shapes (e.g. `[3,96,96]`]).
|
self,
|
||||||
These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
|
shapes: dict[str, list[int]],
|
||||||
and width, assuming a channel-first (c, h, w) format.
|
modes: dict[str, str],
|
||||||
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values are their normalization modes among:
|
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||||
- "mean_std": substract the mean and divide by standard deviation.
|
):
|
||||||
- "min_max": map to [-1, 1] range.
|
"""
|
||||||
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") and values are dictionaries of statistic types and their values
|
Args:
|
||||||
(e.g. `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for training the model for the first time,
|
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
|
||||||
these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should to be
|
are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
|
||||||
be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since
|
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
|
||||||
they are already in the policy state_dict.
|
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
|
||||||
"""
|
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
|
||||||
|
are their normalization modes among:
|
||||||
def __init__(self, shapes, modes, stats=None):
|
- "mean_std": subtract the mean and divide by standard deviation.
|
||||||
|
- "min_max": map to [-1, 1] range.
|
||||||
|
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
|
||||||
|
and values are dictionaries of statistic types and their values (e.g.
|
||||||
|
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
|
||||||
|
training the model for the first time, these statistics will overwrite the default buffers. If
|
||||||
|
not provided, as expected for finetuning or evaluation, the default buffers should to be
|
||||||
|
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
|
||||||
|
dataset is not needed to get the stats, since they are already in the policy state_dict.
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.shapes = shapes
|
self.shapes = shapes
|
||||||
self.modes = modes
|
self.modes = modes
|
||||||
@@ -104,29 +106,33 @@ class Normalize(nn.Module):
|
|||||||
|
|
||||||
# TODO(rcadene): should we remove torch.no_grad?
|
# TODO(rcadene): should we remove torch.no_grad?
|
||||||
@torch.no_grad
|
@torch.no_grad
|
||||||
def forward(self, batch):
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
for key, mode in self.modes.items():
|
for key, mode in self.modes.items():
|
||||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||||
|
|
||||||
if mode == "mean_std":
|
if mode == "mean_std":
|
||||||
mean = buffer["mean"]
|
mean = buffer["mean"]
|
||||||
std = buffer["std"]
|
std = buffer["std"]
|
||||||
assert not torch.isinf(
|
assert not torch.isinf(mean).any(), (
|
||||||
mean
|
"`mean` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||||
).any(), "`mean` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
"`policy.load_state_dict`."
|
||||||
assert not torch.isinf(
|
)
|
||||||
std
|
assert not torch.isinf(std).any(), (
|
||||||
).any(), "`std` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
"`std` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||||
|
"`policy.load_state_dict`."
|
||||||
|
)
|
||||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||||
elif mode == "min_max":
|
elif mode == "min_max":
|
||||||
min = buffer["min"]
|
min = buffer["min"]
|
||||||
max = buffer["max"]
|
max = buffer["max"]
|
||||||
assert not torch.isinf(
|
assert not torch.isinf(min).any(), (
|
||||||
min
|
"`min` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||||
).any(), "`min` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
"`policy.load_state_dict`."
|
||||||
assert not torch.isinf(
|
)
|
||||||
max
|
assert not torch.isinf(max).any(), (
|
||||||
).any(), "`max` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
"`max` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||||
|
"`policy.load_state_dict`."
|
||||||
|
)
|
||||||
# normalize to [0,1]
|
# normalize to [0,1]
|
||||||
batch[key] = (batch[key] - min) / (max - min)
|
batch[key] = (batch[key] - min) / (max - min)
|
||||||
# normalize to [-1, 1]
|
# normalize to [-1, 1]
|
||||||
@@ -138,23 +144,34 @@ class Normalize(nn.Module):
|
|||||||
|
|
||||||
class Unnormalize(nn.Module):
|
class Unnormalize(nn.Module):
|
||||||
"""
|
"""
|
||||||
Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their original range used by the environment.
|
Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their
|
||||||
|
original range used by the environment.
|
||||||
Parameters:
|
|
||||||
shapes (dict): A dictionary where keys are output modalities (e.g. "action") and values are their shapes (e.g. [10]).
|
|
||||||
These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
|
|
||||||
and width, assuming a channel-first (c, h, w) format.
|
|
||||||
modes (dict): A dictionary where keys are output modalities (e.g. "action") and values are their unnormalization modes among:
|
|
||||||
- "mean_std": multiply by standard deviation and add mean
|
|
||||||
- "min_max": go from [-1, 1] range to original range.
|
|
||||||
stats (dict, optional): A dictionary where keys are output modalities (e.g. "action") and values are dictionaries of statistic types and their values
|
|
||||||
(e.g. `{"max": torch.tensor(1)}, "min": torch.tensor(0)}`). If provided, as expected for training the model for the first time,
|
|
||||||
these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should to be
|
|
||||||
be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since
|
|
||||||
they are already in the policy state_dict.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, shapes, modes, stats=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
shapes: dict[str, list[int]],
|
||||||
|
modes: dict[str, str],
|
||||||
|
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
|
||||||
|
are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
|
||||||
|
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
|
||||||
|
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
|
||||||
|
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
|
||||||
|
are their normalization modes among:
|
||||||
|
- "mean_std": subtract the mean and divide by standard deviation.
|
||||||
|
- "min_max": map to [-1, 1] range.
|
||||||
|
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
|
||||||
|
and values are dictionaries of statistic types and their values (e.g.
|
||||||
|
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
|
||||||
|
training the model for the first time, these statistics will overwrite the default buffers. If
|
||||||
|
not provided, as expected for finetuning or evaluation, the default buffers should to be
|
||||||
|
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
|
||||||
|
dataset is not needed to get the stats, since they are already in the policy state_dict.
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.shapes = shapes
|
self.shapes = shapes
|
||||||
self.modes = modes
|
self.modes = modes
|
||||||
@@ -166,29 +183,33 @@ class Unnormalize(nn.Module):
|
|||||||
|
|
||||||
# TODO(rcadene): should we remove torch.no_grad?
|
# TODO(rcadene): should we remove torch.no_grad?
|
||||||
@torch.no_grad
|
@torch.no_grad
|
||||||
def forward(self, batch):
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
for key, mode in self.modes.items():
|
for key, mode in self.modes.items():
|
||||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||||
|
|
||||||
if mode == "mean_std":
|
if mode == "mean_std":
|
||||||
mean = buffer["mean"]
|
mean = buffer["mean"]
|
||||||
std = buffer["std"]
|
std = buffer["std"]
|
||||||
assert not torch.isinf(
|
assert not torch.isinf(mean).any(), (
|
||||||
mean
|
"`mean` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||||
).any(), "`mean` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
"`policy.load_state_dict`."
|
||||||
assert not torch.isinf(
|
)
|
||||||
std
|
assert not torch.isinf(std).any(), (
|
||||||
).any(), "`std` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
"`std` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||||
|
"`policy.load_state_dict`."
|
||||||
|
)
|
||||||
batch[key] = batch[key] * std + mean
|
batch[key] = batch[key] * std + mean
|
||||||
elif mode == "min_max":
|
elif mode == "min_max":
|
||||||
min = buffer["min"]
|
min = buffer["min"]
|
||||||
max = buffer["max"]
|
max = buffer["max"]
|
||||||
assert not torch.isinf(
|
assert not torch.isinf(min).any(), (
|
||||||
min
|
"`min` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||||
).any(), "`min` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
"`policy.load_state_dict`."
|
||||||
assert not torch.isinf(
|
)
|
||||||
max
|
assert not torch.isinf(max).any(), (
|
||||||
).any(), "`max` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
"`max` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||||
|
"`policy.load_state_dict`."
|
||||||
|
)
|
||||||
batch[key] = (batch[key] + 1) / 2
|
batch[key] = (batch[key] + 1) / 2
|
||||||
batch[key] = batch[key] * (max - min) + min
|
batch[key] = batch[key] * (max - min) + min
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -36,10 +36,10 @@ policy:
|
|||||||
action: ["${env.action_dim}"]
|
action: ["${env.action_dim}"]
|
||||||
|
|
||||||
# Normalization / Unnormalization
|
# Normalization / Unnormalization
|
||||||
normalize_input_modes:
|
input_normalization_modes:
|
||||||
observation.images.top: mean_std
|
observation.images.top: mean_std
|
||||||
observation.state: mean_std
|
observation.state: mean_std
|
||||||
unnormalize_output_modes:
|
output_normalization_modes:
|
||||||
action: mean_std
|
action: mean_std
|
||||||
|
|
||||||
# Architecture.
|
# Architecture.
|
||||||
|
|||||||
@@ -50,10 +50,10 @@ policy:
|
|||||||
action: ["${env.action_dim}"]
|
action: ["${env.action_dim}"]
|
||||||
|
|
||||||
# Normalization / Unnormalization
|
# Normalization / Unnormalization
|
||||||
normalize_input_modes:
|
input_normalization_modes:
|
||||||
observation.image: mean_std
|
observation.image: mean_std
|
||||||
observation.state: min_max
|
observation.state: min_max
|
||||||
unnormalize_output_modes:
|
output_normalization_modes:
|
||||||
action: min_max
|
action: min_max
|
||||||
|
|
||||||
# Architecture / modeling.
|
# Architecture / modeling.
|
||||||
|
|||||||
@@ -258,7 +258,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||||||
policy,
|
policy,
|
||||||
video_dir=Path(out_dir) / "eval",
|
video_dir=Path(out_dir) / "eval",
|
||||||
max_episodes_rendered=4,
|
max_episodes_rendered=4,
|
||||||
transform=offline_dataset.transform,
|
|
||||||
seed=cfg.seed,
|
seed=cfg.seed,
|
||||||
)
|
)
|
||||||
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
|
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
|
||||||
|
|||||||
Reference in New Issue
Block a user