WIP add load functions + episode_data_index

This commit is contained in:
Cadene
2024-04-18 23:54:52 +00:00
parent 0bd2ca8d82
commit 64b09ea7a7
4 changed files with 159 additions and 40 deletions

View File

@@ -17,9 +17,9 @@ import tqdm
from datasets import Dataset, Features, Image, Sequence, Value
from huggingface_hub import HfApi
from PIL import Image as PILImage
from safetensors.numpy import save_file
from safetensors.torch import save_file
from lerobot.common.datasets.utils import compute_stats
from lerobot.common.datasets.utils import compute_stats, flatten_dict
def download_and_upload(root, revision, dataset_id):
@@ -98,7 +98,7 @@ def push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id
torch.save(stats, stats_pth_path)
# create and store meta_data
meta_data_dir = root / dataset_id / "train" / "meta_data"
meta_data_dir = root / dataset_id / "meta_data"
meta_data_dir.mkdir(parents=True, exist_ok=True)
api = HfApi()
@@ -115,18 +115,17 @@ def push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id
)
# stats
for key in stats:
stats_path = meta_data_dir / f"stats_{key}.safetensors"
save_file(episode_data_index, stats_path)
api.upload_file(
path_or_fileobj=stats_path,
path_in_repo=str(stats_path).replace(f"{root}/{dataset_id}", ""),
repo_id=f"lerobot/{dataset_id}",
repo_type="dataset",
)
stats_path = meta_data_dir / "stats.safetensors"
save_file(flatten_dict(stats), stats_path)
api.upload_file(
path_or_fileobj=stats_path,
path_in_repo=str(stats_path).replace(f"{root}/{dataset_id}", ""),
repo_id=f"lerobot/{dataset_id}",
repo_type="dataset",
)
# episode_data_index
episode_data_index = {key: np.array(episode_data_index[key]) for key in episode_data_index}
episode_data_index = {key: torch.tensor(episode_data_index[key]) for key in episode_data_index}
ep_data_idx_path = meta_data_dir / "episode_data_index.safetensors"
save_file(episode_data_index, ep_data_idx_path)
api.upload_file(
@@ -139,7 +138,7 @@ def push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id
# copy in tests folder, the first episode and the meta_data directory
num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train")
shutil.copytree(meta_data_dir, f"tests/{meta_data_dir}")
shutil.copytree(meta_data_dir, f"tests/data/{dataset_id}/meta_data")
def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
@@ -516,12 +515,12 @@ if __name__ == "__main__":
revision = "v1.1"
dataset_ids = [
# "pusht",
"pusht",
# "xarm_lift_medium",
# "aloha_sim_insertion_human",
# "aloha_sim_insertion_scripted",
# "aloha_sim_transfer_cube_human",
"aloha_sim_transfer_cube_scripted",
# "aloha_sim_transfer_cube_scripted",
]
for dataset_id in dataset_ids:
download_and_upload(root, revision, dataset_id)