Add multithreading for video generation, Speed policy sampling

commit aed02dc7c6
parent 591985c67d
Author: Cadene
Date: 2024-02-24 18:18:39 +00:00

4 changed files with 59 additions and 6 deletions


@@ -11,7 +11,10 @@ from torchrl.envs import EnvBase
 from lerobot.common.envs.factory import make_env
 from lerobot.common.tdmpc import TDMPC
 from lerobot.common.utils import set_seed
+import threading
 
+def write_video(video_path, stacked_frames, fps):
+    imageio.mimsave(video_path, stacked_frames, fps=fps)
 
 def eval_policy(
     env: EnvBase,
@@ -29,6 +32,7 @@ def eval_policy(
     sum_rewards = []
     max_rewards = []
     successes = []
+    threads = []
 
     for i in range(num_episodes):
         ep_frames = []
@@ -63,7 +67,12 @@
         if save_video:
             video_dir.mkdir(parents=True, exist_ok=True)
             video_path = video_dir / f"eval_episode_{i}.mp4"
-            imageio.mimsave(video_path, stacked_frames, fps=fps)
+            thread = threading.Thread(
+                target=write_video,
+                args=(str(video_path), stacked_frames, fps),
+            )
+            thread.start()
+            threads.append(thread)
 
         first_episode = i == 0
         if wandb and first_episode:
@@ -72,6 +81,9 @@
             )
             wandb.log({"eval_video": eval_video}, step=env_step)
 
+    for thread in threads:
+        thread.join()
+
     metrics = {
         "avg_sum_reward": np.nanmean(sum_rewards),
         "avg_max_reward": np.nanmean(max_rewards),
@@ -90,6 +102,7 @@ def eval(cfg: dict, out_dir=None):
         raise NotImplementedError()
 
     assert torch.cuda.is_available()
     torch.backends.cudnn.benchmark = True
     set_seed(cfg.seed)
     print(colored("Log dir:", "yellow", attrs=["bold"]), out_dir)
@@ -98,9 +111,9 @@
     if cfg.pretrained_model_path:
         policy = TDMPC(cfg)
         if "offline" in cfg.pretrained_model_path:
-            policy.step = 25000
+            policy.step[0] = 25000
         elif "final" in cfg.pretrained_model_path:
-            policy.step = 100000
+            policy.step[0] = 100000
         else:
             raise NotImplementedError()
         policy.load(cfg.pretrained_model_path)
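
The eval_policy changes above move the blocking imageio.mimsave call off the rollout loop: each episode's frames are handed to a background threading.Thread via write_video, and all threads are joined just before the metrics are aggregated. Below is a minimal, self-contained sketch of that pattern; the random frames, the fps value, and the output filenames are placeholders for illustration, not the repo's actual rollout data.

import threading

import imageio
import numpy as np


def write_video(video_path, stacked_frames, fps):
    # mp4 encoding is handled by imageio (ffmpeg backend), independently of the
    # loop that produced the frames, so it can run in a background thread.
    imageio.mimsave(video_path, stacked_frames, fps=fps)


threads = []
for i in range(3):
    # Placeholder for one episode's rendered frames (T x H x W x C, uint8).
    stacked_frames = np.random.randint(0, 256, size=(30, 64, 64, 3), dtype=np.uint8)
    thread = threading.Thread(
        target=write_video,
        args=(f"eval_episode_{i}.mp4", stacked_frames, 15),
    )
    thread.start()
    threads.append(thread)

# Wait for every encoder thread before reporting results, mirroring the join
# that eval_policy performs right before computing its metrics dict.
for thread in threads:
    thread.join()

On the other change visible here, policy.step[0] = 25000 assigns into an existing one-element container instead of rebinding policy.step to a plain Python int, which suggests step is now a tensor (or array) that the policy mutates in place, presumably as part of the policy-sampling speedup mentioned in the commit title.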