Add mobile Aloha and visu with rerun.io
This commit is contained in:
@@ -27,8 +27,10 @@ def download_and_upload(root, revision, dataset_id):
|
|||||||
download_and_upload_pusht(root, revision, dataset_id)
|
download_and_upload_pusht(root, revision, dataset_id)
|
||||||
elif "xarm" in dataset_id:
|
elif "xarm" in dataset_id:
|
||||||
download_and_upload_xarm(root, revision, dataset_id)
|
download_and_upload_xarm(root, revision, dataset_id)
|
||||||
elif "aloha" in dataset_id:
|
elif "aloha_sim" in dataset_id:
|
||||||
download_and_upload_aloha(root, revision, dataset_id)
|
download_and_upload_aloha(root, revision, dataset_id)
|
||||||
|
elif "aloha_mobile" in dataset_id:
|
||||||
|
download_and_upload_aloha_mobile(root, revision, dataset_id)
|
||||||
else:
|
else:
|
||||||
raise ValueError(dataset_id)
|
raise ValueError(dataset_id)
|
||||||
|
|
||||||
@@ -530,20 +532,125 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
|
|||||||
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
|
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
|
||||||
|
|
||||||
|
|
||||||
|
def download_and_upload_aloha_mobile(root, revision, dataset_id, fps=50):
|
||||||
|
num_episodes = {
|
||||||
|
"aloha_mobile_trossen_block_handoff": 5,
|
||||||
|
}
|
||||||
|
|
||||||
|
# episode_len = {
|
||||||
|
# "aloha_sim_insertion_human": 500,
|
||||||
|
# "aloha_sim_insertion_scripted": 400,
|
||||||
|
# "aloha_sim_transfer_cube_human": 400,
|
||||||
|
# "aloha_sim_transfer_cube_scripted": 400,
|
||||||
|
# }
|
||||||
|
|
||||||
|
cameras = {
|
||||||
|
"aloha_mobile_trossen_block_handoff": ['cam_high', 'cam_left_wrist', 'cam_right_wrist'],
|
||||||
|
}
|
||||||
|
|
||||||
|
root = Path(root)
|
||||||
|
raw_dir = root / f"{dataset_id}_raw"
|
||||||
|
|
||||||
|
ep_dicts = []
|
||||||
|
episode_data_index = {"from": [], "to": []}
|
||||||
|
|
||||||
|
id_from = 0
|
||||||
|
for ep_id in tqdm.tqdm(range(num_episodes[dataset_id])):
|
||||||
|
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
|
||||||
|
with h5py.File(ep_path, "r") as ep:
|
||||||
|
num_frames = ep["/action"].shape[0]
|
||||||
|
|
||||||
|
#assert episode_len[dataset_id] == num_frames
|
||||||
|
|
||||||
|
# last step of demonstration is considered done
|
||||||
|
done = torch.zeros(num_frames, dtype=torch.bool)
|
||||||
|
done[-1] = True
|
||||||
|
|
||||||
|
state = torch.from_numpy(ep["/observations/qpos"][:num_frames])
|
||||||
|
action = torch.from_numpy(ep["/action"][:num_frames])
|
||||||
|
|
||||||
|
ep_dict = {}
|
||||||
|
|
||||||
|
for cam in cameras[dataset_id]:
|
||||||
|
image = ep[f"/observations/images/{cam}"][:num_frames] # b h w c
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
# un-pad and uncompress from: https://github.com/MarkFzp/act-plus-plus/blob/26bab0789d05b7496bacef04f5c6b2541a4403b5/postprocess_episodes.py#L50
|
||||||
|
image = np.array([cv2.imdecode(x, 1) for x in image])
|
||||||
|
image = [PILImage.fromarray(x) for x in image]
|
||||||
|
ep_dict[f"observation.images.{cam}"] = image
|
||||||
|
|
||||||
|
ep_dict.update(
|
||||||
|
{
|
||||||
|
"observation.state": state,
|
||||||
|
"action": action,
|
||||||
|
"episode_index": torch.tensor([ep_id] * num_frames),
|
||||||
|
"frame_index": torch.arange(0, num_frames, 1),
|
||||||
|
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
||||||
|
# "next.observation.state": state,
|
||||||
|
# TODO(rcadene): compute reward and success
|
||||||
|
# "next.reward": reward,
|
||||||
|
"next.done": done,
|
||||||
|
# "next.success": success,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(ep_id, int)
|
||||||
|
ep_dicts.append(ep_dict)
|
||||||
|
|
||||||
|
episode_data_index["from"].append(id_from)
|
||||||
|
episode_data_index["to"].append(id_from + num_frames)
|
||||||
|
|
||||||
|
id_from += num_frames
|
||||||
|
|
||||||
|
break
|
||||||
|
|
||||||
|
data_dict = concatenate_episodes(ep_dicts)
|
||||||
|
|
||||||
|
features = {}
|
||||||
|
for cam in cameras[dataset_id]:
|
||||||
|
features[f"observation.images.{cam}"] = Image()
|
||||||
|
features.update({
|
||||||
|
"observation.state": Sequence(
|
||||||
|
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||||
|
),
|
||||||
|
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
|
||||||
|
"episode_index": Value(dtype="int64", id=None),
|
||||||
|
"frame_index": Value(dtype="int64", id=None),
|
||||||
|
"timestamp": Value(dtype="float32", id=None),
|
||||||
|
#'next.reward': Value(dtype='float32', id=None),
|
||||||
|
"next.done": Value(dtype="bool", id=None),
|
||||||
|
#'next.success': Value(dtype='bool', id=None),
|
||||||
|
"index": Value(dtype="int64", id=None),
|
||||||
|
})
|
||||||
|
|
||||||
|
features = Features(features)
|
||||||
|
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
||||||
|
hf_dataset.set_transform(hf_transform_to_torch)
|
||||||
|
|
||||||
|
info = {
|
||||||
|
"fps": fps,
|
||||||
|
}
|
||||||
|
stats = compute_stats(hf_dataset)
|
||||||
|
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
root = "data"
|
root = "data"
|
||||||
revision = "v1.1"
|
revision = "v1.1"
|
||||||
|
|
||||||
dataset_ids = [
|
dataset_ids = [
|
||||||
"pusht",
|
# "pusht",
|
||||||
"xarm_lift_medium",
|
# "xarm_lift_medium",
|
||||||
"xarm_lift_medium_replay",
|
# "xarm_lift_medium_replay",
|
||||||
"xarm_push_medium",
|
# "xarm_push_medium",
|
||||||
"xarm_push_medium_replay",
|
# "xarm_push_medium_replay",
|
||||||
"aloha_sim_insertion_human",
|
# "aloha_sim_insertion_human",
|
||||||
"aloha_sim_insertion_scripted",
|
# "aloha_sim_insertion_scripted",
|
||||||
"aloha_sim_transfer_cube_human",
|
# "aloha_sim_transfer_cube_human",
|
||||||
"aloha_sim_transfer_cube_scripted",
|
# "aloha_sim_transfer_cube_scripted",
|
||||||
|
"aloha_mobile_trossen_block_handoff",
|
||||||
]
|
]
|
||||||
for dataset_id in dataset_ids:
|
for dataset_id in dataset_ids:
|
||||||
download_and_upload(root, revision, dataset_id)
|
download_and_upload(root, revision, dataset_id)
|
||||||
|
|||||||
2
poetry.lock
generated
2
poetry.lock
generated
@@ -1,4 +1,4 @@
|
|||||||
# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
|
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "absl-py"
|
name = "absl-py"
|
||||||
|
|||||||
30
test.py
Normal file
30
test.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
import rerun as rr
|
||||||
|
from datasets import load_from_disk
|
||||||
|
|
||||||
|
# download/load dataset in pyarrow format
|
||||||
|
print("Loading dataset…")
|
||||||
|
#dataset = load_dataset("lerobot/aloha_mobile_trossen_block_handoff", split="train")
|
||||||
|
dataset = load_from_disk("tests/data/aloha_mobile_trossen_block_handoff/train")
|
||||||
|
|
||||||
|
# select the frames belonging to episode number 5
|
||||||
|
print("Select specific episode…")
|
||||||
|
|
||||||
|
print("Starting Rerun…")
|
||||||
|
rr.init("rerun_example_lerobot", spawn=True)
|
||||||
|
|
||||||
|
print("Logging to Rerun…")
|
||||||
|
# for frame_index, timestamp, cam_high, cam_left_wrist, cam_right_wrist, state, action, next_reward in zip(
|
||||||
|
|
||||||
|
for d in dataset:
|
||||||
|
rr.set_time_sequence("frame_index", d["frame_index"])
|
||||||
|
rr.set_time_seconds("timestamp", d["timestamp"])
|
||||||
|
rr.log("observation.images.cam_high", rr.Image( d["observation.images.cam_high"]))
|
||||||
|
rr.log("observation.images.cam_left_wrist", rr.Image(d["observation.images.cam_left_wrist"]))
|
||||||
|
rr.log("observation.images.cam_right_wrist", rr.Image(d["observation.images.cam_right_wrist"]))
|
||||||
|
#rr.log("observation/state", rr.BarChart(state))
|
||||||
|
#rr.log("observation/action", rr.BarChart(action))
|
||||||
|
for idx, val in enumerate(d["action"]):
|
||||||
|
rr.log(f"action_{idx}", rr.Scalar(val))
|
||||||
|
|
||||||
|
for idx, val in enumerate(d["observation.state"]):
|
||||||
|
rr.log(f"state_{idx}", rr.Scalar(val))
|
||||||
Reference in New Issue
Block a user