Clean logging, Refactor

This commit is contained in:
Cadene
2024-02-29 23:13:06 +00:00
parent cb7b375526
commit 0b9027f05e
9 changed files with 229 additions and 131 deletions

View File

@@ -1,4 +1,5 @@
import threading
import time
from pathlib import Path
import hydra
@@ -29,6 +30,7 @@ def eval_policy(
fps: int = 15,
return_first_video: bool = False,
):
start = time.time()
sum_rewards = []
max_rewards = []
successes = []
@@ -84,14 +86,16 @@ def eval_policy(
for thread in threads:
thread.join()
metrics = {
info = {
"avg_sum_reward": np.nanmean(sum_rewards),
"avg_max_reward": np.nanmean(max_rewards),
"pc_success": np.nanmean(successes) * 100,
"eval_s": time.time() - start,
"eval_ep_s": (time.time() - start) / num_episodes,
}
if return_first_video:
return metrics, first_video
return metrics
return info, first_video
return info
@hydra.main(version_base=None, config_name="default", config_path="../configs")