diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 297e9f53..341e516b 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -77,8 +77,19 @@ class SACPolicy(
             )
             critic_nets.append(critic_net)
 
+        target_critic_nets = []
+        for _ in range(config.num_critics):
+            target_critic_net = Critic(
+                encoder=encoder_critic,
+                network=MLP(
+                    input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
+                    **config.critic_network_kwargs
+                )
+            )
+            target_critic_nets.append(target_critic_net)
+
         self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
-        self.critic_target = deepcopy(self.critic_ensemble)
+        self.critic_target = create_critic_ensemble(target_critic_nets, config.num_critics)
 
         self.actor = Policy(
             encoder=encoder_actor,
@@ -169,12 +180,12 @@ class SACPolicy(
 
             # critics subsample size
             min_q, _ = q_targets.min(dim=0)  # Get values from min operation
-
-            # compute td target
-            td_target = rewards + self.config.discount * min_q
+            # breakpoint()
             if self.config.use_backup_entropy:
-                td_target -= self.config.discount * self.temperature() * log_probs \
-                    * ~batch["observation.state_is_pad"][:,0] * ~batch["action_is_pad"][:,0] # shape: [batch_size, horizon]
+                min_q -= self.temperature() * log_probs * ~batch["observation.state_is_pad"][:,0] * ~batch["action_is_pad"][:,0] # shape: [batch_size, horizon]
+            td_target = rewards + self.config.discount * min_q * ~batch["next.done"]
+            # td_target -= self.config.discount * self.temperature() * log_probs \
+            #     * ~batch["observation.state_is_pad"][:,0] * ~batch["action_is_pad"][:,0] # shape: [batch_size, horizon]
             # print(f"Target Q-values: mean={td_target.mean():.3f}, max={td_target.max():.3f}")
 
             # 3- compute predicted qs
diff --git a/lerobot/configs/policy/sac_pusht_keypoints.yaml b/lerobot/configs/policy/sac_pusht_keypoints.yaml
index f4f4dba8..19f02b57 100644
--- a/lerobot/configs/policy/sac_pusht_keypoints.yaml
+++ b/lerobot/configs/policy/sac_pusht_keypoints.yaml
@@ -29,7 +29,7 @@ training:
   online_steps_between_rollouts: 1000
   online_sampling_ratio: 1.0
   online_env_seed: 10000
-  online_buffer_capacity: 10000
+  online_buffer_capacity: 40000
   online_buffer_seed_size: 0
   do_online_rollout_async: false
 
@@ -70,9 +70,9 @@ policy:
   temperature_init: 1.0
   num_critics: 2
   num_subsample_critics: None
-  critic_lr: 1e-4
-  actor_lr: 1e-4
-  temperature_lr: 1e-4
+  critic_lr: 3e-4
+  actor_lr: 3e-4
+  temperature_lr: 3e-4
   critic_target_update_weight: 0.005
   utd_ratio: 2
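
Note on the modeling change: the new TD target folds the entropy bonus into the bootstrapped value and masks the bootstrap term with the episode-termination flag, which matches the standard SAC backup. Below is a minimal standalone sketch of that computation for reference; the function name and plain-tensor arguments (rewards, min_q, log_probs, done, the pad masks) are illustrative stand-ins taken from the diff, not the SACPolicy API.

import torch

def sac_td_target(rewards, min_q, log_probs, done, state_is_pad, action_is_pad,
                  discount=0.99, temperature=1.0, use_backup_entropy=True):
    # Illustration only: mirrors the '+' lines in the first file above.
    if use_backup_entropy:
        # Entropy bonus enters the bootstrapped value; padded steps contribute nothing.
        min_q = min_q - temperature * log_probs * ~state_is_pad * ~action_is_pad
    # Terminal transitions keep only the immediate reward (no bootstrap).
    return rewards + discount * min_q * ~done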