added tdmpc2 to policy factory; shape fixes in tdmpc2
@@ -51,6 +51,13 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
         from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy

         return TDMPCPolicy, TDMPCConfig

+    elif name == "tdmpc2":
+        from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
+        from lerobot.common.policies.tdmpc2.modeling_tdmpc2 import TDMPC2Policy
+
+        return TDMPC2Policy, TDMPC2Config
+
     elif name == "diffusion":
         from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
         from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

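With this branch in place, the TDMPC2 classes become resolvable by name through the policy factory. A minimal usage sketch, assuming the function is importable from the lerobot factory module (the file path is not shown in this diff):

# Hypothetical usage sketch; the import path is an assumption, not shown in the diff.
from lerobot.common.policies.factory import get_policy_and_config_classes

policy_cls, config_cls = get_policy_and_config_classes("tdmpc2")
print(policy_cls.__name__, config_cls.__name__)  # TDMPC2Policy TDMPC2Config
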
@@ -389,7 +389,7 @@ class TDMPC2Policy(
         reward_loss = (
             (
                 temporal_loss_coeffs
-                * soft_cross_entropy(reward_preds, reward, self.config)
+                * soft_cross_entropy(reward_preds, reward, self.config).mean(1)
                 * ~batch["next.reward_is_pad"]
                 * ~batch["observation.state_is_pad"][0]
                 * ~batch["action_is_pad"]
@@ -397,10 +397,11 @@ class TDMPC2Policy(
             .sum(0)
             .mean()
         )

         # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
         ce_value_loss = 0.0
         for i in range(self.config.q_ensemble_size):
-            ce_value_loss += soft_cross_entropy(q_preds_ensemble[i], td_targets, self.config)
+            ce_value_loss += soft_cross_entropy(q_preds_ensemble[i], td_targets, self.config).mean(1)
+
         q_value_loss = (
             (
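A shape sketch of why the extra `.mean(1)` is needed (toy tensors, not the repo's exact variables): `soft_cross_entropy` sums over the bin axis with `keepdim=True`, so it returns one loss value per element with a trailing singleton dimension; averaging over dim 1 collapses the per-element axis before the temporal decay weights and padding masks are multiplied in. The `[horizon, 1]` shape of the decay weights is an assumption based on the surrounding code.

# Toy shape check; the [horizon, num_features, 1] loss shape and the [horizon, 1]
# decay weights are assumptions for illustration.
import torch

horizon, num_features = 5, 8
per_element_ce = torch.rand(horizon, num_features, 1)
temporal_loss_coeffs = 0.9 ** torch.arange(horizon).float().unsqueeze(-1)  # [horizon, 1]

reduced = per_element_ce.mean(1)             # [horizon, 1]
weighted = temporal_loss_coeffs * reduced    # broadcasts to [horizon, 1]
print(weighted.shape, weighted.sum(0).mean())
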
@@ -420,7 +421,6 @@ class TDMPC2Policy(
         # Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1.
         # We won't need these gradients again so detach.
         z_preds = z_preds.detach()
         self.model.change_q_grad(mode=False)
         action_preds, _, log_pis, _ = self.model.pi(z_preds[:-1])

         with torch.no_grad():
@@ -430,14 +430,9 @@ class TDMPC2Policy(
             self.scale.update(qs[0])
             qs = self.scale(qs)

-        rho = torch.pow(self.config.temporal_decay_coeff, torch.arange(len(qs), device=qs.device)).unsqueeze(
-            -1
-        )
-
         pi_loss = (
-            (self.config.entropy_coef * log_pis - qs).mean(dim=(1, 2))
-            * rho
-            # * temporal_loss_coeffs
+            (self.config.entropy_coef * log_pis - qs).mean(dim=2)
+            * temporal_loss_coeffs
             # `action_preds` depends on the first observation and the actions.
             * ~batch["observation.state_is_pad"][0]
             * ~batch["action_is_pad"]
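For the π loss the same idea applies: the per-step mean is now taken only over the trailing dimension (`dim=2`), and the precomputed `temporal_loss_coeffs` replace the locally built `rho` decay vector. A toy reduction sketch with illustrative shapes; the real `log_pis`/`qs` tensors and the final reduction after the padding masks are not shown in this hunk:

# Illustrative shapes and final reduction only; these are assumptions, not the repo's code.
import torch

horizon, batch, trailing = 5, 8, 1
log_pis = torch.randn(horizon, batch, trailing)
qs = torch.randn(horizon, batch, trailing)
entropy_coef = 1e-4
temporal_loss_coeffs = 0.9 ** torch.arange(horizon).float().unsqueeze(-1)  # [horizon, 1]

per_step = (entropy_coef * log_pis - qs).mean(dim=2)   # [horizon, batch]
pi_loss = (per_step * temporal_loss_coeffs).sum(0).mean()
print(pi_loss)
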
@@ -447,7 +442,7 @@ class TDMPC2Policy(
             self.config.consistency_coeff * consistency_loss
             + self.config.reward_coeff * reward_loss
             + self.config.value_coeff * q_value_loss
-            + self.config.pi_coeff * pi_loss
+            + pi_loss
         )

         info.update(
@@ -75,9 +75,6 @@ def soft_cross_entropy(pred, target, cfg):
     """Computes the cross entropy loss between predictions and soft targets."""
     pred = F.log_softmax(pred, dim=-1)
     target = two_hot(target, cfg)
-    import pudb
-
-    pudb.set_trace()
     return -(target * pred).sum(-1, keepdim=True)


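What remains of `soft_cross_entropy` once the pudb lines are gone is a standard cross entropy against soft targets. A self-contained sketch; the normalized random target stands in for the `two_hot(target, cfg)` encoding used in the repo, and the shapes are illustrative:

# Stand-alone sketch; shapes and the random soft target are illustrative assumptions.
import torch
import torch.nn.functional as F

pred = torch.randn(5, 4, 101)                    # e.g. [horizon, num_features, num_bins] logits
target = torch.rand(5, 4, 101)
target = target / target.sum(-1, keepdim=True)   # stand-in for two_hot(target, cfg)

log_probs = F.log_softmax(pred, dim=-1)
ce = -(target * log_probs).sum(-1, keepdim=True)  # one loss per element, shape [5, 4, 1]
print(ce.shape)
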
@@ -137,16 +134,20 @@ def symexp(x):

 def two_hot(x, cfg):
     """Converts a batch of scalars to soft two-hot encoded targets for discrete regression."""
+
+    # x shape [horizon, num_features]
     if cfg.num_bins == 0:
         return x
     elif cfg.num_bins == 1:
         return symlog(x)
     x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax)
-    bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long()
-    bin_offset = (x - cfg.vmin) / cfg.bin_size - bin_idx.float()
-    soft_two_hot = torch.zeros(x.size(0), cfg.num_bins, device=x.device)
-    soft_two_hot.scatter_(1, bin_idx, 1 - bin_offset)
-    soft_two_hot.scatter_(1, (bin_idx + 1) % cfg.num_bins, bin_offset)
+    bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long()  # shape [num_features]
+    bin_offset = ((x - cfg.vmin) / cfg.bin_size - bin_idx.float()).unsqueeze(-1)  # shape [num_features, 1]
+    soft_two_hot = torch.zeros(
+        *x.shape, cfg.num_bins, device=x.device
+    )  # shape [horizon, num_features, num_bins]
+    soft_two_hot.scatter_(2, bin_idx.unsqueeze(-1), 1 - bin_offset)
+    soft_two_hot.scatter_(2, (bin_idx.unsqueeze(-1) + 1) % cfg.num_bins, bin_offset)
     return soft_two_hot


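The reworked `two_hot` scatters into a bin axis appended after the input's own dimensions, so a 2-D input yields a 3-D encoding. A self-contained sketch of that behavior; `symlog` and the `bin_size = (vmax - vmin) / (num_bins - 1)` relation are assumptions mirroring the usual two-hot setup, not quotes from the repo's config class:

# Stand-alone sketch; cfg is a stand-in for the TDMPC2 config referenced in the diff.
from types import SimpleNamespace

import torch


def symlog(x):
    return torch.sign(x) * torch.log(1 + torch.abs(x))


cfg = SimpleNamespace(num_bins=101, vmin=-10.0, vmax=10.0)
cfg.bin_size = (cfg.vmax - cfg.vmin) / (cfg.num_bins - 1)  # assumed relation

x = torch.randn(5, 4)  # [horizon, num_features], e.g. a rollout of predicted rewards

x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax)
bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long()
bin_offset = ((x - cfg.vmin) / cfg.bin_size - bin_idx.float()).unsqueeze(-1)
soft_two_hot = torch.zeros(*x.shape, cfg.num_bins, device=x.device)
soft_two_hot.scatter_(2, bin_idx.unsqueeze(-1), 1 - bin_offset)
soft_two_hot.scatter_(2, (bin_idx.unsqueeze(-1) + 1) % cfg.num_bins, bin_offset)

print(soft_two_hot.shape)                                            # torch.Size([5, 4, 101])
print(torch.allclose(soft_two_hot.sum(-1), torch.ones(5, 4)))        # True: each slice sums to 1
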
@@ -93,6 +93,18 @@ def make_optimizer_and_scheduler(cfg, policy):
     elif policy.name == "tdmpc":
         optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
         lr_scheduler = None
+
+    elif policy.name == "tdmpc2":
+        params_group = [
+            {"params": policy.model._encoder.parameters(), "lr": cfg.training.lr * cfg.training.enc_lr_scale},
+            {"params": policy.model._dynamics.parameters()},
+            {"params": policy.model._reward.parameters()},
+            {"params": policy.model._Qs.parameters()},
+            {"params": policy.model._pi.parameters(), "eps": 1e-5},
+        ]
+        optimizer = torch.optim.Adam(params_group, lr=cfg.training.lr)
+        lr_scheduler = None
+
     elif cfg.policy.name == "vqbet":
         from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler

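The new optimizer branch relies on Adam's per-parameter-group overrides: groups without an explicit `lr` or `eps` inherit the top-level defaults. A toy sketch with stand-in modules and made-up values for `cfg.training.lr` and `cfg.training.enc_lr_scale`:

# Stand-in modules and placeholder values; only the param-group override behavior is the point.
import torch
from torch import nn

encoder, dynamics, pi = nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 2)
base_lr, enc_lr_scale = 3e-4, 0.3

optimizer = torch.optim.Adam(
    [
        {"params": encoder.parameters(), "lr": base_lr * enc_lr_scale},  # scaled-down encoder lr
        {"params": dynamics.parameters()},                               # inherits base_lr
        {"params": pi.parameters(), "eps": 1e-5},                        # custom eps, default lr
    ],
    lr=base_lr,
)
for group in optimizer.param_groups:
    print(group["lr"], group["eps"])  # (9e-05, 1e-08), (0.0003, 1e-08), (0.0003, 1e-05)
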