Changed the init_final value to center the starting mean and std of the policy
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
@@ -148,7 +148,7 @@ class Classifier(
|
||||
def predict_reward(self, x, threshold=0.6):
|
||||
if self.config.num_classes == 2:
|
||||
probs = self.forward(x).probabilities
|
||||
logging.info(f"Predicted reward images: {probs}")
|
||||
logging.debug(f"Predicted reward images: {probs}")
|
||||
return (probs > threshold).float()
|
||||
else:
|
||||
return torch.argmax(self.forward(x).probabilities, dim=1)
|
||||
|
||||
Reference in New Issue
Block a user