Ran pre-commit run --all-files

This commit is contained in:
Simon Alibert
2024-02-29 13:37:48 +01:00
parent 0e0e760e43
commit 7e024fdce6
16 changed files with 124 additions and 237 deletions

View File

@@ -7,6 +7,8 @@ import torch.nn as nn
import lerobot.common.policies.tdmpc_helper as h
FIRST_FRAME = 0
class TOLD(nn.Module):
"""Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""
@@ -17,9 +19,7 @@ class TOLD(nn.Module):
self.cfg = cfg
self._encoder = h.enc(cfg)
self._dynamics = h.dynamics(
cfg.latent_dim + action_dim, cfg.mlp_dim, cfg.latent_dim
)
self._dynamics = h.dynamics(cfg.latent_dim + action_dim, cfg.mlp_dim, cfg.latent_dim)
self._reward = h.mlp(cfg.latent_dim + action_dim, cfg.mlp_dim, 1)
self._pi = h.mlp(cfg.latent_dim, cfg.mlp_dim, action_dim)
self._Qs = nn.ModuleList([h.q(cfg) for _ in range(cfg.num_q)])
@@ -65,20 +65,20 @@ class TOLD(nn.Module):
return h.TruncatedNormal(mu, std).sample(clip=0.3)
return mu
def V(self, z):
def V(self, z): # noqa: N802
"""Predict state value (V)."""
return self._V(z)
def Q(self, z, a, return_type):
def Q(self, z, a, return_type): # noqa: N802
"""Predict state-action value (Q)."""
assert return_type in {"min", "avg", "all"}
x = torch.cat([z, a], dim=-1)
if return_type == "all":
return torch.stack(list(q(x) for q in self._Qs), dim=0)
return torch.stack([q(x) for q in self._Qs], dim=0)
idxs = np.random.choice(self.cfg.num_q, 2, replace=False)
Q1, Q2 = self._Qs[idxs[0]](x), self._Qs[idxs[1]](x)
Q1, Q2 = self._Qs[idxs[0]](x), self._Qs[idxs[1]](x) # noqa: N806
return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
@@ -146,25 +146,21 @@ class TDMPC(nn.Module):
@torch.no_grad()
def estimate_value(self, z, actions, horizon):
"""Estimate value of a trajectory starting at latent state z and executing given actions."""
G, discount = 0, 1
G, discount = 0, 1 # noqa: N806
for t in range(horizon):
if self.cfg.uncertainty_cost > 0:
G -= (
G -= ( # noqa: N806
discount
* self.cfg.uncertainty_cost
* self.model.Q(z, actions[t], return_type="all").std(dim=0)
)
z, reward = self.model.next(z, actions[t])
G += discount * reward
G += discount * reward # noqa: N806
discount *= self.cfg.discount
pi = self.model.pi(z, self.cfg.min_std)
G += discount * self.model.Q(z, pi, return_type="min")
G += discount * self.model.Q(z, pi, return_type="min") # noqa: N806
if self.cfg.uncertainty_cost > 0:
G -= (
discount
* self.cfg.uncertainty_cost
* self.model.Q(z, pi, return_type="all").std(dim=0)
)
G -= discount * self.cfg.uncertainty_cost * self.model.Q(z, pi, return_type="all").std(dim=0) # noqa: N806
return G
@torch.no_grad()
@@ -180,19 +176,13 @@ class TDMPC(nn.Module):
assert step is not None
# Seed steps
if step < self.cfg.seed_steps and self.model.training:
return torch.empty(
self.action_dim, dtype=torch.float32, device=self.device
).uniform_(-1, 1)
return torch.empty(self.action_dim, dtype=torch.float32, device=self.device).uniform_(-1, 1)
# Sample policy trajectories
horizon = int(
min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step))
)
horizon = int(min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step)))
num_pi_trajs = int(self.cfg.mixture_coef * self.cfg.num_samples)
if num_pi_trajs > 0:
pi_actions = torch.empty(
horizon, num_pi_trajs, self.action_dim, device=self.device
)
pi_actions = torch.empty(horizon, num_pi_trajs, self.action_dim, device=self.device)
_z = z.repeat(num_pi_trajs, 1)
for t in range(horizon):
pi_actions[t] = self.model.pi(_z, self.cfg.min_std)
@@ -201,20 +191,16 @@ class TDMPC(nn.Module):
# Initialize state and parameters
z = z.repeat(self.cfg.num_samples + num_pi_trajs, 1)
mean = torch.zeros(horizon, self.action_dim, device=self.device)
std = self.cfg.max_std * torch.ones(
horizon, self.action_dim, device=self.device
)
std = self.cfg.max_std * torch.ones(horizon, self.action_dim, device=self.device)
if not t0 and hasattr(self, "_prev_mean"):
mean[:-1] = self._prev_mean[1:]
# Iterate CEM
for i in range(self.cfg.iterations):
for _ in range(self.cfg.iterations):
actions = torch.clamp(
mean.unsqueeze(1)
+ std.unsqueeze(1)
* torch.randn(
horizon, self.cfg.num_samples, self.action_dim, device=std.device
),
* torch.randn(horizon, self.cfg.num_samples, self.action_dim, device=std.device),
-1,
1,
)
@@ -223,18 +209,14 @@ class TDMPC(nn.Module):
# Compute elite actions
value = self.estimate_value(z, actions, horizon).nan_to_num_(0)
elite_idxs = torch.topk(
value.squeeze(1), self.cfg.num_elites, dim=0
).indices
elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
# Update parameters
max_value = elite_value.max(0)[0]
score = torch.exp(self.cfg.temperature * (elite_value - max_value))
score /= score.sum(0)
_mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (
score.sum(0) + 1e-9
)
_mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (score.sum(0) + 1e-9)
_std = torch.sqrt(
torch.sum(
score.unsqueeze(0) * (elite_actions - _mean.unsqueeze(1)) ** 2,
@@ -331,7 +313,6 @@ class TDMPC(nn.Module):
batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
batch = batch.to(self.device)
FIRST_FRAME = 0
obs = {
"rgb": batch["observation", "image"][FIRST_FRAME].float(),
"state": batch["observation", "state"][FIRST_FRAME],
@@ -359,10 +340,7 @@ class TDMPC(nn.Module):
weights = batch["_weight"][FIRST_FRAME, :, None]
return obs, action, next_obses, reward, mask, done, idxs, weights
if self.cfg.balanced_sampling:
batch = replay_buffer.sample(batch_size)
else:
batch = replay_buffer.sample()
batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
batch, self.cfg.horizon, num_slices
@@ -384,10 +362,7 @@ class TDMPC(nn.Module):
if isinstance(obs, dict):
obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
next_obses = {
k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1)
for k in next_obses
}
next_obses = {k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1) for k in next_obses}
else:
obs = torch.cat([obs, demo_obs])
next_obses = torch.cat([next_obses, demo_next_obses], dim=1)
@@ -429,9 +404,7 @@ class TDMPC(nn.Module):
td_targets = self._td_target(next_z, reward, mask)
# Latent rollout
zs = torch.empty(
horizon + 1, self.batch_size, self.cfg.latent_dim, device=self.device
)
zs = torch.empty(horizon + 1, self.batch_size, self.cfg.latent_dim, device=self.device)
reward_preds = torch.empty_like(reward, device=self.device)
assert reward.shape[0] == horizon
z = self.model.encode(obs)
@@ -452,12 +425,10 @@ class TDMPC(nn.Module):
value_info["V"] = v.mean().item()
# Losses
rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(
-1, 1, 1
rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1, 1)
consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2, keepdim=True) * loss_mask).sum(
dim=0
)
consistency_loss = (
rho * torch.mean(h.mse(zs[1:], z_targets), dim=2, keepdim=True) * loss_mask
).sum(dim=0)
reward_loss = (rho * h.mse(reward_preds, reward) * loss_mask).sum(dim=0)
q_value_loss, priority_loss = 0, 0
for q in range(self.cfg.num_q):
@@ -465,9 +436,7 @@ class TDMPC(nn.Module):
priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0)
expectile = h.linear_schedule(self.cfg.expectile, step)
v_value_loss = (
rho * h.l2_expectile(v_target - v, expectile=expectile) * loss_mask
).sum(dim=0)
v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile) * loss_mask).sum(dim=0)
total_loss = (
self.cfg.consistency_coef * consistency_loss