Add online training with TD-MPC as proof of concept (#338)

Alexander Soare
2024-07-25 11:16:38 +01:00
committed by GitHub
parent abbb1d2367
commit f8a6574698
25 changed files with 1291 additions and 233 deletions


@@ -19,14 +19,10 @@
The comments in this code may sometimes refer to these references:
TD-MPC paper: Temporal Difference Learning for Model Predictive Control (https://arxiv.org/abs/2203.04955)
FOWM paper: Finetuning Offline World Models in the Real World (https://arxiv.org/abs/2310.16029)
TODO(alexander-soare): Make rollout work for batch sizes larger than 1.
TODO(alexander-soare): Use batch-first throughout.
"""
# ruff: noqa: N806
import logging
from collections import deque
from copy import deepcopy
from functools import partial
@@ -56,9 +52,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
process communication to use the xarm environment from FOWM. This is because our xarm
environment uses newer dependencies and does not match the environment in FOWM. See
https://github.com/huggingface/lerobot/pull/103 for implementation details.
- We have NOT checked that training on LeRobot reproduces SOTA results. This is a TODO.
- We have NOT checked that training on LeRobot reproduces the results from FOWM.
- Nevertheless, we have verified that we can train TD-MPC for PushT. See
`lerobot/configs/policy/tdmpc_pusht_keypoints.yaml`.
- Our current xarm datasets were generated using the environment from FOWM. Therefore they do not
match our xarm environment.
"""
name = "tdmpc"
@@ -74,22 +72,6 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__()
logging.warning(
"""
Please note several warnings for this policy.
- Evaluation of pretrained weights created with the original FOWM code
(https://github.com/fyhMer/fowm) works as expected. To be precise: we trained and evaluated a
model with the FOWM code for the xarm_lift_medium_replay dataset. We ported the weights across
to LeRobot, and were able to evaluate with the same success metric. BUT, we had to use inter-
process communication to use the xarm environment from FOWM. This is because our xarm
environment uses newer dependencies and does not match the environment in FOWM. See
https://github.com/huggingface/lerobot/pull/103 for implementation details.
- We have NOT checked that training on LeRobot reproduces SOTA results. This is a TODO.
- Our current xarm datasets were generated using the environment from FOWM. Therefore they do not
match our xarm environment.
"""
)
if config is None:
config = TDMPCConfig()
@@ -114,8 +96,14 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
# Note: This check is covered in the post-init of the config, but we keep a sanity check here just in case.
assert len(image_keys) == 1
self.input_image_key = image_keys[0]
self._use_image = False
self._use_env_state = False
if len(image_keys) > 0:
assert len(image_keys) == 1
self._use_image = True
self.input_image_key = image_keys[0]
if "observation.environment_state" in config.input_shapes:
self._use_env_state = True
self.reset()
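For intuition, a standalone sketch of how the new modality flags fall out of `config.input_shapes`. The shapes below are hypothetical, loosely in the spirit of a keypoints-only PushT config such as `lerobot/configs/policy/tdmpc_pusht_keypoints.yaml`:

```python
# Hypothetical input_shapes for a keypoints-only setup (no camera).
input_shapes = {
    "observation.environment_state": [16],  # e.g. 8 keypoints as (x, y) pairs
    "observation.state": [2],               # e.g. agent position
}

image_keys = [k for k in input_shapes if k.startswith("observation.image")]
use_image = len(image_keys) > 0  # False: the CNN branch is skipped entirely
use_env_state = "observation.environment_state" in input_shapes  # True
```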
@@ -125,10 +113,13 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
called on `env.reset()`
"""
self._queues = {
"observation.image": deque(maxlen=1),
"observation.state": deque(maxlen=1),
"action": deque(maxlen=self.config.n_action_repeats),
"action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
}
if self._use_image:
self._queues["observation.image"] = deque(maxlen=1)
if self._use_env_state:
self._queues["observation.environment_state"] = deque(maxlen=1)
# Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
# CEM for the next step.
self._prev_mean: torch.Tensor | None = None
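A minimal sketch of the queueing pattern used here, with hypothetical sizes: the queue holds either `n_action_repeats` copies of the first planned action or the first `n_action_steps` of a planned trajectory, and one action is popped per environment step:

```python
from collections import deque

import torch

n_action_steps, n_action_repeats = 5, 1
queue = deque(maxlen=max(n_action_steps, n_action_repeats))

planned = torch.randn(5, 2, 4)  # stand-in for a (horizon, batch, action_dim) plan

if n_action_repeats > 1:
    for _ in range(n_action_repeats):
        queue.append(planned[0])  # repeat the first planned action
else:
    # Cache a chunk; iterating over the tensor adds (batch, action_dim) slices.
    queue.extend(planned[:n_action_steps])

next_action = queue.popleft()  # one (batch, action_dim) action per env step until the queue empties
print(next_action.shape)  # torch.Size([2, 4])
```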
@@ -137,8 +128,9 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations."""
batch = self.normalize_inputs(batch)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.image"] = batch[self.input_image_key]
if self._use_image:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.image"] = batch[self.input_image_key]
self._queues = populate_queues(self._queues, batch)
@@ -152,49 +144,57 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
batch[key] = batch[key][:, 0]
# NOTE: Order of observations matters here.
z = self.model.encode({k: batch[k] for k in ["observation.image", "observation.state"]})
if self.config.use_mpc:
batch_size = batch["observation.image"].shape[0]
# Batch processing is not handled in MPC mode, so process the batch in a loop.
action = [] # will be a batch of actions for one step
for i in range(batch_size):
# Note: self.plan does not handle batches, hence the squeeze.
action.append(self.plan(z[i]))
action = torch.stack(action)
encode_keys = []
if self._use_image:
encode_keys.append("observation.image")
if self._use_env_state:
encode_keys.append("observation.environment_state")
encode_keys.append("observation.state")
z = self.model.encode({k: batch[k] for k in encode_keys})
if self.config.use_mpc: # noqa: SIM108
actions = self.plan(z) # (horizon, batch, action_dim)
else:
# Plan with the policy (π) alone.
action = self.model.pi(z)
# Plan with the policy (π) alone. This always returns one action so unsqueeze to get a
# sequence dimension like in the MPC branch.
actions = self.model.pi(z).unsqueeze(0)
self.unnormalize_outputs({"action": action})["action"]
actions = torch.clamp(actions, -1, +1)
for _ in range(self.config.n_action_repeats):
self._queues["action"].append(action)
actions = self.unnormalize_outputs({"action": actions})["action"]
if self.config.n_action_repeats > 1:
for _ in range(self.config.n_action_repeats):
self._queues["action"].append(actions[0])
else:
# Action queue is (n_action_steps, batch_size, action_dim), so we transpose the action.
self._queues["action"].extend(actions[: self.config.n_action_steps])
action = self._queues["action"].popleft()
return torch.clamp(action, -1, 1)
return action
@torch.no_grad()
def plan(self, z: Tensor) -> Tensor:
"""Plan next action using TD-MPC inference.
"""Plan sequence of actions using TD-MPC inference.
Args:
z: (latent_dim,) tensor for the initial state.
z: (batch, latent_dim,) tensor for the initial state.
Returns:
(action_dim,) tensor for the next action.
TODO(alexander-soare) Extend this to be able to work with batches.
(horizon, batch, action_dim,) tensor for the planned trajectory of actions.
"""
device = get_device_from_parameters(self)
batch_size = z.shape[0]
# Sample Nπ trajectories from the policy.
pi_actions = torch.empty(
self.config.horizon,
self.config.n_pi_samples,
batch_size,
self.config.output_shapes["action"][0],
device=device,
)
if self.config.n_pi_samples > 0:
_z = einops.repeat(z, "d -> n d", n=self.config.n_pi_samples)
_z = einops.repeat(z, "b d -> n b d", n=self.config.n_pi_samples)
for t in range(self.config.horizon):
# Note: Adding a small amount of noise here doesn't hurt during inference and may even be
# helpful for CEM.
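A shape-only sketch (hypothetical sizes) of the batched sample expansion: with the batch dimension kept, the latent is repeated across the sample axis with `"b d -> n b d"` instead of the old `"d -> n d"`:

```python
import einops
import torch

batch_size, latent_dim, n_pi_samples, horizon, action_dim = 2, 50, 24, 5, 4

z = torch.randn(batch_size, latent_dim)
_z = einops.repeat(z, "b d -> n b d", n=n_pi_samples)
print(_z.shape)  # torch.Size([24, 2, 50]): every policy sample keeps its own batch element

# The policy rollouts then fill a (horizon, n_pi_samples, batch, action_dim) buffer.
pi_actions = torch.empty(horizon, n_pi_samples, batch_size, action_dim)
print(pi_actions.shape)  # torch.Size([5, 24, 2, 4])
```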
@@ -203,12 +203,14 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
# In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
# trajectories.
z = einops.repeat(z, "d -> n d", n=self.config.n_gaussian_samples + self.config.n_pi_samples)
z = einops.repeat(z, "b d -> n b d", n=self.config.n_gaussian_samples + self.config.n_pi_samples)
# Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization
# algorithm.
# The initial mean and standard deviation for the cross-entropy method (CEM).
mean = torch.zeros(self.config.horizon, self.config.output_shapes["action"][0], device=device)
mean = torch.zeros(
self.config.horizon, batch_size, self.config.output_shapes["action"][0], device=device
)
# Maybe warm start CEM with the mean from the previous step.
if self._prev_mean is not None:
mean[:-1] = self._prev_mean[1:]
@@ -219,6 +221,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
std_normal_noise = torch.randn(
self.config.horizon,
self.config.n_gaussian_samples,
batch_size,
self.config.output_shapes["action"][0],
device=std.device,
)
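The sampling line itself is outside this hunk, but the shapes above imply that the per-batch mean/std broadcast against the noise over the sample axis. A shape-only sketch under that assumption (sizes hypothetical):

```python
import torch

horizon, n_gaussian_samples, batch_size, action_dim = 5, 512, 2, 4

mean = torch.zeros(horizon, batch_size, action_dim)
std = torch.full((horizon, batch_size, action_dim), 0.5)
std_normal_noise = torch.randn(horizon, n_gaussian_samples, batch_size, action_dim)

# unsqueeze(1) inserts the sample axis so mean/std broadcast over every candidate trajectory.
gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1)
print(gaussian_actions.shape)  # torch.Size([5, 512, 2, 4])
```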
@@ -227,21 +230,24 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
# Compute elite actions.
actions = torch.cat([gaussian_actions, pi_actions], dim=1)
value = self.estimate_value(z, actions).nan_to_num_(0)
elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices # (n_elites, batch)
elite_value = value.take_along_dim(elite_idxs, dim=0) # (n_elites, batch)
# (horizon, n_elites, batch, action_dim)
elite_actions = actions.take_along_dim(einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1)
# Update guassian PDF parameters to be the (weighted) mean and standard deviation of the elites.
max_value = elite_value.max(0)[0]
# Update gaussian PDF parameters to be the (weighted) mean and standard deviation of the elites.
max_value = elite_value.max(0, keepdim=True)[0] # (1, batch)
# The weighting is a softmax over trajectory values. Note that this is not the same as the usage
# of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This
makes the equations: μ = Σ(s⋅Γ), σ = √(Σ(s⋅(Γ-μ)²)).
score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
score /= score.sum()
_mean = torch.sum(einops.rearrange(score, "n -> n 1") * elite_actions, dim=1)
score /= score.sum(axis=0, keepdim=True)
# (horizon, batch, action_dim)
_mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1)
_std = torch.sqrt(
torch.sum(
einops.rearrange(score, "n -> n 1")
* (elite_actions - einops.rearrange(_mean, "h d -> h 1 d")) ** 2,
einops.rearrange(score, "n b -> n b 1")
* (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d")) ** 2,
dim=1,
)
)
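A small, self-contained numeric sketch of the batched elite update (tiny hypothetical sizes), mirroring the `take_along_dim` selection and the s = Ω/ΣΩ weighting above:

```python
import einops
import torch

horizon, n_samples, batch_size, action_dim, n_elites, temperature = 3, 8, 2, 4, 4, 0.5

actions = torch.randn(horizon, n_samples, batch_size, action_dim)
value = torch.randn(n_samples, batch_size)  # stand-in for estimate_value's output

# Top-k over the sample dimension, independently for every batch element.
elite_idxs = torch.topk(value, n_elites, dim=0).indices  # (n_elites, batch)
elite_value = value.take_along_dim(elite_idxs, dim=0)    # (n_elites, batch)
elite_actions = actions.take_along_dim(
    einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1
)  # (horizon, n_elites, batch, action_dim)

# s = Ω/ΣΩ per batch element, then μ = Σ(s⋅Γ) and σ = √(Σ(s⋅(Γ-μ)²)).
max_value = elite_value.max(0, keepdim=True)[0]
score = torch.exp(temperature * (elite_value - max_value))
score /= score.sum(dim=0, keepdim=True)
_mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1)
_std = torch.sqrt(
    torch.sum(
        einops.rearrange(score, "n b -> n b 1")
        * (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d")) ** 2,
        dim=1,
    )
)
print(_mean.shape, _std.shape)  # torch.Size([3, 2, 4]) torch.Size([3, 2, 4])
```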
@@ -256,11 +262,9 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
# Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax
# scores from the last iteration.
actions = elite_actions[:, torch.multinomial(score, 1).item()]
actions = elite_actions[:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)]
# Select only the first action
action = actions[0]
return action
return actions
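An indexing sketch (standalone, hypothetical sizes) of the final pick: one elite index is drawn per batch element from the softmax scores, then paired with its batch index:

```python
import torch

horizon, n_elites, batch_size, action_dim = 3, 4, 2, 5
elite_actions = torch.randn(horizon, n_elites, batch_size, action_dim)
score = torch.rand(n_elites, batch_size)
score /= score.sum(dim=0, keepdim=True)

# multinomial samples along the last dim, so transpose to (batch, n_elites) first.
idx = torch.multinomial(score.T, 1).squeeze(-1)            # (batch,) one elite per batch element
actions = elite_actions[:, idx, torch.arange(batch_size)]  # (horizon, batch, action_dim)
print(actions.shape)  # torch.Size([3, 2, 5])
```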
@torch.no_grad()
def estimate_value(self, z: Tensor, actions: Tensor):
@@ -312,13 +316,17 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
return G
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss."""
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
"""Run the batch through the model and compute the loss.
Returns a dictionary with loss as a tensor, and other information as native floats.
"""
device = get_device_from_parameters(self)
batch = self.normalize_inputs(batch)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.image"] = batch[self.input_image_key]
if self._use_image:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.image"] = batch[self.input_image_key]
batch = self.normalize_targets(batch)
info = {}
@@ -328,12 +336,12 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
if batch[key].ndim > 1:
batch[key] = batch[key].transpose(1, 0)
action = batch["action"] # (t, b)
reward = batch["next.reward"] # (t,)
action = batch["action"] # (t, b, action_dim)
reward = batch["next.reward"] # (t, b)
observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
# Apply random image augmentations.
if self.config.max_random_shift_ratio > 0:
if self._use_image and self.config.max_random_shift_ratio > 0:
observations["observation.image"] = flatten_forward_unflatten(
partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
observations["observation.image"],
@@ -345,7 +353,9 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
for k in observations:
current_observation[k] = observations[k][0]
next_observations[k] = observations[k][1:]
horizon = next_observations["observation.image"].shape[0]
horizon, batch_size = next_observations[
"observation.image" if self._use_image else "observation.environment_state"
].shape[:2]
# Run latent rollout using the latent dynamics model and policy model.
# Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
@@ -415,7 +425,8 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
q_value_loss = (
(
F.mse_loss(
temporal_loss_coeffs
* F.mse_loss(
q_preds_ensemble,
einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
reduction="none",
@@ -464,10 +475,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
action_preds = self.model.pi(z_preds[:-1]) # (t, b, a)
# Calculate the MSE between the actions and the action predictions.
# Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation
# gaussian) and sums over the action dimension. Computing the log probability amounts to multiplying
# the MSE by 0.5 and adding a constant offset (the log(2*pi) term) . Here we drop the constant offset
# as it doesn't change the optimization step, and we drop the 0.5 as we instead make a configuration
# parameter for it (see below where we compute the total loss).
# gaussian) and sums over the action dimension. Computing the (negative) log probability amounts to
# multiplying the MSE by 0.5 and adding a constant offset (the log(2*pi)/2 term, times the action
# dimension). Here we drop the constant offset as it doesn't change the optimization step, and we drop
# the 0.5 as we instead make a configuration parameter for it (see below where we compute the total
# loss).
mse = F.mse_loss(action_preds, action, reduction="none").sum(-1) # (t, b)
# NOTE: The original implementation does not take the sum over the temporal dimension like with the
# other losses.
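A quick numeric check of the relationship described in the comment above: for a unit-variance gaussian, the negative log probability equals half the summed squared error plus a constant offset of (action_dim / 2) · log(2π):

```python
import math

import torch
import torch.nn.functional as F
from torch.distributions import Normal

action_dim = 4
action_preds = torch.randn(action_dim)
action = torch.randn(action_dim)

nll = -Normal(action_preds, 1.0).log_prob(action).sum()
mse_sum = F.mse_loss(action_preds, action, reduction="none").sum()
offset = 0.5 * action_dim * math.log(2 * math.pi)

print(torch.allclose(nll, 0.5 * mse_sum + offset))  # True
```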
@@ -728,6 +740,16 @@ class TDMPCObservationEncoder(nn.Module):
nn.LayerNorm(config.latent_dim),
nn.Sigmoid(),
)
if "observation.environment_state" in config.input_shapes:
self.env_state_enc_layers = nn.Sequential(
nn.Linear(
config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim
),
nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Sigmoid(),
)
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
"""Encode the image and/or state vector.
@@ -736,8 +758,11 @@ class TDMPCObservationEncoder(nn.Module):
over all features.
"""
feat = []
# NOTE: Order of observations matters here.
if "observation.image" in self.config.input_shapes:
feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict["observation.image"]))
if "observation.environment_state" in self.config.input_shapes:
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
if "observation.state" in self.config.input_shapes:
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
return torch.stack(feat, dim=0).mean(0)
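A toy sketch (hypothetical sizes) of the fusion step at the end of `forward`: one latent per available modality, stacked and averaged so each modality contributes equally:

```python
import torch

batch_size, latent_dim = 2, 50
image_feat = torch.rand(batch_size, latent_dim)      # CNN branch, if an image is configured
env_state_feat = torch.rand(batch_size, latent_dim)  # new environment-state MLP, if configured
state_feat = torch.rand(batch_size, latent_dim)      # proprioceptive-state MLP

fused = torch.stack([image_feat, env_state_feat, state_feat], dim=0).mean(0)
print(fused.shape)  # torch.Size([2, 50])
```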