tests for tdmpc and diffusion policy are passing
@@ -16,8 +16,8 @@ def make_policy(cfg):
            cfg_obs_encoder=cfg.obs_encoder,
            cfg_optimizer=cfg.optimizer,
            cfg_ema=cfg.ema,
-            n_obs_steps=cfg.n_obs_steps,
-            n_action_steps=cfg.n_action_steps,
+            # n_obs_steps=cfg.n_obs_steps,
+            # n_action_steps=cfg.n_action_steps,
            **cfg.policy,
        )
    elif cfg.policy.name == "act":
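For context: with n_obs_steps and n_action_steps presumably carried by cfg.policy (see the config hunk further down), keeping the explicit keyword arguments next to **cfg.policy would pass the same name twice. A minimal, self-contained sketch with a hypothetical make() stand-in, not the repo's constructor:

def make(**kwargs):
    return kwargs

policy_cfg = {"n_obs_steps": 2, "n_action_steps": 1}

try:
    make(n_obs_steps=2, **policy_cfg)
except TypeError as err:
    print(err)  # make() got multiple values for keyword argument 'n_obs_steps'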
@@ -110,7 +110,6 @@ class TDMPCPolicy(nn.Module):
        # self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
        self.model.eval()
        self.model_target.eval()
-        self.batch_size = cfg.batch_size

        self.register_buffer("step", torch.zeros(1))

@@ -325,7 +324,7 @@ class TDMPCPolicy(nn.Module):
    def _td_target(self, next_z, reward, mask):
        """Compute the TD-target from a reward and the observation at the following time step."""
        next_v = self.model.V(next_z)
-        td_target = reward + self.cfg.discount * mask * next_v
+        td_target = reward + self.cfg.discount * mask * next_v.squeeze(2)
        return td_target

    def forward(self, batch, step):
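The .squeeze(2) is a shape fix: assuming next_z is (horizon, batch, latent_dim) and the value head keeps a trailing singleton dimension, next_v is (horizon, batch, 1) while reward and mask are (horizon, batch). A toy sketch with shapes inferred from the surrounding hunks, not verified against the repo:

import torch

horizon, batch = 5, 3
reward = torch.zeros(horizon, batch)     # per-step, per-sample rewards
mask = torch.ones_like(reward)
next_v = torch.randn(horizon, batch, 1)  # value head output with a trailing singleton dim

# Without the squeeze, the trailing dim misaligns (horizon, batch) with
# (horizon, batch, 1): depending on the sizes this either raises a broadcasting
# error or silently produces a wrongly-shaped target.
td_target = reward + 0.99 * mask * next_v.squeeze(2)
print(td_target.shape)  # torch.Size([5, 3]), matching reward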
@@ -420,6 +419,8 @@ class TDMPCPolicy(nn.Module):
        # idxs = torch.cat([idxs, demo_idxs])
        # weights = torch.cat([weights, demo_weights])

+        batch_size = batch["index"].shape[0]
+
        # TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels)
        # instead of currently (time/horizon, batch size, channels) which is not the pytorch convention
        # batch size b = 256, time/horizon t = 5
@@ -433,7 +434,7 @@ class TDMPCPolicy(nn.Module):
        # idxs = batch["index"] # TODO(rcadene): use idxs to update sampling weights
        done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
        mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
-        weights = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
+        weights = torch.ones(batch_size, dtype=torch.bool, device=reward.device)

        obses = {
            "rgb": batch["observation.image"],
@@ -476,7 +477,7 @@ class TDMPCPolicy(nn.Module):
            td_targets = self._td_target(next_z, reward, mask)

        # Latent rollout
-        zs = torch.empty(horizon + 1, self.batch_size, self.cfg.latent_dim, device=self.device)
+        zs = torch.empty(horizon + 1, batch_size, self.cfg.latent_dim, device=self.device)
        reward_preds = torch.empty_like(reward, device=self.device)
        assert reward.shape[0] == horizon
        z = self.model.encode(obs)
@@ -485,22 +486,21 @@ class TDMPCPolicy(nn.Module):
        for t in range(horizon):
            z, reward_pred = self.model.next(z, action[t])
            zs[t + 1] = z
-            reward_preds[t] = reward_pred
+            reward_preds[t] = reward_pred.squeeze(1)

        with torch.no_grad():
            v_target = self.model_target.Q(zs[:-1].detach(), action, return_type="min")

        # Predictions
        qs = self.model.Q(zs[:-1], action, return_type="all")
+        qs = qs.squeeze(3)
        value_info["Q"] = qs.mean().item()
        v = self.model.V(zs[:-1])
        value_info["V"] = v.mean().item()

        # Losses
-        rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1, 1)
-        consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2, keepdim=True) * loss_mask).sum(
-            dim=0
-        )
+        rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1)
+        consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2) * loss_mask).sum(dim=0)
        reward_loss = (rho * h.mse(reward_preds, reward) * loss_mask).sum(dim=0)
        q_value_loss, priority_loss = 0, 0
        for q in range(self.cfg.num_q):
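The reshaped rho goes with the per-timestep losses now being 2-D (horizon, batch) after the squeezes above: rho holds the temporal discount weights rho**t, and view(-1, 1) lets them broadcast over the batch dimension before the sum over time. A stand-alone sketch with illustrative sizes only:

import torch

horizon, batch, rho_coef = 5, 3, 0.5
per_step_loss = torch.rand(horizon, batch)  # e.g. consistency or reward error per (t, sample)
loss_mask = torch.ones(horizon, batch)

rho = torch.pow(rho_coef, torch.arange(horizon)).view(-1, 1)     # (horizon, 1): 1, 0.5, 0.25, ...
per_sample_loss = (rho * per_step_loss * loss_mask).sum(dim=0)   # (batch,): discounted sum over time
print(per_sample_loss.shape)  # torch.Size([3])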
@@ -508,7 +508,9 @@ class TDMPCPolicy(nn.Module):
            priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0)

        expectile = h.linear_schedule(self.cfg.expectile, step)
-        v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile) * loss_mask).sum(dim=0)
+        v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile).squeeze(2) * loss_mask).sum(
+            dim=0
+        )

        total_loss = (
            self.cfg.consistency_coef * consistency_loss
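For readers unfamiliar with the expectile term: assuming h.l2_expectile follows the common asymmetric squared-error convention used in expectile regression (as in IQL), it weights positive and negative residuals differently. A self-contained sketch of that convention, not the repo's helper:

import torch

def l2_expectile(diff: torch.Tensor, expectile: float = 0.7) -> torch.Tensor:
    # |tau - 1{diff < 0}| * diff^2: positive residuals are weighted by tau,
    # negative residuals by (1 - tau).
    weight = torch.abs(expectile - (diff < 0).float())
    return weight * diff.pow(2)

diff = torch.tensor([-1.0, 0.5, 2.0])
print(l2_expectile(diff, expectile=0.9))  # tensor([0.1000, 0.2250, 3.6000])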
@@ -517,7 +519,7 @@ class TDMPCPolicy(nn.Module):
            + self.cfg.value_coef * v_value_loss
        )

-        weighted_loss = (total_loss.squeeze(1) * weights).mean()
+        weighted_loss = (total_loss * weights).mean()
        weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))
        has_nan = torch.isnan(weighted_loss).item()
        if has_nan:
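With each loss term summed over the horizon (dim=0) and the squeeze fixes applied, total_loss should come out per-sample with shape (batch_size,), which is presumably why the extra .squeeze(1) is gone and it multiplies cleanly with the (batch_size,) weights. The register_hook line rescales the gradient rather than the reported loss value; a toy equivalence check, not the repo's code:

import torch

H = 5  # stands in for cfg.horizon
x = torch.ones(3, requires_grad=True)

loss = (x ** 2).sum()
loss.register_hook(lambda grad: grad * (1 / H))  # scale the gradient, keep the logged loss unchanged
loss.backward()
grad_hooked = x.grad.clone()

x.grad = None
((x ** 2).sum() / H).backward()                  # dividing the loss itself gives the same gradients
assert torch.allclose(grad_hooked, x.grad)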
@@ -38,6 +38,7 @@ policy:

  horizon: ${horizon}
  n_obs_steps: ${n_obs_steps}
  n_action_steps: ${n_action_steps}
  num_inference_steps: 100
  obs_as_global_cond: ${obs_as_global_cond}
  # crop_shape: null
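The ${...} entries are OmegaConf/Hydra interpolations resolved against the top-level config, so the policy section picks up whatever values the experiment sets globally. A minimal illustration with made-up values; the real configs are larger and loaded through the project's Hydra entry point:

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    """
horizon: 16
n_obs_steps: 2
n_action_steps: 1
policy:
  horizon: ${horizon}
  n_obs_steps: ${n_obs_steps}
  n_action_steps: ${n_action_steps}
"""
)
assert cfg.policy.n_obs_steps == 2  # the interpolation resolves to the top-level value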
@@ -36,7 +36,6 @@ policy:
  log_std_max: 2

  # learning
-  batch_size: 256
  max_buffer_size: 10000
  horizon: 5
  reward_coef: 0.5
poetry.lock (generated, 4 lines changed)
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.

[[package]]
name = "absl-py"
@@ -897,7 +897,7 @@ mujoco = "^2.3.7"
type = "git"
url = "git@github.com:huggingface/gym-aloha.git"
reference = "HEAD"
-resolved_reference = "ec7200831e36c14e343cf7d275c6b047f2fe9d11"
+resolved_reference = "c636f05ba0d1760df94537da84c860be1487e17f"

[[package]]
name = "gym-pusht"
@@ -49,7 +49,7 @@ def test_policy(env_name, policy_name, extra_overrides):
    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=4,
-        batch_size=cfg.policy.batch_size,
+        batch_size=2,
        shuffle=True,
        pin_memory=DEVICE != "cpu",
        drop_last=True,
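A stand-alone sketch of the same DataLoader settings against a dummy dataset; the real test uses the repo's dataset fixture and num_workers=4:

import torch

dataset = torch.utils.data.TensorDataset(torch.arange(10).float().unsqueeze(1))
dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=0,   # 0 keeps the example runnable anywhere
    batch_size=2,
    shuffle=True,
    pin_memory=False,
    drop_last=True,
)
print(sum(1 for _ in dataloader))  # 5 batches of 2 samples each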