def download_and_upload_aloha_mobile(root, revision, dataset_id, fps=50):
    """Convert raw mobile-Aloha hdf5 episodes into a HuggingFace dataset and push it to the hub.

    Mirrors ``download_and_upload_aloha`` (the sim variant) but additionally
    JPEG-decodes the camera frames, which the mobile-Aloha recorder stores
    padded + compressed.

    Args:
        root: local directory containing ``{dataset_id}_raw/episode_{i}.hdf5``.
        revision: dataset revision tag passed through to ``push_to_hub``.
        dataset_id: key into the ``num_episodes`` / ``cameras`` tables below.
        fps: capture frame rate, used to derive per-frame timestamps.
    """
    # cv2 is only needed for JPEG decoding; import it once here instead of
    # re-importing inside the per-camera loop as the original did.
    import cv2

    num_episodes = {
        "aloha_mobile_trossen_block_handoff": 5,
    }
    cameras = {
        "aloha_mobile_trossen_block_handoff": ["cam_high", "cam_left_wrist", "cam_right_wrist"],
    }

    root = Path(root)
    raw_dir = root / f"{dataset_id}_raw"

    ep_dicts = []
    episode_data_index = {"from": [], "to": []}

    id_from = 0
    for ep_id in tqdm.tqdm(range(num_episodes[dataset_id])):
        ep_path = raw_dir / f"episode_{ep_id}.hdf5"
        with h5py.File(ep_path, "r") as ep:
            num_frames = ep["/action"].shape[0]
            # TODO(review): the sim variant asserts an expected episode length
            # here; no per-dataset length table exists yet for mobile-Aloha.

            # Last step of the demonstration is considered done.
            done = torch.zeros(num_frames, dtype=torch.bool)
            done[-1] = True

            state = torch.from_numpy(ep["/observations/qpos"][:num_frames])
            action = torch.from_numpy(ep["/action"][:num_frames])

            ep_dict = {}
            for cam in cameras[dataset_id]:
                image = ep[f"/observations/images/{cam}"][:num_frames]  # b h w c
                # un-pad and uncompress, from:
                # https://github.com/MarkFzp/act-plus-plus/blob/26bab0789d05b7496bacef04f5c6b2541a4403b5/postprocess_episodes.py#L50
                image = np.array([cv2.imdecode(x, 1) for x in image])
                ep_dict[f"observation.images.{cam}"] = [PILImage.fromarray(x) for x in image]

            ep_dict.update(
                {
                    "observation.state": state,
                    "action": action,
                    "episode_index": torch.tensor([ep_id] * num_frames),
                    "frame_index": torch.arange(0, num_frames, 1),
                    "timestamp": torch.arange(0, num_frames, 1) / fps,
                    # TODO(rcadene): compute reward and success
                    "next.done": done,
                }
            )

            ep_dicts.append(ep_dict)

            episode_data_index["from"].append(id_from)
            episode_data_index["to"].append(id_from + num_frames)
            id_from += num_frames
            # BUG FIX: a leftover debugging `break` here stopped the loop after
            # episode 0, silently dropping the remaining episodes.

    data_dict = concatenate_episodes(ep_dicts)

    features = {}
    for cam in cameras[dataset_id]:
        features[f"observation.images.{cam}"] = Image()
    features.update(
        {
            "observation.state": Sequence(
                length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
            ),
            "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
            "episode_index": Value(dtype="int64", id=None),
            "frame_index": Value(dtype="int64", id=None),
            "timestamp": Value(dtype="float32", id=None),
            # 'next.reward' / 'next.success' intentionally omitted, see TODO above.
            "next.done": Value(dtype="bool", id=None),
            "index": Value(dtype="int64", id=None),
        }
    )

    hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
    hf_dataset.set_transform(hf_transform_to_torch)

    info = {
        "fps": fps,
    }
    stats = compute_stats(hf_dataset)
    push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
import rerun as rr
from datasets import load_from_disk

# Camera streams present in the aloha_mobile_trossen_block_handoff dataset.
CAMERAS = ("cam_high", "cam_left_wrist", "cam_right_wrist")


def main():
    """Stream one LeRobot episode (camera images, state, action) to a rerun.io viewer.

    Loads a locally saved HuggingFace dataset and logs every frame: the three
    camera images on the frame_index/timestamp timelines, plus one scalar
    series per action and state dimension.
    """
    print("Loading dataset…")
    dataset = load_from_disk("tests/data/aloha_mobile_trossen_block_handoff/train")

    print("Starting Rerun…")
    rr.init("rerun_example_lerobot", spawn=True)

    print("Logging to Rerun…")
    for frame in dataset:
        rr.set_time_sequence("frame_index", frame["frame_index"])
        rr.set_time_seconds("timestamp", frame["timestamp"])
        for cam in CAMERAS:
            key = f"observation.images.{cam}"
            rr.log(key, rr.Image(frame[key]))
        # One scalar timeline per dimension so each joint is plotted separately.
        for idx, val in enumerate(frame["action"]):
            rr.log(f"action_{idx}", rr.Scalar(val))
        for idx, val in enumerate(frame["observation.state"]):
            rr.log(f"state_{idx}", rr.Scalar(val))


if __name__ == "__main__":
    main()