Address comments
This commit is contained in:
@@ -17,6 +17,17 @@ from datasets import Dataset, Features, Image, Sequence, Value
|
||||
from PIL import Image as PILImage
|
||||
|
||||
|
||||
def download_and_upload(root, root_tests, dataset_id):
|
||||
if "pusht" in dataset_id:
|
||||
download_and_upload_pusht(root, root_tests, dataset_id)
|
||||
elif "xarm" in dataset_id:
|
||||
download_and_upload_xarm(root, root_tests, dataset_id)
|
||||
elif "aloha" in dataset_id:
|
||||
download_and_upload_aloha(root, root_tests, dataset_id)
|
||||
else:
|
||||
raise ValueError(dataset_id)
|
||||
|
||||
|
||||
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
||||
import zipfile
|
||||
|
||||
@@ -87,7 +98,6 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
|
||||
states = torch.from_numpy(dataset_dict["state"])
|
||||
actions = torch.from_numpy(dataset_dict["action"])
|
||||
|
||||
data_ids_per_episode = {}
|
||||
ep_dicts = []
|
||||
|
||||
id_from = 0
|
||||
@@ -150,15 +160,11 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
|
||||
"next.reward": torch.cat([reward[1:], reward[[-1]]]),
|
||||
"next.done": torch.cat([done[1:], done[[-1]]]),
|
||||
"next.success": torch.cat([success[1:], success[[-1]]]),
|
||||
"episode_data_id_from": torch.tensor([id_from] * num_frames),
|
||||
"episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames),
|
||||
"episode_data_index_from": torch.tensor([id_from] * num_frames),
|
||||
"episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
|
||||
}
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
assert isinstance(episode_id, int)
|
||||
data_ids_per_episode[episode_id] = torch.arange(id_from, id_to, 1)
|
||||
assert len(data_ids_per_episode[episode_id]) == num_frames
|
||||
|
||||
id_from += num_frames
|
||||
|
||||
data_dict = {}
|
||||
@@ -190,8 +196,8 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
|
||||
"next.done": Value(dtype="bool", id=None),
|
||||
"next.success": Value(dtype="bool", id=None),
|
||||
"index": Value(dtype="int64", id=None),
|
||||
"episode_data_id_from": Value(dtype="int64", id=None),
|
||||
"episode_data_id_to": Value(dtype="int64", id=None),
|
||||
"episode_data_index_from": Value(dtype="int64", id=None),
|
||||
"episode_data_index_to": Value(dtype="int64", id=None),
|
||||
}
|
||||
features = Features(features)
|
||||
dataset = Dataset.from_dict(data_dict, features=features)
|
||||
@@ -265,8 +271,8 @@ def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
|
||||
# "next.observation.state": next_state,
|
||||
"next.reward": next_reward,
|
||||
"next.done": next_done,
|
||||
"episode_data_id_from": torch.tensor([id_from] * num_frames),
|
||||
"episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames),
|
||||
"episode_data_index_from": torch.tensor([id_from] * num_frames),
|
||||
"episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
|
||||
}
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
@@ -301,8 +307,8 @@ def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
|
||||
"next.done": Value(dtype="bool", id=None),
|
||||
#'next.success': Value(dtype='bool', id=None),
|
||||
"index": Value(dtype="int64", id=None),
|
||||
"episode_data_id_from": Value(dtype="int64", id=None),
|
||||
"episode_data_id_to": Value(dtype="int64", id=None),
|
||||
"episode_data_index_from": Value(dtype="int64", id=None),
|
||||
"episode_data_index_to": Value(dtype="int64", id=None),
|
||||
}
|
||||
features = Features(features)
|
||||
dataset = Dataset.from_dict(data_dict, features=features)
|
||||
@@ -390,20 +396,7 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
|
||||
state = torch.from_numpy(ep["/observations/qpos"][:])
|
||||
action = torch.from_numpy(ep["/action"][:])
|
||||
|
||||
ep_dict = {
|
||||
"observation.state": state,
|
||||
"action": action,
|
||||
"episode_id": torch.tensor([ep_id] * num_frames),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
||||
# "next.observation.state": state,
|
||||
# TODO(rcadene): compute reward and success
|
||||
# "next.reward": reward,
|
||||
"next.done": done,
|
||||
# "next.success": success,
|
||||
"episode_data_id_from": torch.tensor([id_from] * num_frames),
|
||||
"episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames),
|
||||
}
|
||||
ep_dict = {}
|
||||
|
||||
for cam in cameras[dataset_id]:
|
||||
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:]) # b h w c
|
||||
@@ -411,6 +404,23 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
|
||||
ep_dict[f"observation.images.{cam}"] = [PILImage.fromarray(x.numpy()) for x in image]
|
||||
# ep_dict[f"next.observation.images.{cam}"] = image
|
||||
|
||||
ep_dict.update(
|
||||
{
|
||||
"observation.state": state,
|
||||
"action": action,
|
||||
"episode_id": torch.tensor([ep_id] * num_frames),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
||||
# "next.observation.state": state,
|
||||
# TODO(rcadene): compute reward and success
|
||||
# "next.reward": reward,
|
||||
"next.done": done,
|
||||
# "next.success": success,
|
||||
"episode_data_index_from": torch.tensor([id_from] * num_frames),
|
||||
"episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
|
||||
}
|
||||
)
|
||||
|
||||
assert isinstance(ep_id, int)
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
@@ -446,8 +456,8 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
|
||||
"next.done": Value(dtype="bool", id=None),
|
||||
#'next.success': Value(dtype='bool', id=None),
|
||||
"index": Value(dtype="int64", id=None),
|
||||
"episode_data_id_from": Value(dtype="int64", id=None),
|
||||
"episode_data_id_to": Value(dtype="int64", id=None),
|
||||
"episode_data_index_from": Value(dtype="int64", id=None),
|
||||
"episode_data_index_to": Value(dtype="int64", id=None),
|
||||
}
|
||||
features = Features(features)
|
||||
dataset = Dataset.from_dict(data_dict, features=features)
|
||||
@@ -461,23 +471,17 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
|
||||
|
||||
if __name__ == "__main__":
|
||||
root = "data"
|
||||
root_tests = "{root_tests}"
|
||||
|
||||
download_and_upload_pusht(root, root_tests, dataset_id="pusht")
|
||||
download_and_upload_xarm(root, root_tests, dataset_id="xarm_lift_medium")
|
||||
download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_insertion_human")
|
||||
download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_insertion_scripted")
|
||||
download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_transfer_cube_human")
|
||||
download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_transfer_cube_scripted")
|
||||
root_tests = "tests/data"
|
||||
|
||||
dataset_ids = [
|
||||
"pusht",
|
||||
"xarm_lift_medium",
|
||||
"aloha_sim_insertion_human",
|
||||
"aloha_sim_insertion_scripted",
|
||||
"aloha_sim_transfer_cube_human",
|
||||
# "pusht",
|
||||
# "xarm_lift_medium",
|
||||
# "aloha_sim_insertion_human",
|
||||
# "aloha_sim_insertion_scripted",
|
||||
# "aloha_sim_transfer_cube_human",
|
||||
"aloha_sim_transfer_cube_scripted",
|
||||
]
|
||||
for dataset_id in dataset_ids:
|
||||
download_and_upload(root, root_tests, dataset_id)
|
||||
# assume stats have been precomputed
|
||||
shutil.copy(f"{root}/{dataset_id}/stats.pth", f"{root_tests}/{dataset_id}/stats.pth")
|
||||
|
||||
Reference in New Issue
Block a user