Add download_and_upload_dataset.py in script, update all datasets, update online training
This commit is contained in:
@@ -41,6 +41,7 @@ import gymnasium as gym
|
||||
import imageio
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
@@ -199,30 +200,28 @@ def eval_policy(
|
||||
ep_dicts = []
|
||||
num_episodes = dones.shape[0]
|
||||
total_frames = 0
|
||||
idx0 = idx1 = 0
|
||||
data_ids_per_episode = {}
|
||||
idx_from = 0
|
||||
for ep_id in range(num_episodes):
|
||||
num_frames = done_indices[ep_id].item() + 1
|
||||
total_frames += num_frames
|
||||
|
||||
# TODO(rcadene): We need to add a missing last frame which is the observation
|
||||
# of a done state. It is critical to have this frame for tdmpc to predict a "done observation/state".
|
||||
ep_dict = {
|
||||
"action": actions[ep_id, :num_frames],
|
||||
"episode": torch.tensor([ep_id] * num_frames),
|
||||
"episode_id": torch.tensor([ep_id] * num_frames),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
||||
"next.done": dones[ep_id, :num_frames],
|
||||
"next.reward": rewards[ep_id, :num_frames].type(torch.float32),
|
||||
"episode_data_id_from": torch.tensor([idx_from] * num_frames),
|
||||
"episode_data_id_to": torch.tensor([idx_from + num_frames - 1] * num_frames),
|
||||
}
|
||||
for key in observations:
|
||||
ep_dict[key] = observations[key][ep_id, :num_frames]
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
total_frames += num_frames
|
||||
idx1 += num_frames
|
||||
|
||||
data_ids_per_episode[ep_id] = torch.arange(idx0, idx1, 1)
|
||||
|
||||
idx0 = idx1
|
||||
idx_from += num_frames
|
||||
|
||||
# similar logic is implemented in dataset preprocessing
|
||||
data_dict = {}
|
||||
@@ -231,6 +230,8 @@ def eval_policy(
|
||||
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
||||
data_dict = Dataset.from_dict(data_dict).with_format("torch")
|
||||
|
||||
if max_episodes_rendered > 0:
|
||||
batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *)
|
||||
|
||||
@@ -280,10 +281,7 @@ def eval_policy(
|
||||
"eval_s": time.time() - start,
|
||||
"eval_ep_s": (time.time() - start) / num_episodes,
|
||||
},
|
||||
"episodes": {
|
||||
"data_dict": data_dict,
|
||||
"data_ids_per_episode": data_ids_per_episode,
|
||||
},
|
||||
"episodes": data_dict,
|
||||
}
|
||||
if max_episodes_rendered > 0:
|
||||
info["videos"] = videos
|
||||
|
||||
Reference in New Issue
Block a user