make load_state_dict work

Cadene
2024-04-24 15:40:09 +00:00
parent 0660f71556
commit 72751b7cf6
9 changed files with 376 additions and 87 deletions

View File: lerobot/common/policies/act/configuration_act.py

@@ -21,10 +21,24 @@ class ActionChunkingTransformerConfig:
This should be no greater than the chunk size. For example, if the chunk size is 100, you may
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
environment, and throws the other 50 out.
image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in
[0, 1]) for normalization.
image_normalization_std: Value by which to divide the input image pixels (after the mean has been
subtracted).
input_shapes: A dictionary defining the shapes of the input data for the policy.
The key represents the input data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "observation.images.top" refers to an input from the
"top" camera with dimensions [3, 480, 640], indicating three color channels and 480x640 resolution.
Importantly, shapes don't include the batch or temporal dimensions.
output_shapes: A dictionary defining the shapes of the output data for the policy.
The key represents the output data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
14-dimensional actions. Importantly, shapes don't include the batch or temporal dimensions.
normalize_input_modes: A dictionary specifying the normalization mode to be applied to various inputs.
The key represents the input data name, and the value specifies the normalization to apply. Supported
modes are "mean_std" (subtract the mean and divide by the standard deviation) and "min_max" (rescale
to the [-1, 1] range).
unnormalize_output_modes: A dictionary specifying the method to unnormalize outputs.
This parameter maps output names to their unnormalization modes, so that results can be transformed
back from the normalized space to their original scale or units. For example, for "action", the
unnormalization mode might be "mean_std" or "min_max".
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from
torchvision.
@@ -51,6 +65,7 @@ class ActionChunkingTransformerConfig:
"""
# Environment.
# TODO(rcadene, alexander-soar): remove these as they are defined in input_shapes, output_shapes
state_dim: int = 14
action_dim: int = 14
@@ -60,6 +75,18 @@ class ActionChunkingTransformerConfig:
chunk_size: int = 100
n_action_steps: int = 100
input_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"observation.images.top": [3, 480, 640],
"observation.state": [14],
}
)
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [14],
}
)
# Normalization / Unnormalization
normalize_input_modes: dict[str, str] = field(
default_factory=lambda: {
@@ -72,6 +99,7 @@ class ActionChunkingTransformerConfig:
"action": "mean_std",
}
)
# Architecture.
# Vision backbone.
vision_backbone: str = "resnet18"

View File: lerobot/common/policies/act/modeling_act.py

@@ -20,11 +20,7 @@ from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
from lerobot.common.policies.utils import (
normalize_inputs,
to_buffer_dict,
unnormalize_outputs,
)
from lerobot.common.policies.normalize import Normalize, Unnormalize
class ActionChunkingTransformerPolicy(nn.Module):
@@ -76,9 +72,10 @@ class ActionChunkingTransformerPolicy(nn.Module):
if cfg is None:
cfg = ActionChunkingTransformerConfig()
self.cfg = cfg
self.dataset_stats = to_buffer_dict(dataset_stats)
self.normalize_input_modes = cfg.normalize_input_modes
self.unnormalize_output_modes = cfg.unnormalize_output_modes
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
@@ -174,7 +171,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
"""
self.eval()
batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
batch = self.normalize_inputs(batch)
if len(self._action_queue) == 0:
# `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
@@ -182,9 +179,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
actions = self._forward(batch)[0][: self.cfg.n_action_steps]
# TODO(rcadene): make _forward return output dictionary?
out_dict = {"action": actions}
out_dict = unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
actions = out_dict["action"]
actions = self.unnormalize_outputs({"action": actions})["action"]
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
@@ -218,9 +213,10 @@ class ActionChunkingTransformerPolicy(nn.Module):
start_time = time.time()
self.train()
batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
batch = self.normalize_inputs(batch)
loss_dict = self.forward(batch)
# TODO(rcadene): unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
# TODO(rcadene): self.unnormalize_outputs(out_dict)
loss = loss_dict["loss"]
loss.backward()
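
The reason this refactor makes `load_state_dict` work: the stats formerly held in a plain `self.dataset_stats` attribute now live inside `Normalize`/`Unnormalize` submodules as non-trainable nn.Parameters, so they are serialized with the checkpoint. A minimal sketch of the resulting round trip (assuming `dataset_stats` may be omitted at construction time, as `Normalize` itself allows):

policy = ActionChunkingTransformerPolicy(cfg, dataset_stats)
state_dict = policy.state_dict()
# The stats buffers now travel with the checkpoint, under keys such as:
#   normalize_inputs.buffer_observation_state.mean
#   normalize_inputs.buffer_observation_state.std
#   unnormalize_outputs.buffer_action.mean

# A fresh policy built without dataset stats recovers them from the checkpoint.
new_policy = ActionChunkingTransformerPolicy(cfg, dataset_stats=None)
new_policy.load_state_dict(state_dict)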

View File: lerobot/common/policies/diffusion/configuration_diffusion.py

@@ -19,10 +19,24 @@ class DiffusionConfig:
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
See `DiffusionPolicy.select_action` for more details.
image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in
[0, 1]) for normalization.
image_normalization_std: Value by which to divide the input image pixels (after the mean has been
subtracted).
input_shapes: A dictionary defining the shapes of the input data for the policy.
The key represents the input data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "observation.image" refers to an input from
a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
Importantly, shapes don't include the batch or temporal dimensions.
output_shapes: A dictionary defining the shapes of the output data for the policy.
The key represents the output data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
normalize_input_modes: A dictionary specifying the normalization mode to be applied to various inputs.
The key represents the input data name, and the value specifies the normalization to apply. Supported
modes are "mean_std" (subtract the mean and divide by the standard deviation) and "min_max" (rescale
to the [-1, 1] range).
unnormalize_output_modes: A dictionary specifying the method to unnormalize outputs.
This parameter maps output names to their unnormalization modes, so that results can be transformed
back from the normalized space to their original scale or units. For example, for "action", the
unnormalization mode might be "mean_std" or "min_max".
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
within the image size. If None, no cropping is done.
@@ -60,6 +74,7 @@ class DiffusionConfig:
# Environment.
# Inherit these from the environment config.
# TODO(rcadene, alexander-soar): remove these as they are defined in input_shapes, output_shapes
state_dim: int = 2
action_dim: int = 2
image_size: tuple[int, int] = (96, 96)
@@ -69,6 +84,18 @@ class DiffusionConfig:
horizon: int = 16
n_action_steps: int = 8
input_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"observation.image": [3, 96, 96],
"observation.state": [2],
}
)
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [2],
}
)
# Normalization / Unnormalization
normalize_input_modes: dict[str, str] = field(
default_factory=lambda: {

View File: lerobot/common/policies/diffusion/modeling_diffusion.py

@@ -26,13 +26,11 @@ from torch import Tensor, nn
from torch.nn.modules.batchnorm import _BatchNorm
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.utils import (
get_device_from_parameters,
get_dtype_from_parameters,
normalize_inputs,
populate_queues,
to_buffer_dict,
unnormalize_outputs,
)
@@ -58,9 +56,10 @@ class DiffusionPolicy(nn.Module):
if cfg is None:
cfg = DiffusionConfig()
self.cfg = cfg
self.dataset_stats = to_buffer_dict(dataset_stats)
self.normalize_input_modes = cfg.normalize_input_modes
self.unnormalize_output_modes = cfg.unnormalize_output_modes
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
# queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None
@@ -133,7 +132,7 @@ class DiffusionPolicy(nn.Module):
assert "observation.state" in batch
assert len(batch) == 2
batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
batch = self.normalize_inputs(batch)
self._queues = populate_queues(self._queues, batch)
@@ -146,9 +145,7 @@ class DiffusionPolicy(nn.Module):
actions = self.diffusion.generate_actions(batch)
# TODO(rcadene): make above methods return output dictionary?
out_dict = {"action": actions}
out_dict = unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
actions = out_dict["action"]
actions = self.unnormalize_outputs({"action": actions})["action"]
self._queues["action"].extend(actions.transpose(0, 1))
@@ -166,12 +163,12 @@ class DiffusionPolicy(nn.Module):
self.diffusion.train()
batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
batch = self.normalize_inputs(batch)
loss = self.forward(batch)["loss"]
loss.backward()
# TODO(rcadene): unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
# TODO(rcadene): self.unnormalize_outputs(out_dict)
grad_norm = torch.nn.utils.clip_grad_norm_(
self.diffusion.parameters(),

View File: lerobot/common/policies/normalize.py

@@ -0,0 +1,174 @@
import torch
from torch import nn
def create_stats_buffers(shapes, modes, stats=None):
"""
This function generates buffers to store the mean and standard deviation, or minimum and maximum values,
used for normalizing tensors. The normalization mode for each key is determined by the `modes` dictionary and can
be either "mean_std" (mean and standard deviation) or "min_max" (minimum and maximum). The buffers are created as
PyTorch nn.ParameterDict objects holding nn.Parameters with requires_grad=False, so they are included in the
module's state_dict and restored by load_state_dict.
If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
and width, assuming a channel-first (c, h, w) format.
Parameters:
shapes (dict): A dictionary where keys represent tensor identifiers and values represent the shapes of those tensors.
modes (dict): A dictionary specifying the normalization mode for each key in `shapes`. Valid modes are "mean_std" or "min_max".
stats (dict, optional): A dictionary containing pre-defined statistics for normalization. It can contain 'mean' and 'std' for
"mean_std" mode, or 'min' and 'max' for "min_max" mode. If provided, these statistics will overwrite the default buffers.
Providing `stats` is expected when training a model for the first time. If not provided, the default buffers are meant to be
overridden by a later call to `policy.load_state_dict(state_dict)`, which is useful when loading a pretrained model for
finetuning or evaluation, since it avoids initializing the training dataset just to access its `stats`.
Returns:
dict: A dictionary where keys match the `modes` and `shapes` keys, and values are nn.ParameterDict objects containing
the appropriate buffers for normalization.
"""
stats_buffers = {}
for key, mode in modes.items():
assert mode in ["mean_std", "min_max"]
shape = list(shapes[key])  # copy so the override below doesn't mutate the caller's `shapes`
# override shape to be invariant to height and width
if "image" in key:
# assume shape is channel first (c, h, w); stats then broadcast over height and width
shape[-1] = 1
shape[-2] = 1
buffer = {}
if mode == "mean_std":
mean = torch.zeros(shape, dtype=torch.float32)
std = torch.ones(shape, dtype=torch.float32)
buffer = nn.ParameterDict(
{
"mean": nn.Parameter(mean, requires_grad=False),
"std": nn.Parameter(std, requires_grad=False),
}
)
elif mode == "min_max":
# TODO(rcadene): should we assume input is in [-1, 1] range?
min = torch.ones(shape, dtype=torch.float32) * -1
max = torch.ones(shape, dtype=torch.float32)
buffer = nn.ParameterDict(
{
"min": nn.Parameter(min, requires_grad=False),
"max": nn.Parameter(max, requires_grad=False),
}
)
if stats is not None:
if mode == "mean_std":
buffer["mean"].data = stats[key]["mean"]
buffer["std"].data = stats[key]["std"]
elif mode == "min_max":
buffer["min"].data = stats[key]["min"]
buffer["max"].data = stats[key]["max"]
stats_buffers[key] = buffer
return stats_buffers
class Normalize(nn.Module):
"""
A PyTorch module for normalizing data based on predefined statistics.
The class is initialized with a set of shapes, modes, and optional pre-defined statistics. It creates buffers for normalization based
on these inputs, which are then used to adjust data during the forward pass. The normalization process operates on a batch of data,
with different keys in the batch being normalized according to the specified modes. The following normalization modes are supported:
- "mean_std": Normalizes data using the mean and standard deviation.
- "min_max": Normalizes data to a [0, 1] range and then to a [-1, 1] range.
Parameters:
shapes (dict): A dictionary where keys represent tensor identifiers and values represent the shapes of those tensors.
modes (dict): A dictionary indicating the normalization mode for each tensor key. Valid modes are "mean_std" or "min_max".
stats (dict, optional): A dictionary containing pre-defined statistics for normalization. It can contain 'mean' and 'std' for
"mean_std" mode, or 'min' and 'max' for "min_max" mode. If provided, these statistics will overwrite the default buffers.
Providing `stats` is expected when training a model for the first time. If not provided, the default buffers are meant to be
overridden by a later call to `policy.load_state_dict(state_dict)`, which is useful when loading a pretrained model for
finetuning or evaluation, since it avoids initializing the training dataset just to access its `stats`.
"""
def __init__(self, shapes, modes, stats=None):
super().__init__()
self.shapes = shapes
self.modes = modes
self.stats = stats
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
stats_buffers = create_stats_buffers(shapes, modes, stats)
for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
# TODO(rcadene): should we remove torch.no_grad?
@torch.no_grad
def forward(self, batch):
for key, mode in self.modes.items():
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
if mode == "mean_std":
mean = buffer["mean"].unsqueeze(0)
std = buffer["std"].unsqueeze(0)
batch[key] = (batch[key] - mean) / (std + 1e-8)
elif mode == "min_max":
min = buffer["min"].unsqueeze(0)
max = buffer["max"].unsqueeze(0)
# normalize to [0,1]
batch[key] = (batch[key] - min) / (max - min)
# normalize to [-1, 1]
batch[key] = batch[key] * 2 - 1
else:
raise ValueError(mode)
return batch
class Unnormalize(nn.Module):
"""
A PyTorch module for unnormalizing data based on predefined statistics.
The class is initialized with a set of shapes, modes, and optional pre-defined statistics. It creates buffers for unnormalization based
on these inputs, which are then used to adjust data during the forward pass. The unnormalization process operates on a batch of data,
with different keys in the batch being unnormalized according to the specified modes. The following unnormalization modes are supported:
- "mean_std": Unnormalizes data using the mean and standard deviation.
- "min_max": Unnormalizes data to a [0, 1] range and then to a [-1, 1] range.
Parameters:
shapes (dict): A dictionary where keys represent tensor identifiers and values represent the shapes of those tensors.
modes (dict): A dictionary indicating the unnormalization mode for each tensor key. Valid modes are "mean_std" or "min_max".
stats (dict, optional): A dictionary containing pre-defined statistics for unnormalization. It can contain 'mean' and 'std' for
"mean_std" mode, or 'min' and 'max' for "min_max" mode. If provided, these statistics will overwrite the default buffers.
Providing `stats` is expected when training a model for the first time. If not provided, the default buffers are meant to be
overridden by a later call to `policy.load_state_dict(state_dict)`, which is useful when loading a pretrained model for
finetuning or evaluation, since it avoids initializing the training dataset just to access its `stats`.
"""
def __init__(self, shapes, modes, stats=None):
super().__init__()
self.shapes = shapes
self.modes = modes
self.stats = stats
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
stats_buffers = create_stats_buffers(shapes, modes, stats)
for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
# TODO(rcadene): should we remove torch.no_grad?
@torch.no_grad
def forward(self, batch):
for key, mode in self.modes.items():
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
if mode == "mean_std":
mean = buffer["mean"].unsqueeze(0)
std = buffer["std"].unsqueeze(0)
batch[key] = batch[key] * std + mean
elif mode == "min_max":
min = buffer["min"].unsqueeze(0)
max = buffer["max"].unsqueeze(0)
batch[key] = (batch[key] + 1) / 2
batch[key] = batch[key] * (max - min) + min
else:
raise ValueError(mode)
return batch
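
To make the arithmetic concrete, here is a small self-contained sketch using only the names defined in this file (the image example relies on the height/width collapse in `create_stats_buffers`):

import torch

from lerobot.common.policies.normalize import Normalize, Unnormalize, create_stats_buffers

# Buffers for image keys collapse height and width, so stats broadcast over all pixels.
buffers = create_stats_buffers({"observation.image": [3, 96, 96]}, {"observation.image": "mean_std"})
print(buffers["observation.image"]["mean"].shape)  # torch.Size([3, 1, 1])

# min_max round trip: [min, max] -> [0, 1] -> [-1, 1], then back again.
stats = {"action": {"min": -torch.ones(5), "max": torch.full((5,), 3.0)}}
normalize = Normalize({"action": [5]}, {"action": "min_max"}, stats)
unnormalize = Unnormalize({"action": [5]}, {"action": "min_max"}, stats)

x = torch.rand(2, 5) * 4 - 1  # batch of actions in [-1, 3]
normalized = normalize({"action": x.clone()})["action"]  # now in [-1, 1]
recovered = unnormalize({"action": normalized})["action"]
torch.testing.assert_close(recovered, x)  # Unnormalize inverts Normalize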

View File: lerobot/common/policies/utils.py

@@ -28,58 +28,3 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
Note: assumes that all parameters have the same dtype.
"""
return next(iter(module.parameters())).dtype
def normalize_inputs(batch, stats, normalize_input_modes):
if normalize_input_modes is None:
return batch
for key, mode in normalize_input_modes.items():
if mode == "mean_std":
mean = stats[key]["mean"].unsqueeze(0)
std = stats[key]["std"].unsqueeze(0)
batch[key] = (batch[key] - mean) / (std + 1e-8)
elif mode == "min_max":
min = stats[key]["min"].unsqueeze(0)
max = stats[key]["max"].unsqueeze(0)
# normalize to [0,1]
batch[key] = (batch[key] - min) / (max - min)
# normalize to [-1, 1]
batch[key] = batch[key] * 2 - 1
else:
raise ValueError(mode)
return batch
def unnormalize_outputs(batch, stats, unnormalize_output_modes):
if unnormalize_output_modes is None:
return batch
for key, mode in unnormalize_output_modes.items():
if mode == "mean_std":
mean = stats[key]["mean"].unsqueeze(0)
std = stats[key]["std"].unsqueeze(0)
batch[key] = batch[key] * std + mean
elif mode == "min_max":
min = stats[key]["min"].unsqueeze(0)
max = stats[key]["max"].unsqueeze(0)
batch[key] = (batch[key] + 1) / 2
batch[key] = batch[key] * (max - min) + min
else:
raise ValueError(mode)
return batch
def to_buffer_dict(dataset_stats):
# TODO(rcadene): replace this function by `torch.BufferDict` when it exists
# see: https://github.com/pytorch/pytorch/issues/37386
# TODO(rcadene): make `to_buffer_dict` generic and add docstring
if dataset_stats is None:
return None
new_ds_stats = {}
for key, stats_dict in dataset_stats.items():
new_stats_dict = {}
for stats_type, value in stats_dict.items():
# set requires_grad=False to have the same behavior as a nn.Buffer
new_stats_dict[stats_type] = nn.Parameter(value, requires_grad=False)
new_ds_stats[key] = nn.ParameterDict(new_stats_dict)
return nn.ParameterDict(new_ds_stats)

View File: lerobot/configs/policy/act.yaml

@@ -34,6 +34,13 @@ policy:
chunk_size: 100 # chunk_size
n_action_steps: 100
input_shapes:
# TODO(rcadene, alexander-soar): add variables for height and width from the dataset/env?
observation.images.top: [3, 480, 640]
observation.state: ["${policy.state_dim}"]
output_shapes:
action: ["${policy.action_dim}"]
# Normalization / Unnormalization
normalize_input_modes:
observation.images.top: mean_std
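
The `${policy.state_dim}` and `${policy.action_dim}` entries are OmegaConf interpolations, so the shape lists stay in sync with the scalar dims defined elsewhere in the policy config. A minimal sketch of how they resolve, using plain OmegaConf outside of Hydra:

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "policy": {
            "state_dim": 14,
            "action_dim": 14,
            "input_shapes": {"observation.state": ["${policy.state_dim}"]},
            "output_shapes": {"action": ["${policy.action_dim}"]},
        }
    }
)
print(OmegaConf.to_container(cfg.policy.output_shapes, resolve=True))  # {'action': [14]}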

View File: lerobot/configs/policy/diffusion.yaml

@@ -50,6 +50,13 @@ policy:
horizon: ${horizon}
n_action_steps: ${n_action_steps}
input_shapes:
# TODO(rcadene, alexander-soar): add variables for height and width from the dataset/env?
observation.image: [3, 96, 96]
observation.state: ["${policy.state_dim}"]
output_shapes:
action: ["${policy.action_dim}"]
# Normalization / Unnormalization
normalize_input_modes:
observation.image: mean_std

View File: tests/test_policies.py

@@ -6,10 +6,10 @@ from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import init_hydra_config
from .utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
# TODO(aliberts): refactor using lerobot/__init__.py variables
@@ -93,3 +93,111 @@ def test_policy(env_name, policy_name, extra_overrides):
# Test step through policy
env.step(action)
# Test load state_dict
if policy_name != "tdmpc":
# TODO(rcadene, alexander-soar): make it work for tdmpc
# TODO(rcadene, alexander-soar): how to remove need for dataset_stats?
new_policy = make_policy(cfg, dataset_stats=dataset.stats)
new_policy.load_state_dict(policy.state_dict())
new_policy.update(batch, step=0)
@pytest.mark.parametrize(
"insert_temporal_dim",
[
False,
True,
],
)
def test_normalize(insert_temporal_dim):
# TODO(rcadene, alexander-soar): test with real data and assert results of normalization/unnormalization
input_shapes = {
"observation.image": [3, 96, 96],
"observation.state": [10],
}
output_shapes = {
"action": [5],
}
normalize_input_modes = {
"observation.image": "mean_std",
"observation.state": "min_max",
}
unnormalize_output_modes = {
"action": "min_max",
}
dataset_stats = {
"observation.image": {
"mean": torch.randn(3, 1, 1),
"std": torch.randn(3, 1, 1),
"min": torch.randn(3, 1, 1),
"max": torch.randn(3, 1, 1),
},
"observation.state": {
"mean": torch.randn(10),
"std": torch.randn(10),
"min": torch.randn(10),
"max": torch.randn(10),
},
"action": {
"mean": torch.randn(5),
"std": torch.randn(5),
"min": torch.randn(5),
"max": torch.randn(5),
},
}
bsize = 2
input_batch = {
"observation.image": torch.randn(bsize, 3, 96, 96),
"observation.state": torch.randn(bsize, 10),
}
output_batch = {
"action": torch.randn(bsize, 5),
}
if insert_temporal_dim:
tdim = 4
for key in input_batch:
# [2,3,96,96] -> [2,tdim,3,96,96]
input_batch[key] = torch.stack([input_batch[key]] * tdim, dim=1)
for key in output_batch:
output_batch[key] = torch.stack([output_batch[key]] * tdim, dim=1)
# test without stats
normalize = Normalize(input_shapes, normalize_input_modes, stats=None)
normalize(input_batch)
# test with stats
normalize = Normalize(input_shapes, normalize_input_modes, stats=dataset_stats)
normalize(input_batch)
# test loading pretrained models
new_normalize = Normalize(input_shapes, normalize_input_modes, stats=None)
new_normalize.load_state_dict(normalize.state_dict())
new_normalize(input_batch)
# test without stats
unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
unnormalize(output_batch)
# test with stats
unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=dataset_stats)
unnormalize(output_batch)
# test loading pretrained models
new_unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
new_unnormalize.load_state_dict(unnormalize.state_dict())
new_unnormalize(output_batch)
if __name__ == "__main__":
test_policy(
*("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_scripted"])
)
# test_normalize(insert_temporal_dim=True)
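
To run the new normalization tests directly, something like the following should work (the file path is an assumption based on the `tests.utils` import above):

import pytest

# Run only the test_normalize variants, verbosely.
pytest.main(["-v", "-k", "test_normalize", "tests/test_policies.py"])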