Simplify configs (#550)

Co-authored-by: Remi <remi.cadene@huggingface.co>
Co-authored-by: HUANG TZU-CHUN <137322177+tc-huang@users.noreply.github.com>
This commit is contained in:
Simon Alibert
2025-01-31 13:57:37 +01:00
committed by GitHub
parent 1ee1acf8ad
commit 3c0a209f9f
119 changed files with 5761 additions and 5466 deletions

View File

@@ -33,21 +33,16 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor
from lerobot.common.constants import OBS_ENV, OBS_ROBOT
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
class TDMPCPolicy(
nn.Module,
PyTorchModelHubMixin,
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "tdmpc"],
):
class TDMPCPolicy(PreTrainedPolicy):
"""Implementation of TD-MPC learning + inference.
Please note several warnings for this policy.
@@ -65,11 +60,10 @@ class TDMPCPolicy(
match our xarm environment.
"""
config_class = TDMPCConfig
name = "tdmpc"
def __init__(
self, config: TDMPCConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
):
def __init__(self, config: TDMPCConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
@@ -77,42 +71,28 @@ class TDMPCPolicy(
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__()
if config is None:
config = TDMPCConfig()
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.model = TDMPCTOLD(config)
self.model_target = deepcopy(self.model)
for param in self.model_target.parameters():
param.requires_grad = False
if config.input_normalization_modes is not None:
self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, dataset_stats
)
else:
self.normalize_inputs = nn.Identity()
self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
self._use_image = False
self._use_env_state = False
if len(image_keys) > 0:
assert len(image_keys) == 1
self._use_image = True
self.input_image_key = image_keys[0]
if "observation.environment_state" in config.input_shapes:
self._use_env_state = True
self.reset()
def get_optim_params(self) -> dict:
# Returns the result of `self.parameters()` unchanged (no filtering or grouping is done here).
# Judging by the method name, this is presumably what the training code hands to the
# optimizer — confirm against the caller.
# NOTE(review): the return annotation says `dict`, but if `parameters` is the method
# inherited from `torch.nn.Module`, it yields an iterator over parameters rather than a
# dict — confirm and align the annotation with the base-class contract.
return self.parameters()
def reset(self):
"""
Clear observation and action queues. Clear previous means for warm starting of MPPI/CEM. Should be
@@ -122,9 +102,9 @@ class TDMPCPolicy(
"observation.state": deque(maxlen=1),
"action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
}
if self._use_image:
if self.config.image_features:
self._queues["observation.image"] = deque(maxlen=1)
if self._use_env_state:
if self.config.env_state_feature:
self._queues["observation.environment_state"] = deque(maxlen=1)
# Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
# CEM for the next step.
@@ -134,9 +114,9 @@ class TDMPCPolicy(
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations."""
batch = self.normalize_inputs(batch)
if self._use_image:
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.image"] = batch[self.input_image_key]
batch["observation.image"] = batch[next(iter(self.config.image_features))]
self._queues = populate_queues(self._queues, batch)
@@ -151,9 +131,9 @@ class TDMPCPolicy(
# NOTE: Order of observations matters here.
encode_keys = []
if self._use_image:
if self.config.image_features:
encode_keys.append("observation.image")
if self._use_env_state:
if self.config.env_state_feature:
encode_keys.append("observation.environment_state")
encode_keys.append("observation.state")
z = self.model.encode({k: batch[k] for k in encode_keys})
@@ -196,7 +176,7 @@ class TDMPCPolicy(
self.config.horizon,
self.config.n_pi_samples,
batch_size,
self.config.output_shapes["action"][0],
self.config.action_feature.shape[0],
device=device,
)
if self.config.n_pi_samples > 0:
@@ -215,7 +195,7 @@ class TDMPCPolicy(
# algorithm.
# The initial mean and standard deviation for the cross-entropy method (CEM).
mean = torch.zeros(
self.config.horizon, batch_size, self.config.output_shapes["action"][0], device=device
self.config.horizon, batch_size, self.config.action_feature.shape[0], device=device
)
# Maybe warm start CEM with the mean from the previous step.
if self._prev_mean is not None:
@@ -228,7 +208,7 @@ class TDMPCPolicy(
self.config.horizon,
self.config.n_gaussian_samples,
batch_size,
self.config.output_shapes["action"][0],
self.config.action_feature.shape[0],
device=std.device,
)
gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1)
@@ -330,9 +310,9 @@ class TDMPCPolicy(
device = get_device_from_parameters(self)
batch = self.normalize_inputs(batch)
if self._use_image:
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.image"] = batch[self.input_image_key]
batch["observation.image"] = batch[next(iter(self.config.image_features))]
batch = self.normalize_targets(batch)
info = {}
@@ -347,7 +327,7 @@ class TDMPCPolicy(
observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
# Apply random image augmentations.
if self._use_image and self.config.max_random_shift_ratio > 0:
if self.config.image_features and self.config.max_random_shift_ratio > 0:
observations["observation.image"] = flatten_forward_unflatten(
partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
observations["observation.image"],
@@ -360,7 +340,7 @@ class TDMPCPolicy(
current_observation[k] = observations[k][0]
next_observations[k] = observations[k][1:]
horizon, batch_size = next_observations[
"observation.image" if self._use_image else "observation.environment_state"
"observation.image" if self.config.image_features else "observation.environment_state"
].shape[:2]
# Run latent rollout using the latent dynamics model and policy model.
@@ -543,7 +523,7 @@ class TDMPCTOLD(nn.Module):
self.config = config
self._encoder = TDMPCObservationEncoder(config)
self._dynamics = nn.Sequential(
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.mlp_dim),
@@ -554,7 +534,7 @@ class TDMPCTOLD(nn.Module):
nn.Sigmoid(),
)
self._reward = nn.Sequential(
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.mlp_dim),
@@ -569,12 +549,12 @@ class TDMPCTOLD(nn.Module):
nn.Linear(config.mlp_dim, config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.output_shapes["action"][0]),
nn.Linear(config.mlp_dim, config.action_feature.shape[0]),
)
self._Qs = nn.ModuleList(
[
nn.Sequential(
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Tanh(),
nn.Linear(config.mlp_dim, config.mlp_dim),
@@ -714,10 +694,13 @@ class TDMPCObservationEncoder(nn.Module):
super().__init__()
self.config = config
if "observation.image" in config.input_shapes:
if config.image_features:
self.image_enc_layers = nn.Sequential(
nn.Conv2d(
config.input_shapes["observation.image"][0], config.image_encoder_hidden_dim, 7, stride=2
next(iter(config.image_features.values())).shape[0],
config.image_encoder_hidden_dim,
7,
stride=2,
),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
@@ -727,9 +710,8 @@ class TDMPCObservationEncoder(nn.Module):
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.ReLU(),
)
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
with torch.inference_mode():
out_shape = self.image_enc_layers(dummy_batch).shape[1:]
dummy_shape = (1, *next(iter(config.image_features.values())).shape)
out_shape = get_output_shape(self.image_enc_layers, dummy_shape)[1:]
self.image_enc_layers.extend(
nn.Sequential(
nn.Flatten(),
@@ -738,19 +720,19 @@ class TDMPCObservationEncoder(nn.Module):
nn.Sigmoid(),
)
)
if "observation.state" in config.input_shapes:
if config.robot_state_feature:
self.state_enc_layers = nn.Sequential(
nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim),
nn.Linear(config.robot_state_feature.shape[0], config.state_encoder_hidden_dim),
nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Sigmoid(),
)
if "observation.environment_state" in config.input_shapes:
if config.env_state_feature:
self.env_state_enc_layers = nn.Sequential(
nn.Linear(
config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim
),
nn.Linear(config.env_state_feature.shape[0], config.state_encoder_hidden_dim),
nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
nn.LayerNorm(config.latent_dim),
@@ -765,12 +747,16 @@ class TDMPCObservationEncoder(nn.Module):
"""
feat = []
# NOTE: Order of observations matters here.
if "observation.image" in self.config.input_shapes:
feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict["observation.image"]))
if "observation.environment_state" in self.config.input_shapes:
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
if "observation.state" in self.config.input_shapes:
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
if self.config.image_features:
feat.append(
flatten_forward_unflatten(
self.image_enc_layers, obs_dict[next(iter(self.config.image_features))]
)
)
if self.config.env_state_feature:
feat.append(self.env_state_enc_layers(obs_dict[OBS_ENV]))
if self.config.robot_state_feature:
feat.append(self.state_enc_layers(obs_dict[OBS_ROBOT]))
return torch.stack(feat, dim=0).mean(0)