From b0e2fcdba7211a875f8e9bd288160b8a86655297 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Mon, 23 Dec 2024 14:12:03 +0100 Subject: [PATCH] added optimizer and sac to factory.py --- lerobot/scripts/train.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 9a0b7e4c..346c3acd 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -93,6 +93,15 @@ def make_optimizer_and_scheduler(cfg, policy): elif policy.name == "tdmpc": optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr) lr_scheduler = None + + elif policy.name == "sac": + optimizer = torch.optim.Adam([ + {'params': policy.actor.parameters(), 'lr': policy.config.actor_lr}, + {'params': policy.critic_ensemble.parameters(), 'lr': policy.config.critic_lr}, + {'params': policy.temperature.parameters(), 'lr': policy.config.temperature_lr}, + ]) + lr_scheduler = None + elif cfg.policy.name == "vqbet": from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler