forked from tangger/lerobot
Add multithreading for video generation; speed up policy sampling
This commit is contained in:
@@ -11,7 +11,10 @@ from torchrl.envs import EnvBase
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.tdmpc import TDMPC
|
||||
from lerobot.common.utils import set_seed
|
||||
import threading
|
||||
|
||||
def write_video(video_path, stacked_frames, fps):
    """Write ``stacked_frames`` to ``video_path`` using ``imageio.mimsave``.

    This is a thin module-level wrapper so the save can be handed to a
    ``threading.Thread`` as its ``target`` callable.

    Args:
        video_path: Destination passed straight through to ``imageio.mimsave``.
        stacked_frames: Frames to encode, passed straight through.
        fps: Frame rate forwarded as the ``fps`` keyword argument.
    """
    writer_options = {"fps": fps}
    imageio.mimsave(video_path, stacked_frames, **writer_options)
|
||||
|
||||
def eval_policy(
|
||||
env: EnvBase,
|
||||
@@ -29,6 +32,7 @@ def eval_policy(
|
||||
sum_rewards = []
|
||||
max_rewards = []
|
||||
successes = []
|
||||
threads = []
|
||||
for i in range(num_episodes):
|
||||
ep_frames = []
|
||||
|
||||
@@ -63,7 +67,12 @@ def eval_policy(
|
||||
if save_video:
|
||||
video_dir.mkdir(parents=True, exist_ok=True)
|
||||
video_path = video_dir / f"eval_episode_{i}.mp4"
|
||||
imageio.mimsave(video_path, stacked_frames, fps=fps)
|
||||
thread = threading.Thread(
|
||||
target=write_video,
|
||||
args=(str(video_path), stacked_frames, fps),
|
||||
)
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
|
||||
first_episode = i == 0
|
||||
if wandb and first_episode:
|
||||
@@ -72,6 +81,9 @@ def eval_policy(
|
||||
)
|
||||
wandb.log({"eval_video": eval_video}, step=env_step)
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
metrics = {
|
||||
"avg_sum_reward": np.nanmean(sum_rewards),
|
||||
"avg_max_reward": np.nanmean(max_rewards),
|
||||
@@ -90,6 +102,7 @@ def eval(cfg: dict, out_dir=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
assert torch.cuda.is_available()
|
||||
torch.backends.cudnn.benchmark = True
|
||||
set_seed(cfg.seed)
|
||||
print(colored("Log dir:", "yellow", attrs=["bold"]), out_dir)
|
||||
|
||||
@@ -98,9 +111,9 @@ def eval(cfg: dict, out_dir=None):
|
||||
if cfg.pretrained_model_path:
|
||||
policy = TDMPC(cfg)
|
||||
if "offline" in cfg.pretrained_model_path:
|
||||
policy.step = 25000
|
||||
policy.step[0] = 25000
|
||||
elif "final" in cfg.pretrained_model_path:
|
||||
policy.step = 100000
|
||||
policy.step[0] = 100000
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
policy.load(cfg.pretrained_model_path)
|
||||
|
||||
Reference in New Issue
Block a user