tests for tdmpc and diffusion policy are passing
@@ -16,8 +16,8 @@ def make_policy(cfg):
             cfg_obs_encoder=cfg.obs_encoder,
             cfg_optimizer=cfg.optimizer,
             cfg_ema=cfg.ema,
-            n_obs_steps=cfg.n_obs_steps,
-            n_action_steps=cfg.n_action_steps,
+            # n_obs_steps=cfg.n_obs_steps,
+            # n_action_steps=cfg.n_action_steps,
             **cfg.policy,
         )
     elif cfg.policy.name == "act":
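A hedged reading of this hunk: since the diffusion yaml further down now carries n_action_steps inside cfg.policy, keeping the explicit keyword arguments alongside **cfg.policy would presumably raise a duplicate-keyword error, which is likely why they are commented out. A minimal sketch of that Python behavior, with a hypothetical make_dummy_policy standing in for the real constructor:

    def make_dummy_policy(n_action_steps=None, **kwargs):
        return {"n_action_steps": n_action_steps, **kwargs}

    policy_cfg = {"n_action_steps": 1}
    make_dummy_policy(**policy_cfg)                      # fine: the value comes from the config dict
    # make_dummy_policy(n_action_steps=1, **policy_cfg)  # TypeError: got multiple values
    #                                                    # for keyword argument 'n_action_steps'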
@@ -110,7 +110,6 @@ class TDMPCPolicy(nn.Module):
         # self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
         self.model.eval()
         self.model_target.eval()
-        self.batch_size = cfg.batch_size

         self.register_buffer("step", torch.zeros(1))

@@ -325,7 +324,7 @@ class TDMPCPolicy(nn.Module):
     def _td_target(self, next_z, reward, mask):
         """Compute the TD-target from a reward and the observation at the following time step."""
         next_v = self.model.V(next_z)
-        td_target = reward + self.cfg.discount * mask * next_v
+        td_target = reward + self.cfg.discount * mask * next_v.squeeze(2)
         return td_target

     def forward(self, batch, step):
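The squeeze(2) suggests the value head returns a trailing singleton dimension that has to be dropped so the TD target matches the time-first (horizon, batch) layout of reward and mask. A minimal shape sketch under those assumed shapes:

    import torch

    horizon, batch = 5, 2
    reward = torch.zeros(horizon, batch)           # assumed (horizon, batch) layout
    next_v = torch.ones(horizon, batch, 1)         # assumed value-head output with a trailing 1
    td_target = reward + 0.99 * next_v.squeeze(2)  # squeeze(2) lines it up with reward
    print(next_v.shape, td_target.shape)           # torch.Size([5, 2, 1]) torch.Size([5, 2])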
@@ -420,6 +419,8 @@ class TDMPCPolicy(nn.Module):
         # idxs = torch.cat([idxs, demo_idxs])
         # weights = torch.cat([weights, demo_weights])

+        batch_size = batch["index"].shape[0]
+
         # TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels)
         # instead of currently (time/horizon, batch size, channels) which is not the pytorch convention
         # batch size b = 256, time/horizon t = 5
@@ -433,7 +434,7 @@ class TDMPCPolicy(nn.Module):
         # idxs = batch["index"]  # TODO(rcadene): use idxs to update sampling weights
         done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
         mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
-        weights = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
+        weights = torch.ones(batch_size, dtype=torch.bool, device=reward.device)

         obses = {
             "rgb": batch["observation.image"],
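Taken together with the previous hunk, the batch size is now read off the batch itself rather than from cfg.batch_size (which the tdmpc yaml below no longer defines), so shapes stay consistent when the dataloader uses a different batch size, e.g. the fixed batch_size=2 in the test. A small sketch of the idea:

    import torch

    batch = {"index": torch.arange(2)}    # a 2-sample batch, as the test dataloader would yield
    batch_size = batch["index"].shape[0]  # 2, independent of any configured batch size
    weights = torch.ones(batch_size, dtype=torch.bool)
    print(batch_size, weights.shape)      # 2 torch.Size([2])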
@@ -476,7 +477,7 @@ class TDMPCPolicy(nn.Module):
         td_targets = self._td_target(next_z, reward, mask)

         # Latent rollout
-        zs = torch.empty(horizon + 1, self.batch_size, self.cfg.latent_dim, device=self.device)
+        zs = torch.empty(horizon + 1, batch_size, self.cfg.latent_dim, device=self.device)
         reward_preds = torch.empty_like(reward, device=self.device)
         assert reward.shape[0] == horizon
         z = self.model.encode(obs)
@@ -485,22 +486,21 @@ class TDMPCPolicy(nn.Module):
         for t in range(horizon):
             z, reward_pred = self.model.next(z, action[t])
             zs[t + 1] = z
-            reward_preds[t] = reward_pred
+            reward_preds[t] = reward_pred.squeeze(1)

         with torch.no_grad():
             v_target = self.model_target.Q(zs[:-1].detach(), action, return_type="min")

         # Predictions
         qs = self.model.Q(zs[:-1], action, return_type="all")
+        qs = qs.squeeze(3)
         value_info["Q"] = qs.mean().item()
         v = self.model.V(zs[:-1])
         value_info["V"] = v.mean().item()

         # Losses
-        rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1, 1)
-        consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2, keepdim=True) * loss_mask).sum(
-            dim=0
-        )
+        rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1)
+        consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2) * loss_mask).sum(dim=0)
         reward_loss = (rho * h.mse(reward_preds, reward) * loss_mask).sum(dim=0)
         q_value_loss, priority_loss = 0, 0
         for q in range(self.cfg.num_q):
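The rho reshape from (horizon, 1, 1) to (horizon, 1), together with dropping keepdim=True and the extra squeezes above, suggests the per-timestep losses are now kept as plain (horizon, batch) tensors; rho then broadcasts over the batch axis and the sum over dim=0 leaves one loss per sample. A shape sketch under assumed sizes (horizon=5, batch=2, latent_dim=3, rho=0.5):

    import torch

    horizon, batch, latent_dim = 5, 2, 3
    rho = torch.pow(0.5, torch.arange(horizon)).view(-1, 1)            # (5, 1)
    per_step_mse = torch.rand(horizon, batch, latent_dim).mean(dim=2)  # (5, 2)
    loss_mask = torch.ones(horizon, batch)
    consistency_loss = (rho * per_step_mse * loss_mask).sum(dim=0)     # sum over time
    print(consistency_loss.shape)                                      # torch.Size([2]), one loss per sample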
@@ -508,7 +508,9 @@ class TDMPCPolicy(nn.Module):
             priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0)

         expectile = h.linear_schedule(self.cfg.expectile, step)
-        v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile) * loss_mask).sum(dim=0)
+        v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile).squeeze(2) * loss_mask).sum(
+            dim=0
+        )

         total_loss = (
             self.cfg.consistency_coef * consistency_loss
@@ -517,7 +519,7 @@ class TDMPCPolicy(nn.Module):
             + self.cfg.value_coef * v_value_loss
         )

-        weighted_loss = (total_loss.squeeze(1) * weights).mean()
+        weighted_loss = (total_loss * weights).mean()
         weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))
         has_nan = torch.isnan(weighted_loss).item()
         if has_nan:
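Assuming the per-sample losses above now come out as (batch,), total_loss already matches the weights tensor of the same shape, so the former .squeeze(1) has nothing left to remove. A tiny sketch of the reduction:

    import torch

    batch = 2
    total_loss = torch.rand(batch)                 # assumed one loss per sample, shape (batch,)
    weights = torch.ones(batch, dtype=torch.bool)  # same shape, so no .squeeze(1) is needed
    weighted_loss = (total_loss * weights).mean()  # bool weights promote to float; result is a scalar
    print(weighted_loss.shape)                     # torch.Size([])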
@@ -38,6 +38,7 @@ policy:

   horizon: ${horizon}
   n_obs_steps: ${n_obs_steps}
+  n_action_steps: ${n_action_steps}
   num_inference_steps: 100
   obs_as_global_cond: ${obs_as_global_cond}
   # crop_shape: null
@@ -36,7 +36,6 @@ policy:
   log_std_max: 2

   # learning
-  batch_size: 256
   max_buffer_size: 10000
   horizon: 5
   reward_coef: 0.5
poetry.lock (generated, 4 lines changed)
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.

 [[package]]
 name = "absl-py"
@@ -897,7 +897,7 @@ mujoco = "^2.3.7"
 type = "git"
 url = "git@github.com:huggingface/gym-aloha.git"
 reference = "HEAD"
-resolved_reference = "ec7200831e36c14e343cf7d275c6b047f2fe9d11"
+resolved_reference = "c636f05ba0d1760df94537da84c860be1487e17f"

 [[package]]
 name = "gym-pusht"
@@ -49,7 +49,7 @@ def test_policy(env_name, policy_name, extra_overrides):
     dataloader = torch.utils.data.DataLoader(
         dataset,
         num_workers=4,
-        batch_size=cfg.policy.batch_size,
+        batch_size=2,
         shuffle=True,
         pin_memory=DEVICE != "cpu",
         drop_last=True,
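Since the tdmpc config no longer defines batch_size, the test pins a small fixed batch size instead of reading it from cfg.policy, which also keeps the test cheap to run. A minimal sketch with a stand-in dataset (not the test's real dataset):

    import torch

    dataset = torch.utils.data.TensorDataset(torch.arange(10).float())  # stand-in dataset
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=2,  # small and fixed, independent of the training config
        shuffle=True,
        drop_last=True,
    )
    (first_batch,) = next(iter(dataloader))
    print(first_batch.shape)  # torch.Size([2])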