diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 9a0b7e4c..346c3acd 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -93,6 +93,15 @@ def make_optimizer_and_scheduler(cfg, policy): elif policy.name == "tdmpc": optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr) lr_scheduler = None + + elif policy.name == "sac": + optimizer = torch.optim.Adam([ + {'params': policy.actor.parameters(), 'lr': policy.config.actor_lr}, + {'params': policy.critic_ensemble.parameters(), 'lr': policy.config.critic_lr}, + {'params': policy.temperature.parameters(), 'lr': policy.config.temperature_lr}, + ]) + lr_scheduler = None + elif cfg.policy.name == "vqbet": from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler