This commit is contained in:
Alexander Soare
2024-03-20 09:45:45 +00:00
parent b1ec3da035
commit 5332766a82
4 changed files with 34 additions and 233 deletions

View File

@@ -3,6 +3,7 @@ import threading
import time
from pathlib import Path
import einops
import hydra
import imageio
import numpy as np
@@ -69,9 +70,9 @@ def eval_policy(
rollout_steps = rollout["next", "done"].shape[1]
done_indices = torch.argmax(rollout["next", "done"].to(int), axis=1) # (batch_size, 1): index of the first done step per rollout
mask = (torch.arange(rollout_steps) <= done_indices).unsqueeze(-1) # (batch_size, rollout_steps, 1)
batch_sum_reward = (rollout["next", "reward"] * mask).flatten(start_dim=1).sum(dim=-1)
batch_max_reward = (rollout["next", "reward"] * mask).flatten(start_dim=1).max(dim=-1)[0]
batch_success = (rollout["next", "success"] * mask).flatten(start_dim=1).any(dim=-1)
batch_sum_reward = einops.reduce((rollout["next", "reward"] * mask), "b n 1 -> b", "sum")
batch_max_reward = einops.reduce((rollout["next", "reward"] * mask), "b n 1 -> b", "max")
batch_success = einops.reduce((rollout["next", "success"] * mask), "b n 1 -> b", "any")
sum_rewards.extend(batch_sum_reward.tolist())
max_rewards.extend(batch_max_reward.tolist())
successes.extend(batch_success.tolist())