match target entropy hil serl

Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
This commit is contained in:
AdilZouitine
2025-04-15 08:00:38 +00:00
committed by Michel Aractingi
parent 5c352ae558
commit 8122721f6d
3 changed files with 11 additions and 7 deletions

View File

@@ -155,7 +155,8 @@ class SACPolicy(
**asdict(config.policy_kwargs),
)
if config.target_entropy is None:
config.target_entropy = -np.prod(continuous_action_dim) / 2 # (-dim(A)/2)
discrete_actions_dim: Literal[1] | Literal[0] = 1 if config.num_discrete_actions is None else 0
config.target_entropy = -np.prod(continuous_action_dim + discrete_actions_dim) / 2 # (-dim(A)/2)
# TODO (azouitine): Handle the case where the temperature parameter is fixed
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
@@ -176,7 +177,7 @@ class SACPolicy(
"temperature": self.log_alpha,
}
if self.config.num_discrete_actions is not None:
optim_params["grasp_critic"] = self.grasp_critic.parameters_to_optimize
optim_params["grasp_critic"] = self.grasp_critic.parameters()
return optim_params
def reset(self):