Fixed eval.py on MPS (#702)

This commit is contained in:
Ilia Larchenko
2025-02-14 06:03:55 +07:00
committed by GitHub
parent 1e49cc4d60
commit c574eb4984

View File

@@ -151,7 +151,9 @@ def rollout(
if return_observations:
all_observations.append(deepcopy(observation))
observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
observation = {
key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
}
with torch.inference_mode():
action = policy.select_action(observation)