fix(utils): Convert observation values in predict_action to torch.Tensor (#1157)

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
Steven Palma
2025-05-28 15:29:08 +02:00
committed by GitHub
parent 69e48bbe19
commit 1fd3b2e2db

View File

@@ -112,6 +112,7 @@ def predict_action(
):
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
for name in observation:
observation[name] = torch.from_numpy(observation[name])
if "image" in name:
observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].permute(2, 0, 1).contiguous()