From 026ad463a9ca0a4ce9e45d3fe7f11349d1781c1e Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Mon, 31 Mar 2025 13:54:21 +0000 Subject: [PATCH] Fix convergence of sac, multiple torch compile on the same model caused divergence --- lerobot/common/policies/sac/modeling_sac.py | 1 - lerobot/scripts/server/actor_server.py | 1 - lerobot/scripts/server/learner_server.py | 4 +--- 3 files changed, 1 insertion(+), 5 deletions(-) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index f3eb8e94..f7866714 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -112,7 +112,6 @@ class SACPolicy( self.critic_ensemble = torch.compile(self.critic_ensemble) self.critic_target = torch.compile(self.critic_target) - self.actor = Policy( encoder=encoder_actor, network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)), diff --git a/lerobot/scripts/server/actor_server.py b/lerobot/scripts/server/actor_server.py index b388f62e..c76dc003 100644 --- a/lerobot/scripts/server/actor_server.py +++ b/lerobot/scripts/server/actor_server.py @@ -231,7 +231,6 @@ def act_with_policy( cfg=cfg.policy, env_cfg=cfg.env, ) - policy = torch.compile(policy) assert isinstance(policy, nn.Module) obs, info = online_env.reset() diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 2334d2e0..0a1b0a77 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -285,9 +285,7 @@ def add_actor_information_and_train( # ds_meta=cfg.dataset, env_cfg=cfg.env, ) - - # compile policy - policy = torch.compile(policy) + assert isinstance(policy, nn.Module) policy.train()