Add get_safe_torch_device in policies

This commit is contained in:
Simon Alibert
2024-03-20 18:38:55 +01:00
parent ec536ef0fa
commit 4631d36c05
6 changed files with 39 additions and 18 deletions

View File

@@ -18,7 +18,7 @@ from lerobot.common.envs.factory import make_env
from lerobot.common.logger import log_output_dir
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils import init_logging, set_seed
from lerobot.common.utils import get_safe_torch_device, init_logging, set_seed
def write_video(video_path, stacked_frames, fps):
@@ -35,7 +35,8 @@ def eval_policy(
fps: int = 15,
return_first_video: bool = False,
):
policy.eval()
if policy is not None:
policy.eval()
start = time.time()
sum_rewards = []
max_rewards = []
@@ -55,7 +56,8 @@ def eval_policy(
with torch.inference_mode():
# TODO(alexander-soare): When `break_when_any_done == False` this rolls out for max_steps even when all
# envs are done the first time. But we only use the first rollout. This is a waste of compute.
policy.clear_action_queue()
if policy is not None:
policy.clear_action_queue()
rollout = env.rollout(
max_steps=max_steps,
policy=policy,
@@ -128,10 +130,8 @@ def eval(cfg: dict, out_dir=None):
init_logging()
if cfg.device == "cuda":
assert torch.cuda.is_available()
else:
logging.warning("Using CPU, this will be slow.")
# Check device is available
get_safe_torch_device(cfg.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True