forked from tangger/lerobot
revision
This commit is contained in:
@@ -3,6 +3,7 @@ import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
import hydra
|
||||
import imageio
|
||||
import numpy as np
|
||||
@@ -69,9 +70,9 @@ def eval_policy(
|
||||
rollout_steps = rollout["next", "done"].shape[1]
|
||||
done_indices = torch.argmax(rollout["next", "done"].to(int), axis=1) # (batch_size, rollout_steps)
|
||||
mask = (torch.arange(rollout_steps) <= done_indices).unsqueeze(-1) # (batch_size, rollout_steps, 1)
|
||||
batch_sum_reward = (rollout["next", "reward"] * mask).flatten(start_dim=1).sum(dim=-1)
|
||||
batch_max_reward = (rollout["next", "reward"] * mask).flatten(start_dim=1).max(dim=-1)[0]
|
||||
batch_success = (rollout["next", "success"] * mask).flatten(start_dim=1).any(dim=-1)
|
||||
batch_sum_reward = einops.reduce((rollout["next", "reward"] * mask), "b n 1 -> b", "sum")
|
||||
batch_max_reward = einops.reduce((rollout["next", "reward"] * mask), "b n 1 -> b", "max")
|
||||
batch_success = einops.reduce((rollout["next", "success"] * mask), "b n 1 -> b", "any")
|
||||
sum_rewards.extend(batch_sum_reward.tolist())
|
||||
max_rewards.extend(batch_max_reward.tolist())
|
||||
successes.extend(batch_success.tolist())
|
||||
|
||||
Reference in New Issue
Block a user