Changed the init_final value to center the starting mean and std of the policy

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
Michel Aractingi
2025-02-13 16:42:43 +01:00
committed by AdilZouitine
parent 24fb8a7f47
commit 0847b2119b
4 changed files with 5 additions and 4 deletions

View File

@@ -148,7 +148,7 @@ class Classifier(
def predict_reward(self, x, threshold=0.6):
if self.config.num_classes == 2:
probs = self.forward(x).probabilities
- logging.info(f"Predicted reward images: {probs}")
+ logging.debug(f"Predicted reward images: {probs}")
return (probs > threshold).float()
else:
return torch.argmax(self.forward(x).probabilities, dim=1)

View File

@@ -95,5 +95,6 @@ class SACConfig:
"use_tanh_squash": True,
"log_std_min": -5,
"log_std_max": 2,
+ "init_final": 0.01,
}
)