SAC works

This commit is contained in:
Adil Zouitine
2025-01-14 11:34:52 +01:00
committed by Michel Aractingi
parent e8449e9630
commit 2fd78879f6
2 changed files with 10 additions and 49 deletions

View File

@@ -254,7 +254,9 @@ class SACPolicy(
next_action_preds, next_log_probs, _ = self.actor(next_observations)
# 2- compute q targets
q_targets = self.critic_forward(next_observations, next_action_preds, use_target=True)
q_targets = self.critic_forward(
observations=next_observations, actions=next_action_preds, use_target=True
)
# subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
if self.config.num_subsample_critics is not None:
@@ -264,9 +266,9 @@ class SACPolicy(
# critics subsample size
min_q, _ = q_targets.min(dim=0) # Get values from min operation
if self.config.use_backup_entropy:
min_q -= temperature * next_log_probs
td_target = rewards + self.config.discount * min_q * ~done
min_q = min_q - (temperature * next_log_probs)
td_target = rewards + (1 - done) * self.config.discount * min_q
# 3- compute predicted qs
q_preds = self.critic_forward(observations, actions, use_target=False)