backup wip

This commit is contained in:
Alexander Soare
2024-03-19 18:50:04 +00:00
parent ea17f4ce50
commit 896a11f60e
16 changed files with 169 additions and 138 deletions

View File

@@ -51,16 +51,25 @@ def eval_policy(
ep_frames.append(env.render()) # noqa: B023
with torch.inference_mode():
# TODO(alexander-soare): Due the `break_when_any_done == False` this rolls out for max_steps even when all
# envs are done the first time. But we only use the first rollout. This is a waste of compute.
rollout = env.rollout(
max_steps=max_steps,
policy=policy,
auto_cast_to_device=True,
callback=maybe_render_frame,
break_when_any_done=False,
)
# print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()]))
batch_sum_reward = rollout["next", "reward"].flatten(start_dim=1).sum(dim=-1)
batch_max_reward = rollout["next", "reward"].flatten(start_dim=1).max(dim=-1)[0]
batch_success = rollout["next", "success"].flatten(start_dim=1).any(dim=-1)
# Figure out where in each rollout sequence the first done condition was encountered (results after this won't
# be included).
# Note: this assumes that the shape of the done key is (batch_size, max_steps, 1).
# Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
rollout_steps = rollout["next", "done"].shape[1]
done_indices = torch.argmax(rollout["next", "done"].to(int), axis=1) # (batch_size, rollout_steps)
mask = (torch.arange(rollout_steps) <= done_indices).unsqueeze(-1) # (batch_size, rollout_steps, 1)
batch_sum_reward = (rollout["next", "reward"] * mask).flatten(start_dim=1).sum(dim=-1)
batch_max_reward = (rollout["next", "reward"] * mask).flatten(start_dim=1).max(dim=-1)[0]
batch_success = (rollout["next", "success"] * mask).flatten(start_dim=1).any(dim=-1)
sum_rewards.extend(batch_sum_reward.tolist())
max_rewards.extend(batch_max_reward.tolist())
successes.extend(batch_success.tolist())