make load_state_dict work
This commit is contained in:
@@ -21,10 +21,24 @@ class ActionChunkingTransformerConfig:
|
|||||||
This should be no greater than the chunk size. For example, if the chunk size size 100, you may
|
This should be no greater than the chunk size. For example, if the chunk size size 100, you may
|
||||||
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
|
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
|
||||||
environment, and throws the other 50 out.
|
environment, and throws the other 50 out.
|
||||||
image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in
|
input_shapes: A dictionary defining the shapes of the input data for the policy.
|
||||||
[0, 1]) for normalization.
|
The key represents the input data name, and the value is a list indicating the dimensions
|
||||||
image_normalization_std: Value by which to divide the input image pixels (after the mean has been
|
of the corresponding data. For example, "observation.images.top" refers to an input from the
|
||||||
subtracted).
|
"top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
|
||||||
|
Importantly, shapes doesnt include batch dimension or temporal dimension.
|
||||||
|
output_shapes: A dictionary defining the shapes of the output data for the policy.
|
||||||
|
The key represents the output data name, and the value is a list indicating the dimensions
|
||||||
|
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
|
||||||
|
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
|
||||||
|
normalize_input_modes: A dictionary specifying the normalization mode to be applied to various inputs.
|
||||||
|
The key represents the input data name, and the value specifies the type of normalization to apply.
|
||||||
|
Common normalization methods include "mean_std" (mean and standard deviation) or "min_max" (to normalize
|
||||||
|
between -1 and 1).
|
||||||
|
unnormalize_output_modes: A dictionary specifying the method to unnormalize outputs.
|
||||||
|
This parameter maps output data types to their unnormalization modes, allowing the results to be
|
||||||
|
transformed back from a normalized state to a standard state. It is typically used when output
|
||||||
|
data needs to be interpreted in its original scale or units. For example, for "action", the
|
||||||
|
unnormalization mode might be "mean_std" or "min_max".
|
||||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||||
use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from
|
use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from
|
||||||
torchvision.
|
torchvision.
|
||||||
@@ -51,6 +65,7 @@ class ActionChunkingTransformerConfig:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Environment.
|
# Environment.
|
||||||
|
# TODO(rcadene, alexander-soar): remove these as they are defined in input_shapes, output_shapes
|
||||||
state_dim: int = 14
|
state_dim: int = 14
|
||||||
action_dim: int = 14
|
action_dim: int = 14
|
||||||
|
|
||||||
@@ -60,6 +75,18 @@ class ActionChunkingTransformerConfig:
|
|||||||
chunk_size: int = 100
|
chunk_size: int = 100
|
||||||
n_action_steps: int = 100
|
n_action_steps: int = 100
|
||||||
|
|
||||||
|
input_shapes: dict[str, str] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"observation.images.top": [3, 480, 640],
|
||||||
|
"observation.state": [14],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
output_shapes: dict[str, str] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"action": [14],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Normalization / Unnormalization
|
# Normalization / Unnormalization
|
||||||
normalize_input_modes: dict[str, str] = field(
|
normalize_input_modes: dict[str, str] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
@@ -72,6 +99,7 @@ class ActionChunkingTransformerConfig:
|
|||||||
"action": "mean_std",
|
"action": "mean_std",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Architecture.
|
# Architecture.
|
||||||
# Vision backbone.
|
# Vision backbone.
|
||||||
vision_backbone: str = "resnet18"
|
vision_backbone: str = "resnet18"
|
||||||
|
|||||||
@@ -20,11 +20,7 @@ from torchvision.models._utils import IntermediateLayerGetter
|
|||||||
from torchvision.ops.misc import FrozenBatchNorm2d
|
from torchvision.ops.misc import FrozenBatchNorm2d
|
||||||
|
|
||||||
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
|
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
|
||||||
from lerobot.common.policies.utils import (
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
normalize_inputs,
|
|
||||||
to_buffer_dict,
|
|
||||||
unnormalize_outputs,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ActionChunkingTransformerPolicy(nn.Module):
|
class ActionChunkingTransformerPolicy(nn.Module):
|
||||||
@@ -76,9 +72,10 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||||||
if cfg is None:
|
if cfg is None:
|
||||||
cfg = ActionChunkingTransformerConfig()
|
cfg = ActionChunkingTransformerConfig()
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.dataset_stats = to_buffer_dict(dataset_stats)
|
|
||||||
self.normalize_input_modes = cfg.normalize_input_modes
|
self.normalize_input_modes = cfg.normalize_input_modes
|
||||||
self.unnormalize_output_modes = cfg.unnormalize_output_modes
|
self.unnormalize_output_modes = cfg.unnormalize_output_modes
|
||||||
|
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
|
||||||
|
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
|
||||||
|
|
||||||
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
|
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
|
||||||
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
||||||
@@ -174,7 +171,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||||||
"""
|
"""
|
||||||
self.eval()
|
self.eval()
|
||||||
|
|
||||||
batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
|
batch = self.normalize_inputs(batch)
|
||||||
|
|
||||||
if len(self._action_queue) == 0:
|
if len(self._action_queue) == 0:
|
||||||
# `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
|
# `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
|
||||||
@@ -182,9 +179,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||||||
actions = self._forward(batch)[0][: self.cfg.n_action_steps]
|
actions = self._forward(batch)[0][: self.cfg.n_action_steps]
|
||||||
|
|
||||||
# TODO(rcadene): make _forward return output dictionary?
|
# TODO(rcadene): make _forward return output dictionary?
|
||||||
out_dict = {"action": actions}
|
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||||
out_dict = unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
|
|
||||||
actions = out_dict["action"]
|
|
||||||
|
|
||||||
self._action_queue.extend(actions.transpose(0, 1))
|
self._action_queue.extend(actions.transpose(0, 1))
|
||||||
return self._action_queue.popleft()
|
return self._action_queue.popleft()
|
||||||
@@ -218,9 +213,10 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
self.train()
|
self.train()
|
||||||
|
|
||||||
batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
|
batch = self.normalize_inputs(batch)
|
||||||
|
|
||||||
loss_dict = self.forward(batch)
|
loss_dict = self.forward(batch)
|
||||||
# TODO(rcadene): unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
|
# TODO(rcadene): self.unnormalize_outputs(out_dict)
|
||||||
loss = loss_dict["loss"]
|
loss = loss_dict["loss"]
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
|||||||
@@ -19,10 +19,24 @@ class DiffusionConfig:
|
|||||||
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
|
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
|
||||||
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
|
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
|
||||||
See `DiffusionPolicy.select_action` for more details.
|
See `DiffusionPolicy.select_action` for more details.
|
||||||
image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in
|
input_shapes: A dictionary defining the shapes of the input data for the policy.
|
||||||
[0, 1]) for normalization.
|
The key represents the input data name, and the value is a list indicating the dimensions
|
||||||
image_normalization_std: Value by which to divide the input image pixels (after the mean has been
|
of the corresponding data. For example, "observation.image" refers to an input from
|
||||||
subtracted).
|
a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
|
||||||
|
Importantly, shapes doesnt include batch dimension or temporal dimension.
|
||||||
|
output_shapes: A dictionary defining the shapes of the output data for the policy.
|
||||||
|
The key represents the output data name, and the value is a list indicating the dimensions
|
||||||
|
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
|
||||||
|
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
|
||||||
|
normalize_input_modes: A dictionary specifying the normalization mode to be applied to various inputs.
|
||||||
|
The key represents the input data name, and the value specifies the type of normalization to apply.
|
||||||
|
Common normalization methods include "mean_std" (mean and standard deviation) or "min_max" (to normalize
|
||||||
|
between -1 and 1).
|
||||||
|
unnormalize_output_modes: A dictionary specifying the method to unnormalize outputs.
|
||||||
|
This parameter maps output data types to their unnormalization modes, allowing the results to be
|
||||||
|
transformed back from a normalized state to a standard state. It is typically used when output
|
||||||
|
data needs to be interpreted in its original scale or units. For example, for "action", the
|
||||||
|
unnormalization mode might be "mean_std" or "min_max".
|
||||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||||
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
|
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
|
||||||
within the image size. If None, no cropping is done.
|
within the image size. If None, no cropping is done.
|
||||||
@@ -60,6 +74,7 @@ class DiffusionConfig:
|
|||||||
|
|
||||||
# Environment.
|
# Environment.
|
||||||
# Inherit these from the environment config.
|
# Inherit these from the environment config.
|
||||||
|
# TODO(rcadene, alexander-soar): remove these as they are defined in input_shapes, output_shapes
|
||||||
state_dim: int = 2
|
state_dim: int = 2
|
||||||
action_dim: int = 2
|
action_dim: int = 2
|
||||||
image_size: tuple[int, int] = (96, 96)
|
image_size: tuple[int, int] = (96, 96)
|
||||||
@@ -69,6 +84,18 @@ class DiffusionConfig:
|
|||||||
horizon: int = 16
|
horizon: int = 16
|
||||||
n_action_steps: int = 8
|
n_action_steps: int = 8
|
||||||
|
|
||||||
|
input_shapes: dict[str, str] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"observation.image": [3, 96, 96],
|
||||||
|
"observation.state": [2],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
output_shapes: dict[str, str] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"action": [2],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Normalization / Unnormalization
|
# Normalization / Unnormalization
|
||||||
normalize_input_modes: dict[str, str] = field(
|
normalize_input_modes: dict[str, str] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
|
|||||||
@@ -26,13 +26,11 @@ from torch import Tensor, nn
|
|||||||
from torch.nn.modules.batchnorm import _BatchNorm
|
from torch.nn.modules.batchnorm import _BatchNorm
|
||||||
|
|
||||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||||
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
from lerobot.common.policies.utils import (
|
from lerobot.common.policies.utils import (
|
||||||
get_device_from_parameters,
|
get_device_from_parameters,
|
||||||
get_dtype_from_parameters,
|
get_dtype_from_parameters,
|
||||||
normalize_inputs,
|
|
||||||
populate_queues,
|
populate_queues,
|
||||||
to_buffer_dict,
|
|
||||||
unnormalize_outputs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -58,9 +56,10 @@ class DiffusionPolicy(nn.Module):
|
|||||||
if cfg is None:
|
if cfg is None:
|
||||||
cfg = DiffusionConfig()
|
cfg = DiffusionConfig()
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.dataset_stats = to_buffer_dict(dataset_stats)
|
|
||||||
self.normalize_input_modes = cfg.normalize_input_modes
|
self.normalize_input_modes = cfg.normalize_input_modes
|
||||||
self.unnormalize_output_modes = cfg.unnormalize_output_modes
|
self.unnormalize_output_modes = cfg.unnormalize_output_modes
|
||||||
|
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
|
||||||
|
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
|
||||||
|
|
||||||
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||||
self._queues = None
|
self._queues = None
|
||||||
@@ -133,7 +132,7 @@ class DiffusionPolicy(nn.Module):
|
|||||||
assert "observation.state" in batch
|
assert "observation.state" in batch
|
||||||
assert len(batch) == 2
|
assert len(batch) == 2
|
||||||
|
|
||||||
batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
|
batch = self.normalize_inputs(batch)
|
||||||
|
|
||||||
self._queues = populate_queues(self._queues, batch)
|
self._queues = populate_queues(self._queues, batch)
|
||||||
|
|
||||||
@@ -146,9 +145,7 @@ class DiffusionPolicy(nn.Module):
|
|||||||
actions = self.diffusion.generate_actions(batch)
|
actions = self.diffusion.generate_actions(batch)
|
||||||
|
|
||||||
# TODO(rcadene): make above methods return output dictionary?
|
# TODO(rcadene): make above methods return output dictionary?
|
||||||
out_dict = {"action": actions}
|
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||||
out_dict = unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
|
|
||||||
actions = out_dict["action"]
|
|
||||||
|
|
||||||
self._queues["action"].extend(actions.transpose(0, 1))
|
self._queues["action"].extend(actions.transpose(0, 1))
|
||||||
|
|
||||||
@@ -166,12 +163,12 @@ class DiffusionPolicy(nn.Module):
|
|||||||
|
|
||||||
self.diffusion.train()
|
self.diffusion.train()
|
||||||
|
|
||||||
batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
|
batch = self.normalize_inputs(batch)
|
||||||
|
|
||||||
loss = self.forward(batch)["loss"]
|
loss = self.forward(batch)["loss"]
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
# TODO(rcadene): unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
|
# TODO(rcadene): self.unnormalize_outputs(out_dict)
|
||||||
|
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
self.diffusion.parameters(),
|
self.diffusion.parameters(),
|
||||||
|
|||||||
174
lerobot/common/policies/normalize.py
Normal file
174
lerobot/common/policies/normalize.py
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
def create_stats_buffers(shapes, modes, stats=None):
|
||||||
|
"""
|
||||||
|
This function generates buffers to store the mean and standard deviation, or minimum and maximum values,
|
||||||
|
used for normalizing tensors. The mode of normalization is determined by the `modes` dictionary, which can
|
||||||
|
be either "mean_std" (for mean and standard deviation) or "min_max" (for minimum and maximum). These buffers
|
||||||
|
are created as PyTorch nn.ParameterDict objects with nn.Parameters set to not require gradients, suitable
|
||||||
|
for normalization purposes.
|
||||||
|
|
||||||
|
If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
|
||||||
|
and width, assuming a channel-first (c, h, w) format.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
shapes (dict): A dictionary where keys represent tensor identifiers and values represent the shapes of those tensors.
|
||||||
|
modes (dict): A dictionary specifying the normalization mode for each key in `shapes`. Valid modes are "mean_std" or "min_max".
|
||||||
|
stats (dict, optional): A dictionary containing pre-defined statistics for normalization. It can contain 'mean' and 'std' for
|
||||||
|
"mean_std" mode, or 'min' and 'max' for "min_max" mode. If provided, these statistics will overwrite the default buffers.
|
||||||
|
It's expected for training the model for the first time. If not provided, the default buffers are supposed to be overriden
|
||||||
|
by a call to `policy.load_state_dict(state_dict)`. It's useful for loading a pretrained model for finetuning or evaluation,
|
||||||
|
without requiring to initialize the dataset used to train the model just to acess the `stats`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A dictionary where keys match the `modes` and `shapes` keys, and values are nn.ParameterDict objects containing
|
||||||
|
the appropriate buffers for normalization.
|
||||||
|
"""
|
||||||
|
stats_buffers = {}
|
||||||
|
|
||||||
|
for key, mode in modes.items():
|
||||||
|
assert mode in ["mean_std", "min_max"]
|
||||||
|
|
||||||
|
shape = shapes[key]
|
||||||
|
|
||||||
|
# override shape to be invariant to height and width
|
||||||
|
if "image" in key:
|
||||||
|
# assume shape is channel first (b, c, h, w) or (b, t, c, h, w)
|
||||||
|
shape[-1] = 1
|
||||||
|
shape[-2] = 1
|
||||||
|
|
||||||
|
buffer = {}
|
||||||
|
if mode == "mean_std":
|
||||||
|
mean = torch.zeros(shape, dtype=torch.float32)
|
||||||
|
std = torch.ones(shape, dtype=torch.float32)
|
||||||
|
buffer = nn.ParameterDict(
|
||||||
|
{
|
||||||
|
"mean": nn.Parameter(mean, requires_grad=False),
|
||||||
|
"std": nn.Parameter(std, requires_grad=False),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif mode == "min_max":
|
||||||
|
# TODO(rcadene): should we assume input is in [-1, 1] range?
|
||||||
|
min = torch.ones(shape, dtype=torch.float32) * -1
|
||||||
|
max = torch.ones(shape, dtype=torch.float32)
|
||||||
|
buffer = nn.ParameterDict(
|
||||||
|
{
|
||||||
|
"min": nn.Parameter(min, requires_grad=False),
|
||||||
|
"max": nn.Parameter(max, requires_grad=False),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if stats is not None:
|
||||||
|
if mode == "mean_std":
|
||||||
|
buffer["mean"].data = stats[key]["mean"]
|
||||||
|
buffer["std"].data = stats[key]["std"]
|
||||||
|
elif mode == "min_max":
|
||||||
|
buffer["min"].data = stats[key]["min"]
|
||||||
|
buffer["max"].data = stats[key]["max"]
|
||||||
|
|
||||||
|
stats_buffers[key] = buffer
|
||||||
|
return stats_buffers
|
||||||
|
|
||||||
|
|
||||||
|
class Normalize(nn.Module):
|
||||||
|
"""
|
||||||
|
A PyTorch module for normalizing data based on predefined statistics.
|
||||||
|
|
||||||
|
The class is initialized with a set of shapes, modes, and optional pre-defined statistics. It creates buffers for normalization based
|
||||||
|
on these inputs, which are then used to adjust data during the forward pass. The normalization process operates on a batch of data,
|
||||||
|
with different keys in the batch being normalized according to the specified modes. The following normalization modes are supported:
|
||||||
|
- "mean_std": Normalizes data using the mean and standard deviation.
|
||||||
|
- "min_max": Normalizes data to a [0, 1] range and then to a [-1, 1] range.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
shapes (dict): A dictionary where keys represent tensor identifiers and values represent the shapes of those tensors.
|
||||||
|
modes (dict): A dictionary indicating the normalization mode for each tensor key. Valid modes are "mean_std" or "min_max".
|
||||||
|
stats (dict, optional): A dictionary containing pre-defined statistics for normalization. It can contain 'mean' and 'std' for
|
||||||
|
"mean_std" mode, or 'min' and 'max' for "min_max" mode. If provided, these statistics will overwrite the default buffers.
|
||||||
|
It's expected for training the model for the first time. If not provided, the default buffers are supposed to be overriden
|
||||||
|
by a call to `policy.load_state_dict(state_dict)`. It's useful for loading a pretrained model for finetuning or evaluation,
|
||||||
|
without requiring to initialize the dataset used to train the model just to acess the `stats`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, shapes, modes, stats=None):
|
||||||
|
super().__init__()
|
||||||
|
self.shapes = shapes
|
||||||
|
self.modes = modes
|
||||||
|
self.stats = stats
|
||||||
|
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
|
||||||
|
stats_buffers = create_stats_buffers(shapes, modes, stats)
|
||||||
|
for key, buffer in stats_buffers.items():
|
||||||
|
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||||
|
|
||||||
|
# TODO(rcadene): should we remove torch.no_grad?
|
||||||
|
@torch.no_grad
|
||||||
|
def forward(self, batch):
|
||||||
|
for key, mode in self.modes.items():
|
||||||
|
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||||
|
|
||||||
|
if mode == "mean_std":
|
||||||
|
mean = buffer["mean"].unsqueeze(0)
|
||||||
|
std = buffer["std"].unsqueeze(0)
|
||||||
|
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||||
|
elif mode == "min_max":
|
||||||
|
min = buffer["min"].unsqueeze(0)
|
||||||
|
max = buffer["max"].unsqueeze(0)
|
||||||
|
# normalize to [0,1]
|
||||||
|
batch[key] = (batch[key] - min) / (max - min)
|
||||||
|
# normalize to [-1, 1]
|
||||||
|
batch[key] = batch[key] * 2 - 1
|
||||||
|
else:
|
||||||
|
raise ValueError(mode)
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
class Unnormalize(nn.Module):
|
||||||
|
"""
|
||||||
|
A PyTorch module for unnormalizing data based on predefined statistics.
|
||||||
|
|
||||||
|
The class is initialized with a set of shapes, modes, and optional pre-defined statistics. It creates buffers for unnormalization based
|
||||||
|
on these inputs, which are then used to adjust data during the forward pass. The unnormalization process operates on a batch of data,
|
||||||
|
with different keys in the batch being normalized according to the specified modes. The following unnormalization modes are supported:
|
||||||
|
- "mean_std": Unnormalizes data using the mean and standard deviation.
|
||||||
|
- "min_max": Unnormalizes data to a [0, 1] range and then to a [-1, 1] range.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
shapes (dict): A dictionary where keys represent tensor identifiers and values represent the shapes of those tensors.
|
||||||
|
modes (dict): A dictionary indicating the unnormalization mode for each tensor key. Valid modes are "mean_std" or "min_max".
|
||||||
|
stats (dict, optional): A dictionary containing pre-defined statistics for unnormalization. It can contain 'mean' and 'std' for
|
||||||
|
"mean_std" mode, or 'min' and 'max' for "min_max" mode. If provided, these statistics will overwrite the default buffers.
|
||||||
|
It's expected for training the model for the first time. If not provided, the default buffers are supposed to be overriden
|
||||||
|
by a call to `policy.load_state_dict(state_dict)`. It's useful for loading a pretrained model for finetuning or evaluation,
|
||||||
|
without requiring to initialize the dataset used to train the model just to acess the `stats`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, shapes, modes, stats=None):
|
||||||
|
super().__init__()
|
||||||
|
self.shapes = shapes
|
||||||
|
self.modes = modes
|
||||||
|
self.stats = stats
|
||||||
|
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
|
||||||
|
stats_buffers = create_stats_buffers(shapes, modes, stats)
|
||||||
|
for key, buffer in stats_buffers.items():
|
||||||
|
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||||
|
|
||||||
|
# TODO(rcadene): should we remove torch.no_grad?
|
||||||
|
@torch.no_grad
|
||||||
|
def forward(self, batch):
|
||||||
|
for key, mode in self.modes.items():
|
||||||
|
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||||
|
|
||||||
|
if mode == "mean_std":
|
||||||
|
mean = buffer["mean"].unsqueeze(0)
|
||||||
|
std = buffer["std"].unsqueeze(0)
|
||||||
|
batch[key] = batch[key] * std + mean
|
||||||
|
elif mode == "min_max":
|
||||||
|
min = buffer["min"].unsqueeze(0)
|
||||||
|
max = buffer["max"].unsqueeze(0)
|
||||||
|
batch[key] = (batch[key] + 1) / 2
|
||||||
|
batch[key] = batch[key] * (max - min) + min
|
||||||
|
else:
|
||||||
|
raise ValueError(mode)
|
||||||
|
return batch
|
||||||
@@ -28,58 +28,3 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
|
|||||||
Note: assumes that all parameters have the same dtype.
|
Note: assumes that all parameters have the same dtype.
|
||||||
"""
|
"""
|
||||||
return next(iter(module.parameters())).dtype
|
return next(iter(module.parameters())).dtype
|
||||||
|
|
||||||
|
|
||||||
def normalize_inputs(batch, stats, normalize_input_modes):
|
|
||||||
if normalize_input_modes is None:
|
|
||||||
return batch
|
|
||||||
for key, mode in normalize_input_modes.items():
|
|
||||||
if mode == "mean_std":
|
|
||||||
mean = stats[key]["mean"].unsqueeze(0)
|
|
||||||
std = stats[key]["std"].unsqueeze(0)
|
|
||||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
|
||||||
elif mode == "min_max":
|
|
||||||
min = stats[key]["min"].unsqueeze(0)
|
|
||||||
max = stats[key]["max"].unsqueeze(0)
|
|
||||||
# normalize to [0,1]
|
|
||||||
batch[key] = (batch[key] - min) / (max - min)
|
|
||||||
# normalize to [-1, 1]
|
|
||||||
batch[key] = batch[key] * 2 - 1
|
|
||||||
else:
|
|
||||||
raise ValueError(mode)
|
|
||||||
return batch
|
|
||||||
|
|
||||||
|
|
||||||
def unnormalize_outputs(batch, stats, unnormalize_output_modes):
|
|
||||||
if unnormalize_output_modes is None:
|
|
||||||
return batch
|
|
||||||
for key, mode in unnormalize_output_modes.items():
|
|
||||||
if mode == "mean_std":
|
|
||||||
mean = stats[key]["mean"].unsqueeze(0)
|
|
||||||
std = stats[key]["std"].unsqueeze(0)
|
|
||||||
batch[key] = batch[key] * std + mean
|
|
||||||
elif mode == "min_max":
|
|
||||||
min = stats[key]["min"].unsqueeze(0)
|
|
||||||
max = stats[key]["max"].unsqueeze(0)
|
|
||||||
batch[key] = (batch[key] + 1) / 2
|
|
||||||
batch[key] = batch[key] * (max - min) + min
|
|
||||||
else:
|
|
||||||
raise ValueError(mode)
|
|
||||||
return batch
|
|
||||||
|
|
||||||
|
|
||||||
def to_buffer_dict(dataset_stats):
|
|
||||||
# TODO(rcadene): replace this function by `torch.BufferDict` when it exists
|
|
||||||
# see: https://github.com/pytorch/pytorch/issues/37386
|
|
||||||
# TODO(rcadene): make `to_buffer_dict` generic and add docstring
|
|
||||||
if dataset_stats is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
new_ds_stats = {}
|
|
||||||
for key, stats_dict in dataset_stats.items():
|
|
||||||
new_stats_dict = {}
|
|
||||||
for stats_type, value in stats_dict.items():
|
|
||||||
# set requires_grad=False to have the same behavior as a nn.Buffer
|
|
||||||
new_stats_dict[stats_type] = nn.Parameter(value, requires_grad=False)
|
|
||||||
new_ds_stats[key] = nn.ParameterDict(new_stats_dict)
|
|
||||||
return nn.ParameterDict(new_ds_stats)
|
|
||||||
|
|||||||
@@ -34,6 +34,13 @@ policy:
|
|||||||
chunk_size: 100 # chunk_size
|
chunk_size: 100 # chunk_size
|
||||||
n_action_steps: 100
|
n_action_steps: 100
|
||||||
|
|
||||||
|
input_shapes:
|
||||||
|
# TODO(rcadene, alexander-soar): add variables for height and width from the dataset/env?
|
||||||
|
observation.images.top: [3, 480, 640]
|
||||||
|
observation.state: ["${policy.state_dim}"]
|
||||||
|
output_shapes:
|
||||||
|
action: ["${policy.action_dim}"]
|
||||||
|
|
||||||
# Normalization / Unnormalization
|
# Normalization / Unnormalization
|
||||||
normalize_input_modes:
|
normalize_input_modes:
|
||||||
observation.images.top: mean_std
|
observation.images.top: mean_std
|
||||||
|
|||||||
@@ -50,6 +50,13 @@ policy:
|
|||||||
horizon: ${horizon}
|
horizon: ${horizon}
|
||||||
n_action_steps: ${n_action_steps}
|
n_action_steps: ${n_action_steps}
|
||||||
|
|
||||||
|
input_shapes:
|
||||||
|
# TODO(rcadene, alexander-soar): add variables for height and width from the dataset/env?
|
||||||
|
observation.image: [3, 96, 96]
|
||||||
|
observation.state: ["${policy.state_dim}"]
|
||||||
|
output_shapes:
|
||||||
|
action: ["${policy.action_dim}"]
|
||||||
|
|
||||||
# Normalization / Unnormalization
|
# Normalization / Unnormalization
|
||||||
normalize_input_modes:
|
normalize_input_modes:
|
||||||
observation.image: mean_std
|
observation.image: mean_std
|
||||||
|
|||||||
@@ -6,10 +6,10 @@ from lerobot.common.datasets.utils import cycle
|
|||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.factory import make_env
|
||||||
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
|
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
from lerobot.common.policies.policy_protocol import Policy
|
from lerobot.common.policies.policy_protocol import Policy
|
||||||
from lerobot.common.utils.utils import init_hydra_config
|
from lerobot.common.utils.utils import init_hydra_config
|
||||||
|
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
|
||||||
from .utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
|
|
||||||
|
|
||||||
|
|
||||||
# TODO(aliberts): refactor using lerobot/__init__.py variables
|
# TODO(aliberts): refactor using lerobot/__init__.py variables
|
||||||
@@ -93,3 +93,111 @@ def test_policy(env_name, policy_name, extra_overrides):
|
|||||||
|
|
||||||
# Test step through policy
|
# Test step through policy
|
||||||
env.step(action)
|
env.step(action)
|
||||||
|
|
||||||
|
# Test load state_dict
|
||||||
|
if policy_name != "tdmpc":
|
||||||
|
# TODO(rcadene, alexander-soar): make it work for tdmpc
|
||||||
|
# TODO(rcadene, alexander-soar): how to remove need for dataset_stats?
|
||||||
|
new_policy = make_policy(cfg, dataset_stats=dataset.stats)
|
||||||
|
new_policy.load_state_dict(policy.state_dict())
|
||||||
|
new_policy.update(batch, step=0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"insert_temporal_dim",
|
||||||
|
[
|
||||||
|
False,
|
||||||
|
True,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_normalize(insert_temporal_dim):
|
||||||
|
# TODO(rcadene, alexander-soar): test with real data and assert results of normalization/unnormalization
|
||||||
|
|
||||||
|
input_shapes = {
|
||||||
|
"observation.image": [3, 96, 96],
|
||||||
|
"observation.state": [10],
|
||||||
|
}
|
||||||
|
output_shapes = {
|
||||||
|
"action": [5],
|
||||||
|
}
|
||||||
|
|
||||||
|
normalize_input_modes = {
|
||||||
|
"observation.image": "mean_std",
|
||||||
|
"observation.state": "min_max",
|
||||||
|
}
|
||||||
|
unnormalize_output_modes = {
|
||||||
|
"action": "min_max",
|
||||||
|
}
|
||||||
|
|
||||||
|
dataset_stats = {
|
||||||
|
"observation.image": {
|
||||||
|
"mean": torch.randn(3, 1, 1),
|
||||||
|
"std": torch.randn(3, 1, 1),
|
||||||
|
"min": torch.randn(3, 1, 1),
|
||||||
|
"max": torch.randn(3, 1, 1),
|
||||||
|
},
|
||||||
|
"observation.state": {
|
||||||
|
"mean": torch.randn(10),
|
||||||
|
"std": torch.randn(10),
|
||||||
|
"min": torch.randn(10),
|
||||||
|
"max": torch.randn(10),
|
||||||
|
},
|
||||||
|
"action": {
|
||||||
|
"mean": torch.randn(5),
|
||||||
|
"std": torch.randn(5),
|
||||||
|
"min": torch.randn(5),
|
||||||
|
"max": torch.randn(5),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
bsize = 2
|
||||||
|
input_batch = {
|
||||||
|
"observation.image": torch.randn(bsize, 3, 96, 96),
|
||||||
|
"observation.state": torch.randn(bsize, 10),
|
||||||
|
}
|
||||||
|
output_batch = {
|
||||||
|
"action": torch.randn(bsize, 5),
|
||||||
|
}
|
||||||
|
|
||||||
|
if insert_temporal_dim:
|
||||||
|
tdim = 4
|
||||||
|
|
||||||
|
for key in input_batch:
|
||||||
|
# [2,3,96,96] -> [2,tdim,3,96,96]
|
||||||
|
input_batch[key] = torch.stack([input_batch[key]] * tdim, dim=1)
|
||||||
|
|
||||||
|
for key in output_batch:
|
||||||
|
output_batch[key] = torch.stack([output_batch[key]] * tdim, dim=1)
|
||||||
|
|
||||||
|
# test without stats
|
||||||
|
normalize = Normalize(input_shapes, normalize_input_modes, stats=None)
|
||||||
|
normalize(input_batch)
|
||||||
|
|
||||||
|
# test with stats
|
||||||
|
normalize = Normalize(input_shapes, normalize_input_modes, stats=dataset_stats)
|
||||||
|
normalize(input_batch)
|
||||||
|
|
||||||
|
# test loading pretrained models
|
||||||
|
new_normalize = Normalize(input_shapes, normalize_input_modes, stats=None)
|
||||||
|
new_normalize.load_state_dict(normalize.state_dict())
|
||||||
|
new_normalize(input_batch)
|
||||||
|
|
||||||
|
# test wihtout stats
|
||||||
|
unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
|
||||||
|
unnormalize(output_batch)
|
||||||
|
|
||||||
|
# test with stats
|
||||||
|
unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=dataset_stats)
|
||||||
|
unnormalize(output_batch)
|
||||||
|
|
||||||
|
# test loading pretrained models
|
||||||
|
new_unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
|
||||||
|
new_unnormalize.load_state_dict(unnormalize.state_dict())
|
||||||
|
unnormalize(output_batch)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_policy(
|
||||||
|
*("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_scripted"])
|
||||||
|
)
|
||||||
|
# test_policy(insert_temporal_dim=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user