break_when_any_done==True for batch_size==1
This commit is contained in:
@@ -51,14 +51,14 @@ def eval_policy(
|
||||
ep_frames.append(env.render()) # noqa: B023
|
||||
|
||||
with torch.inference_mode():
|
||||
# TODO(alexander-soare): Due the `break_when_any_done == False` this rolls out for max_steps even when all
|
||||
# TODO(alexander-soare): When `break_when_any_done == False` this rolls out for max_steps even when all
|
||||
# envs are done the first time. But we only use the first rollout. This is a waste of compute.
|
||||
rollout = env.rollout(
|
||||
max_steps=max_steps,
|
||||
policy=policy,
|
||||
auto_cast_to_device=True,
|
||||
callback=maybe_render_frame,
|
||||
break_when_any_done=False,
|
||||
break_when_any_done=env.batch_size[0] == 1,
|
||||
)
|
||||
# Figure out where in each rollout sequence the first done condition was encountered (results after this won't
|
||||
# be included).
|
||||
|
||||
Reference in New Issue
Block a user