Add rounding for safety
This commit is contained in:
committed by
Michel Aractingi
parent
a3ada81816
commit
68c271ad25
@@ -421,6 +421,7 @@ class SACPolicy(
|
||||
# In the buffer we have the full action space (continuous + discrete)
|
||||
# We need to split them before concatenating them in the critic forward
|
||||
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
|
||||
actions_discrete = torch.round(actions_discrete)
|
||||
actions_discrete = actions_discrete.long()
|
||||
|
||||
if complementary_info is not None:
|
||||
|
||||
Reference in New Issue
Block a user