Add mobile Aloha and visualization with rerun.io
@@ -27,8 +27,10 @@ def download_and_upload(root, revision, dataset_id):
         download_and_upload_pusht(root, revision, dataset_id)
     elif "xarm" in dataset_id:
         download_and_upload_xarm(root, revision, dataset_id)
-    elif "aloha" in dataset_id:
+    elif "aloha_sim" in dataset_id:
         download_and_upload_aloha(root, revision, dataset_id)
+    elif "aloha_mobile" in dataset_id:
+        download_and_upload_aloha_mobile(root, revision, dataset_id)
     else:
         raise ValueError(dataset_id)
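The dispatch above routes by substring, which is why the bare "aloha" check is split into the more specific "aloha_sim" and "aloha_mobile" branches. A minimal sketch of the resulting behavior, using dataset ids from the __main__ block further down:

download_and_upload("data", "v1.1", "aloha_sim_insertion_human")           # "aloha_sim" branch
download_and_upload("data", "v1.1", "aloha_mobile_trossen_block_handoff")  # "aloha_mobile" branch
download_and_upload("data", "v1.1", "aloha")                               # no branch matches -> ValueError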
@@ -530,20 +532,125 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
     push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
+
+
+def download_and_upload_aloha_mobile(root, revision, dataset_id, fps=50):
+    # cv2 is only needed here, for decoding the JPEG-compressed camera frames
+    import cv2
+
+    num_episodes = {
+        "aloha_mobile_trossen_block_handoff": 5,
+    }
+
+    # episode_len = {
+    #     "aloha_sim_insertion_human": 500,
+    #     "aloha_sim_insertion_scripted": 400,
+    #     "aloha_sim_transfer_cube_human": 400,
+    #     "aloha_sim_transfer_cube_scripted": 400,
+    # }
+
+    cameras = {
+        "aloha_mobile_trossen_block_handoff": ["cam_high", "cam_left_wrist", "cam_right_wrist"],
+    }
+
+    root = Path(root)
+    raw_dir = root / f"{dataset_id}_raw"
+    ep_dicts = []
+    episode_data_index = {"from": [], "to": []}
+
+    id_from = 0
+    for ep_id in tqdm.tqdm(range(num_episodes[dataset_id])):
+        ep_path = raw_dir / f"episode_{ep_id}.hdf5"
+        with h5py.File(ep_path, "r") as ep:
+            num_frames = ep["/action"].shape[0]
+
+            # assert episode_len[dataset_id] == num_frames
+
+            # last step of demonstration is considered done
+            done = torch.zeros(num_frames, dtype=torch.bool)
+            done[-1] = True
+
+            state = torch.from_numpy(ep["/observations/qpos"][:num_frames])
+            action = torch.from_numpy(ep["/action"][:num_frames])
+
+            ep_dict = {}
+
+            for cam in cameras[dataset_id]:
+                # each stored frame is a padded, JPEG-compressed byte buffer (b h w c after decoding)
+                image = ep[f"/observations/images/{cam}"][:num_frames]
+
+                # un-pad and uncompress, from: https://github.com/MarkFzp/act-plus-plus/blob/26bab0789d05b7496bacef04f5c6b2541a4403b5/postprocess_episodes.py#L50
+                image = np.array([cv2.imdecode(x, 1) for x in image])
+                image = [PILImage.fromarray(x) for x in image]
+                ep_dict[f"observation.images.{cam}"] = image
+
+            ep_dict.update(
+                {
+                    "observation.state": state,
+                    "action": action,
+                    "episode_index": torch.tensor([ep_id] * num_frames),
+                    "frame_index": torch.arange(0, num_frames, 1),
+                    "timestamp": torch.arange(0, num_frames, 1) / fps,
+                    # "next.observation.state": state,
+                    # TODO(rcadene): compute reward and success
+                    # "next.reward": reward,
+                    "next.done": done,
+                    # "next.success": success,
+                }
+            )
+
+            assert isinstance(ep_id, int)
+            ep_dicts.append(ep_dict)
+
+            episode_data_index["from"].append(id_from)
+            episode_data_index["to"].append(id_from + num_frames)
+
+            id_from += num_frames
+
+        # TODO: debug leftover, stops after the first episode
+        break
+
+    data_dict = concatenate_episodes(ep_dicts)
+
+    features = {}
+    for cam in cameras[dataset_id]:
+        features[f"observation.images.{cam}"] = Image()
+    features.update({
+        "observation.state": Sequence(
+            length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
+        ),
+        "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
+        "episode_index": Value(dtype="int64", id=None),
+        "frame_index": Value(dtype="int64", id=None),
+        "timestamp": Value(dtype="float32", id=None),
+        # "next.reward": Value(dtype="float32", id=None),
+        "next.done": Value(dtype="bool", id=None),
+        # "next.success": Value(dtype="bool", id=None),
+        "index": Value(dtype="int64", id=None),
+    })
+
+    features = Features(features)
+    hf_dataset = Dataset.from_dict(data_dict, features=features)
+    hf_dataset.set_transform(hf_transform_to_torch)
+
+    info = {
+        "fps": fps,
+    }
+    stats = compute_stats(hf_dataset)
+    push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
+
+
 if __name__ == "__main__":
     root = "data"
     revision = "v1.1"
 
     dataset_ids = [
-        "pusht",
-        "xarm_lift_medium",
-        "xarm_lift_medium_replay",
-        "xarm_push_medium",
-        "xarm_push_medium_replay",
-        "aloha_sim_insertion_human",
-        "aloha_sim_insertion_scripted",
-        "aloha_sim_transfer_cube_human",
-        "aloha_sim_transfer_cube_scripted",
+        # "pusht",
+        # "xarm_lift_medium",
+        # "xarm_lift_medium_replay",
+        # "xarm_push_medium",
+        # "xarm_push_medium_replay",
+        # "aloha_sim_insertion_human",
+        # "aloha_sim_insertion_scripted",
+        # "aloha_sim_transfer_cube_human",
+        # "aloha_sim_transfer_cube_scripted",
+        "aloha_mobile_trossen_block_handoff",
     ]
     for dataset_id in dataset_ids:
         download_and_upload(root, revision, dataset_id)
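Before running the converter, it can help to sanity-check the raw HDF5 layout the new function expects: /action, /observations/qpos, and padded JPEG-compressed frames under /observations/images/<cam>. A minimal standalone sketch, assuming the episode_0.hdf5 path that raw_dir resolves to above:

import cv2
import h5py

with h5py.File("data/aloha_mobile_trossen_block_handoff_raw/episode_0.hdf5", "r") as ep:
    num_frames = ep["/action"].shape[0]
    print("frames:", num_frames, "| qpos dim:", ep["/observations/qpos"].shape[1])
    for cam in ["cam_high", "cam_left_wrist", "cam_right_wrist"]:
        # each stored frame is a padded, JPEG-compressed byte buffer; decode one to check
        frame = cv2.imdecode(ep[f"/observations/images/{cam}"][0], 1)
        print(cam, "->", frame.shape)  # (h, w, 3), BGR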
poetry.lock (generated, 2 lines changed)
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
 
 [[package]]
 name = "absl-py"
test.py (new file, 30 lines)
@@ -0,0 +1,30 @@
+import rerun as rr
+from datasets import load_from_disk
+
+# download/load dataset in pyarrow format
+print("Loading dataset…")
+# dataset = load_dataset("lerobot/aloha_mobile_trossen_block_handoff", split="train")
+dataset = load_from_disk("tests/data/aloha_mobile_trossen_block_handoff/train")
+
+# TODO: select the frames belonging to episode number 5 (selection not implemented yet)
+print("Select specific episode…")
+
+print("Starting Rerun…")
+rr.init("rerun_example_lerobot", spawn=True)
+
+print("Logging to Rerun…")
+# for frame_index, timestamp, cam_high, cam_left_wrist, cam_right_wrist, state, action, next_reward in zip(
+for d in dataset:
+    rr.set_time_sequence("frame_index", d["frame_index"])
+    rr.set_time_seconds("timestamp", d["timestamp"])
+    rr.log("observation.images.cam_high", rr.Image(d["observation.images.cam_high"]))
+    rr.log("observation.images.cam_left_wrist", rr.Image(d["observation.images.cam_left_wrist"]))
+    rr.log("observation.images.cam_right_wrist", rr.Image(d["observation.images.cam_right_wrist"]))
+    # rr.log("observation/state", rr.BarChart(d["observation.state"]))
+    # rr.log("observation/action", rr.BarChart(d["action"]))
+    for idx, val in enumerate(d["action"]):
+        rr.log(f"action_{idx}", rr.Scalar(val))
+
+    for idx, val in enumerate(d["observation.state"]):
+        rr.log(f"state_{idx}", rr.Scalar(val))
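test.py prints "Select specific episode…" but never actually filters. A minimal sketch of that step with the stock datasets.Dataset.filter API, keying on the episode_index column written by the converter (episode 0 is the only one converted while the debug break above is in place):

# keep only the frames of one episode before logging
episode_id = 0
dataset = dataset.filter(lambda frame: frame["episode_index"] == episode_id)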