added tdmpc2 to policy factory; shape fixes in tdmpc2

Michel Aractingi
2024-11-26 11:58:29 +00:00
parent 16edbbdeee
commit 14490148f3
4 changed files with 34 additions and 19 deletions

View File

@@ -51,6 +51,13 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
return TDMPCPolicy, TDMPCConfig
elif name == "tdmpc2":
from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
from lerobot.common.policies.tdmpc2.modeling_tdmpc2 import TDMPC2Policy
return TDMPC2Policy, TDMPC2Config
elif name == "diffusion":
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
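
For reference, the new branch can be exercised in isolation. The snippet below is a hypothetical usage sketch; the `lerobot.common.policies.factory` module path is assumed from the import style in this hunk.

# Hypothetical usage sketch: resolve the new "tdmpc2" name through the factory.
# The factory module path is an assumption based on the imports shown above.
from lerobot.common.policies.factory import get_policy_and_config_classes

policy_cls, config_cls = get_policy_and_config_classes("tdmpc2")
print(policy_cls.__name__, config_cls.__name__)  # expected: TDMPC2Policy TDMPC2Config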

View File

@@ -389,7 +389,7 @@ class TDMPC2Policy(
reward_loss = (
(
temporal_loss_coeffs
* soft_cross_entropy(reward_preds, reward, self.config)
* soft_cross_entropy(reward_preds, reward, self.config).mean(1)
* ~batch["next.reward_is_pad"]
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
@@ -397,10 +397,11 @@ class TDMPC2Policy(
.sum(0)
.mean()
)
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
ce_value_loss = 0.0
for i in range(self.config.q_ensemble_size):
ce_value_loss += soft_cross_entropy(q_preds_ensemble[i], td_targets, self.config)
ce_value_loss += soft_cross_entropy(q_preds_ensemble[i], td_targets, self.config).mean(1)
q_value_loss = (
(
@@ -420,7 +421,6 @@ class TDMPC2Policy(
# Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1.
# We won't need these gradients again so detach.
z_preds = z_preds.detach()
self.model.change_q_grad(mode=False)
action_preds, _, log_pis, _ = self.model.pi(z_preds[:-1])
with torch.no_grad():
@@ -430,14 +430,9 @@ class TDMPC2Policy(
self.scale.update(qs[0])
qs = self.scale(qs)
rho = torch.pow(self.config.temporal_decay_coeff, torch.arange(len(qs), device=qs.device)).unsqueeze(
-1
)
pi_loss = (
(self.config.entropy_coef * log_pis - qs).mean(dim=(1, 2))
* rho
# * temporal_loss_coeffs
(self.config.entropy_coef * log_pis - qs).mean(dim=2)
* temporal_loss_coeffs
# `action_preds` depends on the first observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
@@ -447,7 +442,7 @@ class TDMPC2Policy(
self.config.consistency_coeff * consistency_loss
+ self.config.reward_coeff * reward_loss
+ self.config.value_coeff * q_value_loss
+ self.config.pi_coeff * pi_loss
+ pi_loss
)
info.update(
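
Shape-wise, the point of the changes in this file is that `soft_cross_entropy` now keeps a per-element value, and the added `.mean(1)` reduces it so the result lines up with the temporal decay coefficients and the `*_is_pad` masks before the final `.sum(0).mean()`. The snippet below is an illustrative sketch of that shared reduction pattern with toy shapes, not the exact tensor layout of `TDMPC2Policy`.

# Illustrative sketch (toy shapes, not the exact layout in TDMPC2Policy) of the
# reduction pattern shared by the reward, Q-value and pi losses: scale a per-step
# loss by a temporal decay, zero out padded steps, sum over time, average over batch.
import torch

horizon, batch_size = 5, 4
per_step_loss = torch.rand(horizon, batch_size)               # e.g. cross entropy per (t, b)
is_pad = torch.zeros(horizon, batch_size, dtype=torch.bool)   # padding mask per (t, b)
decay = 0.5                                                    # stand-in for temporal_decay_coeff
temporal_loss_coeffs = torch.pow(decay, torch.arange(horizon, dtype=torch.float32)).unsqueeze(-1)

loss = (temporal_loss_coeffs * per_step_loss * ~is_pad).sum(0).mean()
print(loss)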

View File

@@ -75,9 +75,6 @@ def soft_cross_entropy(pred, target, cfg):
"""Computes the cross entropy loss between predictions and soft targets."""
pred = F.log_softmax(pred, dim=-1)
target = two_hot(target, cfg)
import pudb
pudb.set_trace()
return -(target * pred).sum(-1, keepdim=True)
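
As a sanity check on the return expression above: when the soft target happens to be one-hot, `-(target * log_softmax(pred)).sum(-1)` reduces to the ordinary cross entropy. A small self-contained sketch:

# Small sketch: the soft cross entropy above matches F.cross_entropy when the
# target is exactly one-hot (here built from integer labels for the comparison).
import torch
import torch.nn.functional as F

pred = torch.randn(3, 5)                       # logits over 5 bins
labels = torch.tensor([0, 2, 4])
target = F.one_hot(labels, num_classes=5).float()

soft_ce = -(target * F.log_softmax(pred, dim=-1)).sum(-1)
hard_ce = F.cross_entropy(pred, labels, reduction="none")
print(torch.allclose(soft_ce, hard_ce))        # True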
@@ -137,16 +134,20 @@ def symexp(x):
def two_hot(x, cfg):
"""Converts a batch of scalars to soft two-hot encoded targets for discrete regression."""
# x shape [horizon, num_features]
if cfg.num_bins == 0:
return x
elif cfg.num_bins == 1:
return symlog(x)
x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax)
bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long()
bin_offset = (x - cfg.vmin) / cfg.bin_size - bin_idx.float()
soft_two_hot = torch.zeros(x.size(0), cfg.num_bins, device=x.device)
soft_two_hot.scatter_(1, bin_idx, 1 - bin_offset)
soft_two_hot.scatter_(1, (bin_idx + 1) % cfg.num_bins, bin_offset)
bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long()  # shape [horizon, num_features]
bin_offset = ((x - cfg.vmin) / cfg.bin_size - bin_idx.float()).unsqueeze(-1)  # shape [horizon, num_features, 1]
soft_two_hot = torch.zeros(
*x.shape, cfg.num_bins, device=x.device
) # shape [horizon, num_features, num_bins]
soft_two_hot.scatter_(2, bin_idx.unsqueeze(-1), 1 - bin_offset)
soft_two_hot.scatter_(2, (bin_idx.unsqueeze(-1) + 1) % cfg.num_bins, bin_offset)
return soft_two_hot
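
The shape fix is easiest to see on a toy input. Below is a minimal standalone sketch that mirrors the batched indexing above; vmin, vmax and num_bins are made-up placeholders, not the project defaults.

# Minimal standalone sketch of the batched two-hot encoding above, with made-up
# bin parameters; it mirrors the new indexing rather than importing the project code.
import torch

def symlog(x):
    return torch.sign(x) * torch.log(1 + torch.abs(x))

vmin, vmax, num_bins = -10.0, 10.0, 101
bin_size = (vmax - vmin) / (num_bins - 1)

x = torch.tensor([[0.3, -1.2], [2.0, 0.0]])                  # shape [horizon=2, num_features=2]
x = torch.clamp(symlog(x), vmin, vmax)
bin_idx = torch.floor((x - vmin) / bin_size).long()           # shape [horizon, num_features]
bin_offset = ((x - vmin) / bin_size - bin_idx.float()).unsqueeze(-1)
soft_two_hot = torch.zeros(*x.shape, num_bins)
soft_two_hot.scatter_(2, bin_idx.unsqueeze(-1), 1 - bin_offset)
soft_two_hot.scatter_(2, (bin_idx.unsqueeze(-1) + 1) % num_bins, bin_offset)

print(soft_two_hot.shape)     # torch.Size([2, 2, 101])
print(soft_two_hot.sum(-1))   # each (t, feature) entry sums to 1 across two adjacent bins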

View File

@@ -93,6 +93,18 @@ def make_optimizer_and_scheduler(cfg, policy):
elif policy.name == "tdmpc":
optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
lr_scheduler = None
elif policy.name == "tdmpc2":
params_group = [
{"params": policy.model._encoder.parameters(), "lr": cfg.training.lr * cfg.training.enc_lr_scale},
{"params": policy.model._dynamics.parameters()},
{"params": policy.model._reward.parameters()},
{"params": policy.model._Qs.parameters()},
{"params": policy.model._pi.parameters(), "eps": 1e-5},
]
optimizer = torch.optim.Adam(params_group, lr=cfg.training.lr)
lr_scheduler = None
elif cfg.policy.name == "vqbet":
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
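
The per-group dictionaries above rely on `torch.optim.Adam` letting each parameter group override the optimizer defaults: a scaled `lr` for the encoder, a custom `eps` for the policy head, and base settings everywhere else. A toy illustration with stand-in modules and hyperparameters, not the TDMPC2 world model:

# Toy illustration of per-parameter-group overrides in torch.optim.Adam; the
# modules and hyperparameter values are stand-ins, not the TDMPC2 world model.
import torch
from torch import nn

encoder, dynamics, pi_head = nn.Linear(4, 8), nn.Linear(8, 8), nn.Linear(8, 2)
base_lr, enc_lr_scale = 3e-4, 0.3

optimizer = torch.optim.Adam(
    [
        {"params": encoder.parameters(), "lr": base_lr * enc_lr_scale},  # scaled encoder lr
        {"params": dynamics.parameters()},                               # inherits base lr
        {"params": pi_head.parameters(), "eps": 1e-5},                   # custom eps only
    ],
    lr=base_lr,
)

for group in optimizer.param_groups:
    print(group["lr"], group["eps"])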