tests for tdmpc and diffusion policy are passing

Author: Cadene
Date: 2024-04-09 02:50:32 +00:00
parent 1e09507bc1
commit 73dfa3c8e3
6 changed files with 19 additions and 17 deletions

(File 1 of 6)

@@ -16,8 +16,8 @@ def make_policy(cfg):
             cfg_obs_encoder=cfg.obs_encoder,
             cfg_optimizer=cfg.optimizer,
             cfg_ema=cfg.ema,
-            n_obs_steps=cfg.n_obs_steps,
-            n_action_steps=cfg.n_action_steps,
+            # n_obs_steps=cfg.n_obs_steps,
+            # n_action_steps=cfg.n_action_steps,
             **cfg.policy,
         )
     elif cfg.policy.name == "act":

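The two kwargs are commented out rather than deleted, most likely because n_obs_steps and n_action_steps already live inside cfg.policy (they are visible as context lines in the policy config hunk further below), so passing them explicitly on top of **cfg.policy would collide. A minimal sketch of that collision, using a hypothetical builder rather than the repo's DiffusionPolicy:

def build_policy(n_obs_steps, n_action_steps, **kwargs):
    # Stand-in for a policy constructor that also accepts extra config keys.
    return {"n_obs_steps": n_obs_steps, "n_action_steps": n_action_steps, **kwargs}

cfg_policy = {"n_obs_steps": 1, "n_action_steps": 1, "horizon": 16}  # hypothetical values

build_policy(**cfg_policy)  # fine: each key is supplied exactly once
# build_policy(n_obs_steps=1, n_action_steps=1, **cfg_policy)
# -> TypeError: build_policy() got multiple values for keyword argument 'n_obs_steps'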
(File 2 of 6)

@@ -110,7 +110,6 @@ class TDMPCPolicy(nn.Module):
         # self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
         self.model.eval()
         self.model_target.eval()
-        self.batch_size = cfg.batch_size
         self.register_buffer("step", torch.zeros(1))
@@ -325,7 +324,7 @@ class TDMPCPolicy(nn.Module):
     def _td_target(self, next_z, reward, mask):
         """Compute the TD-target from a reward and the observation at the following time step."""
         next_v = self.model.V(next_z)
-        td_target = reward + self.cfg.discount * mask * next_v
+        td_target = reward + self.cfg.discount * mask * next_v.squeeze(2)
         return td_target

     def forward(self, batch, step):
@@ -420,6 +419,8 @@ class TDMPCPolicy(nn.Module):
         # idxs = torch.cat([idxs, demo_idxs])
         # weights = torch.cat([weights, demo_weights])

+        batch_size = batch["index"].shape[0]
+
         # TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels)
         # instead of currently (time/horizon, batch size, channels) which is not the pytorch convention
         # batch size b = 256, time/horizon t = 5
@@ -433,7 +434,7 @@ class TDMPCPolicy(nn.Module):
         # idxs = batch["index"]  # TODO(rcadene): use idxs to update sampling weights
         done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
         mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
-        weights = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
+        weights = torch.ones(batch_size, dtype=torch.bool, device=reward.device)

         obses = {
             "rgb": batch["observation.image"],
@@ -476,7 +477,7 @@ class TDMPCPolicy(nn.Module):
         td_targets = self._td_target(next_z, reward, mask)

         # Latent rollout
-        zs = torch.empty(horizon + 1, self.batch_size, self.cfg.latent_dim, device=self.device)
+        zs = torch.empty(horizon + 1, batch_size, self.cfg.latent_dim, device=self.device)
         reward_preds = torch.empty_like(reward, device=self.device)
         assert reward.shape[0] == horizon
         z = self.model.encode(obs)
@@ -485,22 +486,21 @@ class TDMPCPolicy(nn.Module):
         for t in range(horizon):
             z, reward_pred = self.model.next(z, action[t])
             zs[t + 1] = z
-            reward_preds[t] = reward_pred
+            reward_preds[t] = reward_pred.squeeze(1)

         with torch.no_grad():
             v_target = self.model_target.Q(zs[:-1].detach(), action, return_type="min")

         # Predictions
         qs = self.model.Q(zs[:-1], action, return_type="all")
+        qs = qs.squeeze(3)
         value_info["Q"] = qs.mean().item()
         v = self.model.V(zs[:-1])
         value_info["V"] = v.mean().item()

         # Losses
-        rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1, 1)
-        consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2, keepdim=True) * loss_mask).sum(
-            dim=0
-        )
+        rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1)
+        consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2) * loss_mask).sum(dim=0)
         reward_loss = (rho * h.mse(reward_preds, reward) * loss_mask).sum(dim=0)
         q_value_loss, priority_loss = 0, 0
         for q in range(self.cfg.num_q):
@@ -508,7 +508,9 @@ class TDMPCPolicy(nn.Module):
             priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0)

         expectile = h.linear_schedule(self.cfg.expectile, step)
-        v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile) * loss_mask).sum(dim=0)
+        v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile).squeeze(2) * loss_mask).sum(
+            dim=0
+        )

         total_loss = (
             self.cfg.consistency_coef * consistency_loss
@@ -517,7 +519,7 @@ class TDMPCPolicy(nn.Module):
             + self.cfg.value_coef * v_value_loss
         )
-        weighted_loss = (total_loss.squeeze(1) * weights).mean()
+        weighted_loss = (total_loss * weights).mean()
         weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))

         has_nan = torch.isnan(weighted_loss).item()
         if has_nan:

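Most of the TDMPCPolicy edits above are shape bookkeeping: the value, reward, and Q heads return tensors with a trailing singleton dimension, and the added .squeeze(...) calls drop it so that rho (now shaped (horizon, 1)) and loss_mask broadcast against plain (horizon, batch) tensors. A minimal sketch of the pitfall these changes appear to guard against, with illustrative shapes rather than the repo's actual ones:

import torch

b = 4
per_step_weight = torch.ones(b)                   # shape (b,)
value_pred = torch.ones(b, 1)                     # a value head often returns (b, 1)
wrong = per_step_weight * value_pred              # silently broadcasts to (b, b)
right = per_step_weight * value_pred.squeeze(1)   # shape (b,), as intended
print(wrong.shape, right.shape)                   # torch.Size([4, 4]) torch.Size([4])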
(File 3 of 6)

@@ -38,6 +38,7 @@ policy:
   horizon: ${horizon}
   n_obs_steps: ${n_obs_steps}
   n_action_steps: ${n_action_steps}
+  num_inference_steps: 100
   obs_as_global_cond: ${obs_as_global_cond}
   # crop_shape: null

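The new num_inference_steps: 100 entry fixes how many denoising iterations the diffusion policy runs when generating actions. A hedged sketch of how such a value is typically consumed by a diffusers-style noise scheduler (illustrative; not necessarily the exact call site in this repo):

from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)  # assumed training setting
scheduler.set_timesteps(num_inference_steps=100)     # value from the config above
print(len(scheduler.timesteps))                      # 100 denoising steps at inference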
(File 4 of 6)

@@ -36,7 +36,6 @@ policy:
   log_std_max: 2

   # learning
-  batch_size: 256
   max_buffer_size: 10000
   horizon: 5
   reward_coef: 0.5

(File 5 of 6: poetry.lock, generated, 4 changed lines)

@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.

 [[package]]
 name = "absl-py"
@@ -897,7 +897,7 @@ mujoco = "^2.3.7"
 type = "git"
 url = "git@github.com:huggingface/gym-aloha.git"
 reference = "HEAD"
-resolved_reference = "ec7200831e36c14e343cf7d275c6b047f2fe9d11"
+resolved_reference = "c636f05ba0d1760df94537da84c860be1487e17f"

 [[package]]
 name = "gym-pusht"

(File 6 of 6)

@@ -49,7 +49,7 @@ def test_policy(env_name, policy_name, extra_overrides):
     dataloader = torch.utils.data.DataLoader(
         dataset,
         num_workers=4,
-        batch_size=cfg.policy.batch_size,
+        batch_size=2,
         shuffle=True,
         pin_memory=DEVICE != "cpu",
         drop_last=True,
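
Together with the removal of batch_size from the TDMPC policy config and of self.batch_size from TDMPCPolicy, the test stops reading cfg.policy.batch_size: the policy now derives the batch size from the batch itself, so the dataloader can use any small value such as 2. A tiny sketch mirroring the batch_size = batch["index"].shape[0] line added in forward(), with a hypothetical minimal batch:

import torch

batch = {"index": torch.arange(2)}     # hypothetical two-sample batch
batch_size = batch["index"].shape[0]   # -> 2, no config value needed
print(batch_size)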