Improve dataset examples (#82)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
Remi
2024-04-18 11:43:16 +02:00
committed by GitHub
parent d5c4b0c344
commit 0928afd37d
15 changed files with 274 additions and 165 deletions

View File

@@ -200,13 +200,13 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
"episode_data_index_to": Value(dtype="int64", id=None),
}
features = Features(features)
dataset = Dataset.from_dict(data_dict, features=features)
dataset = dataset.with_format("torch")
hf_dataset = Dataset.from_dict(data_dict, features=features)
hf_dataset = hf_dataset.with_format("torch")
num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
@@ -311,13 +311,13 @@ def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
"episode_data_index_to": Value(dtype="int64", id=None),
}
features = Features(features)
dataset = Dataset.from_dict(data_dict, features=features)
dataset = dataset.with_format("torch")
hf_dataset = Dataset.from_dict(data_dict, features=features)
hf_dataset = hf_dataset.with_format("torch")
num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
@@ -460,13 +460,13 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
"episode_data_index_to": Value(dtype="int64", id=None),
}
features = Features(features)
dataset = Dataset.from_dict(data_dict, features=features)
dataset = dataset.with_format("torch")
hf_dataset = Dataset.from_dict(data_dict, features=features)
hf_dataset = hf_dataset.with_format("torch")
num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
if __name__ == "__main__":