added optimizer and sac to factory.py
This commit is contained in:
@@ -93,6 +93,15 @@ def make_optimizer_and_scheduler(cfg, policy):
|
|||||||
elif policy.name == "tdmpc":
|
elif policy.name == "tdmpc":
|
||||||
optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
|
optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
|
||||||
lr_scheduler = None
|
lr_scheduler = None
|
||||||
|
|
||||||
|
elif policy.name == "sac":
|
||||||
|
optimizer = torch.optim.Adam([
|
||||||
|
{'params': policy.actor.parameters(), 'lr': policy.config.actor_lr},
|
||||||
|
{'params': policy.critic_ensemble.parameters(), 'lr': policy.config.critic_lr},
|
||||||
|
{'params': policy.temperature.parameters(), 'lr': policy.config.temperature_lr},
|
||||||
|
])
|
||||||
|
lr_scheduler = None
|
||||||
|
|
||||||
elif cfg.policy.name == "vqbet":
|
elif cfg.policy.name == "vqbet":
|
||||||
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
|
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user