forked from tangger/lerobot
WIP add load functions + episode_data_index
@@ -17,9 +17,9 @@ import tqdm
 from datasets import Dataset, Features, Image, Sequence, Value
 from huggingface_hub import HfApi
 from PIL import Image as PILImage
-from safetensors.numpy import save_file
+from safetensors.torch import save_file

-from lerobot.common.datasets.utils import compute_stats
+from lerobot.common.datasets.utils import compute_stats, flatten_dict


 def download_and_upload(root, revision, dataset_id):
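The import swap above matters for everything written later in this diff: safetensors.torch.save_file serializes a flat dict of torch.Tensor values, whereas safetensors.numpy.save_file expects numpy arrays. A minimal sketch of the torch backend's save/load round trip (the file name here is illustrative):

import torch
from safetensors.torch import load_file, save_file

# safetensors stores a flat {str: torch.Tensor} mapping; nested dicts or numpy
# arrays are not accepted by the torch backend's save_file.
tensors = {"from": torch.tensor([0, 50]), "to": torch.tensor([50, 100])}
save_file(tensors, "episode_data_index.safetensors")
reloaded = load_file("episode_data_index.safetensors")  # dict of torch.Tensor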
@@ -98,7 +98,7 @@ def push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id
     torch.save(stats, stats_pth_path)

     # create and store meta_data
-    meta_data_dir = root / dataset_id / "train" / "meta_data"
+    meta_data_dir = root / dataset_id / "meta_data"
     meta_data_dir.mkdir(parents=True, exist_ok=True)

     api = HfApi()
@@ -115,18 +115,17 @@ def push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id
     )

     # stats
-    for key in stats:
-        stats_path = meta_data_dir / f"stats_{key}.safetensors"
-        save_file(episode_data_index, stats_path)
-        api.upload_file(
-            path_or_fileobj=stats_path,
-            path_in_repo=str(stats_path).replace(f"{root}/{dataset_id}", ""),
-            repo_id=f"lerobot/{dataset_id}",
-            repo_type="dataset",
-        )
+    stats_path = meta_data_dir / "stats.safetensors"
+    save_file(flatten_dict(stats), stats_path)
+    api.upload_file(
+        path_or_fileobj=stats_path,
+        path_in_repo=str(stats_path).replace(f"{root}/{dataset_id}", ""),
+        repo_id=f"lerobot/{dataset_id}",
+        repo_type="dataset",
+    )

     # episode_data_index
-    episode_data_index = {key: np.array(episode_data_index[key]) for key in episode_data_index}
+    episode_data_index = {key: torch.tensor(episode_data_index[key]) for key in episode_data_index}
     ep_data_idx_path = meta_data_dir / "episode_data_index.safetensors"
     save_file(episode_data_index, ep_data_idx_path)
     api.upload_file(
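The per-key stats files are replaced by a single stats.safetensors written through flatten_dict: safetensors only stores a single-level mapping of names to tensors, while the stats dict is nested (per data key, per statistic). A hypothetical sketch of what such a flattening helper could do; the actual lerobot.common.datasets.utils.flatten_dict may differ, and the "/" separator is an assumption:

import torch
from safetensors.torch import save_file

def flatten_dict_sketch(d: dict, parent_key: str = "", sep: str = "/") -> dict:
    # Collapse nested dicts into compound keys, e.g. {"a": {"b": t}} -> {"a/b": t}.
    out = {}
    for k, v in d.items():
        key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            out.update(flatten_dict_sketch(v, key, sep=sep))
        else:
            out[key] = v
    return out

stats = {"observation.state": {"mean": torch.zeros(2), "std": torch.ones(2)}}
save_file(flatten_dict_sketch(stats), "stats.safetensors")
# stored keys: "observation.state/mean", "observation.state/std"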
@@ -139,7 +138,7 @@ def push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id
     # copy in tests folder, the first episode and the meta_data directory
     num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
     hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train")
-    shutil.copytree(meta_data_dir, f"tests/{meta_data_dir}")
+    shutil.copytree(meta_data_dir, f"tests/data/{dataset_id}/meta_data")


 def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
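The test fixture above keeps only the first episode, counting its frames as episode_data_index["to"][0] - episode_data_index["from"][0]: "from" holds each episode's first frame index into the flat dataset and the subtraction implies "to" is an exclusive end index. A short sketch of slicing an arbitrary episode with that index; the tensor values are made up for illustration:

import torch

# Illustrative values only: three episodes of 50, 60, and 70 frames.
episode_data_index = {
    "from": torch.tensor([0, 50, 110]),   # first frame index of each episode
    "to": torch.tensor([50, 110, 180]),   # exclusive end index of each episode
}

ep = 1
start = episode_data_index["from"][ep].item()
end = episode_data_index["to"][ep].item()
frame_ids = range(start, end)  # e.g. hf_dataset.select(frame_ids) -> episode 1 only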
@@ -516,12 +515,12 @@ if __name__ == "__main__":
     revision = "v1.1"

     dataset_ids = [
-        # "pusht",
+        "pusht",
         # "xarm_lift_medium",
         # "aloha_sim_insertion_human",
         # "aloha_sim_insertion_scripted",
         # "aloha_sim_transfer_cube_human",
-        "aloha_sim_transfer_cube_scripted",
+        # "aloha_sim_transfer_cube_scripted",
     ]
     for dataset_id in dataset_ids:
         download_and_upload(root, revision, dataset_id)
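The commit message mentions load functions, but the hunks shown here only cover the save/upload side. A hypothetical load-side counterpart could look like the sketch below; the helper name, return shape, and "/" key separator are assumptions, not code from this commit:

from pathlib import Path

from safetensors.torch import load_file

def load_meta_data_sketch(meta_data_dir: Path) -> tuple[dict, dict]:
    # Read back the two files written by push_to_hub above.
    episode_data_index = load_file(str(meta_data_dir / "episode_data_index.safetensors"))
    flat_stats = load_file(str(meta_data_dir / "stats.safetensors"))
    # Un-flatten "outer/inner" keys back into the nested stats layout.
    stats: dict = {}
    for key, tensor in flat_stats.items():
        outer, inner = key.rsplit("/", 1)
        stats.setdefault(outer, {})[inner] = tensor
    return episode_data_index, stats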