- Added JointMaskingActionSpace wrapper in gym_manipulator in order to select which joints will be controlled. For example, we can disable the gripper actions for some tasks.
- Added Nan detection mechanisms in the actor, learner and gym_manipulator for the case where we encounter nans in the loop. - changed the non-blocking in the `.to(device)` functions to only work for the case of cuda because they were causing nans when running the policy on mps - Added some joint clipping and limits in the env, robot and policy configs. TODO clean this part and make the limits in one config file only. Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
@@ -145,8 +145,8 @@ class Classifier(
|
||||
|
||||
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)
|
||||
|
||||
def predict_reward(self, x):
|
||||
def predict_reward(self, x, threshold=0.6):
|
||||
if self.config.num_classes == 2:
|
||||
return (self.forward(x).probabilities > 0.6).float()
|
||||
return (self.forward(x).probabilities > threshold).float()
|
||||
else:
|
||||
return torch.argmax(self.forward(x).probabilities, dim=1)
|
||||
|
||||
Reference in New Issue
Block a user