Address comments

Cadene
2024-04-16 17:14:40 +00:00
parent b241ea46dd
commit 36d9e885ef
24 changed files with 100 additions and 94 deletions

@@ -17,6 +17,17 @@ from datasets import Dataset, Features, Image, Sequence, Value
 from PIL import Image as PILImage
 
 
+def download_and_upload(root, root_tests, dataset_id):
+    if "pusht" in dataset_id:
+        download_and_upload_pusht(root, root_tests, dataset_id)
+    elif "xarm" in dataset_id:
+        download_and_upload_xarm(root, root_tests, dataset_id)
+    elif "aloha" in dataset_id:
+        download_and_upload_aloha(root, root_tests, dataset_id)
+    else:
+        raise ValueError(dataset_id)
+
+
 def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
     import zipfile
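
Note: the new download_and_upload dispatcher routes on substring matches of dataset_id and fails fast on anything unknown. A minimal sketch of calling it (paths are illustrative):

download_and_upload("data", "tests/data", "pusht")            # routes to download_and_upload_pusht
download_and_upload("data", "tests/data", "xarm_lift_medium")  # routes to download_and_upload_xarm
download_and_upload("data", "tests/data", "unknown_env")       # raises ValueError("unknown_env")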
@@ -87,7 +98,6 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
     states = torch.from_numpy(dataset_dict["state"])
     actions = torch.from_numpy(dataset_dict["action"])
 
-    data_ids_per_episode = {}
     ep_dicts = []
 
     id_from = 0
@@ -150,15 +160,11 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
             "next.reward": torch.cat([reward[1:], reward[[-1]]]),
             "next.done": torch.cat([done[1:], done[[-1]]]),
             "next.success": torch.cat([success[1:], success[[-1]]]),
-            "episode_data_id_from": torch.tensor([id_from] * num_frames),
-            "episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames),
+            "episode_data_index_from": torch.tensor([id_from] * num_frames),
+            "episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
         }
         ep_dicts.append(ep_dict)
 
-        assert isinstance(episode_id, int)
-        data_ids_per_episode[episode_id] = torch.arange(id_from, id_to, 1)
-        assert len(data_ids_per_episode[episode_id]) == num_frames
-
         id_from += num_frames
 
     data_dict = {}
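
Note: the rename from episode_data_id_* to episode_data_index_* also changes the upper bound from inclusive (id_from + num_frames - 1) to exclusive (id_from + num_frames), which is why the data_ids_per_episode bookkeeping and its off-by-one asserts can be dropped. A minimal sketch of the equivalence (numbers are illustrative):

import torch

id_from, num_frames = 100, 50

# Old convention: inclusive end, so recovering the episode needed a +1.
old_to = id_from + num_frames - 1
old_ids = torch.arange(id_from, old_to + 1, 1)

# New convention: exclusive end, matching torch.arange and slice semantics directly.
new_to = id_from + num_frames
new_ids = torch.arange(id_from, new_to, 1)

assert torch.equal(old_ids, new_ids)
assert len(new_ids) == num_frames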
@@ -190,8 +196,8 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
         "next.done": Value(dtype="bool", id=None),
         "next.success": Value(dtype="bool", id=None),
         "index": Value(dtype="int64", id=None),
-        "episode_data_id_from": Value(dtype="int64", id=None),
-        "episode_data_id_to": Value(dtype="int64", id=None),
+        "episode_data_index_from": Value(dtype="int64", id=None),
+        "episode_data_index_to": Value(dtype="int64", id=None),
     }
     features = Features(features)
     dataset = Dataset.from_dict(data_dict, features=features)
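
Note: because episode_data_index_to is exclusive, a consumer can slice an episode straight out of the Hugging Face dataset with no arithmetic. A sketch against a hypothetical three-frame, single-episode dataset:

from datasets import Dataset

ds = Dataset.from_dict(
    {
        "index": [0, 1, 2],
        "episode_data_index_from": [0, 0, 0],
        "episode_data_index_to": [3, 3, 3],
    }
)
row = ds[1]
episode = ds[row["episode_data_index_from"] : row["episode_data_index_to"]]
assert episode["index"] == [0, 1, 2]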
@@ -265,8 +271,8 @@ def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
             # "next.observation.state": next_state,
             "next.reward": next_reward,
             "next.done": next_done,
-            "episode_data_id_from": torch.tensor([id_from] * num_frames),
-            "episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames),
+            "episode_data_index_from": torch.tensor([id_from] * num_frames),
+            "episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
         }
         ep_dicts.append(ep_dict)
 
@@ -301,8 +307,8 @@ def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
         "next.done": Value(dtype="bool", id=None),
         #'next.success': Value(dtype='bool', id=None),
         "index": Value(dtype="int64", id=None),
-        "episode_data_id_from": Value(dtype="int64", id=None),
-        "episode_data_id_to": Value(dtype="int64", id=None),
+        "episode_data_index_from": Value(dtype="int64", id=None),
+        "episode_data_index_to": Value(dtype="int64", id=None),
     }
     features = Features(features)
     dataset = Dataset.from_dict(data_dict, features=features)
@@ -390,20 +396,7 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
             state = torch.from_numpy(ep["/observations/qpos"][:])
             action = torch.from_numpy(ep["/action"][:])
 
-            ep_dict = {
-                "observation.state": state,
-                "action": action,
-                "episode_id": torch.tensor([ep_id] * num_frames),
-                "frame_id": torch.arange(0, num_frames, 1),
-                "timestamp": torch.arange(0, num_frames, 1) / fps,
-                # "next.observation.state": state,
-                # TODO(rcadene): compute reward and success
-                # "next.reward": reward,
-                "next.done": done,
-                # "next.success": success,
-                "episode_data_id_from": torch.tensor([id_from] * num_frames),
-                "episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames),
-            }
-
+            ep_dict = {}
             for cam in cameras[dataset_id]:
                 image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])  # b h w c
@@ -411,6 +404,23 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
                 ep_dict[f"observation.images.{cam}"] = [PILImage.fromarray(x.numpy()) for x in image]
                 # ep_dict[f"next.observation.images.{cam}"] = image
 
+            ep_dict.update(
+                {
+                    "observation.state": state,
+                    "action": action,
+                    "episode_id": torch.tensor([ep_id] * num_frames),
+                    "frame_id": torch.arange(0, num_frames, 1),
+                    "timestamp": torch.arange(0, num_frames, 1) / fps,
+                    # "next.observation.state": state,
+                    # TODO(rcadene): compute reward and success
+                    # "next.reward": reward,
+                    "next.done": done,
+                    # "next.success": success,
+                    "episode_data_index_from": torch.tensor([id_from] * num_frames),
+                    "episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
+                }
+            )
+            assert isinstance(ep_id, int)
             ep_dicts.append(ep_dict)
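
Note: building ep_dict from the per-camera image lists first and merging the shared tensors with a single update() keeps the camera set data-dependent while guaranteeing every column ends up with num_frames entries. A compact sketch of the pattern (camera names and counts are illustrative):

import torch

num_frames = 3
cameras = ["top"]

ep_dict = {}
for cam in cameras:
    # Stand-in for the list of PIL images decoded from the hdf5 file.
    ep_dict[f"observation.images.{cam}"] = [None] * num_frames

# Camera-independent columns are merged in afterwards with one update().
ep_dict.update(
    {
        "frame_id": torch.arange(0, num_frames, 1),
        "episode_data_index_from": torch.tensor([0] * num_frames),
        "episode_data_index_to": torch.tensor([num_frames] * num_frames),
    }
)

assert all(len(v) == num_frames for v in ep_dict.values())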
@@ -446,8 +456,8 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
         "next.done": Value(dtype="bool", id=None),
         #'next.success': Value(dtype='bool', id=None),
         "index": Value(dtype="int64", id=None),
-        "episode_data_id_from": Value(dtype="int64", id=None),
-        "episode_data_id_to": Value(dtype="int64", id=None),
+        "episode_data_index_from": Value(dtype="int64", id=None),
+        "episode_data_index_to": Value(dtype="int64", id=None),
     }
     features = Features(features)
     dataset = Dataset.from_dict(data_dict, features=features)
@@ -461,23 +471,17 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
 if __name__ == "__main__":
     root = "data"
-    root_tests = "{root_tests}"
-
-    download_and_upload_pusht(root, root_tests, dataset_id="pusht")
-    download_and_upload_xarm(root, root_tests, dataset_id="xarm_lift_medium")
-    download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_insertion_human")
-    download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_insertion_scripted")
-    download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_transfer_cube_human")
-    download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_transfer_cube_scripted")
+    root_tests = "tests/data"
 
     dataset_ids = [
-        "pusht",
-        "xarm_lift_medium",
-        "aloha_sim_insertion_human",
-        "aloha_sim_insertion_scripted",
-        "aloha_sim_transfer_cube_human",
+        # "pusht",
+        # "xarm_lift_medium",
+        # "aloha_sim_insertion_human",
+        # "aloha_sim_insertion_scripted",
+        # "aloha_sim_transfer_cube_human",
         "aloha_sim_transfer_cube_scripted",
     ]
     for dataset_id in dataset_ids:
+        download_and_upload(root, root_tests, dataset_id)
         # assume stats have been precomputed
         shutil.copy(f"{root}/{dataset_id}/stats.pth", f"{root_tests}/{dataset_id}/stats.pth")
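
Note: with the exclusive-end convention, the per-episode [from, to) ranges should tile the converted dataset exactly. A quick sanity check one could run after conversion (toy boundaries for a 3-frame plus 2-frame dataset):

import torch

index_from = torch.tensor([0, 0, 0, 3, 3])
index_to = torch.tensor([3, 3, 3, 5, 5])

# Every frame index must fall inside its own episode's [from, to) range,
# and the final exclusive end must equal the total number of frames.
index = torch.arange(len(index_from))
assert torch.all((index >= index_from) & (index < index_to))
assert index_to[-1].item() == len(index_from)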