Add rounding for safety

This commit is contained in:
AdilZouitine
2025-04-08 08:50:02 +00:00
committed by Michel Aractingi
parent a3ada81816
commit 68c271ad25

View File

@@ -421,6 +421,7 @@ class SACPolicy(
# In the buffer we have the full action space (continuous + discrete)
# We need to split them before concatenating them in the critic forward
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
actions_discrete = torch.round(actions_discrete)
actions_discrete = actions_discrete.long()
if complementary_info is not None: