make load_state_dict work
lerobot/common/policies/act/configuration_act.py
@@ -21,10 +21,24 @@ class ActionChunkingTransformerConfig:
This should be no greater than the chunk size. For example, if the chunk size is 100, you may
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
environment, and throws the other 50 out.
-image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in
-    [0, 1]) for normalization.
-image_normalization_std: Value by which to divide the input image pixels (after the mean has been
-    subtracted).
+input_shapes: A dictionary defining the shapes of the input data for the policy.
+    The key represents the input data name, and the value is a list indicating the dimensions
+    of the corresponding data. For example, "observation.images.top" refers to an input from the
+    "top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
+    Importantly, shapes don't include the batch dimension or the temporal dimension.
+output_shapes: A dictionary defining the shapes of the output data for the policy.
+    The key represents the output data name, and the value is a list indicating the dimensions
+    of the corresponding data. For example, "action" refers to an output shape of [14], indicating
+    14-dimensional actions. Importantly, shapes don't include the batch dimension or the temporal dimension.
+normalize_input_modes: A dictionary specifying the normalization mode to be applied to various inputs.
+    The key represents the input data name, and the value specifies the type of normalization to apply.
+    Common normalization methods include "mean_std" (mean and standard deviation) and "min_max" (to normalize
+    between -1 and 1).
+unnormalize_output_modes: A dictionary specifying the method to unnormalize outputs.
+    This parameter maps output data names to their unnormalization modes, allowing the results to be
+    transformed back from a normalized state to the original scale. This is typically needed when output
+    data must be interpreted in its original scale or units. For example, for "action", the
+    unnormalization mode might be "mean_std" or "min_max".
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from
    torchvision.
@@ -51,6 +65,7 @@ class ActionChunkingTransformerConfig:
"""

# Environment.
+# TODO(rcadene, alexander-soar): remove these as they are defined in input_shapes, output_shapes
state_dim: int = 14
action_dim: int = 14

@@ -60,6 +75,18 @@ class ActionChunkingTransformerConfig:
chunk_size: int = 100
n_action_steps: int = 100

+input_shapes: dict[str, list[int]] = field(
+    default_factory=lambda: {
+        "observation.images.top": [3, 480, 640],
+        "observation.state": [14],
+    }
+)
+output_shapes: dict[str, list[int]] = field(
+    default_factory=lambda: {
+        "action": [14],
+    }
+)
+
# Normalization / Unnormalization
normalize_input_modes: dict[str, str] = field(
    default_factory=lambda: {

@@ -72,6 +99,7 @@ class ActionChunkingTransformerConfig:
    "action": "mean_std",
}
)

# Architecture.
# Vision backbone.
vision_backbone: str = "resnet18"
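For illustration (not part of the diff), the new fields can be overridden when the dataclass is constructed. A minimal sketch, assuming the defaults shown above:

    from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig

    # Override the shape/normalization fields introduced by this commit.
    cfg = ActionChunkingTransformerConfig(
        input_shapes={
            "observation.images.top": [3, 480, 640],
            "observation.state": [14],
        },
        output_shapes={"action": [14]},
        normalize_input_modes={
            "observation.images.top": "mean_std",
            "observation.state": "mean_std",
        },
        unnormalize_output_modes={"action": "mean_std"},
    )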
lerobot/common/policies/act/modeling_act.py
@@ -20,11 +20,7 @@ from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d

from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
-from lerobot.common.policies.utils import (
-    normalize_inputs,
-    to_buffer_dict,
-    unnormalize_outputs,
-)
+from lerobot.common.policies.normalize import Normalize, Unnormalize


class ActionChunkingTransformerPolicy(nn.Module):

@@ -76,9 +72,10 @@ class ActionChunkingTransformerPolicy(nn.Module):
if cfg is None:
    cfg = ActionChunkingTransformerConfig()
self.cfg = cfg
-self.dataset_stats = to_buffer_dict(dataset_stats)
self.normalize_input_modes = cfg.normalize_input_modes
self.unnormalize_output_modes = cfg.unnormalize_output_modes
+self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
+self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)

# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).

@@ -174,7 +171,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
"""
self.eval()

-batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
+batch = self.normalize_inputs(batch)

if len(self._action_queue) == 0:
    # `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively

@@ -182,9 +179,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
actions = self._forward(batch)[0][: self.cfg.n_action_steps]

# TODO(rcadene): make _forward return output dictionary?
-out_dict = {"action": actions}
-out_dict = unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
-actions = out_dict["action"]
+actions = self.unnormalize_outputs({"action": actions})["action"]

self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()

@@ -218,9 +213,10 @@ class ActionChunkingTransformerPolicy(nn.Module):
start_time = time.time()
self.train()

-batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
+batch = self.normalize_inputs(batch)

loss_dict = self.forward(batch)
-# TODO(rcadene): unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
+# TODO(rcadene): self.unnormalize_outputs(out_dict)
loss = loss_dict["loss"]
loss.backward()
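The practical effect of these changes, and the point of the commit title, is that normalization statistics now live in registered submodules and therefore travel with the checkpoint. A rough sketch (the constructor signature and buffer key names are inferred from the code in this diff, not guaranteed):

    # `cfg` and `dataset_stats` as used elsewhere in this diff.
    policy = ActionChunkingTransformerPolicy(cfg, dataset_stats)
    state_dict = policy.state_dict()
    # now contains entries such as "normalize_inputs.buffer_observation_state.mean"

    # A fresh policy gets default buffers, so the checkpoint can fill them in
    # without the training dataset being available:
    new_policy = ActionChunkingTransformerPolicy(cfg)  # assumes dataset_stats defaults to None
    new_policy.load_state_dict(state_dict)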
lerobot/common/policies/diffusion/configuration_diffusion.py
@@ -19,10 +19,24 @@ class DiffusionConfig:
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
    See `DiffusionPolicy.select_action` for more details.
-image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in
-    [0, 1]) for normalization.
-image_normalization_std: Value by which to divide the input image pixels (after the mean has been
-    subtracted).
+input_shapes: A dictionary defining the shapes of the input data for the policy.
+    The key represents the input data name, and the value is a list indicating the dimensions
+    of the corresponding data. For example, "observation.image" refers to an input from
+    a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
+    Importantly, shapes don't include the batch dimension or the temporal dimension.
+output_shapes: A dictionary defining the shapes of the output data for the policy.
+    The key represents the output data name, and the value is a list indicating the dimensions
+    of the corresponding data. For example, "action" refers to an output shape of [14], indicating
+    14-dimensional actions. Importantly, shapes don't include the batch dimension or the temporal dimension.
+normalize_input_modes: A dictionary specifying the normalization mode to be applied to various inputs.
+    The key represents the input data name, and the value specifies the type of normalization to apply.
+    Common normalization methods include "mean_std" (mean and standard deviation) and "min_max" (to normalize
+    between -1 and 1).
+unnormalize_output_modes: A dictionary specifying the method to unnormalize outputs.
+    This parameter maps output data names to their unnormalization modes, allowing the results to be
+    transformed back from a normalized state to the original scale. This is typically needed when output
+    data must be interpreted in its original scale or units. For example, for "action", the
+    unnormalization mode might be "mean_std" or "min_max".
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
    within the image size. If None, no cropping is done.

@@ -60,6 +74,7 @@ class DiffusionConfig:

# Environment.
# Inherit these from the environment config.
+# TODO(rcadene, alexander-soar): remove these as they are defined in input_shapes, output_shapes
state_dim: int = 2
action_dim: int = 2
image_size: tuple[int, int] = (96, 96)

@@ -69,6 +84,18 @@ class DiffusionConfig:
horizon: int = 16
n_action_steps: int = 8

+input_shapes: dict[str, list[int]] = field(
+    default_factory=lambda: {
+        "observation.image": [3, 96, 96],
+        "observation.state": [2],
+    }
+)
+output_shapes: dict[str, list[int]] = field(
+    default_factory=lambda: {
+        "action": [2],
+    }
+)
+
# Normalization / Unnormalization
normalize_input_modes: dict[str, str] = field(
    default_factory=lambda: {
lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -26,13 +26,11 @@ from torch import Tensor, nn
from torch.nn.modules.batchnorm import _BatchNorm

from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
+from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.utils import (
    get_device_from_parameters,
    get_dtype_from_parameters,
-    normalize_inputs,
    populate_queues,
-    to_buffer_dict,
-    unnormalize_outputs,
)


@@ -58,9 +56,10 @@ class DiffusionPolicy(nn.Module):
if cfg is None:
    cfg = DiffusionConfig()
self.cfg = cfg
-self.dataset_stats = to_buffer_dict(dataset_stats)
self.normalize_input_modes = cfg.normalize_input_modes
self.unnormalize_output_modes = cfg.unnormalize_output_modes
+self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
+self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)

# queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None

@@ -133,7 +132,7 @@ class DiffusionPolicy(nn.Module):
assert "observation.state" in batch
assert len(batch) == 2

-batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
+batch = self.normalize_inputs(batch)

self._queues = populate_queues(self._queues, batch)

@@ -146,9 +145,7 @@ class DiffusionPolicy(nn.Module):
actions = self.diffusion.generate_actions(batch)

# TODO(rcadene): make above methods return output dictionary?
-out_dict = {"action": actions}
-out_dict = unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
-actions = out_dict["action"]
+actions = self.unnormalize_outputs({"action": actions})["action"]

self._queues["action"].extend(actions.transpose(0, 1))

@@ -166,12 +163,12 @@ class DiffusionPolicy(nn.Module):

self.diffusion.train()

-batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
+batch = self.normalize_inputs(batch)

loss = self.forward(batch)["loss"]
loss.backward()

-# TODO(rcadene): unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
+# TODO(rcadene): self.unnormalize_outputs(out_dict)

grad_norm = torch.nn.utils.clip_grad_norm_(
    self.diffusion.parameters(),
lerobot/common/policies/normalize.py (new file, 174 lines)
@@ -0,0 +1,174 @@
import torch
from torch import nn


def create_stats_buffers(shapes, modes, stats=None):
    """
    This function generates buffers to store the mean and standard deviation, or the minimum and maximum
    values, used for normalizing tensors. The mode of normalization is determined by the `modes` dictionary
    and can be either "mean_std" (mean and standard deviation) or "min_max" (minimum and maximum). The
    buffers are created as PyTorch nn.ParameterDict objects, with nn.Parameters set to not require
    gradients, which is suitable for normalization purposes.

    If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
    and width, assuming a channel-first (c, h, w) format.

    Parameters:
        shapes (dict): A dictionary where keys represent tensor identifiers and values represent the shapes of those tensors.
        modes (dict): A dictionary specifying the normalization mode for each key in `shapes`. Valid modes are "mean_std" or "min_max".
        stats (dict, optional): A dictionary containing pre-defined statistics for normalization. It can contain 'mean' and 'std' for
            "mean_std" mode, or 'min' and 'max' for "min_max" mode. If provided, these statistics overwrite the default buffers. This
            is expected when training the model for the first time. If not provided, the default buffers are expected to be overridden
            by a call to `policy.load_state_dict(state_dict)`. This is useful for loading a pretrained model for finetuning or
            evaluation, without having to initialize the dataset used to train the model just to access the `stats`.

    Returns:
        dict: A dictionary where keys match the `modes` and `shapes` keys, and values are nn.ParameterDict objects containing
            the appropriate buffers for normalization.
    """
    stats_buffers = {}

    for key, mode in modes.items():
        assert mode in ["mean_std", "min_max"]

        # copy the shape so the caller's `shapes` dict is not mutated below
        shape = list(shapes[key])

        # override shape to be invariant to height and width
        if "image" in key:
            # assume images are channel first, i.e. batches are (b, c, h, w) or (b, t, c, h, w)
            shape[-1] = 1
            shape[-2] = 1

        buffer = {}
        if mode == "mean_std":
            mean = torch.zeros(shape, dtype=torch.float32)
            std = torch.ones(shape, dtype=torch.float32)
            buffer = nn.ParameterDict(
                {
                    "mean": nn.Parameter(mean, requires_grad=False),
                    "std": nn.Parameter(std, requires_grad=False),
                }
            )
        elif mode == "min_max":
            # TODO(rcadene): should we assume input is in [-1, 1] range?
            min = torch.ones(shape, dtype=torch.float32) * -1
            max = torch.ones(shape, dtype=torch.float32)
            buffer = nn.ParameterDict(
                {
                    "min": nn.Parameter(min, requires_grad=False),
                    "max": nn.Parameter(max, requires_grad=False),
                }
            )

        if stats is not None:
            if mode == "mean_std":
                buffer["mean"].data = stats[key]["mean"]
                buffer["std"].data = stats[key]["std"]
            elif mode == "min_max":
                buffer["min"].data = stats[key]["min"]
                buffer["max"].data = stats[key]["max"]

        stats_buffers[key] = buffer
    return stats_buffers


class Normalize(nn.Module):
    """
    A PyTorch module for normalizing data based on predefined statistics.

    The class is initialized with a set of shapes, modes, and optional pre-defined statistics. It creates
    buffers for normalization based on these inputs, which are then used to adjust data during the forward
    pass. The normalization process operates on a batch of data, with different keys in the batch being
    normalized according to the specified modes. The following normalization modes are supported:
    - "mean_std": Normalizes data using the mean and standard deviation.
    - "min_max": Normalizes data to a [0, 1] range and then to a [-1, 1] range.

    Parameters:
        shapes (dict): A dictionary where keys represent tensor identifiers and values represent the shapes of those tensors.
        modes (dict): A dictionary indicating the normalization mode for each tensor key. Valid modes are "mean_std" or "min_max".
        stats (dict, optional): A dictionary containing pre-defined statistics for normalization. It can contain 'mean' and 'std' for
            "mean_std" mode, or 'min' and 'max' for "min_max" mode. If provided, these statistics overwrite the default buffers. This
            is expected when training the model for the first time. If not provided, the default buffers are expected to be overridden
            by a call to `policy.load_state_dict(state_dict)`. This is useful for loading a pretrained model for finetuning or
            evaluation, without having to initialize the dataset used to train the model just to access the `stats`.
    """

    def __init__(self, shapes, modes, stats=None):
        super().__init__()
        self.shapes = shapes
        self.modes = modes
        self.stats = stats
        # e.g. `self.buffer_observation_state["mean"]` holds a tensor of shape (state_dim,)
        stats_buffers = create_stats_buffers(shapes, modes, stats)
        for key, buffer in stats_buffers.items():
            setattr(self, "buffer_" + key.replace(".", "_"), buffer)

    # TODO(rcadene): should we remove torch.no_grad?
    @torch.no_grad
    def forward(self, batch):
        for key, mode in self.modes.items():
            buffer = getattr(self, "buffer_" + key.replace(".", "_"))

            if mode == "mean_std":
                mean = buffer["mean"].unsqueeze(0)
                std = buffer["std"].unsqueeze(0)
                batch[key] = (batch[key] - mean) / (std + 1e-8)
            elif mode == "min_max":
                min = buffer["min"].unsqueeze(0)
                max = buffer["max"].unsqueeze(0)
                # normalize to [0, 1]
                batch[key] = (batch[key] - min) / (max - min)
                # normalize to [-1, 1]
                batch[key] = batch[key] * 2 - 1
            else:
                raise ValueError(mode)
        return batch


class Unnormalize(nn.Module):
    """
    A PyTorch module for unnormalizing data based on predefined statistics.

    The class is initialized with a set of shapes, modes, and optional pre-defined statistics. It creates
    buffers for unnormalization based on these inputs, which are then used to adjust data during the forward
    pass. The unnormalization process operates on a batch of data, with different keys in the batch being
    unnormalized according to the specified modes. The following unnormalization modes are supported:
    - "mean_std": Unnormalizes data using the mean and standard deviation.
    - "min_max": Maps data from the [-1, 1] range back to the original range using the min and max statistics.

    Parameters:
        shapes (dict): A dictionary where keys represent tensor identifiers and values represent the shapes of those tensors.
        modes (dict): A dictionary indicating the unnormalization mode for each tensor key. Valid modes are "mean_std" or "min_max".
        stats (dict, optional): A dictionary containing pre-defined statistics for unnormalization. It can contain 'mean' and 'std' for
            "mean_std" mode, or 'min' and 'max' for "min_max" mode. If provided, these statistics overwrite the default buffers. This
            is expected when training the model for the first time. If not provided, the default buffers are expected to be overridden
            by a call to `policy.load_state_dict(state_dict)`. This is useful for loading a pretrained model for finetuning or
            evaluation, without having to initialize the dataset used to train the model just to access the `stats`.
    """

    def __init__(self, shapes, modes, stats=None):
        super().__init__()
        self.shapes = shapes
        self.modes = modes
        self.stats = stats
        # e.g. `self.buffer_observation_state["mean"]` holds a tensor of shape (state_dim,)
        stats_buffers = create_stats_buffers(shapes, modes, stats)
        for key, buffer in stats_buffers.items():
            setattr(self, "buffer_" + key.replace(".", "_"), buffer)

    # TODO(rcadene): should we remove torch.no_grad?
    @torch.no_grad
    def forward(self, batch):
        for key, mode in self.modes.items():
            buffer = getattr(self, "buffer_" + key.replace(".", "_"))

            if mode == "mean_std":
                mean = buffer["mean"].unsqueeze(0)
                std = buffer["std"].unsqueeze(0)
                batch[key] = batch[key] * std + mean
            elif mode == "min_max":
                min = buffer["min"].unsqueeze(0)
                max = buffer["max"].unsqueeze(0)
                # map from [-1, 1] back to [0, 1]
                batch[key] = (batch[key] + 1) / 2
                # map from [0, 1] back to [min, max]
                batch[key] = batch[key] * (max - min) + min
            else:
                raise ValueError(mode)
        return batch
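A minimal, self-contained sketch of how the two modules compose (using only the code above; the stats values are made up):

    import torch

    from lerobot.common.policies.normalize import Normalize, Unnormalize

    shapes = {"action": [5]}
    modes = {"action": "min_max"}
    stats = {"action": {"min": torch.zeros(5), "max": torch.ones(5)}}

    normalize = Normalize(shapes, modes, stats)
    unnormalize = Unnormalize(shapes, modes, stats)

    batch = {"action": torch.rand(2, 5)}  # batch of 5-dim "actions" in [0, 1]
    original = batch["action"].clone()
    normalized = normalize(batch)["action"]  # mapped to [-1, 1]
    recovered = unnormalize({"action": normalized})["action"]
    assert torch.allclose(recovered, original, atol=1e-6)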
lerobot/common/policies/utils.py
@@ -28,58 +28,3 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
Note: assumes that all parameters have the same dtype.
"""
return next(iter(module.parameters())).dtype
-
-
-def normalize_inputs(batch, stats, normalize_input_modes):
-    if normalize_input_modes is None:
-        return batch
-    for key, mode in normalize_input_modes.items():
-        if mode == "mean_std":
-            mean = stats[key]["mean"].unsqueeze(0)
-            std = stats[key]["std"].unsqueeze(0)
-            batch[key] = (batch[key] - mean) / (std + 1e-8)
-        elif mode == "min_max":
-            min = stats[key]["min"].unsqueeze(0)
-            max = stats[key]["max"].unsqueeze(0)
-            # normalize to [0, 1]
-            batch[key] = (batch[key] - min) / (max - min)
-            # normalize to [-1, 1]
-            batch[key] = batch[key] * 2 - 1
-        else:
-            raise ValueError(mode)
-    return batch
-
-
-def unnormalize_outputs(batch, stats, unnormalize_output_modes):
-    if unnormalize_output_modes is None:
-        return batch
-    for key, mode in unnormalize_output_modes.items():
-        if mode == "mean_std":
-            mean = stats[key]["mean"].unsqueeze(0)
-            std = stats[key]["std"].unsqueeze(0)
-            batch[key] = batch[key] * std + mean
-        elif mode == "min_max":
-            min = stats[key]["min"].unsqueeze(0)
-            max = stats[key]["max"].unsqueeze(0)
-            batch[key] = (batch[key] + 1) / 2
-            batch[key] = batch[key] * (max - min) + min
-        else:
-            raise ValueError(mode)
-    return batch
-
-
-def to_buffer_dict(dataset_stats):
-    # TODO(rcadene): replace this function by `torch.BufferDict` when it exists
-    # see: https://github.com/pytorch/pytorch/issues/37386
-    # TODO(rcadene): make `to_buffer_dict` generic and add docstring
-    if dataset_stats is None:
-        return None
-
-    new_ds_stats = {}
-    for key, stats_dict in dataset_stats.items():
-        new_stats_dict = {}
-        for stats_type, value in stats_dict.items():
-            # set requires_grad=False to have the same behavior as a nn.Buffer
-            new_stats_dict[stats_type] = nn.Parameter(value, requires_grad=False)
-        new_ds_stats[key] = nn.ParameterDict(new_stats_dict)
-    return nn.ParameterDict(new_ds_stats)
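These helpers are superseded by `Normalize`/`Unnormalize`. Note that `to_buffer_dict` returned None when no stats were given, so a freshly constructed policy had no buffers for `load_state_dict` to fill; the new modules always create default buffers with stable names. A standalone sketch (not from the repo) of why registering stats through `nn.ParameterDict` puts them into the checkpoint:

    import torch
    from torch import nn

    class PlainStats(nn.Module):
        def __init__(self):
            super().__init__()
            self.stats = {"mean": torch.zeros(3)}  # plain dict: not registered

    class RegisteredStats(nn.Module):
        def __init__(self):
            super().__init__()
            self.stats = nn.ParameterDict(
                {"mean": nn.Parameter(torch.zeros(3), requires_grad=False)}
            )

    print(PlainStats().state_dict().keys())       # odict_keys([])
    print(RegisteredStats().state_dict().keys())  # odict_keys(['stats.mean'])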
lerobot/configs/policy/act.yaml
@@ -34,6 +34,13 @@ policy:
chunk_size: 100 # chunk_size
n_action_steps: 100

+input_shapes:
+  # TODO(rcadene, alexander-soar): add variables for height and width from the dataset/env?
+  observation.images.top: [3, 480, 640]
+  observation.state: ["${policy.state_dim}"]
+output_shapes:
+  action: ["${policy.action_dim}"]
+
# Normalization / Unnormalization
normalize_input_modes:
  observation.images.top: mean_std
lerobot/configs/policy/diffusion.yaml
@@ -50,6 +50,13 @@ policy:
horizon: ${horizon}
n_action_steps: ${n_action_steps}

+input_shapes:
+  # TODO(rcadene, alexander-soar): add variables for height and width from the dataset/env?
+  observation.image: [3, 96, 96]
+  observation.state: ["${policy.state_dim}"]
+output_shapes:
+  action: ["${policy.action_dim}"]
+
# Normalization / Unnormalization
normalize_input_modes:
  observation.image: mean_std
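For reference (not shown in the diff), the `${policy.state_dim}`-style entries are OmegaConf interpolations that Hydra resolves when the config is loaded. A minimal sketch of the mechanism, assuming only OmegaConf:

    from omegaconf import OmegaConf

    # Hypothetical fragment mirroring the YAML above.
    cfg = OmegaConf.create(
        {
            "policy": {
                "state_dim": 2,
                "input_shapes": {"observation.state": ["${policy.state_dim}"]},
            }
        }
    )
    resolved = OmegaConf.to_container(cfg, resolve=True)
    print(resolved["policy"]["input_shapes"]["observation.state"])  # [2]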
tests/test_policies.py
@@ -6,10 +6,10 @@ from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
from lerobot.common.policies.factory import make_policy
+from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import init_hydra_config

-from .utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
+from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env


# TODO(aliberts): refactor using lerobot/__init__.py variables

@@ -93,3 +93,111 @@ def test_policy(env_name, policy_name, extra_overrides):

    # Test step through policy
    env.step(action)

+    # Test load state_dict
+    if policy_name != "tdmpc":
+        # TODO(rcadene, alexander-soar): make it work for tdmpc
+        # TODO(rcadene, alexander-soar): how to remove need for dataset_stats?
+        new_policy = make_policy(cfg, dataset_stats=dataset.stats)
+        new_policy.load_state_dict(policy.state_dict())
+        new_policy.update(batch, step=0)
+
+
+@pytest.mark.parametrize(
+    "insert_temporal_dim",
+    [
+        False,
+        True,
+    ],
+)
+def test_normalize(insert_temporal_dim):
+    # TODO(rcadene, alexander-soar): test with real data and assert results of normalization/unnormalization
+
+    input_shapes = {
+        "observation.image": [3, 96, 96],
+        "observation.state": [10],
+    }
+    output_shapes = {
+        "action": [5],
+    }
+
+    normalize_input_modes = {
+        "observation.image": "mean_std",
+        "observation.state": "min_max",
+    }
+    unnormalize_output_modes = {
+        "action": "min_max",
+    }
+
+    dataset_stats = {
+        "observation.image": {
+            "mean": torch.randn(3, 1, 1),
+            "std": torch.randn(3, 1, 1),
+            "min": torch.randn(3, 1, 1),
+            "max": torch.randn(3, 1, 1),
+        },
+        "observation.state": {
+            "mean": torch.randn(10),
+            "std": torch.randn(10),
+            "min": torch.randn(10),
+            "max": torch.randn(10),
+        },
+        "action": {
+            "mean": torch.randn(5),
+            "std": torch.randn(5),
+            "min": torch.randn(5),
+            "max": torch.randn(5),
+        },
+    }
+
+    bsize = 2
+    input_batch = {
+        "observation.image": torch.randn(bsize, 3, 96, 96),
+        "observation.state": torch.randn(bsize, 10),
+    }
+    output_batch = {
+        "action": torch.randn(bsize, 5),
+    }
+
+    if insert_temporal_dim:
+        tdim = 4
+
+        for key in input_batch:
+            # [2, 3, 96, 96] -> [2, tdim, 3, 96, 96]
+            input_batch[key] = torch.stack([input_batch[key]] * tdim, dim=1)
+
+        for key in output_batch:
+            output_batch[key] = torch.stack([output_batch[key]] * tdim, dim=1)
+
+    # test without stats
+    normalize = Normalize(input_shapes, normalize_input_modes, stats=None)
+    normalize(input_batch)
+
+    # test with stats
+    normalize = Normalize(input_shapes, normalize_input_modes, stats=dataset_stats)
+    normalize(input_batch)
+
+    # test loading pretrained models
+    new_normalize = Normalize(input_shapes, normalize_input_modes, stats=None)
+    new_normalize.load_state_dict(normalize.state_dict())
+    new_normalize(input_batch)
+
+    # test without stats
+    unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
+    unnormalize(output_batch)
+
+    # test with stats
+    unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=dataset_stats)
+    unnormalize(output_batch)
+
+    # test loading pretrained models
+    new_unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
+    new_unnormalize.load_state_dict(unnormalize.state_dict())
+    new_unnormalize(output_batch)
+
+
+if __name__ == "__main__":
+    test_policy(
+        *("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_scripted"])
+    )
+    # test_normalize(insert_temporal_dim=True)