diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index f3eb8e946..f78667146 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -112,7 +112,6 @@ class SACPolicy(
 
         self.critic_ensemble = torch.compile(self.critic_ensemble)
         self.critic_target = torch.compile(self.critic_target)
-
         self.actor = Policy(
             encoder=encoder_actor,
             network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)),
diff --git a/lerobot/scripts/server/actor_server.py b/lerobot/scripts/server/actor_server.py
index b388f62e6..c76dc003d 100644
--- a/lerobot/scripts/server/actor_server.py
+++ b/lerobot/scripts/server/actor_server.py
@@ -231,7 +231,6 @@ def act_with_policy(
         cfg=cfg.policy,
         env_cfg=cfg.env,
     )
-    policy = torch.compile(policy)
     assert isinstance(policy, nn.Module)
 
     obs, info = online_env.reset()
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index 2334d2e08..0a1b0a776 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -285,9 +285,7 @@ def add_actor_information_and_train(
         # ds_meta=cfg.dataset,
         env_cfg=cfg.env,
     )
-
-    # compile policy
-    policy = torch.compile(policy)
+
     assert isinstance(policy, nn.Module)
 
     policy.train()
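
Note: the actor and learner hunks drop the unconditional torch.compile(policy) wrappers. As a minimal sketch only, not part of this diff, compilation could instead be gated behind an opt-in flag; maybe_compile and use_torch_compile below are hypothetical names, not existing lerobot identifiers.

# Hypothetical sketch: opt-in compilation instead of outright removal.
# `maybe_compile` and the `use_torch_compile` flag are assumptions,
# not part of the lerobot codebase or config shown above.
import torch
import torch.nn as nn


def maybe_compile(module: nn.Module, enabled: bool = False) -> nn.Module:
    """Return torch.compile(module) when enabled, else the module unchanged."""
    return torch.compile(module) if enabled else module

A call site such as learner_server.py could then use policy = maybe_compile(policy, enabled=cfg.policy.use_torch_compile). Since torch.compile returns an nn.Module subclass, the assert isinstance(policy, nn.Module) check kept in both servers passes either way.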