Add online training with TD-MPC as proof of concept (#338)

This commit is contained in:
Alexander Soare
2024-07-25 11:16:38 +01:00
committed by GitHub
parent abbb1d2367
commit f8a6574698
25 changed files with 1291 additions and 233 deletions

View File

@@ -56,16 +56,13 @@ import einops
import gymnasium as gym
import numpy as np
import torch
from datasets import Dataset, Features, Image, Sequence, Value, concatenate_datasets
from huggingface_hub import snapshot_download
from huggingface_hub.utils._errors import RepositoryNotFoundError
from huggingface_hub.utils._validators import HFValidationError
from PIL import Image as PILImage
from torch import Tensor, nn
from tqdm import trange
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import hf_transform_to_torch
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.logger import log_output_dir
@@ -318,41 +315,17 @@ def eval_policy(
rollout_data,
done_indices,
start_episode_index=batch_ix * env.num_envs,
start_data_index=(
0 if episode_data is None else (episode_data["episode_data_index"]["to"][-1].item())
),
start_data_index=(0 if episode_data is None else (episode_data["index"][-1].item() + 1)),
fps=env.unwrapped.metadata["render_fps"],
)
if episode_data is None:
episode_data = this_episode_data
else:
# Some sanity checks to make sure we are not correctly compiling the data.
assert (
episode_data["hf_dataset"]["episode_index"][-1] + 1
== this_episode_data["hf_dataset"]["episode_index"][0]
)
assert (
episode_data["hf_dataset"]["index"][-1] + 1 == this_episode_data["hf_dataset"]["index"][0]
)
assert torch.equal(
episode_data["episode_data_index"]["to"][-1],
this_episode_data["episode_data_index"]["from"][0],
)
# Some sanity checks to make sure we are correctly compiling the data.
assert episode_data["episode_index"][-1] + 1 == this_episode_data["episode_index"][0]
assert episode_data["index"][-1] + 1 == this_episode_data["index"][0]
# Concatenate the episode data.
episode_data = {
"hf_dataset": concatenate_datasets(
[episode_data["hf_dataset"], this_episode_data["hf_dataset"]]
),
"episode_data_index": {
k: torch.cat(
[
episode_data["episode_data_index"][k],
this_episode_data["episode_data_index"][k],
]
)
for k in ["from", "to"]
},
}
episode_data = {k: torch.cat([episode_data[k], this_episode_data[k]]) for k in episode_data}
# Maybe render video for visualization.
if max_episodes_rendered > 0 and len(ep_frames) > 0:
@@ -434,89 +407,39 @@ def _compile_episode_data(
Similar logic is implemented when datasets are pushed to hub (see: `push_to_hub`).
"""
ep_dicts = []
episode_data_index = {"from": [], "to": []}
total_frames = 0
data_index_from = start_data_index
for ep_ix in range(rollout_data["action"].shape[0]):
num_frames = done_indices[ep_ix].item() + 1 # + 1 to include the first done frame
# + 2 to include the first done frame and the last observation frame.
num_frames = done_indices[ep_ix].item() + 2
total_frames += num_frames
# TODO(rcadene): We need to add a missing last frame which is the observation
# of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state"
# Here we do `num_frames - 1` as we don't want to include the last observation frame just yet.
ep_dict = {
"action": rollout_data["action"][ep_ix, :num_frames],
"episode_index": torch.tensor([start_episode_index + ep_ix] * num_frames),
"frame_index": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
"next.done": rollout_data["done"][ep_ix, :num_frames],
"next.reward": rollout_data["reward"][ep_ix, :num_frames].type(torch.float32),
"action": rollout_data["action"][ep_ix, : num_frames - 1],
"episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)),
"frame_index": torch.arange(0, num_frames - 1, 1),
"timestamp": torch.arange(0, num_frames - 1, 1) / fps,
"next.done": rollout_data["done"][ep_ix, : num_frames - 1],
"next.success": rollout_data["success"][ep_ix, : num_frames - 1],
"next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32),
}
# For the last observation frame, all other keys will just be copy padded.
for k in ep_dict:
ep_dict[k] = torch.cat([ep_dict[k], ep_dict[k][-1:]])
for key in rollout_data["observation"]:
ep_dict[key] = rollout_data["observation"][key][ep_ix][:num_frames]
ep_dict[key] = rollout_data["observation"][key][ep_ix, :num_frames]
ep_dicts.append(ep_dict)
episode_data_index["from"].append(data_index_from)
episode_data_index["to"].append(data_index_from + num_frames)
data_index_from += num_frames
data_dict = {}
for key in ep_dicts[0]:
if "image" not in key:
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
else:
if key not in data_dict:
data_dict[key] = []
for ep_dict in ep_dicts:
for img in ep_dict[key]:
# sanity check that images are channel first
c, h, w = img.shape
assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
# sanity check that images are float32 in range [0,1]
assert img.dtype == torch.float32, f"expect torch.float32, but instead {img.dtype=}"
assert img.max() <= 1, f"expect pixels lower than 1, but instead {img.max()=}"
assert img.min() >= 0, f"expect pixels greater than 1, but instead {img.min()=}"
# from float32 in range [0,1] to uint8 in range [0,255]
img *= 255
img = img.type(torch.uint8)
# convert to channel last and numpy as expected by PIL
img = PILImage.fromarray(img.permute(1, 2, 0).numpy())
data_dict[key].append(img)
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1)
episode_data_index["from"] = torch.tensor(episode_data_index["from"])
episode_data_index["to"] = torch.tensor(episode_data_index["to"])
# TODO(rcadene): clean this
features = {}
for key in rollout_data["observation"]:
if "image" in key:
features[key] = Image()
else:
features[key] = Sequence(length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None))
features.update(
{
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
"episode_index": Value(dtype="int64", id=None),
"frame_index": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None),
"next.reward": Value(dtype="float32", id=None),
"next.done": Value(dtype="bool", id=None),
#'next.success': Value(dtype='bool', id=None),
"index": Value(dtype="int64", id=None),
}
)
features = Features(features)
hf_dataset = Dataset.from_dict(data_dict, features=features)
hf_dataset.set_transform(hf_transform_to_torch)
return {
"hf_dataset": hf_dataset,
"episode_data_index": episode_data_index,
}
return data_dict
def main(