Add mobile Aloha and visualization with rerun.io

This commit is contained in:
Remi Cadene
2024-04-20 16:19:55 +02:00
parent 2a59825a00
commit 7cee7a0f20
3 changed files with 148 additions and 11 deletions

View File

@@ -27,8 +27,10 @@ def download_and_upload(root, revision, dataset_id):
download_and_upload_pusht(root, revision, dataset_id)
elif "xarm" in dataset_id:
download_and_upload_xarm(root, revision, dataset_id)
elif "aloha" in dataset_id:
elif "aloha_sim" in dataset_id:
download_and_upload_aloha(root, revision, dataset_id)
elif "aloha_mobile" in dataset_id:
download_and_upload_aloha_mobile(root, revision, dataset_id)
else:
raise ValueError(dataset_id)
@@ -530,20 +532,125 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
def download_and_upload_aloha_mobile(root, revision, dataset_id, fps=50):
    """Convert raw mobile-Aloha HDF5 episodes into a Hugging Face dataset and push it.

    Reads ``episode_<i>.hdf5`` files from ``<root>/<dataset_id>_raw``, decodes the
    compressed camera frames, builds one dictionary per episode, concatenates
    them, computes statistics and hands everything to ``push_to_hub``.

    Every episode listed in the per-dataset episode-count table is processed.

    Args:
        root: Directory that contains the ``<dataset_id>_raw`` folder.
        revision: Forwarded unchanged to ``push_to_hub``.
        dataset_id: Key into the per-dataset tables below, e.g.
            ``"aloha_mobile_trossen_block_handoff"``.
        fps: Frame rate used to derive ``timestamp`` from the frame index; it is
            also stored in the dataset ``info``.

    Raises:
        KeyError: If ``dataset_id`` is missing from the episode-count or camera table.
    """
    # Local import: cv2 is used only for decoding frames in this function.
    import cv2

    num_episodes = {
        "aloha_mobile_trossen_block_handoff": 5,
    }
    cameras = {
        "aloha_mobile_trossen_block_handoff": ["cam_high", "cam_left_wrist", "cam_right_wrist"],
    }

    root = Path(root)
    raw_dir = root / f"{dataset_id}_raw"
    camera_names = cameras[dataset_id]

    ep_dicts = []
    # Half-open [from, to) frame ranges of each episode in the concatenated dataset.
    episode_data_index = {"from": [], "to": []}

    id_from = 0
    for ep_id in tqdm.tqdm(range(num_episodes[dataset_id])):
        ep_path = raw_dir / f"episode_{ep_id}.hdf5"
        with h5py.File(ep_path, "r") as ep:
            num_frames = ep["/action"].shape[0]

            # last step of demonstration is considered done
            done = torch.zeros(num_frames, dtype=torch.bool)
            done[-1] = True

            state = torch.from_numpy(ep["/observations/qpos"][:num_frames])
            action = torch.from_numpy(ep["/action"][:num_frames])

            ep_dict = {}
            for cam in camera_names:
                encoded_frames = ep[f"/observations/images/{cam}"][:num_frames]
                # un-pad and uncompress from: https://github.com/MarkFzp/act-plus-plus/blob/26bab0789d05b7496bacef04f5c6b2541a4403b5/postprocess_episodes.py#L50
                # NOTE(review): cv2.imdecode yields BGR channel order while PIL assumes
                # RGB; confirm the raw frames were encoded so that no swap is needed.
                ep_dict[f"observation.images.{cam}"] = [
                    PILImage.fromarray(cv2.imdecode(frame, cv2.IMREAD_COLOR)) for frame in encoded_frames
                ]

            ep_dict.update(
                {
                    "observation.state": state,
                    "action": action,
                    "episode_index": torch.tensor([ep_id] * num_frames),
                    "frame_index": torch.arange(0, num_frames, 1),
                    "timestamp": torch.arange(0, num_frames, 1) / fps,
                    # TODO(rcadene): compute reward and success
                    "next.done": done,
                }
            )
            ep_dicts.append(ep_dict)

        episode_data_index["from"].append(id_from)
        episode_data_index["to"].append(id_from + num_frames)
        id_from += num_frames

    data_dict = concatenate_episodes(ep_dicts)

    features = {f"observation.images.{cam}": Image() for cam in camera_names}
    features.update(
        {
            "observation.state": Sequence(
                length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
            ),
            "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
            "episode_index": Value(dtype="int64", id=None),
            "frame_index": Value(dtype="int64", id=None),
            "timestamp": Value(dtype="float32", id=None),
            "next.done": Value(dtype="bool", id=None),
            # presumably filled in by concatenate_episodes; verify against that helper
            "index": Value(dtype="int64", id=None),
        }
    )
    features = Features(features)
    hf_dataset = Dataset.from_dict(data_dict, features=features)
    hf_dataset.set_transform(hf_transform_to_torch)

    info = {
        "fps": fps,
    }
    stats = compute_stats(hf_dataset)
    push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
if __name__ == "__main__":
    # Root data directory and the revision passed along with every dataset.
    root = "data"
    revision = "v1.1"

    # Datasets to convert, handled one after another in this order.
    dataset_ids = (
        "pusht",
        "xarm_lift_medium",
        "xarm_lift_medium_replay",
        "xarm_push_medium",
        "xarm_push_medium_replay",
        "aloha_sim_insertion_human",
        "aloha_sim_insertion_scripted",
        "aloha_sim_transfer_cube_human",
        "aloha_sim_transfer_cube_scripted",
        "aloha_mobile_trossen_block_handoff",
    )

    for dataset_id in dataset_ids:
        download_and_upload(root=root, revision=revision, dataset_id=dataset_id)

2
poetry.lock generated
View File

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
[[package]]
name = "absl-py"

30
test.py Normal file
View File

@@ -0,0 +1,30 @@
import rerun as rr
from datasets import load_from_disk
# Load the dataset from its on-disk pyarrow format.
# Alternative (not used here, presumably pulls the same data from the hub):
#   load_dataset("lerobot/aloha_mobile_trossen_block_handoff", split="train")
print("Loading dataset…")
dataset = load_from_disk("tests/data/aloha_mobile_trossen_block_handoff/train")

# NOTE: no episode filtering is performed; every frame of the dataset is logged below.
print("Select specific episode…")

print("Starting Rerun…")
rr.init("rerun_example_lerobot", spawn=True)

print("Logging to Rerun…")
# Dataset columns holding camera frames; each key doubles as the Rerun entity path.
image_keys = (
    "observation.images.cam_high",
    "observation.images.cam_left_wrist",
    "observation.images.cam_right_wrist",
)
for frame in dataset:
    # Position both timelines before logging anything for this frame.
    rr.set_time_sequence("frame_index", frame["frame_index"])
    rr.set_time_seconds("timestamp", frame["timestamp"])

    for key in image_keys:
        rr.log(key, rr.Image(frame[key]))

    # One scalar entity per vector component, actions first, then states.
    for idx, val in enumerate(frame["action"]):
        rr.log(f"action_{idx}", rr.Scalar(val))
    for idx, val in enumerate(frame["observation.state"]):
        rr.log(f"state_{idx}", rr.Scalar(val))