forked from tangger/lerobot
Clean logging, Refactor
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import hydra
|
||||
@@ -29,6 +30,7 @@ def eval_policy(
|
||||
fps: int = 15,
|
||||
return_first_video: bool = False,
|
||||
):
|
||||
start = time.time()
|
||||
sum_rewards = []
|
||||
max_rewards = []
|
||||
successes = []
|
||||
@@ -84,14 +86,16 @@ def eval_policy(
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
metrics = {
|
||||
info = {
|
||||
"avg_sum_reward": np.nanmean(sum_rewards),
|
||||
"avg_max_reward": np.nanmean(max_rewards),
|
||||
"pc_success": np.nanmean(successes) * 100,
|
||||
"eval_s": time.time() - start,
|
||||
"eval_ep_s": (time.time() - start) / num_episodes,
|
||||
}
|
||||
if return_first_video:
|
||||
return metrics, first_video
|
||||
return metrics
|
||||
return info, first_video
|
||||
return info
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_name="default", config_path="../configs")
|
||||
|
||||
Reference in New Issue
Block a user