Include observation.environment_state with keypoints in PushT dataset (#303)

Co-authored-by: Remi <re.cadene@gmail.com>

Author: Alexander Soare (committed via GitHub)
Date: 2024-07-09 08:27:40 +01:00
Commit: a4d77b99f0 (parent: 7bd5ab16d1)

27 changed files with 739 additions and 609 deletions


@@ -36,7 +36,7 @@ from lerobot.common.datasets.utils import (
 from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos

 DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
-CODEBASE_VERSION = "v1.4"
+CODEBASE_VERSION = "v1.5"


 class LeRobotDataset(torch.utils.data.Dataset):
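The schema gains a new key, hence the version bump. A hedged sketch of the consumer-side implication, assuming CODEBASE_VERSION doubles as the hub revision tag pinned at download time:

```python
# Hedged sketch, assuming CODEBASE_VERSION is used as the hub revision tag
# when datasets are fetched; the repo id is illustrative.
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset

assert CODEBASE_VERSION == "v1.5"
dataset = LeRobotDataset("lerobot/pusht")  # resolved at revision CODEBASE_VERSION
```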


@@ -54,7 +54,14 @@ def check_format(raw_dir):
     assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)


-def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
+def load_from_raw(
+    raw_dir: Path,
+    videos_dir: Path,
+    fps: int,
+    video: bool,
+    episodes: list[int] | None = None,
+    keypoints_instead_of_image: bool = False,
+):
     try:
         import pymunk
         from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
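The new keyword defaults to False, so existing callers keep their behavior. A hedged sketch of a direct call in keypoint mode (paths are placeholders):

```python
from pathlib import Path

# Placeholder paths; with keypoints_instead_of_image=True no frames are
# written, so video=False is the sensible pairing.
data_dict = load_from_raw(
    raw_dir=Path("data/pusht_raw"),
    videos_dir=Path("data/videos"),
    fps=10,
    video=False,
    keypoints_instead_of_image=True,
)
```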
@@ -105,10 +112,11 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
         assert (episode_ids[from_idx:to_idx] == ep_idx).all()

         # get image
-        image = imgs[from_idx:to_idx]
-        assert image.min() >= 0.0
-        assert image.max() <= 255.0
-        image = image.type(torch.uint8)
+        if not keypoints_instead_of_image:
+            image = imgs[from_idx:to_idx]
+            assert image.min() >= 0.0
+            assert image.max() <= 255.0
+            image = image.type(torch.uint8)

         # get state
         state = states[from_idx:to_idx]
@@ -116,9 +124,11 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
         block_pos = state[:, 2:4]
         block_angle = state[:, 4]

-        # get reward, success, done
+        # get reward, success, done, and (maybe) keypoints
         reward = torch.zeros(num_frames)
         success = torch.zeros(num_frames, dtype=torch.bool)
+        if keypoints_instead_of_image:
+            keypoints = torch.zeros(num_frames, 16)  # 8 keypoints each with 2 coords
         done = torch.zeros(num_frames, dtype=torch.bool)
         for i in range(num_frames):
             space = pymunk.Space()
@@ -134,7 +144,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
             ]
             space.add(*walls)

-            block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
+            block_body, block_shapes = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
             goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
             block_geom = pymunk_to_shapely(block_body, block_body.shapes)
             intersection_area = goal_geom.intersection(block_geom).area
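This relies on the updated gym_pusht API in which add_tee returns the shapes alongside the body. A self-contained sketch of one frame's keypoint extraction, assuming PushTEnv.get_keypoints maps those shapes to an (8, 2) array of world-frame vertices:

```python
import pymunk
import torch
from gym_pusht.envs.pusht import PushTEnv

# Rebuild a physics space at one recorded block pose and read off the
# T-block vertices; flattening (8, 2) -> (16,) matches the buffer above.
space = pymunk.Space()
block_body, block_shapes = PushTEnv.add_tee(space, [256.0, 256.0], 0.0)
keypoints_frame = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())
assert keypoints_frame.numel() == 16
```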
@@ -142,33 +152,40 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
             coverage = intersection_area / goal_area
             reward[i] = np.clip(coverage / success_threshold, 0, 1)
             success[i] = coverage > success_threshold
+            if keypoints_instead_of_image:
+                keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())

         # last step of demonstration is considered done
         done[-1] = True

         ep_dict = {}

-        imgs_array = [x.numpy() for x in image]
-        img_key = "observation.image"
-        if video:
-            # save png images in temporary directory
-            tmp_imgs_dir = videos_dir / "tmp_images"
-            save_images_concurrently(imgs_array, tmp_imgs_dir)
+        if not keypoints_instead_of_image:
+            imgs_array = [x.numpy() for x in image]
+            img_key = "observation.image"
+            if video:
+                # save png images in temporary directory
+                tmp_imgs_dir = videos_dir / "tmp_images"
+                save_images_concurrently(imgs_array, tmp_imgs_dir)

-            # encode images to a mp4 video
-            fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
-            video_path = videos_dir / fname
-            encode_video_frames(tmp_imgs_dir, video_path, fps)
+                # encode images to a mp4 video
+                fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
+                video_path = videos_dir / fname
+                encode_video_frames(tmp_imgs_dir, video_path, fps)

-            # clean temporary images directory
-            shutil.rmtree(tmp_imgs_dir)
+                # clean temporary images directory
+                shutil.rmtree(tmp_imgs_dir)

-            # store the reference to the video frame
-            ep_dict[img_key] = [{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)]
-        else:
-            ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
+                # store the reference to the video frame
+                ep_dict[img_key] = [
+                    {"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
+                ]
+            else:
+                ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]

         ep_dict["observation.state"] = agent_pos
+        if keypoints_instead_of_image:
+            ep_dict["observation.environment_state"] = keypoints
         ep_dict["action"] = actions[from_idx:to_idx]
         ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
         ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
@@ -180,7 +197,6 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
         ep_dict["next.done"] = torch.cat([done[1:], done[[-1]]])
         ep_dict["next.success"] = torch.cat([success[1:], success[[-1]]])

         ep_dicts.append(ep_dict)
-
     data_dict = concatenate_episodes(ep_dicts)

     total_frames = data_dict["frame_index"].shape[0]
@@ -188,17 +204,23 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
     return data_dict


-def to_hf_dataset(data_dict, video):
+def to_hf_dataset(data_dict, video, keypoints_instead_of_image: bool = False):
     features = {}

-    if video:
-        features["observation.image"] = VideoFrame()
-    else:
-        features["observation.image"] = Image()
+    if not keypoints_instead_of_image:
+        if video:
+            features["observation.image"] = VideoFrame()
+        else:
+            features["observation.image"] = Image()

     features["observation.state"] = Sequence(
         length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
     )
+    if keypoints_instead_of_image:
+        features["observation.environment_state"] = Sequence(
+            length=data_dict["observation.environment_state"].shape[1],
+            feature=Value(dtype="float32", id=None),
+        )
     features["action"] = Sequence(
         length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
     )
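For reference, the resulting schema in keypoint mode, sketched with the datasets library (reward/done/index columns elided; lengths follow PushT's 2-D state and action):

```python
from datasets import Features, Sequence, Value

features = Features({
    "observation.state": Sequence(length=2, feature=Value("float32")),
    "observation.environment_state": Sequence(length=16, feature=Value("float32")),
    "action": Sequence(length=2, feature=Value("float32")),
})
```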
@@ -222,17 +244,21 @@ def from_raw_to_lerobot_format(
     video: bool = True,
     episodes: list[int] | None = None,
 ):
+    # Manually change this to True to use keypoints of the T instead of an image observation (but don't merge
+    # with True). Also make sure to use video = 0 in the `push_dataset_to_hub.py` script.
+    keypoints_instead_of_image = False
+
     # sanity check
     check_format(raw_dir)

     if fps is None:
         fps = 10

-    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
-    hf_dataset = to_hf_dataset(data_dict, video)
+    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, keypoints_instead_of_image)
+    hf_dataset = to_hf_dataset(data_dict, video, keypoints_instead_of_image)
     episode_data_index = calculate_episode_data_index(hf_dataset)
     info = {
         "fps": fps,
-        "video": video,
+        "video": video if not keypoints_instead_of_image else 0,
     }
     return hf_dataset, episode_data_index, info
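Finally, a hedged end-to-end sketch; the module path and data locations are assumptions, and since the keypoints switch is deliberately hard-coded, it must be edited to True in the source before running a keypoints conversion:

```python
from pathlib import Path

from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import (  # assumed module path
    from_raw_to_lerobot_format,
)

hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
    raw_dir=Path("data/pusht_raw"),  # placeholder
    videos_dir=Path("data/videos"),  # unused in keypoint mode
    fps=None,                        # falls back to 10
    video=False,
)
print(info)  # {"fps": 10, "video": 0} once keypoints_instead_of_image is True
```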