From ad7eea132d3796edb1911c3c69a08bda8f416148 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Mon, 23 Dec 2024 14:12:03 +0100 Subject: [PATCH] added optimizer and sac to factory.py --- lerobot/common/policies/factory.py | 6 ++++++ lerobot/common/policies/sac/configuration_sac.py | 1 + 2 files changed, 7 insertions(+) diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 8def95a3..3519f077 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -59,6 +59,12 @@ def get_policy_class(name: str) -> PreTrainedPolicy: from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy return PI0FASTPolicy + elif name == "sac": + from lerobot.common.policies.sac.configuration_sac import SACConfig + from lerobot.common.policies.sac.modeling_sac import SACPolicy + + return SACPolicy, SACConfig + else: raise NotImplementedError(f"Policy with name {name} is not implemented.") diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index d324462e..6db198e8 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -26,6 +26,7 @@ class SACConfig: num_subsample_critics = None critic_lr = 3e-4 actor_lr = 3e-4 + temperature_lr = 3e-4 critic_target_update_weight = 0.005 utd_ratio = 2 critic_network_kwargs = {