make unit tests pass
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -61,13 +61,17 @@ class ActionChunkingTransformerConfig:
|
||||
n_action_steps: int = 100
|
||||
|
||||
# Normalization / Unnormalization
|
||||
normalize_input_modes: dict[str, str] = {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "mean_std",
|
||||
}
|
||||
unnormalize_output_modes: dict[str, str] = {
|
||||
"action": "mean_std",
|
||||
}
|
||||
normalize_input_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "mean_std",
|
||||
}
|
||||
)
|
||||
unnormalize_output_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": "mean_std",
|
||||
}
|
||||
)
|
||||
# Architecture.
|
||||
# Vision backbone.
|
||||
vision_backbone: str = "resnet18"
|
||||
|
||||
@@ -22,6 +22,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d
|
||||
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
|
||||
from lerobot.common.policies.utils import (
|
||||
normalize_inputs,
|
||||
to_buffer_dict,
|
||||
unnormalize_outputs,
|
||||
)
|
||||
|
||||
@@ -75,7 +76,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
if cfg is None:
|
||||
cfg = ActionChunkingTransformerConfig()
|
||||
self.cfg = cfg
|
||||
self.register_buffer("dataset_stats", dataset_stats)
|
||||
self.dataset_stats = to_buffer_dict(dataset_stats)
|
||||
self.normalize_input_modes = cfg.normalize_input_modes
|
||||
self.unnormalize_output_modes = cfg.unnormalize_output_modes
|
||||
|
||||
@@ -179,7 +180,12 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
# `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
|
||||
# has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||
actions = self._forward(batch)[0][: self.cfg.n_action_steps]
|
||||
actions = unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
|
||||
|
||||
# TODO(rcadene): make _forward return output dictionary?
|
||||
out_dict = {"action": actions}
|
||||
out_dict = unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
|
||||
actions = out_dict["action"]
|
||||
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
@@ -214,7 +220,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
|
||||
batch = normalize_inputs(batch, self.dataset_stats, self.normalize_input_modes)
|
||||
loss_dict = self.forward(batch)
|
||||
# TODO(rcadene): unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
|
||||
# TODO(rcadene): unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
|
||||
loss = loss_dict["loss"]
|
||||
loss.backward()
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -70,13 +70,17 @@ class DiffusionConfig:
|
||||
n_action_steps: int = 8
|
||||
|
||||
# Normalization / Unnormalization
|
||||
normalize_input_modes: dict[str, str] = {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "min_max",
|
||||
}
|
||||
unnormalize_output_modes: dict[str, str] = {
|
||||
"action": "min_max",
|
||||
}
|
||||
normalize_input_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "min_max",
|
||||
}
|
||||
)
|
||||
unnormalize_output_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": "min_max",
|
||||
}
|
||||
)
|
||||
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
|
||||
@@ -31,6 +31,7 @@ from lerobot.common.policies.utils import (
|
||||
get_dtype_from_parameters,
|
||||
normalize_inputs,
|
||||
populate_queues,
|
||||
to_buffer_dict,
|
||||
unnormalize_outputs,
|
||||
)
|
||||
|
||||
@@ -57,7 +58,7 @@ class DiffusionPolicy(nn.Module):
|
||||
if cfg is None:
|
||||
cfg = DiffusionConfig()
|
||||
self.cfg = cfg
|
||||
self.register_buffer("dataset_stats", dataset_stats)
|
||||
self.dataset_stats = to_buffer_dict(dataset_stats)
|
||||
self.normalize_input_modes = cfg.normalize_input_modes
|
||||
self.unnormalize_output_modes = cfg.unnormalize_output_modes
|
||||
|
||||
@@ -144,7 +145,11 @@ class DiffusionPolicy(nn.Module):
|
||||
else:
|
||||
actions = self.diffusion.generate_actions(batch)
|
||||
|
||||
actions = unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
|
||||
# TODO(rcadene): make above methods return output dictionary?
|
||||
out_dict = {"action": actions}
|
||||
out_dict = unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
|
||||
actions = out_dict["action"]
|
||||
|
||||
self._queues["action"].extend(actions.transpose(0, 1))
|
||||
|
||||
action = self._queues["action"].popleft()
|
||||
@@ -166,7 +171,7 @@ class DiffusionPolicy(nn.Module):
|
||||
loss = self.forward(batch)["loss"]
|
||||
loss.backward()
|
||||
|
||||
# TODO(rcadene): unnormalize_outputs(actions, self.dataset_stats, self.unnormalize_output_modes)
|
||||
# TODO(rcadene): unnormalize_outputs(out_dict, self.dataset_stats, self.unnormalize_output_modes)
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.diffusion.parameters(),
|
||||
|
||||
@@ -66,3 +66,20 @@ def unnormalize_outputs(batch, stats, unnormalize_output_modes):
|
||||
else:
|
||||
raise ValueError(mode)
|
||||
return batch
|
||||
|
||||
|
||||
def to_buffer_dict(dataset_stats):
    """Wrap a two-level statistics mapping into nested `nn.ParameterDict`s.

    Each leaf value is wrapped in an `nn.Parameter` created with
    `requires_grad=False`, so that it behaves like a buffer rather than a
    trainable weight. The two-level key structure of the input
    (outer key -> statistic name -> value) is mirrored in the result.

    Args:
        dataset_stats: Mapping of outer keys to mappings of statistic names to
            values, or `None`. Values are presumably `torch.Tensor`s, as
            required by `nn.Parameter` -- TODO confirm against callers.

    Returns:
        `None` when `dataset_stats` is `None`; otherwise an `nn.ParameterDict`
        whose values are themselves `nn.ParameterDict`s of frozen parameters.
    """
    # TODO(rcadene): replace this function by `torch.BufferDict` when it exists
    # see: https://github.com/pytorch/pytorch/issues/37386
    # TODO(rcadene): make `to_buffer_dict` generic
    if dataset_stats is None:
        return None

    return nn.ParameterDict(
        {
            outer_key: nn.ParameterDict(
                {
                    # requires_grad=False mimics the behavior of a nn.Buffer
                    stat_name: nn.Parameter(stat_value, requires_grad=False)
                    for stat_name, stat_value in per_key_stats.items()
                }
            )
            for outer_key, per_key_stats in dataset_stats.items()
        }
    )
|
||||
|
||||
Reference in New Issue
Block a user