Merge remote-tracking branch 'upstream/main' into refactor_dp

@@ -32,6 +32,7 @@ import json
 import logging
+import threading
 import time
 from copy import deepcopy
 from datetime import datetime as dt
 from pathlib import Path
 
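Note: `threading` is newly imported because the merged video-writing code further down writes each episode's frames to disk on a background thread, so encoding does not block the evaluation loop. A minimal sketch of that pattern, with a hypothetical `write_frames` standing in for the repo's `write_video`:

    import threading

    def write_frames(path, frames):
        # hypothetical stand-in for write_video(path, frames, fps)
        print(f"writing {len(frames)} frames to {path}")

    threads = []
    for ep_id, frames in enumerate([[1, 2, 3], [4, 5]]):
        t = threading.Thread(target=write_frames, args=(f"eval_episode_{ep_id}.mp4", frames))
        t.start()
        threads.append(t)
    for t in threads:
        t.join()  # join once at the end, after all episodes are queued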
@@ -56,15 +57,15 @@ def write_video(video_path, stacked_frames, fps):
 
 def eval_policy(
     env: gym.vector.VectorEnv,
-    policy,
-    save_video: bool = False,
+    policy: torch.nn.Module,
+    max_episodes_rendered: int = 0,
     video_dir: Path = None,
     # TODO(rcadene): make it possible to overwrite fps? we should use env.fps
-    fps: int = 15,
-    return_first_video: bool = False,
     transform: callable = None,
     seed=None,
 ):
+    fps = env.unwrapped.metadata["render_fps"]
+
     if policy is not None:
         policy.eval()
     device = "cpu" if policy is None else next(policy.parameters()).device
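Note: the signature change replaces the two booleans `save_video`/`return_first_video` with a single `max_episodes_rendered` count, and `fps` is no longer a parameter: it is read from `env.unwrapped.metadata["render_fps"]`. A hedged sketch of the call-site difference (the stub below only mimics the new return shape):

    from pathlib import Path

    def eval_policy(env, policy, max_episodes_rendered=0, video_dir=None, transform=None, seed=None):
        # stub mimicking the new API: videos ride along inside the info dict
        info = {"aggregated": {}}
        if max_episodes_rendered > 0:
            info["videos"] = []  # stacked frames in the real code
        return info

    # old call (removed): info, first_video = eval_policy(env, policy, save_video=True, fps=15, return_first_video=True)
    info = eval_policy(env=None, policy=None, max_episodes_rendered=10, video_dir=Path("out/eval"))
    print("videos" in info)  # True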
@@ -83,14 +84,11 @@ def eval_policy(
     # needed as I'm currently taking a ceil.
     ep_frames = []
 
-    def maybe_render_frame(env):
-        if save_video:  # noqa: B023
-            if return_first_video:
-                visu = env.envs[0].render()
-                visu = visu[None, ...]  # add batch dim
-            else:
-                visu = np.stack([env.render() for env in env.envs])
-            ep_frames.append(visu)  # noqa: B023
+    def render_frame(env):
+        # noqa: B023
+        eps_rendered = min(max_episodes_rendered, len(env.envs))
+        visu = np.stack([env.envs[i].render() for i in range(eps_rendered)])
+        ep_frames.append(visu)  # noqa: B023
 
     for _ in range(num_episodes):
         seeds.append("TODO")
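Note: the old `maybe_render_frame` either rendered only env 0 (for `return_first_video`) or every env; the new `render_frame` always renders the first `min(max_episodes_rendered, num_envs)` envs and stacks the frames batch-first. A small self-contained illustration, assuming a hypothetical `FakeEnv`:

    import numpy as np

    class FakeEnv:
        def render(self):
            return np.zeros((48, 64, 3), dtype=np.uint8)  # one (h, w, c) frame

    envs = [FakeEnv() for _ in range(4)]
    max_episodes_rendered = 2
    eps_rendered = min(max_episodes_rendered, len(envs))
    visu = np.stack([envs[i].render() for i in range(eps_rendered)])
    print(visu.shape)  # (2, 48, 64, 3): only the first two envs are rendered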
@@ -104,8 +102,14 @@ def eval_policy(
 
     # reset the environment
     observation, info = env.reset(seed=seed)
-    maybe_render_frame(env)
+    if max_episodes_rendered > 0:
+        render_frame(env)
 
+    observations = []
+    actions = []
+    # episode
+    # frame_id
+    # timestamp
     rewards = []
     successes = []
     dones = []
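Note: the bare `# episode`, `# frame_id`, `# timestamp` comments mark fields that are not buffered step by step; they are synthesized per episode later from the frame count and fps (see the `ep_dict` hunk below). A toy version of that derivation:

    import torch

    num_frames, fps, ep_id = 4, 15, 0
    episode = torch.tensor([ep_id] * num_frames)
    frame_id = torch.arange(0, num_frames, 1)
    timestamp = torch.arange(0, num_frames, 1) / fps
    print(timestamp)  # tensor([0.0000, 0.0667, 0.1333, 0.2000])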
@@ -113,8 +117,13 @@ def eval_policy(
     done = torch.tensor([False for _ in env.envs])
     step = 0
     while not done.all():
+        # format from env keys to lerobot keys
+        observation = preprocess_observation(observation)
+        observations.append(deepcopy(observation))
+
         # apply transform to normalize the observations
-        observation = preprocess_observation(observation, transform)
+        for key in observation:
+            observation[key] = torch.stack([transform({key: item})[key] for item in observation[key]])
 
         # send observation to device/gpu
         observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
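Note: raw observations are buffered with `deepcopy` before normalization so the dataset keeps unnormalized values, and the transform is applied per key and per env item, then re-batched with `torch.stack`. A runnable sketch with a hypothetical normalization `transform`:

    import torch

    def transform(item: dict) -> dict:
        # hypothetical single-item (unbatched) normalization
        return {k: (v - 0.5) * 2.0 for k, v in item.items()}

    observation = {"observation.state": torch.rand(4, 2)}  # batched over 4 envs
    for key in observation:
        observation[key] = torch.stack([transform({key: item})[key] for item in observation[key]])
    print(observation["observation.state"].shape)  # torch.Size([4, 2])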
@@ -126,11 +135,13 @@ def eval_policy(
         # apply inverse transform to unnormalize the action
         action = postprocess_action(action, transform)
 
-        # apply the next
+        # apply the next action
        observation, reward, terminated, truncated, info = env.step(action)
-        maybe_render_frame(env)
+        if max_episodes_rendered > 0:
+            render_frame(env)
+
         # TODO(rcadene): implement a wrapper over env to return torch tensors in float32 (and cuda?)
         action = torch.from_numpy(action)
         reward = torch.from_numpy(reward)
         terminated = torch.from_numpy(terminated)
         truncated = torch.from_numpy(truncated)
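Note: `env.step` returns numpy arrays batched over envs, which the loop converts with `torch.from_numpy` (the TODO suggests hiding this in an env wrapper). For example:

    import numpy as np
    import torch

    reward = np.array([0.0, 1.0], dtype=np.float32)
    terminated = np.array([False, True])
    reward_t = torch.from_numpy(reward)          # torch.float32, shares memory
    terminated_t = torch.from_numpy(terminated)  # torch.bool
    print(reward_t.dtype, terminated_t.dtype)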
@@ -147,12 +158,24 @@ def eval_policy(
         success = [False for _ in env.envs]
         success = torch.tensor(success)
 
+        actions.append(action)
         rewards.append(reward)
         dones.append(done)
         successes.append(success)
 
         step += 1
 
+    env.close()
+
+    # add the last observation when the env is done
+    observation = preprocess_observation(observation)
+    observations.append(deepcopy(observation))
+
+    new_obses = {}
+    for key in observations[0].keys():  # noqa: SIM118
+        new_obses[key] = torch.stack([obs[key] for obs in observations], dim=1)
+    observations = new_obses
+    actions = torch.stack(actions, dim=1)
     rewards = torch.stack(rewards, dim=1)
     successes = torch.stack(successes, dim=1)
     dones = torch.stack(dones, dim=1)
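Note: each buffer is a list over time of `(num_envs,)` tensors, so stacking with `dim=1` yields `(batch, time)` tensors with the env dimension first:

    import torch

    num_envs, num_steps = 2, 3
    rewards_per_step = [torch.full((num_envs,), float(t)) for t in range(num_steps)]
    rewards = torch.stack(rewards_per_step, dim=1)  # (b, t)
    print(rewards.shape)  # torch.Size([2, 3])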
@@ -172,29 +195,61 @@ def eval_policy(
     max_rewards.extend(batch_max_reward.tolist())
     all_successes.extend(batch_success.tolist())
 
-    env.close()
+    # similar logic is implemented in dataset preprocessing
+    ep_dicts = []
+    num_episodes = dones.shape[0]
+    total_frames = 0
+    idx0 = idx1 = 0
+    data_ids_per_episode = {}
+    for ep_id in range(num_episodes):
+        num_frames = done_indices[ep_id].item() + 1
+        # TODO(rcadene): We need to add a missing last frame which is the observation
+        # of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state"
+        ep_dict = {
+            "action": actions[ep_id, :num_frames],
+            "episode": torch.tensor([ep_id] * num_frames),
+            "frame_id": torch.arange(0, num_frames, 1),
+            "timestamp": torch.arange(0, num_frames, 1) / fps,
+            "next.done": dones[ep_id, :num_frames],
+            "next.reward": rewards[ep_id, :num_frames].type(torch.float32),
+        }
+        for key in observations:
+            ep_dict[key] = observations[key][ep_id, :num_frames]
+        ep_dicts.append(ep_dict)
 
-    if save_video or return_first_video:
+        total_frames += num_frames
+        idx1 += num_frames
+
+        data_ids_per_episode[ep_id] = torch.arange(idx0, idx1, 1)
+
+        idx0 = idx1
+
+    # similar logic is implemented in dataset preprocessing
+    data_dict = {}
+    keys = ep_dicts[0].keys()
+    for key in keys:
+        data_dict[key] = torch.cat([x[key] for x in ep_dicts])
+    data_dict["index"] = torch.arange(0, total_frames, 1)
+
+    if max_episodes_rendered > 0:
         batch_stacked_frames = np.stack(ep_frames, 1)  # (b, t, *)
 
-        if save_video:
-            for stacked_frames, done_index in zip(
-                batch_stacked_frames, done_indices.flatten().tolist(), strict=False
-            ):
-                if episode_counter >= num_episodes:
-                    continue
-                video_dir.mkdir(parents=True, exist_ok=True)
-                video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
-                thread = threading.Thread(
-                    target=write_video,
-                    args=(str(video_path), stacked_frames[:done_index], fps),
-                )
-                thread.start()
-                threads.append(thread)
-                episode_counter += 1
+        for stacked_frames, done_index in zip(
+            batch_stacked_frames, done_indices.flatten().tolist(), strict=False
+        ):
+            if episode_counter >= num_episodes:
+                continue
+            video_dir.mkdir(parents=True, exist_ok=True)
+            video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
+            thread = threading.Thread(
+                target=write_video,
+                args=(str(video_path), stacked_frames[:done_index], fps),
+            )
+            thread.start()
+            threads.append(thread)
+            episode_counter += 1
 
-        if return_first_video:
-            first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2)
+        videos = einops.rearrange(batch_stacked_frames, "b t h w c -> b t c h w")
 
     for thread in threads:
         thread.join()
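Note: this hunk keeps the refactor_dp episode-to-flat-dataset conversion (each episode becomes an `ep_dict`, all are concatenated into `data_dict` with a global `index`, and `data_ids_per_episode` maps episode id to its frame indices) while adopting upstream's `max_episodes_rendered` video path. A condensed sketch of the indexing logic:

    import torch

    ep_dicts = [
        {"action": torch.zeros(3, 2)},  # episode 0: 3 frames
        {"action": torch.ones(2, 2)},   # episode 1: 2 frames
    ]
    data_dict = {"action": torch.cat([d["action"] for d in ep_dicts])}
    total_frames = data_dict["action"].shape[0]
    data_dict["index"] = torch.arange(0, total_frames, 1)

    data_ids_per_episode, idx0 = {}, 0
    for ep_id, d in enumerate(ep_dicts):
        idx1 = idx0 + d["action"].shape[0]
        data_ids_per_episode[ep_id] = torch.arange(idx0, idx1, 1)
        idx0 = idx1
    print(total_frames, data_ids_per_episode[1])  # 5 tensor([3, 4])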
@@ -225,9 +280,13 @@ def eval_policy(
             "eval_s": time.time() - start,
             "eval_ep_s": (time.time() - start) / num_episodes,
         },
+        "episodes": {
+            "data_dict": data_dict,
+            "data_ids_per_episode": data_ids_per_episode,
+        },
     }
-    if return_first_video:
-        return info, first_video
+    if max_episodes_rendered > 0:
+        info["videos"] = videos
     return info
 
 
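Note: instead of returning a single `first_video` in `(t, c, h, w)` layout, the merged code rearranges all rendered episodes to channel-first and attaches them to `info["videos"]`:

    import numpy as np
    import einops

    batch_stacked_frames = np.zeros((2, 5, 48, 64, 3), dtype=np.uint8)  # (b, t, h, w, c)
    videos = einops.rearrange(batch_stacked_frames, "b t h w c -> b t c h w")
    print(videos.shape)  # (2, 5, 3, 48, 64)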
@@ -253,16 +312,14 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
     logging.info("Making environment.")
     env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
 
-    # when policy is None, rollout a random policy
-    policy = make_policy(cfg) if cfg.policy.pretrained_model_path else None
+    logging.info("Making policy.")
+    policy = make_policy(cfg)
 
     info = eval_policy(
         env,
-        policy=policy,
-        save_video=True,
+        policy,
+        max_episodes_rendered=10,
         video_dir=Path(out_dir) / "eval",
-        fps=cfg.env.fps,
-        # TODO(rcadene): what should we do with the transform?
         transform=transform,
         seed=cfg.seed,
     )
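Note: the call site drops the "random policy when no checkpoint" fallback; `make_policy(cfg)` is now always invoked. A toy contrast, with stand-ins for the config and factory:

    def make_policy(cfg):
        # stand-in for lerobot's policy factory
        return object()

    class Cfg:
        class policy:
            pretrained_model_path = ""

    cfg = Cfg()
    policy_old = make_policy(cfg) if cfg.policy.pretrained_model_path else None  # before
    policy_new = make_policy(cfg)  # after
    print(policy_old is None, policy_new is not None)  # True True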
@@ -270,6 +327,9 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
 
     # Save info
     with open(Path(out_dir) / "eval_info.json", "w") as f:
+        # remove pytorch tensors which are not serializable to save the evaluation results only
+        del info["episodes"]
+        del info["videos"]
         json.dump(info, f, indent=2)
 
     logging.info("End of eval")
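Note: `info` now carries tensor-valued entries (`episodes`, `videos`) that `json.dump` cannot serialize, hence the two `del` statements before writing:

    import json
    import torch

    info = {
        "aggregated": {"avg_sum_reward": 1.0},
        "episodes": {"data_dict": {"action": torch.zeros(2)}},
        "videos": torch.zeros(1),
    }
    del info["episodes"]
    del info["videos"]
    print(json.dumps(info, indent=2))  # serializes cleanly now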