Add rounding for safety
This commit is contained in:
@@ -421,6 +421,7 @@ class SACPolicy(
|
|||||||
# In the buffer we have the full action space (continuous + discrete)
|
# In the buffer we have the full action space (continuous + discrete)
|
||||||
# We need to split them before concatenating them in the critic forward
|
# We need to split them before concatenating them in the critic forward
|
||||||
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
|
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
|
||||||
|
actions_discrete = torch.round(actions_discrete)
|
||||||
actions_discrete = actions_discrete.long()
|
actions_discrete = actions_discrete.long()
|
||||||
|
|
||||||
if complementary_info is not None:
|
if complementary_info is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user