Ran pre-commit run --all-files
@@ -7,6 +7,8 @@ import torch.nn as nn
 
 import lerobot.common.policies.tdmpc_helper as h
 
+FIRST_FRAME = 0
+
 
 class TOLD(nn.Module):
     """Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""
@@ -17,9 +19,7 @@ class TOLD(nn.Module):
 
         self.cfg = cfg
         self._encoder = h.enc(cfg)
-        self._dynamics = h.dynamics(
-            cfg.latent_dim + action_dim, cfg.mlp_dim, cfg.latent_dim
-        )
+        self._dynamics = h.dynamics(cfg.latent_dim + action_dim, cfg.mlp_dim, cfg.latent_dim)
         self._reward = h.mlp(cfg.latent_dim + action_dim, cfg.mlp_dim, 1)
         self._pi = h.mlp(cfg.latent_dim, cfg.mlp_dim, action_dim)
         self._Qs = nn.ModuleList([h.q(cfg) for _ in range(cfg.num_q)])
@@ -65,20 +65,20 @@ class TOLD(nn.Module):
             return h.TruncatedNormal(mu, std).sample(clip=0.3)
         return mu
 
-    def V(self, z):
+    def V(self, z):  # noqa: N802
         """Predict state value (V)."""
         return self._V(z)
 
-    def Q(self, z, a, return_type):
+    def Q(self, z, a, return_type):  # noqa: N802
        """Predict state-action value (Q)."""
         assert return_type in {"min", "avg", "all"}
         x = torch.cat([z, a], dim=-1)
 
         if return_type == "all":
-            return torch.stack(list(q(x) for q in self._Qs), dim=0)
+            return torch.stack([q(x) for q in self._Qs], dim=0)
 
         idxs = np.random.choice(self.cfg.num_q, 2, replace=False)
-        Q1, Q2 = self._Qs[idxs[0]](x), self._Qs[idxs[1]](x)
+        Q1, Q2 = self._Qs[idxs[0]](x), self._Qs[idxs[1]](x)  # noqa: N806
         return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
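
For reference, a minimal standalone sketch of the randomized two-head Q estimate used in TOLD.Q above; the q_heads list and feature sizes are illustrative stand-ins for the ensemble built by h.q(cfg), not code from this repository:

import numpy as np
import torch
import torch.nn as nn

# Illustrative ensemble of Q heads (placeholders for h.q(cfg)).
q_heads = nn.ModuleList([nn.Linear(8, 1) for _ in range(5)])
x = torch.randn(4, 8)  # stand-in for torch.cat([z, a], dim=-1)

# "all": stack every head's prediction along a new leading dimension.
q_all = torch.stack([q(x) for q in q_heads], dim=0)  # (num_q, batch, 1)

# "min"/"avg": draw two distinct heads at random and combine them,
# mirroring the clipped double-Q estimate above.
i, j = np.random.choice(len(q_heads), 2, replace=False)
q_min = torch.min(q_heads[i](x), q_heads[j](x))
q_avg = (q_heads[i](x) + q_heads[j](x)) / 2
print(q_all.shape, q_min.shape, q_avg.shape)
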
@@ -146,25 +146,21 @@ class TDMPC(nn.Module):
     @torch.no_grad()
     def estimate_value(self, z, actions, horizon):
         """Estimate value of a trajectory starting at latent state z and executing given actions."""
-        G, discount = 0, 1
+        G, discount = 0, 1  # noqa: N806
         for t in range(horizon):
             if self.cfg.uncertainty_cost > 0:
-                G -= (
+                G -= (  # noqa: N806
                     discount
                     * self.cfg.uncertainty_cost
                     * self.model.Q(z, actions[t], return_type="all").std(dim=0)
                 )
             z, reward = self.model.next(z, actions[t])
-            G += discount * reward
+            G += discount * reward  # noqa: N806
             discount *= self.cfg.discount
         pi = self.model.pi(z, self.cfg.min_std)
-        G += discount * self.model.Q(z, pi, return_type="min")
+        G += discount * self.model.Q(z, pi, return_type="min")  # noqa: N806
         if self.cfg.uncertainty_cost > 0:
-            G -= (
-                discount
-                * self.cfg.uncertainty_cost
-                * self.model.Q(z, pi, return_type="all").std(dim=0)
-            )
+            G -= discount * self.cfg.uncertainty_cost * self.model.Q(z, pi, return_type="all").std(dim=0)  # noqa: N806
         return G
 
     @torch.no_grad()
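
To illustrate what estimate_value above computes (a discounted sum of predicted rewards plus a bootstrapped terminal value), here is a self-contained toy sketch; the reward list and terminal value are made up, with the terminal value standing in for the model.Q(z, pi, return_type="min") bootstrap:

import torch

def estimate_return(rewards, terminal_value, discount=0.99):
    # Accumulate discounted rewards over the horizon, then bootstrap with
    # a terminal value estimate, as estimate_value does with the Q-network.
    g, d = torch.zeros_like(terminal_value), 1.0
    for r in rewards:
        g = g + d * r
        d = d * discount
    return g + d * terminal_value

rewards = [torch.tensor([1.0]), torch.tensor([0.5]), torch.tensor([0.25])]
print(estimate_return(rewards, torch.tensor([2.0])))
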
@@ -180,19 +176,13 @@ class TDMPC(nn.Module):
         assert step is not None
         # Seed steps
         if step < self.cfg.seed_steps and self.model.training:
-            return torch.empty(
-                self.action_dim, dtype=torch.float32, device=self.device
-            ).uniform_(-1, 1)
+            return torch.empty(self.action_dim, dtype=torch.float32, device=self.device).uniform_(-1, 1)
 
         # Sample policy trajectories
-        horizon = int(
-            min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step))
-        )
+        horizon = int(min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step)))
         num_pi_trajs = int(self.cfg.mixture_coef * self.cfg.num_samples)
         if num_pi_trajs > 0:
-            pi_actions = torch.empty(
-                horizon, num_pi_trajs, self.action_dim, device=self.device
-            )
+            pi_actions = torch.empty(horizon, num_pi_trajs, self.action_dim, device=self.device)
             _z = z.repeat(num_pi_trajs, 1)
             for t in range(horizon):
                 pi_actions[t] = self.model.pi(_z, self.cfg.min_std)
@@ -201,20 +191,16 @@ class TDMPC(nn.Module):
         # Initialize state and parameters
         z = z.repeat(self.cfg.num_samples + num_pi_trajs, 1)
         mean = torch.zeros(horizon, self.action_dim, device=self.device)
-        std = self.cfg.max_std * torch.ones(
-            horizon, self.action_dim, device=self.device
-        )
+        std = self.cfg.max_std * torch.ones(horizon, self.action_dim, device=self.device)
         if not t0 and hasattr(self, "_prev_mean"):
             mean[:-1] = self._prev_mean[1:]
 
         # Iterate CEM
-        for i in range(self.cfg.iterations):
+        for _ in range(self.cfg.iterations):
             actions = torch.clamp(
                 mean.unsqueeze(1)
                 + std.unsqueeze(1)
-                * torch.randn(
-                    horizon, self.cfg.num_samples, self.action_dim, device=std.device
-                ),
+                * torch.randn(horizon, self.cfg.num_samples, self.action_dim, device=std.device),
                 -1,
                 1,
             )
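
A compact sketch of one CEM iteration as it appears in this hunk and the next: sample clamped candidate action sequences around the current (mean, std), score them, and refit the statistics on the elites. The sizes, temperature, and toy value function are illustrative, not values from the config:

import torch

horizon, num_samples, num_elites, action_dim, temperature = 5, 512, 64, 4, 0.5
mean = torch.zeros(horizon, action_dim)
std = 2.0 * torch.ones(horizon, action_dim)

# Sample candidates and clamp them to the valid action range.
actions = torch.clamp(
    mean.unsqueeze(1) + std.unsqueeze(1) * torch.randn(horizon, num_samples, action_dim), -1, 1
)
value = -actions.pow(2).sum(dim=(0, 2))  # toy value in place of estimate_value
elite_idxs = torch.topk(value, num_elites, dim=0).indices
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]

# Softmax-weight the elites and refit mean/std for the next iteration.
score = torch.exp(temperature * (elite_value - elite_value.max()))
score = score / score.sum()
mean = torch.sum(score.view(1, -1, 1) * elite_actions, dim=1)
std = torch.sqrt(torch.sum(score.view(1, -1, 1) * (elite_actions - mean.unsqueeze(1)) ** 2, dim=1))
print(mean.shape, std.shape)  # torch.Size([5, 4]) torch.Size([5, 4])
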
@@ -223,18 +209,14 @@ class TDMPC(nn.Module):
 
             # Compute elite actions
             value = self.estimate_value(z, actions, horizon).nan_to_num_(0)
-            elite_idxs = torch.topk(
-                value.squeeze(1), self.cfg.num_elites, dim=0
-            ).indices
+            elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices
             elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
 
             # Update parameters
             max_value = elite_value.max(0)[0]
             score = torch.exp(self.cfg.temperature * (elite_value - max_value))
             score /= score.sum(0)
-            _mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (
-                score.sum(0) + 1e-9
-            )
+            _mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (score.sum(0) + 1e-9)
             _std = torch.sqrt(
                 torch.sum(
                     score.unsqueeze(0) * (elite_actions - _mean.unsqueeze(1)) ** 2,
@@ -331,7 +313,6 @@ class TDMPC(nn.Module):
             batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
             batch = batch.to(self.device)
 
-            FIRST_FRAME = 0
             obs = {
                 "rgb": batch["observation", "image"][FIRST_FRAME].float(),
                 "state": batch["observation", "state"][FIRST_FRAME],
@@ -359,10 +340,7 @@ class TDMPC(nn.Module):
             weights = batch["_weight"][FIRST_FRAME, :, None]
             return obs, action, next_obses, reward, mask, done, idxs, weights
 
-        if self.cfg.balanced_sampling:
-            batch = replay_buffer.sample(batch_size)
-        else:
-            batch = replay_buffer.sample()
+        batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
 
         obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
             batch, self.cfg.horizon, num_slices
@@ -384,10 +362,7 @@ class TDMPC(nn.Module):
 
         if isinstance(obs, dict):
             obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
-            next_obses = {
-                k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1)
-                for k in next_obses
-            }
+            next_obses = {k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1) for k in next_obses}
         else:
             obs = torch.cat([obs, demo_obs])
             next_obses = torch.cat([next_obses, demo_next_obses], dim=1)
@@ -429,9 +404,7 @@ class TDMPC(nn.Module):
         td_targets = self._td_target(next_z, reward, mask)
 
         # Latent rollout
-        zs = torch.empty(
-            horizon + 1, self.batch_size, self.cfg.latent_dim, device=self.device
-        )
+        zs = torch.empty(horizon + 1, self.batch_size, self.cfg.latent_dim, device=self.device)
         reward_preds = torch.empty_like(reward, device=self.device)
         assert reward.shape[0] == horizon
         z = self.model.encode(obs)
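
For context, a toy version of the latent rollout that the zs buffer above is filled by: encode the first observation, then unroll the learned dynamics over the action sequence. The linear encode/dynamics modules here are placeholders for the model's h.enc / h.dynamics networks, not the real architecture:

import torch

latent_dim, batch_size, horizon = 6, 3, 4
encode = torch.nn.Linear(10, latent_dim)  # placeholder encoder
dynamics = torch.nn.Linear(latent_dim + 2, latent_dim)  # placeholder dynamics

obs = torch.randn(batch_size, 10)
actions = torch.randn(horizon, batch_size, 2)

# Pre-allocate the rollout buffer and fill it step by step, as the update
# loop does with zs after encoding the first observation.
zs = torch.empty(horizon + 1, batch_size, latent_dim)
zs[0] = encode(obs)
for t in range(horizon):
    zs[t + 1] = dynamics(torch.cat([zs[t], actions[t]], dim=-1))
print(zs.shape)  # torch.Size([5, 3, 6])
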
@@ -452,12 +425,10 @@ class TDMPC(nn.Module):
         value_info["V"] = v.mean().item()
 
         # Losses
-        rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(
-            -1, 1, 1
-        )
+        rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1, 1)
-        consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2, keepdim=True) * loss_mask).sum(
-            dim=0
-        )
+        consistency_loss = (
+            rho * torch.mean(h.mse(zs[1:], z_targets), dim=2, keepdim=True) * loss_mask
+        ).sum(dim=0)
         reward_loss = (rho * h.mse(reward_preds, reward) * loss_mask).sum(dim=0)
         q_value_loss, priority_loss = 0, 0
         for q in range(self.cfg.num_q):
@@ -465,9 +436,7 @@ class TDMPC(nn.Module):
             priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0)
 
         expectile = h.linear_schedule(self.cfg.expectile, step)
-        v_value_loss = (
-            rho * h.l2_expectile(v_target - v, expectile=expectile) * loss_mask
-        ).sum(dim=0)
+        v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile) * loss_mask).sum(dim=0)
 
         total_loss = (
             self.cfg.consistency_coef * consistency_loss