id -> index, finish moving compute_stats before hf_dataset push_to_hub

This commit is contained in:
Cadene
2024-04-19 10:33:42 +00:00
parent 64b09ea7a7
commit 714a776277
9 changed files with 120 additions and 99 deletions

View File

@@ -19,7 +19,7 @@ from huggingface_hub import HfApi
from PIL import Image as PILImage
from safetensors.torch import save_file
from lerobot.common.datasets.utils import compute_stats, flatten_dict
from lerobot.common.datasets.utils import compute_stats, convert_images_to_channel_first_tensors, flatten_dict
def download_and_upload(root, revision, dataset_id):
@@ -75,28 +75,18 @@ def concatenate_episodes(ep_dicts):
for x in ep_dict[key]:
data_dict[key].append(x)
total_frames = data_dict["frame_id"].shape[0]
total_frames = data_dict["frame_index"].shape[0]
data_dict["index"] = torch.arange(0, total_frames, 1)
return data_dict
def push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id):
hf_dataset = hf_dataset.with_format("torch")
def push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id):
# push to main to indicate latest version
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
# push to version branch
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision=revision)
# get stats
stats_pth_path = root / dataset_id / "stats.pth"
if stats_pth_path.exists():
stats = torch.load(stats_pth_path)
else:
stats = compute_stats(hf_dataset)
torch.save(stats, stats_pth_path)
# create and store meta_data
meta_data_dir = root / dataset_id / "meta_data"
meta_data_dir.mkdir(parents=True, exist_ok=True)
@@ -237,8 +227,8 @@ def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
"observation.image": [PILImage.fromarray(x.numpy()) for x in image],
"observation.state": agent_pos,
"action": actions[id_from:id_to],
"episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_id": torch.arange(0, num_frames, 1),
"episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_index": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
# "next.observation.image": image[1:],
# "next.observation.state": agent_pos[1:],
@@ -262,8 +252,8 @@ def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
),
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
"episode_id": Value(dtype="int64", id=None),
"frame_id": Value(dtype="int64", id=None),
"episode_index": Value(dtype="int64", id=None),
"frame_index": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None),
"next.reward": Value(dtype="float32", id=None),
"next.done": Value(dtype="bool", id=None),
@@ -272,11 +262,14 @@ def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
}
features = Features(features)
hf_dataset = Dataset.from_dict(data_dict, features=features)
hf_dataset = hf_dataset.with_format("torch")
hf_dataset.set_transform(convert_images_to_channel_first_tensors)
info = {
"fps": fps,
}
push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id)
stats = compute_stats(hf_dataset)
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
def download_and_upload_xarm(root, revision, dataset_id, fps=15):
@@ -334,8 +327,8 @@ def download_and_upload_xarm(root, revision, dataset_id, fps=15):
"observation.image": [PILImage.fromarray(x.numpy()) for x in image],
"observation.state": state,
"action": action,
"episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_id": torch.arange(0, num_frames, 1),
"episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_index": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
# "next.observation.image": next_image,
# "next.observation.state": next_state,
@@ -358,8 +351,8 @@ def download_and_upload_xarm(root, revision, dataset_id, fps=15):
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
),
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
"episode_id": Value(dtype="int64", id=None),
"frame_id": Value(dtype="int64", id=None),
"episode_index": Value(dtype="int64", id=None),
"frame_index": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None),
"next.reward": Value(dtype="float32", id=None),
"next.done": Value(dtype="bool", id=None),
@@ -368,11 +361,14 @@ def download_and_upload_xarm(root, revision, dataset_id, fps=15):
}
features = Features(features)
hf_dataset = Dataset.from_dict(data_dict, features=features)
hf_dataset = hf_dataset.with_format("torch")
hf_dataset.set_transform(convert_images_to_channel_first_tensors)
info = {
"fps": fps,
}
push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id)
stats = compute_stats(hf_dataset)
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
def download_and_upload_aloha(root, revision, dataset_id, fps=50):
@@ -464,8 +460,8 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
{
"observation.state": state,
"action": action,
"episode_id": torch.tensor([ep_id] * num_frames),
"frame_id": torch.arange(0, num_frames, 1),
"episode_index": torch.tensor([ep_id] * num_frames),
"frame_index": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
# "next.observation.state": state,
# TODO(rcadene): compute reward and success
@@ -493,8 +489,8 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
),
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
"episode_id": Value(dtype="int64", id=None),
"frame_id": Value(dtype="int64", id=None),
"episode_index": Value(dtype="int64", id=None),
"frame_index": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None),
#'next.reward': Value(dtype='float32', id=None),
"next.done": Value(dtype="bool", id=None),
@@ -503,11 +499,14 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
}
features = Features(features)
hf_dataset = Dataset.from_dict(data_dict, features=features)
hf_dataset = hf_dataset.with_format("torch")
hf_dataset.set_transform(convert_images_to_channel_first_tensors)
info = {
"fps": fps,
}
push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id)
stats = compute_stats(hf_dataset)
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
if __name__ == "__main__":