Add download_and_upload_dataset.py in script, update all datasets, update online training

Cadene
2024-04-15 21:26:33 +00:00
parent c6aca7fe44
commit 67d79732f9
8 changed files with 493 additions and 519 deletions


@@ -41,6 +41,7 @@ import gymnasium as gym
 import imageio
 import numpy as np
 import torch
+from datasets import Dataset
 from huggingface_hub import snapshot_download
 
 from lerobot.common.datasets.factory import make_dataset
@@ -199,30 +200,28 @@ def eval_policy(
     ep_dicts = []
     num_episodes = dones.shape[0]
     total_frames = 0
-    idx0 = idx1 = 0
-    data_ids_per_episode = {}
+    idx_from = 0
     for ep_id in range(num_episodes):
         num_frames = done_indices[ep_id].item() + 1
+        total_frames += num_frames
 
         # TODO(rcadene): We need to add a missing last frame which is the observation
         # of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state"
         ep_dict = {
             "action": actions[ep_id, :num_frames],
-            "episode": torch.tensor([ep_id] * num_frames),
+            "episode_id": torch.tensor([ep_id] * num_frames),
             "frame_id": torch.arange(0, num_frames, 1),
             "timestamp": torch.arange(0, num_frames, 1) / fps,
             "next.done": dones[ep_id, :num_frames],
             "next.reward": rewards[ep_id, :num_frames].type(torch.float32),
+            "episode_data_id_from": torch.tensor([idx_from] * num_frames),
+            "episode_data_id_to": torch.tensor([idx_from + num_frames - 1] * num_frames),
         }
         for key in observations:
             ep_dict[key] = observations[key][ep_id, :num_frames]
         ep_dicts.append(ep_dict)
 
-        total_frames += num_frames
-        idx1 += num_frames
-        data_ids_per_episode[ep_id] = torch.arange(idx0, idx1, 1)
-        idx0 = idx1
+        idx_from += num_frames
 
-    # similar logic is implemented in dataset preprocessing
     data_dict = {}
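The new `episode_data_id_from` / `episode_data_id_to` columns replace the old `data_ids_per_episode` mapping: every frame now stores the first and last flat index of the episode it belongs to. Below is a minimal sketch of how an episode can be recovered from the flattened data under that convention; the toy tensors and the helper function are illustrative only, not part of the commit.

```python
import torch

# Toy flattened rollout: two episodes of 3 and 2 frames, laid out the way
# eval_policy now builds them (same column names as in the diff above).
flat = {
    "episode_id": torch.tensor([0, 0, 0, 1, 1]),
    "episode_data_id_from": torch.tensor([0, 0, 0, 3, 3]),
    "episode_data_id_to": torch.tensor([2, 2, 2, 4, 4]),
    "index": torch.arange(5),
}

def episode_slice(data: dict, ep_id: int) -> slice:
    # Every frame of an episode carries the same from/to bounds, so reading
    # them off the first matching frame is enough.
    mask = data["episode_id"] == ep_id
    id_from = data["episode_data_id_from"][mask][0].item()
    id_to = data["episode_data_id_to"][mask][0].item()
    return slice(id_from, id_to + 1)

print(flat["index"][episode_slice(flat, 1)])  # tensor([3, 4])
```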
@@ -231,6 +230,8 @@ def eval_policy(
         data_dict[key] = torch.cat([x[key] for x in ep_dicts])
     data_dict["index"] = torch.arange(0, total_frames, 1)
+
+    data_dict = Dataset.from_dict(data_dict).with_format("torch")
 
     if max_episodes_rendered > 0:
         batch_stacked_frames = np.stack(ep_frames, 1)  # (b, t, *)
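`Dataset.from_dict(data_dict).with_format("torch")` wraps the flat dict of tensors in a Hugging Face `datasets.Dataset` that hands back torch tensors when indexed, which is what allows the whole rollout to be returned as a single object further below. A small illustration with only two of the columns built above; the values are made up.

```python
import torch
from datasets import Dataset

# Toy version of the conversion done in the diff: dict of tensors -> Dataset.
data_dict = {
    "index": torch.arange(0, 4, 1),
    "next.reward": torch.tensor([0.0, 1.0, 0.0, 1.0]),
}
ds = Dataset.from_dict(data_dict).with_format("torch")

print(len(ds))               # 4 rows, one per frame
print(ds[0]["next.reward"])  # a torch scalar, thanks to with_format("torch")
print(ds["index"][:2])       # whole columns come back as tensors as well
```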
@@ -280,10 +281,7 @@ def eval_policy(
"eval_s": time.time() - start,
"eval_ep_s": (time.time() - start) / num_episodes,
},
"episodes": {
"data_dict": data_dict,
"data_ids_per_episode": data_ids_per_episode,
},
"episodes": data_dict,
}
if max_episodes_rendered > 0:
info["videos"] = videos