From d1855a202ad004f801def5f773234b6f2ecd97e3 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 1 May 2024 16:40:04 +0100 Subject: [PATCH] Refactor TD-MPC (#103) Co-authored-by: Cadene Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> --- Makefile | 8 +- .../common/policies/act/configuration_act.py | 19 +- lerobot/common/policies/act/modeling_act.py | 12 +- .../diffusion/configuration_diffusion.py | 17 +- .../policies/diffusion/modeling_diffusion.py | 11 +- lerobot/common/policies/factory.py | 5 +- lerobot/common/policies/policy_protocol.py | 18 + .../policies/tdmpc/configuration_tdmpc.py | 150 ++++ lerobot/common/policies/tdmpc/helper.py | 576 ------------- .../common/policies/tdmpc/modeling_tdmpc.py | 798 ++++++++++++++++++ lerobot/common/policies/tdmpc/policy.py | 495 ----------- lerobot/configs/default.yaml | 1 + lerobot/configs/policy/tdmpc.yaml | 135 ++- lerobot/scripts/eval.py | 7 +- lerobot/scripts/train.py | 37 +- tests/test_available.py | 2 +- tests/test_policies.py | 19 +- 17 files changed, 1105 insertions(+), 1205 deletions(-) create mode 100644 lerobot/common/policies/tdmpc/configuration_tdmpc.py delete mode 100644 lerobot/common/policies/tdmpc/helper.py create mode 100644 lerobot/common/policies/tdmpc/modeling_tdmpc.py delete mode 100644 lerobot/common/policies/tdmpc/policy.py diff --git a/Makefile b/Makefile index ea6c3091..20d2c553 100644 --- a/Makefile +++ b/Makefile @@ -22,8 +22,8 @@ test-end-to-end: ${MAKE} test-act-ete-eval ${MAKE} test-diffusion-ete-train ${MAKE} test-diffusion-ete-eval - # ${MAKE} test-tdmpc-ete-train - # ${MAKE} test-tdmpc-ete-eval + ${MAKE} test-tdmpc-ete-train + ${MAKE} test-tdmpc-ete-eval ${MAKE} test-default-ete-eval test-act-ete-train: @@ -74,8 +74,10 @@ test-tdmpc-ete-train: python lerobot/scripts/train.py \ policy=tdmpc \ env=xarm \ + env.task=XarmLift-v0 \ + dataset_repo_id=lerobot/xarm_lift_medium_replay \ wandb.enable=False \ - training.offline_steps=1 \ + training.offline_steps=2 \ training.online_steps=2 \ eval.n_episodes=1 \ env.episode_length=2 \ diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py index b3700a26..a3980b14 100644 --- a/lerobot/common/policies/act/configuration_act.py +++ b/lerobot/common/policies/act/configuration_act.py @@ -22,16 +22,17 @@ class ACTConfig: The key represents the input data name, and the value is a list indicating the dimensions of the corresponding data. For example, "observation.images.top" refers to an input from the "top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution. - Importantly, shapes doesnt include batch dimension or temporal dimension. + Importantly, shapes doesn't include batch dimension or temporal dimension. output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents the output data name, and the value is a list indicating the dimensions of the corresponding data. For example, "action" refers to an output shape of [14], indicating - 14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension. - normalize_input_modes: A dictionary with key represents the modality (e.g. "observation.state"), - and the value specifies the normalization mode to apply. The two availables - modes are "mean_std" which substracts the mean and divide by the standard - deviation and "min_max" which rescale in a [-1, 1] range. 
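The "mean_std" and "min_max" normalization modes described in these config docstrings boil down to simple arithmetic. Below is a minimal illustrative sketch only (the helper names and the epsilon guard are mine, not part of the patch); the actual logic lives in the `Normalize`/`Unnormalize` modules under `lerobot.common.policies.normalize`, which `modeling_tdmpc.py` further down imports.

```python
import torch


def normalize_mean_std(x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """'mean_std': subtract the dataset mean and divide by the dataset standard deviation."""
    return (x - mean) / (std + 1e-8)


def normalize_min_max(x: torch.Tensor, min_: torch.Tensor, max_: torch.Tensor) -> torch.Tensor:
    """'min_max': rescale from the dataset's [min, max] range to [-1, 1]."""
    return (x - min_) / (max_ - min_ + 1e-8) * 2.0 - 1.0
```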
- unnormalize_output_modes: Similar dictionary as `normalize_input_modes`, but to unormalize in original scale. + 14-dimensional actions. Importantly, shapes doesn't include batch dimension or temporal dimension. + input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"), + and the value specifies the normalization mode to apply. The two available modes are "mean_std" + which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a + [-1, 1] range. + output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the + original scale. Note that this is also used for normalizing the training targets. vision_backbone: Name of the torchvision resnet backbone to use for encoding images. pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone. `None` means no pretrained weights. @@ -62,13 +63,13 @@ class ACTConfig: chunk_size: int = 100 n_action_steps: int = 100 - input_shapes: dict[str, list[str]] = field( + input_shapes: dict[str, list[int]] = field( default_factory=lambda: { "observation.images.top": [3, 480, 640], "observation.state": [14], } ) - output_shapes: dict[str, list[str]] = field( + output_shapes: dict[str, list[int]] = field( default_factory=lambda: { "action": [14], } diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 448bd2cb..f9e52e02 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -31,11 +31,17 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): name = "act" - def __init__(self, config: ACTConfig | None = None, dataset_stats=None): + def __init__( + self, + config: ACTConfig | None = None, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, + ): """ Args: config: Policy configuration class instance or None, in which case the default instantiation of the configuration class is used. + dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected + that they will be passed with a call to `load_state_dict` before the policy is used. """ super().__init__() if config is None: @@ -58,7 +64,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): self._action_queue = deque([], maxlen=self.config.n_action_steps) @torch.no_grad - def select_action(self, batch: dict[str, Tensor], **_) -> Tensor: + def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations. This method wraps `select_actions` in order to return one action at a time for execution in the @@ -81,7 +87,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): self._action_queue.extend(actions.transpose(0, 1)) return self._action_queue.popleft() - def forward(self, batch, **_) -> dict[str, Tensor]: + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) batch = self.normalize_targets(batch) diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py index af07154d..b5188488 100644 --- a/lerobot/common/policies/diffusion/configuration_diffusion.py +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -3,7 +3,7 @@ from dataclasses import dataclass, field @dataclass class DiffusionConfig: - """Configuration class for Diffusion Policy. 
+ """Configuration class for DiffusionPolicy. Defaults are configured for training with PushT providing proprioceptive and single camera observations. @@ -25,11 +25,12 @@ class DiffusionConfig: The key represents the output data name, and the value is a list indicating the dimensions of the corresponding data. For example, "action" refers to an output shape of [14], indicating 14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension. - normalize_input_modes: A dictionary with key represents the modality (e.g. "observation.state"), - and the value specifies the normalization mode to apply. The two availables - modes are "mean_std" which substracts the mean and divide by the standard - deviation and "min_max" which rescale in a [-1, 1] range. - unnormalize_output_modes: Similar dictionary as `normalize_input_modes`, but to unormalize in original scale. + input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"), + and the value specifies the normalization mode to apply. The two available modes are "mean_std" + which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a + [-1, 1] range. + output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the + original scale. Note that this is also used for normalizing the training targets. vision_backbone: Name of the torchvision resnet backbone to use for encoding images. crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit within the image size. If None, no cropping is done. @@ -70,13 +71,13 @@ class DiffusionConfig: horizon: int = 16 n_action_steps: int = 8 - input_shapes: dict[str, list[str]] = field( + input_shapes: dict[str, list[int]] = field( default_factory=lambda: { "observation.image": [3, 96, 96], "observation.state": [2], } ) - output_shapes: dict[str, list[str]] = field( + output_shapes: dict[str, list[int]] = field( default_factory=lambda: { "action": [2], } diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index f57daf63..5b6da771 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -43,15 +43,16 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin): def __init__( self, config: DiffusionConfig | None = None, - dataset_stats=None, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, ): """ Args: config: Policy configuration class instance or None, in which case the default instantiation of - the configuration class is used. + the configuration class is used. + dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected + that they will be passed with a call to `load_state_dict` before the policy is used. """ super().__init__() - # TODO(alexander-soare): LR scheduler will be removed. if config is None: config = DiffusionConfig() self.config = config @@ -88,7 +89,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin): } @torch.no_grad - def select_action(self, batch: dict[str, Tensor], **_) -> Tensor: + def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations. 
This method handles caching a history of observations and an action trajectory generated by the @@ -136,7 +137,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin): action = self._queues["action"].popleft() return action - def forward(self, batch: dict[str, Tensor], **_) -> dict[str, Tensor]: + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) batch = self.normalize_targets(batch) diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 727aa80b..808a3145 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -24,7 +24,10 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg): def get_policy_and_config_classes(name: str) -> tuple[Policy, object]: """Get the policy's class and config class given a name (matching the policy class' `name` attribute).""" if name == "tdmpc": - raise NotImplementedError("Coming soon!") + from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig + from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy + + return TDMPCPolicy, TDMPCConfig elif name == "diffusion": from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy diff --git a/lerobot/common/policies/policy_protocol.py b/lerobot/common/policies/policy_protocol.py index 62bc9dfc..5749c6a8 100644 --- a/lerobot/common/policies/policy_protocol.py +++ b/lerobot/common/policies/policy_protocol.py @@ -21,6 +21,14 @@ class Policy(Protocol): name: str + def __init__(self, cfg, dataset_stats: dict[str, dict[str, Tensor]] | None = None): + """ + Args: + cfg: Policy configuration class instance or None, in which case the default instantiation of the + configuration class is used. + dataset_stats: Dataset statistics to be used for normalization. + """ + def reset(self): """To be called whenever the environment is reset. @@ -39,3 +47,13 @@ class Policy(Protocol): When the model uses a history of observations, or outputs a sequence of actions, this method deals with caching. """ + + +@runtime_checkable +class PolicyWithUpdate(Policy, Protocol): + def update(self): + """An update method that is to be called after a training optimization step. + + Implements an additional updates the model parameters may need (for example, doing an EMA step for a + target model, or incrementing an internal buffer). + """ diff --git a/lerobot/common/policies/tdmpc/configuration_tdmpc.py b/lerobot/common/policies/tdmpc/configuration_tdmpc.py new file mode 100644 index 00000000..82e3a507 --- /dev/null +++ b/lerobot/common/policies/tdmpc/configuration_tdmpc.py @@ -0,0 +1,150 @@ +from dataclasses import dataclass, field + + +@dataclass +class TDMPCConfig: + """Configuration class for TDMPCPolicy. + + Defaults are configured for training with xarm_lift_medium_replay providing proprioceptive and single + camera observations. + + The parameters you will most likely need to change are the ones which depend on the environment / sensors. + Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift`. + + Args: + n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google + action repeats in Q-learning or ask your favorite chatbot) + horizon: Horizon for model predictive control. 
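As a quick illustration of the "action repeats" idea mentioned in the `n_action_repeats` description above: the policy plans once, then the same action is applied for several consecutive environment steps. This is a rough conceptual sketch under an assumed gymnasium-style `env`/`policy` interface, not the actual eval script; in this patch the repetition is handled internally by `TDMPCPolicy.select_action` via an action queue, so the evaluation loop itself does not change.

```python
def rollout_with_action_repeats(env, policy, n_action_repeats: int = 2, max_steps: int = 200):
    """Conceptual sketch: each planned action is replayed for `n_action_repeats` env steps."""
    obs, _ = env.reset()
    steps = 0
    while steps < max_steps:
        action = policy.select_action(obs)        # plan once ...
        for _ in range(n_action_repeats):         # ... apply the same action repeatedly
            obs, reward, terminated, truncated, _ = env.step(action)
            steps += 1
            if terminated or truncated:
                return
```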
+ input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"), + and the value specifies the normalization mode to apply. The two available modes are "mean_std" + which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a + [-1, 1] range. Note that here this defaults to None meaning inputs are not normalized. This is to + match the original implementation. + output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the + original scale. Note that this is also used for normalizing the training targets. NOTE: Clipping + to [-1, +1] is used during MPPI/CEM. Therefore, it is recommended that you stick with "min_max" + normalization mode here. + image_encoder_hidden_dim: Number of channels for the convolutional layers used for image encoding. + state_encoder_hidden_dim: Hidden dimension for MLP used for state vector encoding. + latent_dim: Observation's latent embedding dimension. + q_ensemble_size: Number of Q function estimators to use in an ensemble for uncertainty estimation. + mlp_dim: Hidden dimension of MLPs used for modelling the dynamics encoder, reward function, policy + (π), Q ensemble, and V. + discount: Discount factor (γ) to use for the reinforcement learning formalism. + use_mpc: Whether to use model predictive control. The alternative is to just sample the policy model + (π) for each step. + cem_iterations: Number of iterations for the MPPI/CEM loop in MPC. + max_std: Maximum standard deviation for actions sampled from the gaussian PDF in CEM. + min_std: Minimum standard deviation for noise applied to actions sampled from the policy model (π). + Doubles up as the minimum standard deviation for actions sampled from the gaussian PDF in CEM. + n_gaussian_samples: Number of samples to draw from the gaussian distribution every CEM iteration. Must + be non-zero. + n_pi_samples: Number of samples to draw from the policy / world model rollout every CEM iteration. Can + be zero. + uncertainty_regularizer_coeff: Coefficient for the uncertainty regularization used when estimating + trajectory values (this is the λ coeffiecient in eqn 4 of FOWM). + n_elites: The number of elite samples to use for updating the gaussian parameters every CEM iteration. + elite_weighting_temperature: The temperature to use for softmax weighting (by trajectory value) of the + elites, when updating the gaussian parameters for CEM. + gaussian_mean_momentum: Momentum (α) used for EMA updates of the mean parameter μ of the gaussian + paramters optimized in CEM. Updates are calculated as μ⁻ ← αμ⁻ + (1-α)μ. + max_random_shift_ratio: Maximum random shift (as a proportion of the image size) to apply to the + image(s) (in units of pixels) for training-time augmentation. If set to 0, no such augmentation + is applied. Note that the input images are assumed to be square for this augmentation. + reward_coeff: Loss weighting coefficient for the reward regression loss. + expectile_weight: Weighting (τ) used in expectile regression for the state value function (V). + v_pred < v_target is weighted by τ and v_pred >= v_target is weighted by (1-τ). τ is expected to + be in [0, 1]. Setting τ closer to 1 results in a more "optimistic" V. This is sensible to do + because v_target is obtained by evaluating the learned state-action value functions (Q) with + in-sample actions that may not be always optimal. 
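The expectile weighting described for `expectile_weight` above is compact enough to write out. A standalone sketch for illustration (the function name is mine; the equivalent in-line computation appears later in `TDMPCPolicy.forward`):

```python
import torch


def expectile_loss(v_pred: torch.Tensor, v_target: torch.Tensor, tau: float = 0.9) -> torch.Tensor:
    """Expectile regression: under-estimates (v_pred < v_target) are weighted by tau,
    over-estimates by (1 - tau). A tau close to 1 therefore yields a more "optimistic" V."""
    diff = v_target - v_pred
    weight = torch.where(diff > 0, tau, 1 - tau)
    return (weight * diff**2).mean()
```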
+ value_coeff: Loss weighting coefficient for both the state-action value (Q) TD loss, and the state + value (V) expectile regression loss. + consistency_coeff: Loss weighting coefficient for the consistency loss. + advantage_scaling: A factor by which the advantages are scaled prior to exponentiation for advantage + weighted regression of the policy (π) estimator parameters. Note that the exponentiated advantages + are clamped at 100.0. + pi_coeff: Loss weighting coefficient for the action regression loss. + temporal_decay_coeff: Exponential decay coefficient for decaying the loss coefficient for future time- + steps. Hint: each loss computation involves `horizon` steps worth of actions starting from the + current time step. + target_model_momentum: Momentum (α) used for EMA updates of the target models. Updates are calculated + as ϕ ← αϕ + (1-α)θ where ϕ are the parameters of the target model and θ are the parameters of the + model being trained. + """ + + # Input / output structure. + n_action_repeats: int = 2 + horizon: int = 5 + + input_shapes: dict[str, list[int]] = field( + default_factory=lambda: { + "observation.image": [3, 84, 84], + "observation.state": [4], + } + ) + output_shapes: dict[str, list[int]] = field( + default_factory=lambda: { + "action": [4], + } + ) + + # Normalization / Unnormalization + input_normalization_modes: dict[str, str] | None = None + output_normalization_modes: dict[str, str] = field( + default_factory=lambda: {"action": "min_max"}, + ) + + # Architecture / modeling. + # Neural networks. + image_encoder_hidden_dim: int = 32 + state_encoder_hidden_dim: int = 256 + latent_dim: int = 50 + q_ensemble_size: int = 5 + mlp_dim: int = 512 + # Reinforcement learning. + discount: float = 0.9 + + # Inference. + use_mpc: bool = True + cem_iterations: int = 6 + max_std: float = 2.0 + min_std: float = 0.05 + n_gaussian_samples: int = 512 + n_pi_samples: int = 51 + uncertainty_regularizer_coeff: float = 1.0 + n_elites: int = 50 + elite_weighting_temperature: float = 0.5 + gaussian_mean_momentum: float = 0.1 + + # Training and loss computation. + max_random_shift_ratio: float = 0.0476 + # Loss coefficients. + reward_coeff: float = 0.5 + expectile_weight: float = 0.9 + value_coeff: float = 0.1 + consistency_coeff: float = 20.0 + advantage_scaling: float = 3.0 + pi_coeff: float = 0.5 + temporal_decay_coeff: float = 0.5 + # Target model. + target_model_momentum: float = 0.995 + + def __post_init__(self): + """Input validation (not exhaustive).""" + if self.input_shapes["observation.image"][-2] != self.input_shapes["observation.image"][-1]: + # TODO(alexander-soare): This limitation is solely because of code in the random shift + # augmentation. It should be able to be removed. + raise ValueError( + "Only square images are handled now. Got image shape " + f"{self.input_shapes['observation.image']}." + ) + if self.n_gaussian_samples <= 0: + raise ValueError( + f"The number of guassian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`" + ) + if self.output_normalization_modes != {"action": "min_max"}: + raise ValueError( + "TD-MPC assumes the action space dimensions to all be in [-1, 1]. Therefore it is strongly " + f"advised that you stick with the default. See {self.__class__.__name__} docstring for more " + "information." 
+ ) diff --git a/lerobot/common/policies/tdmpc/helper.py b/lerobot/common/policies/tdmpc/helper.py deleted file mode 100644 index 964f1718..00000000 --- a/lerobot/common/policies/tdmpc/helper.py +++ /dev/null @@ -1,576 +0,0 @@ -import os -import pickle -import re - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F # noqa: N812 -from torch import distributions as pyd -from torch.distributions.utils import _standard_normal - -DEFAULT_ACT_FN = nn.Mish() - - -def __REDUCE__(b): # noqa: N802, N807 - return "mean" if b else "none" - - -def l1(pred, target, reduce=False): - """Computes the L1-loss between predictions and targets.""" - return F.l1_loss(pred, target, reduction=__REDUCE__(reduce)) - - -def mse(pred, target, reduce=False): - """Computes the MSE loss between predictions and targets.""" - return F.mse_loss(pred, target, reduction=__REDUCE__(reduce)) - - -def l2_expectile(diff, expectile=0.7, reduce=False): - weight = torch.where(diff > 0, expectile, (1 - expectile)) - loss = weight * (diff**2) - reduction = __REDUCE__(reduce) - if reduction == "mean": - return torch.mean(loss) - elif reduction == "sum": - return torch.sum(loss) - return loss - - -def _get_out_shape(in_shape, layers): - """Utility function. Returns the output shape of a network for a given input shape.""" - x = torch.randn(*in_shape).unsqueeze(0) - return (nn.Sequential(*layers) if isinstance(layers, list) else layers)(x).squeeze(0).shape - - -def gaussian_logprob(eps, log_std): - """Compute Gaussian log probability.""" - residual = (-0.5 * eps.pow(2) - log_std).sum(-1, keepdim=True) - return residual - 0.5 * np.log(2 * np.pi) * eps.size(-1) - - -def squash(mu, pi, log_pi): - """Apply squashing function.""" - mu = torch.tanh(mu) - pi = torch.tanh(pi) - log_pi -= torch.log(F.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True) - return mu, pi, log_pi - - -def orthogonal_init(m): - """Orthogonal layer initialization.""" - if isinstance(m, nn.Linear): - nn.init.orthogonal_(m.weight.data) - if m.bias is not None: - nn.init.zeros_(m.bias) - elif isinstance(m, nn.Conv2d): - gain = nn.init.calculate_gain("relu") - nn.init.orthogonal_(m.weight.data, gain) - if m.bias is not None: - nn.init.zeros_(m.bias) - - -def ema(m, m_target, tau): - """Update slow-moving average of online network (target network) at rate tau.""" - with torch.no_grad(): - # TODO(rcadene, aliberts): issue with strict=False - # for p, p_target in zip(m.parameters(), m_target.parameters(), strict=False): - # p_target.data.lerp_(p.data, tau) - m_params_iter = iter(m.parameters()) - m_target_params_iter = iter(m_target.parameters()) - - while True: - try: - p = next(m_params_iter) - p_target = next(m_target_params_iter) - p_target.data.lerp_(p.data, tau) - except StopIteration: - # If any iterator is exhausted, exit the loop - break - - -def set_requires_grad(net, value): - """Enable/disable gradients for a given (sub)network.""" - for param in net.parameters(): - param.requires_grad_(value) - - -class TruncatedNormal(pyd.Normal): - """Utility class implementing the truncated normal distribution.""" - - default_sample_shape = torch.Size() - - def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6): - super().__init__(loc, scale, validate_args=False) - self.low = low - self.high = high - self.eps = eps - - def _clamp(self, x): - clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps) - x = x - x.detach() + clamped_x.detach() - return x - - def sample(self, clip=None, sample_shape=default_sample_shape): - 
shape = self._extended_shape(sample_shape) - eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device) - eps *= self.scale - if clip is not None: - eps = torch.clamp(eps, -clip, clip) - x = self.loc + eps - return self._clamp(x) - - -class NormalizeImg(nn.Module): - """Normalizes pixel observations to [0,1) range.""" - - def __init__(self): - super().__init__() - - def forward(self, x): - return x.div(255.0) - - -class Flatten(nn.Module): - """Flattens its input to a (batched) vector.""" - - def __init__(self): - super().__init__() - - def forward(self, x): - return x.view(x.size(0), -1) - - -def enc(cfg): - obs_shape = { - "rgb": (3, cfg.img_size, cfg.img_size), - "state": (cfg.state_dim,), - } - - """Returns a TOLD encoder.""" - pixels_enc_layers, state_enc_layers = None, None - if cfg.modality in {"pixels", "all"}: - C = int(3 * cfg.frame_stack) # noqa: N806 - pixels_enc_layers = [ - NormalizeImg(), - nn.Conv2d(C, cfg.num_channels, 7, stride=2), - nn.ReLU(), - nn.Conv2d(cfg.num_channels, cfg.num_channels, 5, stride=2), - nn.ReLU(), - nn.Conv2d(cfg.num_channels, cfg.num_channels, 3, stride=2), - nn.ReLU(), - nn.Conv2d(cfg.num_channels, cfg.num_channels, 3, stride=2), - nn.ReLU(), - ] - out_shape = _get_out_shape((C, cfg.img_size, cfg.img_size), pixels_enc_layers) - pixels_enc_layers.extend( - [ - Flatten(), - nn.Linear(np.prod(out_shape), cfg.latent_dim), - nn.LayerNorm(cfg.latent_dim), - nn.Sigmoid(), - ] - ) - if cfg.modality == "pixels": - return ConvExt(nn.Sequential(*pixels_enc_layers)) - if cfg.modality in {"state", "all"}: - state_dim = obs_shape[0] if cfg.modality == "state" else obs_shape["state"][0] - state_enc_layers = [ - nn.Linear(state_dim, cfg.enc_dim), - nn.ELU(), - nn.Linear(cfg.enc_dim, cfg.latent_dim), - nn.LayerNorm(cfg.latent_dim), - nn.Sigmoid(), - ] - if cfg.modality == "state": - return nn.Sequential(*state_enc_layers) - else: - raise NotImplementedError - - encoders = {} - for k in obs_shape: - if k == "state": - encoders[k] = nn.Sequential(*state_enc_layers) - elif k.endswith("rgb"): - encoders[k] = ConvExt(nn.Sequential(*pixels_enc_layers)) - else: - raise NotImplementedError - return Multiplexer(nn.ModuleDict(encoders)) - - -def mlp(in_dim, mlp_dim, out_dim, act_fn=DEFAULT_ACT_FN): - """Returns an MLP.""" - if isinstance(mlp_dim, int): - mlp_dim = [mlp_dim, mlp_dim] - return nn.Sequential( - nn.Linear(in_dim, mlp_dim[0]), - nn.LayerNorm(mlp_dim[0]), - act_fn, - nn.Linear(mlp_dim[0], mlp_dim[1]), - nn.LayerNorm(mlp_dim[1]), - act_fn, - nn.Linear(mlp_dim[1], out_dim), - ) - - -def dynamics(in_dim, mlp_dim, out_dim, act_fn=DEFAULT_ACT_FN): - """Returns a dynamics network.""" - return nn.Sequential( - mlp(in_dim, mlp_dim, out_dim, act_fn), - nn.LayerNorm(out_dim), - nn.Sigmoid(), - ) - - -def q(cfg): - action_dim = cfg.action_dim - """Returns a Q-function that uses Layer Normalization.""" - return nn.Sequential( - nn.Linear(cfg.latent_dim + action_dim, cfg.mlp_dim), - nn.LayerNorm(cfg.mlp_dim), - nn.Tanh(), - nn.Linear(cfg.mlp_dim, cfg.mlp_dim), - nn.ELU(), - nn.Linear(cfg.mlp_dim, 1), - ) - - -def v(cfg): - """Returns a state value function that uses Layer Normalization.""" - return nn.Sequential( - nn.Linear(cfg.latent_dim, cfg.mlp_dim), - nn.LayerNorm(cfg.mlp_dim), - nn.Tanh(), - nn.Linear(cfg.mlp_dim, cfg.mlp_dim), - nn.ELU(), - nn.Linear(cfg.mlp_dim, 1), - ) - - -def aug(cfg): - obs_shape = { - "rgb": (3, cfg.img_size, cfg.img_size), - "state": (4,), - } - - """Multiplex augmentation""" - if cfg.modality == "state": - return nn.Identity() - 
elif cfg.modality == "pixels": - return RandomShiftsAug(cfg) - else: - augs = {} - for k in obs_shape: - if k == "state": - augs[k] = nn.Identity() - elif k.endswith("rgb"): - augs[k] = RandomShiftsAug(cfg) - else: - raise NotImplementedError - return Multiplexer(nn.ModuleDict(augs)) - - -class ConvExt(nn.Module): - """Auxiliary conv net accommodating high-dim input""" - - def __init__(self, conv): - super().__init__() - self.conv = conv - - def forward(self, x): - if x.ndim > 4: - batch_shape = x.shape[:-3] - out = self.conv(x.view(-1, *x.shape[-3:])) - out = out.view(*batch_shape, *out.shape[1:]) - else: - out = self.conv(x) - return out - - -class Multiplexer(nn.Module): - """Model multiplexer""" - - def __init__(self, choices): - super().__init__() - self.choices = choices - - def forward(self, x, key=None): - if isinstance(x, dict): - if key is not None: - return self.choices[key](x) - return {k: self.choices[k](_x) for k, _x in x.items()} - return self.choices(x) - - -class RandomShiftsAug(nn.Module): - """ - Random shift image augmentation. - Adapted from https://github.com/facebookresearch/drqv2 - """ - - def __init__(self, cfg): - super().__init__() - assert cfg.modality in {"pixels", "all"} - self.pad = int(cfg.img_size / 21) - - def forward(self, x): - n, c, h, w = x.size() - assert h == w - padding = tuple([self.pad] * 4) - x = F.pad(x, padding, "replicate") - eps = 1.0 / (h + 2 * self.pad) - arange = torch.linspace( - -1.0 + eps, - 1.0 - eps, - h + 2 * self.pad, - device=x.device, - dtype=torch.float32, - )[:h] - arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2) - base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2) - base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1) - shift = torch.randint( - 0, - 2 * self.pad + 1, - size=(n, 1, 1, 2), - device=x.device, - dtype=torch.float32, - ) - shift *= 2.0 / (h + 2 * self.pad) - grid = base_grid + shift - return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False) - - -# TODO(aliberts): remove class -# class Episode: -# """Storage object for a single episode.""" - -# def __init__(self, cfg, init_obs): -# action_dim = cfg.action_dim - -# self.cfg = cfg -# self.device = torch.device(cfg.buffer_device) -# if cfg.modality in {"pixels", "state"}: -# dtype = torch.float32 if cfg.modality == "state" else torch.uint8 -# self.obses = torch.empty( -# (cfg.episode_length + 1, *init_obs.shape), -# dtype=dtype, -# device=self.device, -# ) -# self.obses[0] = torch.tensor(init_obs, dtype=dtype, device=self.device) -# elif cfg.modality == "all": -# self.obses = {} -# for k, v in init_obs.items(): -# assert k in {"rgb", "state"} -# dtype = torch.float32 if k == "state" else torch.uint8 -# self.obses[k] = torch.empty( -# (cfg.episode_length + 1, *v.shape), dtype=dtype, device=self.device -# ) -# self.obses[k][0] = torch.tensor(v, dtype=dtype, device=self.device) -# else: -# raise ValueError -# self.actions = torch.empty((cfg.episode_length, action_dim), dtype=torch.float32, device=self.device) -# self.rewards = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device) -# self.dones = torch.empty((cfg.episode_length,), dtype=torch.bool, device=self.device) -# self.masks = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device) -# self.cumulative_reward = 0 -# self.done = False -# self.success = False -# self._idx = 0 - -# def __len__(self): -# return self._idx - -# @classmethod -# def from_trajectory(cls, cfg, obses, actions, rewards, dones=None, masks=None): -# """Constructs an episode 
from a trajectory.""" - -# if cfg.modality in {"pixels", "state"}: -# episode = cls(cfg, obses[0]) -# episode.obses[1:] = torch.tensor(obses[1:], dtype=episode.obses.dtype, device=episode.device) -# elif cfg.modality == "all": -# episode = cls(cfg, {k: v[0] for k, v in obses.items()}) -# for k in obses: -# episode.obses[k][1:] = torch.tensor( -# obses[k][1:], dtype=episode.obses[k].dtype, device=episode.device -# ) -# else: -# raise NotImplementedError -# episode.actions = torch.tensor(actions, dtype=episode.actions.dtype, device=episode.device) -# episode.rewards = torch.tensor(rewards, dtype=episode.rewards.dtype, device=episode.device) -# episode.dones = ( -# torch.tensor(dones, dtype=episode.dones.dtype, device=episode.device) -# if dones is not None -# else torch.zeros_like(episode.dones) -# ) -# episode.masks = ( -# torch.tensor(masks, dtype=episode.masks.dtype, device=episode.device) -# if masks is not None -# else torch.ones_like(episode.masks) -# ) -# episode.cumulative_reward = torch.sum(episode.rewards) -# episode.done = True -# episode._idx = cfg.episode_length -# return episode - -# @property -# def first(self): -# return len(self) == 0 - -# def __add__(self, transition): -# self.add(*transition) -# return self - -# def add(self, obs, action, reward, done, mask=1.0, success=False): -# """Add a transition into the episode.""" -# if isinstance(obs, dict): -# for k, v in obs.items(): -# self.obses[k][self._idx + 1] = torch.tensor( -# v, dtype=self.obses[k].dtype, device=self.obses[k].device -# ) -# else: -# self.obses[self._idx + 1] = torch.tensor(obs, dtype=self.obses.dtype, device=self.obses.device) -# self.actions[self._idx] = action -# self.rewards[self._idx] = reward -# self.dones[self._idx] = done -# self.masks[self._idx] = mask -# self.cumulative_reward += reward -# self.done = done -# self.success = self.success or success -# self._idx += 1 - - -def get_dataset_dict(cfg, env, return_reward_normalizer=False): - """Construct a dataset for env""" - required_keys = [ - "observations", - "next_observations", - "actions", - "rewards", - "dones", - "masks", - ] - - if cfg.task.startswith("xarm"): - dataset_path = os.path.join(cfg.dataset_dir, "buffer.pkl") - print(f"Using offline dataset '{dataset_path}'") - with open(dataset_path, "rb") as f: - dataset_dict = pickle.load(f) - for k in required_keys: - if k not in dataset_dict and k[:-1] in dataset_dict: - dataset_dict[k] = dataset_dict.pop(k[:-1]) - elif cfg.task.startswith("legged"): - dataset_path = os.path.join(cfg.dataset_dir, "buffer.pkl") - print(f"Using offline dataset '{dataset_path}'") - with open(dataset_path, "rb") as f: - dataset_dict = pickle.load(f) - dataset_dict["actions"] /= env.unwrapped.clip_actions - print(f"clip_actions={env.unwrapped.clip_actions}") - else: - import d4rl - - dataset_dict = d4rl.qlearning_dataset(env) - dones = np.full_like(dataset_dict["rewards"], False, dtype=bool) - - for i in range(len(dones) - 1): - if ( - np.linalg.norm(dataset_dict["observations"][i + 1] - dataset_dict["next_observations"][i]) - > 1e-6 - or dataset_dict["terminals"][i] == 1.0 - ): - dones[i] = True - - dones[-1] = True - - dataset_dict["masks"] = 1.0 - dataset_dict["terminals"] - del dataset_dict["terminals"] - - for k, v in dataset_dict.items(): - dataset_dict[k] = v.astype(np.float32) - - dataset_dict["dones"] = dones - - if cfg.is_data_clip: - lim = 1 - cfg.data_clip_eps - dataset_dict["actions"] = np.clip(dataset_dict["actions"], -lim, lim) - reward_normalizer = get_reward_normalizer(cfg, dataset_dict) - 
dataset_dict["rewards"] = reward_normalizer(dataset_dict["rewards"]) - - for key in required_keys: - assert key in dataset_dict, f"Missing `{key}` in dataset." - - if return_reward_normalizer: - return dataset_dict, reward_normalizer - return dataset_dict - - -def get_trajectory_boundaries_and_returns(dataset): - """ - Split dataset into trajectories and compute returns - """ - episode_starts = [0] - episode_ends = [] - - episode_return = 0 - episode_returns = [] - - n_transitions = len(dataset["rewards"]) - - for i in range(n_transitions): - episode_return += dataset["rewards"][i] - - if dataset["dones"][i]: - episode_returns.append(episode_return) - episode_ends.append(i + 1) - if i + 1 < n_transitions: - episode_starts.append(i + 1) - episode_return = 0.0 - - return episode_starts, episode_ends, episode_returns - - -def normalize_returns(dataset, scaling=1000): - """ - Normalize returns in the dataset - """ - (_, _, episode_returns) = get_trajectory_boundaries_and_returns(dataset) - dataset["rewards"] /= np.max(episode_returns) - np.min(episode_returns) - dataset["rewards"] *= scaling - return dataset - - -def get_reward_normalizer(cfg, dataset): - """ - Get a reward normalizer for the dataset - """ - if cfg.task.startswith("xarm"): - return lambda x: x - elif "maze" in cfg.task: - return lambda x: x - 1.0 - elif cfg.task.split("-")[0] in ["hopper", "halfcheetah", "walker2d"]: - (_, _, episode_returns) = get_trajectory_boundaries_and_returns(dataset) - return lambda x: x / (np.max(episode_returns) - np.min(episode_returns)) * 1000.0 - elif hasattr(cfg, "reward_scale"): - return lambda x: x * cfg.reward_scale - return lambda x: x - - -def linear_schedule(schdl, step): - """ - Outputs values following a linear decay schedule. - Adapted from https://github.com/facebookresearch/drqv2 - """ - try: - return float(schdl) - except ValueError: - match = re.match(r"linear\((.+),(.+),(.+),(.+)\)", schdl) - if match: - init, final, start, end = (float(g) for g in match.groups()) - mix = np.clip((step - start) / (end - start), 0.0, 1.0) - return (1.0 - mix) * init + mix * final - match = re.match(r"linear\((.+),(.+),(.+)\)", schdl) - if match: - init, final, duration = (float(g) for g in match.groups()) - mix = np.clip(step / duration, 0.0, 1.0) - return (1.0 - mix) * init + mix * final - raise NotImplementedError(schdl) diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py new file mode 100644 index 00000000..4205b4fc --- /dev/null +++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py @@ -0,0 +1,798 @@ +"""Implementation of Finetuning Offline World Models in the Real World. + +The comments in this code may sometimes refer to these references: + TD-MPC paper: Temporal Difference Learning for Model Predictive Control (https://arxiv.org/abs/2203.04955) + FOWM paper: Finetuning Offline World Models in the Real World (https://arxiv.org/abs/2310.16029) + +TODO(alexander-soare): Make rollout work for batch sizes larger than 1. +TODO(alexander-soare): Use batch-first throughout. 
+""" + +# ruff: noqa: N806 + +import logging +from collections import deque +from copy import deepcopy +from functools import partial +from typing import Callable + +import einops +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F # noqa: N812 +from huggingface_hub import PyTorchModelHubMixin +from torch import Tensor + +from lerobot.common.policies.normalize import Normalize, Unnormalize +from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig +from lerobot.common.policies.utils import get_device_from_parameters, populate_queues + + +class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): + """Implementation of TD-MPC learning + inference. + + Please note several warnings for this policy. + - Evaluation of pretrained weights created with the original FOWM code + (https://github.com/fyhMer/fowm) works as expected. To be precise: we trained and evaluated a + model with the FOWM code for the xarm_lift_medium_replay dataset. We ported the weights across + to LeRobot, and were able to evaluate with the same success metric. BUT, we had to use inter- + process communication to use the xarm environment from FOWM. This is because our xarm + environment uses newer dependencies and does not match the environment in FOWM. See + https://github.com/huggingface/lerobot/pull/103 for implementation details. + - We have NOT checked that training on LeRobot reproduces SOTA results. This is a TODO. + - Our current xarm datasets were generated using the environment from FOWM. Therefore they do not + match our xarm environment. + """ + + name = "tdmpc" + + def __init__( + self, config: TDMPCConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None + ): + """ + Args: + config: Policy configuration class instance or None, in which case the default instantiation of + the configuration class is used. + dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected + that they will be passed with a call to `load_state_dict` before the policy is used. + """ + super().__init__() + logging.warning( + """ + Please note several warnings for this policy. + + - Evaluation of pretrained weights created with the original FOWM code + (https://github.com/fyhMer/fowm) works as expected. To be precise: we trained and evaluated a + model with the FOWM code for the xarm_lift_medium_replay dataset. We ported the weights across + to LeRobot, and were able to evaluate with the same success metric. BUT, we had to use inter- + process communication to use the xarm environment from FOWM. This is because our xarm + environment uses newer dependencies and does not match the environment in FOWM. See + https://github.com/huggingface/lerobot/pull/103 for implementation details. + - We have NOT checked that training on LeRobot reproduces SOTA results. This is a TODO. + - Our current xarm datasets were generated using the environment from FOWM. Therefore they do not + match our xarm environment. 
+ """ + ) + + if config is None: + config = TDMPCConfig() + self.config = config + self.model = TDMPCTOLD(config) + self.model_target = deepcopy(self.model) + self.model_target.eval() + + if config.input_normalization_modes is not None: + self.normalize_inputs = Normalize( + config.input_shapes, config.input_normalization_modes, dataset_stats + ) + else: + self.normalize_inputs = nn.Identity() + self.normalize_targets = Normalize( + config.output_shapes, config.output_normalization_modes, dataset_stats + ) + self.unnormalize_outputs = Unnormalize( + config.output_shapes, config.output_normalization_modes, dataset_stats + ) + + def save(self, fp): + """Save state dict of TOLD model to filepath.""" + torch.save(self.state_dict(), fp) + + def load(self, fp): + """Load a saved state dict from filepath into current agent.""" + self.load_state_dict(torch.load(fp)) + + def reset(self): + """ + Clear observation and action queues. Clear previous means for warm starting of MPPI/CEM. Should be + called on `env.reset()` + """ + self._queues = { + "observation.image": deque(maxlen=1), + "observation.state": deque(maxlen=1), + "action": deque(maxlen=self.config.n_action_repeats), + } + # Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start + # CEM for the next step. + self._prev_mean: torch.Tensor | None = None + + @torch.no_grad() + def select_action(self, batch: dict[str, Tensor]): + """Select a single action given environment observations.""" + assert "observation.image" in batch + assert "observation.state" in batch + assert len(batch) == 2 + + batch = self.normalize_inputs(batch) + + self._queues = populate_queues(self._queues, batch) + + # When the action queue is depleted, populate it again by querying the policy. + if len(self._queues["action"]) == 0: + batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch} + + # Remove the time dimensions as it is not handled yet. + for key in batch: + assert batch[key].shape[1] == 1 + batch[key] = batch[key][:, 0] + + # NOTE: Order of observations matters here. + z = self.model.encode({k: batch[k] for k in ["observation.image", "observation.state"]}) + if self.config.use_mpc: + batch_size = batch["observation.image"].shape[0] + # Batch processing is not handled in MPC mode, so process the batch in a loop. + action = [] # will be a batch of actions for one step + for i in range(batch_size): + # Note: self.plan does not handle batches, hence the squeeze. + action.append(self.plan(z[i])) + action = torch.stack(action) + else: + # Plan with the policy (π) alone. + action = self.model.pi(z) + + self.unnormalize_outputs({"action": action})["action"] + + for _ in range(self.config.n_action_repeats): + self._queues["action"].append(action) + + action = self._queues["action"].popleft() + return torch.clamp(action, -1, 1) + + @torch.no_grad() + def plan(self, z: Tensor) -> Tensor: + """Plan next action using TD-MPC inference. + + Args: + z: (latent_dim,) tensor for the initial state. + Returns: + (action_dim,) tensor for the next action. + + TODO(alexander-soare) Extend this to be able to work with batches. + """ + device = get_device_from_parameters(self) + + # Sample Nπ trajectories from the policy. 
+ pi_actions = torch.empty( + self.config.horizon, + self.config.n_pi_samples, + self.config.output_shapes["action"][0], + device=device, + ) + if self.config.n_pi_samples > 0: + _z = einops.repeat(z, "d -> n d", n=self.config.n_pi_samples) + for t in range(self.config.horizon): + # Note: Adding a small amount of noise here doesn't hurt during inference and may even be + # helpful for CEM. + pi_actions[t] = self.model.pi(_z, self.config.min_std) + _z = self.model.latent_dynamics(_z, pi_actions[t]) + + # In the CEM loop we will need this for a call to estimate_value with the gaussian sampled + # trajectories. + z = einops.repeat(z, "d -> n d", n=self.config.n_gaussian_samples + self.config.n_pi_samples) + + # Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization + # algorithm. + # The initial mean and standard deviation for the cross-entropy method (CEM). + mean = torch.zeros(self.config.horizon, self.config.output_shapes["action"][0], device=device) + # Maybe warm start CEM with the mean from the previous step. + if self._prev_mean is not None: + mean[:-1] = self._prev_mean[1:] + std = self.config.max_std * torch.ones_like(mean) + + for _ in range(self.config.cem_iterations): + # Randomly sample action trajectories for the gaussian distribution. + std_normal_noise = torch.randn( + self.config.horizon, + self.config.n_gaussian_samples, + self.config.output_shapes["action"][0], + device=std.device, + ) + gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1) + + # Compute elite actions. + actions = torch.cat([gaussian_actions, pi_actions], dim=1) + value = self.estimate_value(z, actions).nan_to_num_(0) + elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices + elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs] + + # Update guassian PDF parameters to be the (weighted) mean and standard deviation of the elites. + max_value = elite_value.max(0)[0] + # The weighting is a softmax over trajectory values. Note that this is not the same as the usage + # of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This + # makes the equations: μ = Σ(s⋅Γ), σ = Σ(s⋅(Γ-μ)²). + score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value)) + score /= score.sum() + _mean = torch.sum(einops.rearrange(score, "n -> n 1") * elite_actions, dim=1) + _std = torch.sqrt( + torch.sum( + einops.rearrange(score, "n -> n 1") + * (elite_actions - einops.rearrange(_mean, "h d -> h 1 d")) ** 2, + dim=1, + ) + ) + # Update mean with an exponential moving average, and std with a direct replacement. + mean = ( + self.config.gaussian_mean_momentum * mean + (1 - self.config.gaussian_mean_momentum) * _mean + ) + std = _std.clamp_(self.config.min_std, self.config.max_std) + + # Keep track of the mean for warm-starting subsequent steps. + self._prev_mean = mean + + # Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax + # scores from the last iteration. + actions = elite_actions[:, torch.multinomial(score, 1).item()] + + # Select only the first action + action = actions[0] + return action + + @torch.no_grad() + def estimate_value(self, z: Tensor, actions: Tensor): + """Estimates the value of a trajectory as per eqn 4 of the FOWM paper. + + Args: + z: (batch, latent_dim) tensor of initial latent states. + actions: (horizon, batch, action_dim) tensor of action trajectories. + Returns: + (batch,) tensor of values. 
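To make the weighted update in the CEM loop above more concrete (s = Ω/ΣΩ, μ = Σ s·Γ, σ = √(Σ s·(Γ−μ)²)), here is a self-contained numerical sketch with made-up shapes and values, not the policy's actual tensors:

```python
import torch

temperature = 0.5
elite_value = torch.tensor([1.0, 2.0, 0.5, 1.5])   # (n_elites,) trajectory values
elite_actions = torch.randn(3, 4, 2)                # (horizon, n_elites, action_dim)

# Normalized softmax over trajectory values (max subtracted for numerical stability).
score = torch.exp(temperature * (elite_value - elite_value.max()))
score = score / score.sum()                          # s = Ω / ΣΩ

weights = score.reshape(1, -1, 1)                    # broadcast over horizon and action dims
mean = (weights * elite_actions).sum(dim=1)          # μ = Σ s·Γ, shape (horizon, action_dim)
std = torch.sqrt((weights * (elite_actions - mean.unsqueeze(1)) ** 2).sum(dim=1))  # σ
```

In `plan`, `mean` is then blended with the previous iterate using `gaussian_mean_momentum`, while `std` replaces the old value directly (clamped to `[min_std, max_std]`).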
+ """ + # Initialize return and running discount factor. + G, running_discount = 0, 1 + # Iterate over the actions in the trajectory to simulate the trajectory using the latent dynamics + # model. Keep track of return. + for t in range(actions.shape[0]): + # We will compute the reward in a moment. First compute the uncertainty regularizer from eqn 4 + # of the FOWM paper. + if self.config.uncertainty_regularizer_coeff > 0: + regularization = -( + self.config.uncertainty_regularizer_coeff * self.model.Qs(z, actions[t]).std(0) + ) + else: + regularization = 0 + # Estimate the next state (latent) and reward. + z, reward = self.model.latent_dynamics_and_reward(z, actions[t]) + # Update the return and running discount. + G += running_discount * (reward + regularization) + running_discount *= self.config.discount + # Add the estimated value of the final state (using the minimum for a conservative estimate). + # Do so by predicting the next action, then taking a minimum over the ensemble of state-action value + # estimators. + # Note: This small amount of added noise seems to help a bit at inference time as observed by success + # metrics over 50 episodes of xarm_lift_medium_replay. + next_action = self.model.pi(z, self.config.min_std) # (batch, action_dim) + terminal_values = self.model.Qs(z, next_action) # (ensemble, batch) + # Randomly choose 2 of the Qs for terminal value estimation (as in App C. of the FOWM paper). + if self.config.q_ensemble_size > 2: + G += ( + running_discount + * torch.min(terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))], dim=0)[ + 0 + ] + ) + else: + G += running_discount * torch.min(terminal_values, dim=0)[0] + # Finally, also regularize the terminal value. + if self.config.uncertainty_regularizer_coeff > 0: + G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0) + return G + + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + """Run the batch through the model and compute the loss.""" + device = get_device_from_parameters(self) + + batch = self.normalize_inputs(batch) + batch = self.normalize_targets(batch) + + info = {} + + # TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation. + batch_size = batch["index"].shape[0] + + # (b, t) -> (t, b) + for key in batch: + if batch[key].ndim > 1: + batch[key] = batch[key].transpose(1, 0) + + action = batch["action"] # (t, b) + reward = batch["next.reward"] # (t,) + observations = {k: v for k, v in batch.items() if k.startswith("observation.")} + + # Apply random image augmentations. + if self.config.max_random_shift_ratio > 0: + observations["observation.image"] = flatten_forward_unflatten( + partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio), + observations["observation.image"], + ) + + # Get the current observation for predicting trajectories, and all future observations for use in + # the latent consistency loss and TD loss. + current_observation, next_observations = {}, {} + for k in observations: + current_observation[k] = observations[k][0] + next_observations[k] = observations[k][1:] + horizon = next_observations["observation.image"].shape[0] + + # Run latent rollout using the latent dynamics model and policy model. + # Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action + # gives us a next `z`. 
+ z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device) + z_preds[0] = self.model.encode(current_observation) + reward_preds = torch.empty_like(reward, device=device) + for t in range(horizon): + z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(z_preds[t], action[t]) + + # Compute Q and V value predictions based on the latent rollout. + q_preds_ensemble = self.model.Qs(z_preds[:-1], action) # (ensemble, horizon, batch) + v_preds = self.model.V(z_preds[:-1]) + info.update({"Q": q_preds_ensemble.mean().item(), "V": v_preds.mean().item()}) + + # Compute various targets with stopgrad. + with torch.no_grad(): + # Latent state consistency targets. + z_targets = self.model_target.encode(next_observations) + # State-action value targets (or TD targets) as in eqn 3 of the FOWM. Unlike TD-MPC which uses the + # learned state-action value function in conjunction with the learned policy: Q(z, π(z)), FOWM + # uses a learned state value function: V(z). This means the TD targets only depend on in-sample + # actions (not actions estimated by π). + # Note: Here we do not use self.model_target, but self.model. This is to follow the original code + # and the FOWM paper. + q_targets = reward + self.config.discount * self.model.V(self.model.encode(next_observations)) + # From eqn 3 of FOWM. These appear as Q(z, a). Here we call them v_targets to emphasize that we + # are using them to compute loss for V. + v_targets = self.model_target.Qs(z_preds[:-1].detach(), action, return_min=True) + + # Compute losses. + # Exponentially decay the loss weight with respect to the timestep. Steps that are more distant in the + # future have less impact on the loss. Note: unsqueeze will let us broadcast to (seq, batch). + temporal_loss_coeffs = torch.pow( + self.config.temporal_decay_coeff, torch.arange(horizon, device=device) + ).unsqueeze(-1) + # Compute consistency loss as MSE loss between latents predicted from the rollout and latents + # predicted from the (target model's) observation encoder. + consistency_loss = ( + ( + temporal_loss_coeffs + * F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1) + # `z_preds` depends on the current observation and the actions. + * ~batch["observation.state_is_pad"][0] + * ~batch["action_is_pad"] + # `z_targets` depends on the next observation. + * ~batch["observation.state_is_pad"][1:] + ) + .sum(0) + .mean() + ) + # Compute the reward loss as MSE loss between rewards predicted from the rollout and the dataset + # rewards. + reward_loss = ( + ( + temporal_loss_coeffs + * F.mse_loss(reward_preds, reward, reduction="none") + * ~batch["next.reward_is_pad"] + # `reward_preds` depends on the current observation and the actions. + * ~batch["observation.state_is_pad"][0] + * ~batch["action_is_pad"] + ) + .sum(0) + .mean() + ) + # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble. + q_value_loss = ( + ( + F.mse_loss( + q_preds_ensemble, + einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]), + reduction="none", + ).sum(0) # sum over ensemble + # `q_preds_ensemble` depends on the first observation and the actions. + * ~batch["observation.state_is_pad"][0] + * ~batch["action_is_pad"] + # q_targets depends on the reward and the next observations. + * ~batch["next.reward_is_pad"] + * ~batch["observation.state_is_pad"][1:] + ) + .sum(0) + .mean() + ) + # Compute state value loss as in eqn 3 of FOWM. 
+ diff = v_targets - v_preds + # Expectile loss penalizes: + # - `v_preds < v_targets` with weighting `expectile_weight` + # - `v_preds >= v_targets` with weighting `1 - expectile_weight` + raw_v_value_loss = torch.where( + diff > 0, self.config.expectile_weight, (1 - self.config.expectile_weight) + ) * (diff**2) + v_value_loss = ( + ( + temporal_loss_coeffs + * raw_v_value_loss + # `v_targets` depends on the first observation and the actions, as does `v_preds`. + * ~batch["observation.state_is_pad"][0] + * ~batch["action_is_pad"] + ) + .sum(0) + .mean() + ) + + # Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1. + # We won't need these gradients again so detach. + z_preds = z_preds.detach() + # Use stopgrad for the advantage calculation. + with torch.no_grad(): + advantage = self.model_target.Qs(z_preds[:-1], action, return_min=True) - self.model.V( + z_preds[:-1] + ) + info["advantage"] = advantage[0] + # (t, b) + exp_advantage = torch.clamp(torch.exp(advantage * self.config.advantage_scaling), max=100.0) + action_preds = self.model.pi(z_preds[:-1]) # (t, b, a) + # Calculate the MSE between the actions and the action predictions. + # Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation + # gaussian) and sums over the action dimension. Computing the log probability amounts to multiplying + # the MSE by 0.5 and adding a constant offset (the log(2*pi) term) . Here we drop the constant offset + # as it doesn't change the optimization step, and we drop the 0.5 as we instead make a configuration + # parameter for it (see below where we compute the total loss). + mse = F.mse_loss(action_preds, action, reduction="none").sum(-1) # (t, b) + # NOTE: The original implementation does not take the sum over the temporal dimension like with the + # other losses. + # TODO(alexander-soare): Take the sum over the temporal dimension and check that training still works + # as well as expected. + pi_loss = ( + exp_advantage + * mse + * temporal_loss_coeffs + # `action_preds` depends on the first observation and the actions. + * ~batch["observation.state_is_pad"][0] + * ~batch["action_is_pad"] + ).mean() + + loss = ( + self.config.consistency_coeff * consistency_loss + + self.config.reward_coeff * reward_loss + + self.config.value_coeff * q_value_loss + + self.config.value_coeff * v_value_loss + + self.config.pi_coeff * pi_loss + ) + + info.update( + { + "consistency_loss": consistency_loss.item(), + "reward_loss": reward_loss.item(), + "Q_value_loss": q_value_loss.item(), + "V_value_loss": v_value_loss.item(), + "pi_loss": pi_loss.item(), + "loss": loss, + "sum_loss": loss.item() * self.config.horizon, + } + ) + + # Undo (b, t) -> (t, b). + for key in batch: + if batch[key].ndim > 1: + batch[key] = batch[key].transpose(1, 0) + + return info + + def update(self): + """Update the target model's parameters with an EMA step.""" + # Note a minor variation with respect to the original FOWM code. Here they do this based on an EMA + # update frequency parameter which is set to 2 (every 2 steps an update is done). 
To simplify the code + # we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995) + update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum) + + +class TDMPCTOLD(nn.Module): + """Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC.""" + + def __init__(self, config: TDMPCConfig): + super().__init__() + self.config = config + self._encoder = TDMPCObservationEncoder(config) + self._dynamics = nn.Sequential( + nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim), + nn.LayerNorm(config.mlp_dim), + nn.Mish(), + nn.Linear(config.mlp_dim, config.mlp_dim), + nn.LayerNorm(config.mlp_dim), + nn.Mish(), + nn.Linear(config.mlp_dim, config.latent_dim), + nn.LayerNorm(config.latent_dim), + nn.Sigmoid(), + ) + self._reward = nn.Sequential( + nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim), + nn.LayerNorm(config.mlp_dim), + nn.Mish(), + nn.Linear(config.mlp_dim, config.mlp_dim), + nn.LayerNorm(config.mlp_dim), + nn.Mish(), + nn.Linear(config.mlp_dim, 1), + ) + self._pi = nn.Sequential( + nn.Linear(config.latent_dim, config.mlp_dim), + nn.LayerNorm(config.mlp_dim), + nn.Mish(), + nn.Linear(config.mlp_dim, config.mlp_dim), + nn.LayerNorm(config.mlp_dim), + nn.Mish(), + nn.Linear(config.mlp_dim, config.output_shapes["action"][0]), + ) + self._Qs = nn.ModuleList( + [ + nn.Sequential( + nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim), + nn.LayerNorm(config.mlp_dim), + nn.Tanh(), + nn.Linear(config.mlp_dim, config.mlp_dim), + nn.ELU(), + nn.Linear(config.mlp_dim, 1), + ) + for _ in range(config.q_ensemble_size) + ] + ) + self._V = nn.Sequential( + nn.Linear(config.latent_dim, config.mlp_dim), + nn.LayerNorm(config.mlp_dim), + nn.Tanh(), + nn.Linear(config.mlp_dim, config.mlp_dim), + nn.ELU(), + nn.Linear(config.mlp_dim, 1), + ) + self._init_weights() + + def _init_weights(self): + """Initialize model weights. + + Orthogonal initialization for all linear and convolutional layers' weights (apart from final layers + of reward network and Q networks which get zero initialization). + Zero initialization for all linear and convolutional layers' biases. + """ + + def _apply_fn(m): + if isinstance(m, nn.Linear): + nn.init.orthogonal_(m.weight.data) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Conv2d): + gain = nn.init.calculate_gain("relu") + nn.init.orthogonal_(m.weight.data, gain) + if m.bias is not None: + nn.init.zeros_(m.bias) + + self.apply(_apply_fn) + for m in [self._reward, *self._Qs]: + assert isinstance( + m[-1], nn.Linear + ), "Sanity check. The last linear layer needs 0 initialization on weights." + nn.init.zeros_(m[-1].weight) + nn.init.zeros_(m[-1].bias) # this has already been done, but keep this line here for good measure + + def encode(self, obs: dict[str, Tensor]) -> Tensor: + """Encodes an observation into its latent representation.""" + return self._encoder(obs) + + def latent_dynamics_and_reward(self, z: Tensor, a: Tensor) -> tuple[Tensor, Tensor]: + """Predict the next state's latent representation and the reward given a current latent and action. + + Args: + z: (*, latent_dim) tensor for the current state's latent representation. + a: (*, action_dim) tensor for the action to be applied. + Returns: + A tuple containing: + - (*, latent_dim) tensor for the next state's latent representation. + - (*,) tensor for the estimated reward. 
+ """ + x = torch.cat([z, a], dim=-1) + return self._dynamics(x), self._reward(x).squeeze(-1) + + def latent_dynamics(self, z: Tensor, a: Tensor) -> Tensor: + """Predict the next state's latent representation given a current latent and action. + + Args: + z: (*, latent_dim) tensor for the current state's latent representation. + a: (*, action_dim) tensor for the action to be applied. + Returns: + (*, latent_dim) tensor for the next state's latent representation. + """ + x = torch.cat([z, a], dim=-1) + return self._dynamics(x) + + def pi(self, z: Tensor, std: float = 0.0) -> Tensor: + """Samples an action from the learned policy. + + The policy can also have added (truncated) Gaussian noise injected for encouraging exploration when + generating rollouts for online training. + + Args: + z: (*, latent_dim) tensor for the current state's latent representation. + std: The standard deviation of the injected noise. + Returns: + (*, action_dim) tensor for the sampled action. + """ + action = torch.tanh(self._pi(z)) + if std > 0: + std = torch.ones_like(action) * std + action += torch.randn_like(action) * std + return action + + def V(self, z: Tensor) -> Tensor: # noqa: N802 + """Predict state value (V). + + Args: + z: (*, latent_dim) tensor for the current state's latent representation. + Returns: + (*,) tensor of estimated state values. + """ + return self._V(z).squeeze(-1) + + def Qs(self, z: Tensor, a: Tensor, return_min: bool = False) -> Tensor: # noqa: N802 + """Predict state-action value for all of the learned Q functions. + + Args: + z: (*, latent_dim) tensor for the current state's latent representation. + a: (*, action_dim) tensor for the action to be applied. + return_min: Set to true for implementing the detail in App. C of the FOWM paper: randomly select + 2 of the Qs and return the minimum + Returns: + (q_ensemble, *) tensor for the value predictions of each learned Q function in the ensemble OR + (*,) tensor if return_min=True. + """ + x = torch.cat([z, a], dim=-1) + if not return_min: + return torch.stack([q(x).squeeze(-1) for q in self._Qs], dim=0) + else: + if len(self._Qs) > 2: # noqa: SIM108 + Qs = [self._Qs[i] for i in np.random.choice(len(self._Qs), size=2)] + else: + Qs = self._Qs + return torch.stack([q(x).squeeze(-1) for q in Qs], dim=0).min(dim=0)[0] + + +class TDMPCObservationEncoder(nn.Module): + """Encode image and/or state vector observations.""" + + def __init__(self, config: TDMPCConfig): + """ + Creates encoders for pixel and/or state modalities. + TODO(alexander-soare): The original work allows for multiple images by concatenating them along the + channel dimension. Re-implement this capability. 
+ """ + super().__init__() + self.config = config + + if "observation.image" in config.input_shapes: + self.image_enc_layers = nn.Sequential( + nn.Conv2d( + config.input_shapes["observation.image"][0], config.image_encoder_hidden_dim, 7, stride=2 + ), + nn.ReLU(), + nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2), + nn.ReLU(), + nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2), + nn.ReLU(), + nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2), + nn.ReLU(), + ) + dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"]) + with torch.inference_mode(): + out_shape = self.image_enc_layers(dummy_batch).shape[1:] + self.image_enc_layers.extend( + nn.Sequential( + nn.Flatten(), + nn.Linear(np.prod(out_shape), config.latent_dim), + nn.LayerNorm(config.latent_dim), + nn.Sigmoid(), + ) + ) + if "observation.state" in config.input_shapes: + self.state_enc_layers = nn.Sequential( + nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim), + nn.ELU(), + nn.Linear(config.state_encoder_hidden_dim, config.latent_dim), + nn.LayerNorm(config.latent_dim), + nn.Sigmoid(), + ) + + def forward(self, obs_dict: dict[str, Tensor]) -> Tensor: + """Encode the image and/or state vector. + + Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken + over all features. + """ + feat = [] + if "observation.image" in self.config.input_shapes: + feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict["observation.image"])) + if "observation.state" in self.config.input_shapes: + feat.append(self.state_enc_layers(obs_dict["observation.state"])) + return torch.stack(feat, dim=0).mean(0) + + +def random_shifts_aug(x: Tensor, max_random_shift_ratio: float) -> Tensor: + """Randomly shifts images horizontally and vertically. + + Adapted from https://github.com/facebookresearch/drqv2 + """ + b, _, h, w = x.size() + assert h == w, "non-square images not handled yet" + pad = int(round(max_random_shift_ratio * h)) + x = F.pad(x, tuple([pad] * 4), "replicate") + eps = 1.0 / (h + 2 * pad) + arange = torch.linspace( + -1.0 + eps, + 1.0 - eps, + h + 2 * pad, + device=x.device, + dtype=torch.float32, + )[:h] + arange = einops.repeat(arange, "w -> h w 1", h=h) + base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2) + base_grid = einops.repeat(base_grid, "h w c -> b h w c", b=b) + # A random shift in units of pixels and within the boundaries of the padding. + shift = torch.randint( + 0, + 2 * pad + 1, + size=(b, 1, 1, 2), + device=x.device, + dtype=torch.float32, + ) + shift *= 2.0 / (h + 2 * pad) + grid = base_grid + shift + return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False) + + +def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float): + """Update EMA parameters in place with ema_param <- alpha * ema_param + (1 - alpha) * param.""" + for ema_module, module in zip(ema_net.modules(), net.modules(), strict=True): + for (n_p_ema, p_ema), (n_p, p) in zip( + ema_module.named_parameters(recurse=False), module.named_parameters(recurse=False), strict=True + ): + assert n_p_ema == n_p, "Parameter names don't match for EMA model update" + if isinstance(p, dict): + raise RuntimeError("Dict parameter not supported") + if isinstance(module, nn.modules.batchnorm._BatchNorm) or not p.requires_grad: + # Copy BatchNorm parameters, and non-trainable parameters directly. 
+ p_ema.copy_(p.to(dtype=p_ema.dtype).data) + with torch.no_grad(): + p_ema.mul_(alpha) + p_ema.add_(p.to(dtype=p_ema.dtype).data, alpha=1 - alpha) + + +def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor: + """Helper to temporarily flatten extra dims at the start of the image tensor. + + Args: + fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return + (B, *), where * is any number of dimensions. + image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions, generally + different from *. + Returns: + A return value from the callable reshaped to (**, *). + """ + if image_tensor.ndim == 4: + return fn(image_tensor) + start_dims = image_tensor.shape[:-3] + inp = torch.flatten(image_tensor, end_dim=-4) + flat_out = fn(inp) + return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:])) diff --git a/lerobot/common/policies/tdmpc/policy.py b/lerobot/common/policies/tdmpc/policy.py deleted file mode 100644 index adaa30c0..00000000 --- a/lerobot/common/policies/tdmpc/policy.py +++ /dev/null @@ -1,495 +0,0 @@ -# ruff: noqa: N806 - -import time -from collections import deque -from copy import deepcopy - -import einops -import numpy as np -import torch -import torch.nn as nn - -import lerobot.common.policies.tdmpc.helper as h -from lerobot.common.policies.utils import populate_queues -from lerobot.common.utils.utils import get_safe_torch_device - -FIRST_FRAME = 0 - - -class TOLD(nn.Module): - """Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC.""" - - def __init__(self, cfg): - super().__init__() - action_dim = cfg.action_dim - - self.cfg = cfg - self._encoder = h.enc(cfg) - self._dynamics = h.dynamics(cfg.latent_dim + action_dim, cfg.mlp_dim, cfg.latent_dim) - self._reward = h.mlp(cfg.latent_dim + action_dim, cfg.mlp_dim, 1) - self._pi = h.mlp(cfg.latent_dim, cfg.mlp_dim, action_dim) - self._Qs = nn.ModuleList([h.q(cfg) for _ in range(cfg.num_q)]) - self._V = h.v(cfg) - self.apply(h.orthogonal_init) - for m in [self._reward, *self._Qs]: - m[-1].weight.data.fill_(0) - m[-1].bias.data.fill_(0) - - def track_q_grad(self, enable=True): - """Utility function. Enables/disables gradient tracking of Q-networks.""" - for m in self._Qs: - h.set_requires_grad(m, enable) - - def track_v_grad(self, enable=True): - """Utility function. 
Enables/disables gradient tracking of Q-networks.""" - if hasattr(self, "_V"): - h.set_requires_grad(self._V, enable) - - def encode(self, obs): - """Encodes an observation into its latent representation.""" - out = self._encoder(obs) - if isinstance(obs, dict): - # fusion - out = torch.stack([v for k, v in out.items()]).mean(dim=0) - return out - - def next(self, z, a): - """Predicts next latent state (d) and single-step reward (R).""" - x = torch.cat([z, a], dim=-1) - return self._dynamics(x), self._reward(x) - - def next_dynamics(self, z, a): - """Predicts next latent state (d).""" - x = torch.cat([z, a], dim=-1) - return self._dynamics(x) - - def pi(self, z, std=0): - """Samples an action from the learned policy (pi).""" - mu = torch.tanh(self._pi(z)) - if std > 0: - std = torch.ones_like(mu) * std - return h.TruncatedNormal(mu, std).sample(clip=0.3) - return mu - - def V(self, z): # noqa: N802 - """Predict state value (V).""" - return self._V(z) - - def Q(self, z, a, return_type): # noqa: N802 - """Predict state-action value (Q).""" - assert return_type in {"min", "avg", "all"} - x = torch.cat([z, a], dim=-1) - - if return_type == "all": - return torch.stack([q(x) for q in self._Qs], dim=0) - - idxs = np.random.choice(self.cfg.num_q, 2, replace=False) - Q1, Q2 = self._Qs[idxs[0]](x), self._Qs[idxs[1]](x) - return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2 - - -class TDMPCPolicy(nn.Module): - """Implementation of TD-MPC learning + inference.""" - - name = "tdmpc" - - def __init__(self, cfg, n_obs_steps, n_action_steps, device): - super().__init__() - self.action_dim = cfg.action_dim - - self.cfg = cfg - self.n_obs_steps = n_obs_steps - self.n_action_steps = n_action_steps - self.device = get_safe_torch_device(device) - self.std = h.linear_schedule(cfg.std_schedule, 0) - self.model = TOLD(cfg) - self.model.to(self.device) - self.model_target = deepcopy(self.model) - self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr) - self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr) - # self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr) - self.model.eval() - self.model_target.eval() - - self.register_buffer("step", torch.zeros(1)) - - def state_dict(self): - """Retrieve state dict of TOLD model, including slow-moving target network.""" - return { - "model": self.model.state_dict(), - "model_target": self.model_target.state_dict(), - } - - def save(self, fp): - """Save state dict of TOLD model to filepath.""" - torch.save(self.state_dict(), fp) - - def load(self, fp): - """Load a saved state dict from filepath into current agent.""" - d = torch.load(fp) - self.model.load_state_dict(d["model"]) - self.model_target.load_state_dict(d["model_target"]) - - def reset(self): - """ - Clear observation and action queues. 
Should be called on `env.reset()` - """ - self._queues = { - "observation.image": deque(maxlen=self.n_obs_steps), - "observation.state": deque(maxlen=self.n_obs_steps), - "action": deque(maxlen=self.n_action_steps), - } - - @torch.no_grad() - def select_action(self, batch, step): - assert "observation.image" in batch - assert "observation.state" in batch - assert len(batch) == 2 - - self._queues = populate_queues(self._queues, batch) - - t0 = step == 0 - - self.eval() - - if len(self._queues["action"]) == 0: - batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch} - - if self.n_obs_steps == 1: - # hack to remove the time dimension - for key in batch: - assert batch[key].shape[1] == 1 - batch[key] = batch[key][:, 0] - - actions = [] - batch_size = batch["observation.image"].shape[0] - for i in range(batch_size): - obs = { - "rgb": batch["observation.image"][[i]], - "state": batch["observation.state"][[i]], - } - # Note: unsqueeze needed because `act` still uses non-batch logic. - action = self.act(obs, t0=t0, step=self.step) - actions.append(action) - action = torch.stack(actions) - - # tdmpc returns an action for 1 timestep only, so we copy it over `n_action_steps` time - if i in range(self.n_action_steps): - self._queues["action"].append(action) - - action = self._queues["action"].popleft() - return action - - @torch.no_grad() - def act(self, obs, t0=False, step=None): - """Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag.""" - obs = {k: o.detach() for k, o in obs.items()} if isinstance(obs, dict) else obs.detach() - z = self.model.encode(obs) - if self.cfg.mpc: - a = self.plan(z, t0=t0, step=step) - else: - a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0) - return a - - @torch.no_grad() - def estimate_value(self, z, actions, horizon): - """Estimate value of a trajectory starting at latent state z and executing given actions.""" - G, discount = 0, 1 - for t in range(horizon): - if self.cfg.uncertainty_cost > 0: - G -= ( - discount - * self.cfg.uncertainty_cost - * self.model.Q(z, actions[t], return_type="all").std(dim=0) - ) - z, reward = self.model.next(z, actions[t]) - G += discount * reward - discount *= self.cfg.discount - pi = self.model.pi(z, self.cfg.min_std) - G += discount * self.model.Q(z, pi, return_type="min") - if self.cfg.uncertainty_cost > 0: - G -= discount * self.cfg.uncertainty_cost * self.model.Q(z, pi, return_type="all").std(dim=0) - return G - - @torch.no_grad() - def plan(self, z, step=None, t0=True): - """ - Plan next action using TD-MPC inference. - z: latent state. - step: current time step. determines e.g. planning horizon. - t0: whether current step is the first step of an episode. - """ - # during eval: eval_mode: uniform sampling and action noise is disabled during evaluation. 
- - assert step is not None - # Seed steps - if step < self.cfg.seed_steps and self.model.training: - return torch.empty(self.action_dim, dtype=torch.float32, device=self.device).uniform_(-1, 1) - - # Sample policy trajectories - horizon = int(min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step))) - num_pi_trajs = int(self.cfg.mixture_coef * self.cfg.num_samples) - if num_pi_trajs > 0: - pi_actions = torch.empty(horizon, num_pi_trajs, self.action_dim, device=self.device) - _z = z.repeat(num_pi_trajs, 1) - for t in range(horizon): - pi_actions[t] = self.model.pi(_z, self.cfg.min_std) - _z = self.model.next_dynamics(_z, pi_actions[t]) - - # Initialize state and parameters - z = z.repeat(self.cfg.num_samples + num_pi_trajs, 1) - mean = torch.zeros(horizon, self.action_dim, device=self.device) - std = self.cfg.max_std * torch.ones(horizon, self.action_dim, device=self.device) - if not t0 and hasattr(self, "_prev_mean"): - mean[:-1] = self._prev_mean[1:] - - # Iterate CEM - for _ in range(self.cfg.iterations): - actions = torch.clamp( - mean.unsqueeze(1) - + std.unsqueeze(1) - * torch.randn(horizon, self.cfg.num_samples, self.action_dim, device=std.device), - -1, - 1, - ) - if num_pi_trajs > 0: - actions = torch.cat([actions, pi_actions], dim=1) - - # Compute elite actions - value = self.estimate_value(z, actions, horizon).nan_to_num_(0) - elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices - elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs] - - # Update parameters - max_value = elite_value.max(0)[0] - score = torch.exp(self.cfg.temperature * (elite_value - max_value)) - score /= score.sum(0) - _mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (score.sum(0) + 1e-9) - _std = torch.sqrt( - torch.sum( - score.unsqueeze(0) * (elite_actions - _mean.unsqueeze(1)) ** 2, - dim=1, - ) - / (score.sum(0) + 1e-9) - ) - _std = _std.clamp_(self.std, self.cfg.max_std) - mean, std = self.cfg.momentum * mean + (1 - self.cfg.momentum) * _mean, _std - - # Outputs - # TODO(rcadene): remove numpy with - # # Convert score tensor to probabilities using softmax - # probabilities = torch.softmax(score, dim=0) - # # Generate a random sample index based on the probabilities - # sample_index = torch.multinomial(probabilities, 1).item() - score = score.squeeze(1).cpu().numpy() - actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)] - self._prev_mean = mean - mean, std = actions[0], _std[0] - a = mean - if self.model.training: - a += std * torch.randn(self.action_dim, device=std.device) - return torch.clamp(a, -1, 1) - - def update_pi(self, zs, acts=None): - """Update policy using a sequence of latent states.""" - self.pi_optim.zero_grad(set_to_none=True) - self.model.track_q_grad(False) - self.model.track_v_grad(False) - - info = {} - # Advantage Weighted Regression - assert acts is not None - vs = self.model.V(zs) - qs = self.model_target.Q(zs, acts, return_type="min") - adv = qs - vs - exp_a = torch.exp(adv * self.cfg.A_scaling) - exp_a = torch.clamp(exp_a, max=100.0) - log_probs = h.gaussian_logprob(self.model.pi(zs) - acts, 0) - rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device)) - pi_loss = -((exp_a * log_probs).mean(dim=(1, 2)) * rho).mean() - info["adv"] = adv[0] - - pi_loss.backward() - torch.nn.utils.clip_grad_norm_( - self.model._pi.parameters(), - self.cfg.grad_clip_norm, - error_if_nonfinite=False, - ) - self.pi_optim.step() - self.model.track_q_grad(True) - 
self.model.track_v_grad(True) - - info["pi_loss"] = pi_loss.item() - return pi_loss.item(), info - - @torch.no_grad() - def _td_target(self, next_z, reward, mask): - """Compute the TD-target from a reward and the observation at the following time step.""" - next_v = self.model.V(next_z) - td_target = reward + self.cfg.discount * mask * next_v.squeeze(2) - return td_target - - def forward(self, batch, step): - # TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation. - raise NotImplementedError() - - def update(self, batch, step): - """Main update function. Corresponds to one iteration of the model learning.""" - start_time = time.time() - - batch_size = batch["index"].shape[0] - - # TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels) - # instead of currently (time/horizon, batch size, channels) which is not the pytorch convention - # batch size b = 256, time/horizon t = 5 - # b t ... -> t b ... - for key in batch: - if batch[key].ndim > 1: - batch[key] = batch[key].transpose(1, 0) - - action = batch["action"] - reward = batch["next.reward"] - # idxs = batch["index"] # TODO(rcadene): use idxs to update sampling weights - done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device) - mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device) - weights = torch.ones(batch_size, dtype=torch.bool, device=reward.device) - - obses = { - "rgb": batch["observation.image"], - "state": batch["observation.state"], - } - - shapes = {} - for k in obses: - shapes[k] = obses[k].shape - obses[k] = einops.rearrange(obses[k], "t b ... -> (t b) ... ") - - # Apply augmentations - aug_tf = h.aug(self.cfg) - obses = aug_tf(obses) - - for k in obses: - t, b = shapes[k][:2] - obses[k] = einops.rearrange(obses[k], "(t b) ... -> t b ... 
", b=b, t=t) - - obs, next_obses = {}, {} - for k in obses: - obs[k] = obses[k][0] - next_obses[k] = obses[k][1:].clone() - - horizon = next_obses["rgb"].shape[0] - loss_mask = torch.ones_like(mask, device=self.device) - for t in range(1, horizon): - loss_mask[t] = loss_mask[t - 1] * (~done[t - 1]) - - self.optim.zero_grad(set_to_none=True) - self.std = h.linear_schedule(self.cfg.std_schedule, step) - self.model.train() - - data_s = time.time() - start_time - - # Compute targets - with torch.no_grad(): - next_z = self.model.encode(next_obses) - z_targets = self.model_target.encode(next_obses) - td_targets = self._td_target(next_z, reward, mask) - - # Latent rollout - zs = torch.empty(horizon + 1, batch_size, self.cfg.latent_dim, device=self.device) - reward_preds = torch.empty_like(reward, device=self.device) - assert reward.shape[0] == horizon - z = self.model.encode(obs) - zs[0] = z - value_info = {"Q": 0.0, "V": 0.0} - for t in range(horizon): - z, reward_pred = self.model.next(z, action[t]) - zs[t + 1] = z - reward_preds[t] = reward_pred.squeeze(1) - - with torch.no_grad(): - v_target = self.model_target.Q(zs[:-1].detach(), action, return_type="min") - - # Predictions - qs = self.model.Q(zs[:-1], action, return_type="all") - qs = qs.squeeze(3) - value_info["Q"] = qs.mean().item() - v = self.model.V(zs[:-1]) - value_info["V"] = v.mean().item() - - # Losses - rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1) - consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2) * loss_mask).sum(dim=0) - reward_loss = (rho * h.mse(reward_preds, reward) * loss_mask).sum(dim=0) - q_value_loss, priority_loss = 0, 0 - for q in range(self.cfg.num_q): - q_value_loss += (rho * h.mse(qs[q], td_targets) * loss_mask).sum(dim=0) - priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0) - - expectile = h.linear_schedule(self.cfg.expectile, step) - v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile).squeeze(2) * loss_mask).sum( - dim=0 - ) - - total_loss = ( - self.cfg.consistency_coef * consistency_loss - + self.cfg.reward_coef * reward_loss - + self.cfg.value_coef * q_value_loss - + self.cfg.value_coef * v_value_loss - ) - - weighted_loss = (total_loss * weights).mean() - weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon)) - has_nan = torch.isnan(weighted_loss).item() - if has_nan: - print(f"weighted_loss has nan: {total_loss=} {weights=}") - else: - weighted_loss.backward() - - grad_norm = torch.nn.utils.clip_grad_norm_( - self.model.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False - ) - self.optim.step() - - # TODO(rcadene): implement PrioritizedSampling by modifying sampler.weights with priorities computed by a criterion - # if self.cfg.per: - # # Update priorities - # priorities = priority_loss.clamp(max=1e4).detach() - # has_nan = torch.isnan(priorities).any().item() - # if has_nan: - # print(f"priorities has nan: {priorities=}") - # else: - # replay_buffer.update_priority( - # idxs[:num_slices], - # priorities[:num_slices], - # ) - # if demo_batch_size > 0: - # demo_buffer.update_priority(demo_idxs, priorities[num_slices:]) - - # Update policy + target network - _, pi_update_info = self.update_pi(zs[:-1].detach(), acts=action) - - if step % self.cfg.update_freq == 0: - h.ema(self.model._encoder, self.model_target._encoder, self.cfg.tau) - h.ema(self.model._Qs, self.model_target._Qs, self.cfg.tau) - - self.model.eval() - - info = { - "consistency_loss": 
float(consistency_loss.mean().item()), - "reward_loss": float(reward_loss.mean().item()), - "Q_value_loss": float(q_value_loss.mean().item()), - "V_value_loss": float(v_value_loss.mean().item()), - "sum_loss": float(total_loss.mean().item()), - "loss": float(weighted_loss.mean().item()), - "grad_norm": float(grad_norm), - "lr": self.cfg.lr, - "data_s": data_s, - "update_s": time.time() - start_time, - } - # info["demo_batch_size"] = demo_batch_size - info["expectile"] = expectile - info.update(value_info) - info.update(pi_update_info) - - self.step[0] = step - return info diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index 6bf78573..b3b85c0c 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -17,6 +17,7 @@ training: offline_steps: ??? online_steps: ??? online_steps_between_rollouts: ??? + online_sampling_ratio: 0.5 eval_freq: ??? save_freq: ??? log_freq: 250 diff --git a/lerobot/configs/policy/tdmpc.yaml b/lerobot/configs/policy/tdmpc.yaml index a4761d13..6387882c 100644 --- a/lerobot/configs/policy/tdmpc.yaml +++ b/lerobot/configs/policy/tdmpc.yaml @@ -1,85 +1,76 @@ # @package _global_ -n_action_steps: 2 -n_obs_steps: 1 +seed: 1 + +training: + offline_steps: 25000 + online_steps: 25000 + eval_freq: 5000 + online_steps_between_rollouts: 1 + online_sampling_ratio: 0.5 + + batch_size: 256 + grad_clip_norm: 10.0 + lr: 3e-4 + + delta_timestamps: + observation.image: "[i / ${fps} for i in range(${policy.horizon} + 1)]" + observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]" + action: "[i / ${fps} for i in range(${policy.horizon})]" + next.reward: "[i / ${fps} for i in range(${policy.horizon})]" policy: name: tdmpc - reward_scale: 1.0 + pretrained_model_path: - episode_length: ${env.episode_length} - discount: 0.9 - modality: 'all' - - # pixels - frame_stack: 1 - num_channels: 32 - img_size: ${env.image_size} - state_dim: ${env.action_dim} - action_dim: ${env.action_dim} - - # planning - mpc: true - iterations: 6 - num_samples: 512 - num_elites: 50 - mixture_coef: 0.1 - min_std: 0.05 - max_std: 2.0 - temperature: 0.5 - momentum: 0.1 - uncertainty_cost: 1 - - # actor - log_std_min: -10 - log_std_max: 2 - - # learning - batch_size: 256 - max_buffer_size: 10000 + # Input / output structure. + n_action_repeats: 2 horizon: 5 - reward_coef: 0.5 - value_coef: 0.1 - consistency_coef: 20 - rho: 0.5 - kappa: 0.1 - lr: 3e-4 - std_schedule: ${policy.min_std} - horizon_schedule: ${policy.horizon} - per: true - per_alpha: 0.6 - per_beta: 0.4 - grad_clip_norm: 10 - seed_steps: 0 - update_freq: 2 - tau: 0.01 - online_steps_between_rollouts: 1 - # offline rl - # dataset_dir: ??? - data_first_percent: 1.0 - is_data_clip: true - data_clip_eps: 1e-5 - expectile: 0.9 - A_scaling: 3.0 + input_shapes: + # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env? + observation.image: [3, 84, 84] + observation.state: ["${env.state_dim}"] + output_shapes: + action: ["${env.action_dim}"] - # offline->online - offline_steps: ${offline_steps} - pretrained_model_path: "" - # pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt" - # pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt" - balanced_sampling: true - demo_schedule: 0.5 + # Normalization / Unnormalization + input_normalization_modes: null + output_normalization_modes: + action: min_max - # architecture - enc_dim: 256 - num_q: 5 - mlp_dim: 512 + # Architecture / modeling. 
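# Illustrative sketch (not part of the diff): the `delta_timestamps` strings in the training section
# above are evaluated into per-key lists of offsets, in seconds, relative to the current frame, so
# each training sample carries horizon + 1 observations but only horizon actions and rewards.
# With policy.horizon = 5 as configured here and an assumed fps of 15 (illustrative value only):
fps, horizon = 15, 5
observation_deltas = [i / fps for i in range(horizon + 1)]    # 6 frames: 0, 1/15, ..., 5/15 seconds
action_and_reward_deltas = [i / fps for i in range(horizon)]  # 5 steps: 0, 1/15, ..., 4/15 seconds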
+ # Neural networks. + image_encoder_hidden_dim: 32 + state_encoder_hidden_dim: 256 latent_dim: 50 + q_ensemble_size: 5 + mlp_dim: 512 + # Reinforcement learning. + discount: 0.9 - delta_timestamps: - observation.image: "[i / ${fps} for i in range(6)]" - observation.state: "[i / ${fps} for i in range(6)]" - action: "[i / ${fps} for i in range(5)]" - next.reward: "[i / ${fps} for i in range(5)]" + # Inference. + use_mpc: false + cem_iterations: 6 + max_std: 2.0 + min_std: 0.05 + n_gaussian_samples: 512 + n_pi_samples: 51 + uncertainty_regularizer_coeff: 1.0 + n_elites: 50 + elite_weighting_temperature: 0.5 + gaussian_mean_momentum: 0.1 + + # Training and loss computation. + max_random_shift_ratio: 0.0476 + # Loss coefficients. + reward_coeff: 0.5 + expectile_weight: 0.9 + value_coeff: 0.1 + consistency_coeff: 20.0 + advantage_scaling: 3.0 + pi_coeff: 0.5 + temporal_decay_coeff: 0.5 + # Target model. + target_model_momentum: 0.995 diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index c74af290..6c9e28bf 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -67,10 +67,10 @@ def eval_policy( """ set `return_episode_data` to return a Hugging Face dataset object in an "episodes" key of the return dict. """ + policy.eval() + fps = env.unwrapped.metadata["render_fps"] - if policy is not None: - policy.eval() device = "cpu" if policy is None else next(policy.parameters()).device start = time.time() @@ -132,7 +132,7 @@ def eval_policy( # get the next action for the environment with torch.inference_mode(): - action = policy.select_action(observation, step=step) + action = policy.select_action(observation) # convert to cpu numpy action = postprocess_action(action) @@ -386,6 +386,7 @@ def eval( else: # Note: We need the dataset stats to pass to the policy's normalization modules. policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats) + policy.eval() info = eval_policy( env, diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 565c5f3a..bd27b28a 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -15,6 +15,7 @@ from lerobot.common.datasets.utils import cycle from lerobot.common.envs.factory import make_env from lerobot.common.logger import Logger, log_output_dir from lerobot.common.policies.factory import make_policy +from lerobot.common.policies.policy_protocol import PolicyWithUpdate from lerobot.common.utils.utils import ( format_big_number, get_safe_torch_device, @@ -39,12 +40,17 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None): optimizer.step() optimizer.zero_grad() + if lr_scheduler is not None: lr_scheduler.step() if hasattr(policy, "ema") and policy.ema is not None: policy.ema.step(policy.diffusion) + if isinstance(policy, PolicyWithUpdate): + # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC). 
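# Illustrative sketch (not part of the diff): for the `isinstance(policy, PolicyWithUpdate)` check
# above to work at runtime, the protocol must be decorated with `typing.runtime_checkable`. A
# minimal version of what such a protocol could look like (the actual definition added by this PR
# lives in lerobot/common/policies/policy_protocol.py):
from typing import Protocol, runtime_checkable

@runtime_checkable
class PolicyWithUpdate(Protocol):
    def update(self):
        """Post-optimizer-step bookkeeping, e.g. an EMA update of a target network (as in TD-MPC)."""
        ...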
+ policy.update() + info = { "loss": loss.item(), "grad_norm": float(grad_norm), @@ -246,11 +252,12 @@ def train(cfg: dict, out_dir=None, job_name=None): raise NotImplementedError() if job_name is None: raise NotImplementedError() - if cfg.training.online_steps > 0: - assert cfg.eval.batch_size == 1, "eval.batch_size > 1 not supported for online training steps" init_logging() + if cfg.training.online_steps > 0 and cfg.eval.batch_size > 1: + logging.warning("eval.batch_size > 1 not supported for online training steps") + # Check device is available get_safe_torch_device(cfg.device, log=True) @@ -305,7 +312,10 @@ def train(cfg: dict, out_dir=None, job_name=None): num_training_steps=cfg.training.offline_steps, ) elif policy.name == "tdmpc": - raise NotImplementedError("TD-MPC not implemented yet.") + optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr) + lr_scheduler = None + else: + raise NotImplementedError() num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_total_params = sum(p.numel() for p in policy.parameters()) @@ -361,12 +371,12 @@ def train(cfg: dict, out_dir=None, job_name=None): ) dl_iter = cycle(dataloader) + policy.train() step = 0 # number of policy update (forward + backward + optim) is_offline = True for offline_step in range(cfg.training.offline_steps): if offline_step == 0: logging.info("Start offline training on a fixed dataset") - policy.train() batch = next(dl_iter) for key in batch: @@ -414,6 +424,7 @@ def train(cfg: dict, out_dir=None, job_name=None): if env_step == 0: logging.info("Start online training by interacting with environment") + policy.eval() with torch.no_grad(): eval_info = eval_policy( rollout_env, @@ -422,17 +433,17 @@ def train(cfg: dict, out_dir=None, job_name=None): seed=cfg.seed, ) - add_episodes_inplace( - online_dataset, - concat_dataset, - sampler, - hf_dataset=eval_info["episodes"]["hf_dataset"], - episode_data_index=eval_info["episodes"]["episode_data_index"], - pc_online_samples=cfg.get("demo_schedule", 0.5), - ) + add_episodes_inplace( + online_dataset, + concat_dataset, + sampler, + hf_dataset=eval_info["episodes"]["hf_dataset"], + episode_data_index=eval_info["episodes"]["episode_data_index"], + pc_online_samples=cfg.training.online_sampling_ratio, + ) + policy.train() for _ in range(cfg.training.online_steps_between_rollouts): - policy.train() batch = next(dl_iter) for key in batch: diff --git a/tests/test_available.py b/tests/test_available.py index 0cfdf52b..b3d0cd78 100644 --- a/tests/test_available.py +++ b/tests/test_available.py @@ -6,7 +6,7 @@ import pytest import lerobot from lerobot.common.policies.act.modeling_act import ACTPolicy from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy -from lerobot.common.policies.tdmpc.policy import TDMPCPolicy +from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy from tests.utils import require_env diff --git a/tests/test_policies.py b/tests/test_policies.py index ed046659..50f36a25 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -19,10 +19,6 @@ from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env @pytest.mark.parametrize("policy_name", available_policies) def test_get_policy_and_config_classes(policy_name: str): """Check that the correct policy and config classes are returned.""" - if policy_name == "tdmpc": - with pytest.raises(NotImplementedError): - get_policy_and_config_classes(policy_name) - return policy_cls, config_cls = get_policy_and_config_classes(policy_name) 
assert policy_cls.name == policy_name assert issubclass(config_cls, inspect.signature(policy_cls.__init__).parameters["config"].annotation) @@ -32,8 +28,7 @@ def test_get_policy_and_config_classes(policy_name: str): @pytest.mark.parametrize( "env_name,policy_name,extra_overrides", [ - # ("xarm", "tdmpc", ["policy.mpc=true"]), - # ("pusht", "tdmpc", ["policy.mpc=false"]), + ("xarm", "tdmpc", ["policy.use_mpc=true", "dataset_repo_id=lerobot/xarm_lift_medium"]), ("pusht", "diffusion", []), ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_human"]), ( @@ -103,7 +98,7 @@ def test_policy(env_name, policy_name, extra_overrides): batch[key] = batch[key].to(DEVICE, non_blocking=True) # Test updating the policy - policy.forward(batch, step=0) + policy.forward(batch) # reset the policy and environment policy.reset() @@ -117,7 +112,7 @@ def test_policy(env_name, policy_name, extra_overrides): # get the next action for the environment with torch.inference_mode(): - action = policy.select_action(observation, step=0) + action = policy.select_action(observation) # convert action to cpu numpy array action = postprocess_action(action) @@ -129,20 +124,12 @@ def test_policy(env_name, policy_name, extra_overrides): @pytest.mark.parametrize("policy_name", available_policies) def test_policy_defaults(policy_name: str): """Check that the policy can be instantiated with defaults.""" - if policy_name == "tdmpc": - with pytest.raises(NotImplementedError): - get_policy_and_config_classes(policy_name) - return policy_cls, _ = get_policy_and_config_classes(policy_name) policy_cls() @pytest.mark.parametrize("policy_name", available_policies) def test_save_and_load_pretrained(policy_name: str): - if policy_name == "tdmpc": - with pytest.raises(NotImplementedError): - get_policy_and_config_classes(policy_name) - return policy_cls, _ = get_policy_and_config_classes(policy_name) policy: Policy = policy_cls() save_dir = "/tmp/test_save_and_load_pretrained_{policy_cls.__name__}"
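# Illustrative sketch (not part of the diff): what `test_save_and_load_pretrained` exercises, in
# isolation. The policy classes mix in `PyTorchModelHubMixin`, so a default-constructed policy
# should round-trip through `save_pretrained` / `from_pretrained`; the temporary path and the
# parameter-equality check below are illustrative only.
import tempfile

import torch

from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy

policy = TDMPCPolicy()  # default TDMPCConfig, no dataset stats
with tempfile.TemporaryDirectory() as save_dir:
    policy.save_pretrained(save_dir)
    reloaded = TDMPCPolicy.from_pretrained(save_dir)
assert all(
    torch.equal(p, p_) for p, p_ in zip(policy.parameters(), reloaded.parameters(), strict=True)
)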