Use v1.1, hf_transform_to_torch, Add 3 xarm datasets
This commit is contained in:
@@ -19,7 +19,7 @@ from huggingface_hub import HfApi
|
||||
from PIL import Image as PILImage
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from lerobot.common.datasets.utils import compute_stats, convert_images_to_channel_first_tensors, flatten_dict
|
||||
from lerobot.common.datasets.utils import compute_stats, flatten_dict, hf_transform_to_torch
|
||||
|
||||
|
||||
def download_and_upload(root, revision, dataset_id):
|
||||
@@ -127,7 +127,11 @@ def push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dat
|
||||
|
||||
# copy in tests folder, the first episode and the meta_data directory
|
||||
num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
|
||||
hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train")
|
||||
hf_dataset.select(range(num_items_first_ep)).with_format("torch").save_to_disk(
|
||||
f"tests/data/{dataset_id}/train"
|
||||
)
|
||||
if Path(f"tests/data/{dataset_id}/meta_data").exists():
|
||||
shutil.rmtree(f"tests/data/{dataset_id}/meta_data")
|
||||
shutil.copytree(meta_data_dir, f"tests/data/{dataset_id}/meta_data")
|
||||
|
||||
|
||||
@@ -262,8 +266,7 @@ def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
|
||||
}
|
||||
features = Features(features)
|
||||
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
||||
hf_dataset = hf_dataset.with_format("torch")
|
||||
hf_dataset.set_transform(convert_images_to_channel_first_tensors)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
|
||||
info = {
|
||||
"fps": fps,
|
||||
@@ -274,13 +277,14 @@ def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
|
||||
|
||||
def download_and_upload_xarm(root, revision, dataset_id, fps=15):
|
||||
root = Path(root)
|
||||
raw_dir = root / f"{dataset_id}_raw"
|
||||
raw_dir = root / "xarm_datasets_raw"
|
||||
if not raw_dir.exists():
|
||||
import zipfile
|
||||
|
||||
import gdown
|
||||
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
# from https://github.com/fyhMer/fowm/blob/main/scripts/download_datasets.py
|
||||
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
|
||||
zip_path = raw_dir / "data.zip"
|
||||
gdown.download(url, str(zip_path), quiet=False)
|
||||
@@ -361,8 +365,7 @@ def download_and_upload_xarm(root, revision, dataset_id, fps=15):
|
||||
}
|
||||
features = Features(features)
|
||||
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
||||
hf_dataset = hf_dataset.with_format("torch")
|
||||
hf_dataset.set_transform(convert_images_to_channel_first_tensors)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
|
||||
info = {
|
||||
"fps": fps,
|
||||
@@ -468,8 +471,6 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
|
||||
# "next.reward": reward,
|
||||
"next.done": done,
|
||||
# "next.success": success,
|
||||
"episode_data_index_from": torch.tensor([id_from] * num_frames),
|
||||
"episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -499,8 +500,7 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
|
||||
}
|
||||
features = Features(features)
|
||||
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
||||
hf_dataset = hf_dataset.with_format("torch")
|
||||
hf_dataset.set_transform(convert_images_to_channel_first_tensors)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
|
||||
info = {
|
||||
"fps": fps,
|
||||
@@ -515,11 +515,14 @@ if __name__ == "__main__":
|
||||
|
||||
dataset_ids = [
|
||||
"pusht",
|
||||
# "xarm_lift_medium",
|
||||
# "aloha_sim_insertion_human",
|
||||
# "aloha_sim_insertion_scripted",
|
||||
# "aloha_sim_transfer_cube_human",
|
||||
# "aloha_sim_transfer_cube_scripted",
|
||||
"xarm_lift_medium",
|
||||
"xarm_lift_medium_replay",
|
||||
"xarm_push_medium",
|
||||
"xarm_push_medium_replay",
|
||||
"aloha_sim_insertion_human",
|
||||
"aloha_sim_insertion_scripted",
|
||||
"aloha_sim_transfer_cube_human",
|
||||
"aloha_sim_transfer_cube_scripted",
|
||||
]
|
||||
for dataset_id in dataset_ids:
|
||||
download_and_upload(root, revision, dataset_id)
|
||||
|
||||
Reference in New Issue
Block a user