make unit tests pass

This commit is contained in:
Cadene
2024-04-23 21:39:39 +00:00
parent 42ed7bb670
commit 0660f71556
13 changed files with 79 additions and 38 deletions

View File

@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
@dataclass
@@ -61,13 +61,17 @@ class ActionChunkingTransformerConfig:
n_action_steps: int = 100
# Normalization / Unnormalization
normalize_input_modes: dict[str, str] = {
"observation.image": "mean_std",
"observation.state": "mean_std",
}
unnormalize_output_modes: dict[str, str] = {
"action": "mean_std",
}
normalize_input_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.image": "mean_std",
"observation.state": "mean_std",
}
)
unnormalize_output_modes: dict[str, str] = field(
default_factory=lambda: {
"action": "mean_std",
}
)
# Architecture.
# Vision backbone.
vision_backbone: str = "resnet18"

View File

@@ -22,6 +22,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
from lerobot.common.policies.utils import (
normalize_inputs,
to_buffer_dict,
unnormalize_outputs,
)
@@ -75,7 +76,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
if cfg is None:
cfg = ActionChunkingTransformerConfig()
self.cfg = cfg
self.register_buffer("dataset_stats", dataset_stats)
self.dataset_stats = to_buffer_dict(dataset_stats)
self.normalize_input_modes = cfg.normalize_input_modes
self.unnormalize_output_modes = cfg.unnormalize_output_modes
@@ -179,7 +180,12 @@ class ActionChunkingTransformerPolicy(nn.Module):
# `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
# has shape (n_action_steps, batch_size, *), hence the transpose.
actions = self._forward(batch)[0][: self.cfg.n_action_steps]
actions = unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
# TODO(rcadene): make _forward return output dictionary?
out_dict = {"action": actions}
out_dict = unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
actions = out_dict["action"]
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
@@ -214,7 +220,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
loss_dict = self.forward(batch)
# TODO(rcadene): unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
# TODO(rcadene): unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
loss = loss_dict["loss"]
loss.backward()

View File

@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
@dataclass
@@ -70,13 +70,17 @@ class DiffusionConfig:
n_action_steps: int = 8
# Normalization / Unnormalization
normalize_input_modes: dict[str, str] = {
"observation.image": "mean_std",
"observation.state": "min_max",
}
unnormalize_output_modes: dict[str, str] = {
"action": "min_max",
}
normalize_input_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.image": "mean_std",
"observation.state": "min_max",
}
)
unnormalize_output_modes: dict[str, str] = field(
default_factory=lambda: {
"action": "min_max",
}
)
# Architecture / modeling.
# Vision backbone.

View File

@@ -31,6 +31,7 @@ from lerobot.common.policies.utils import (
get_dtype_from_parameters,
normalize_inputs,
populate_queues,
to_buffer_dict,
unnormalize_outputs,
)
@@ -57,7 +58,7 @@ class DiffusionPolicy(nn.Module):
if cfg is None:
cfg = DiffusionConfig()
self.cfg = cfg
self.register_buffer("dataset_stats", dataset_stats)
self.dataset_stats = to_buffer_dict(dataset_stats)
self.normalize_input_modes = cfg.normalize_input_modes
self.unnormalize_output_modes = cfg.unnormalize_output_modes
@@ -144,7 +145,11 @@ class DiffusionPolicy(nn.Module):
else:
actions = self.diffusion.generate_actions(batch)
actions = unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
# TODO(rcadene): make above methods return output dictionary?
out_dict = {"action": actions}
out_dict = unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
actions = out_dict["action"]
self._queues["action"].extend(actions.transpose(0, 1))
action = self._queues["action"].popleft()
@@ -166,7 +171,7 @@ class DiffusionPolicy(nn.Module):
loss = self.forward(batch)["loss"]
loss.backward()
# TODO(rcadene): unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
# TODO(rcadene): unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
grad_norm = torch.nn.utils.clip_grad_norm_(
self.diffusion.parameters(),

View File

@@ -66,3 +66,20 @@ def unnormalize_outputs(batch, stats, unnormalize_output_modes):
else:
raise ValueError(mode)
return batch
def to_buffer_dict(dataset_stats):
    """Wrap nested dataset statistics in non-trainable ``nn.ParameterDict``s.

    Each leaf tensor is wrapped in an ``nn.Parameter`` with
    ``requires_grad=False`` so it behaves like a registered buffer: it moves
    with the module on ``.to()`` / ``.cuda()`` and is saved in the state dict,
    but never receives gradients.

    Args:
        dataset_stats: nested mapping ``{key: {stats_type: tensor}}``, or
            ``None``.

    Returns:
        ``nn.ParameterDict`` mirroring the input structure, or ``None`` when
        the input is ``None``.
    """
    # TODO(rcadene): replace this function by `torch.BufferDict` when it exists
    # see: https://github.com/pytorch/pytorch/issues/37386
    # TODO(rcadene): make `to_buffer_dict` generic and add docstring
    if dataset_stats is None:
        return None
    return nn.ParameterDict(
        {
            key: nn.ParameterDict(
                {
                    stats_type: nn.Parameter(value, requires_grad=False)
                    for stats_type, value in stats_dict.items()
                }
            )
            for key, stats_dict in dataset_stats.items()
        }
    )