diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index 068b154e0..0e39a92db 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -2,11 +2,8 @@ from pathlib import Path import einops import numpy as np -import pygame -import pymunk import torch import tqdm -from gym_pusht.envs.pusht import pymunk_to_shapely from lerobot.common.datasets.utils import download_and_extract_zip, load_data_with_delta_timestamps from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer @@ -18,64 +15,6 @@ PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip" PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr") -def get_goal_pose_body(pose): - mass = 1 - inertia = pymunk.moment_for_box(mass, (50, 100)) - body = pymunk.Body(mass, inertia) - # preserving the legacy assignment order for compatibility - # the order here doesn't matter somehow, maybe because CoM is aligned with body origin - body.position = pose[:2].tolist() - body.angle = pose[2] - return body - - -def add_segment(space, a, b, radius): - shape = pymunk.Segment(space.static_body, a, b, radius) - shape.color = pygame.Color("LightGray") # https://htmlcolorcodes.com/color-names - return shape - - -def add_tee( - space, - position, - angle, - scale=30, - color="LightSlateGray", - mask=None, -): - if mask is None: - mask = pymunk.ShapeFilter.ALL_MASKS() - mass = 1 - length = 4 - vertices1 = [ - (-length * scale / 2, scale), - (length * scale / 2, scale), - (length * scale / 2, 0), - (-length * scale / 2, 0), - ] - inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1) - vertices2 = [ - (-scale / 2, scale), - (-scale / 2, length * scale), - (scale / 2, length * scale), - (scale / 2, scale), - ] - inertia2 = pymunk.moment_for_poly(mass, vertices=vertices1) - body = pymunk.Body(mass, inertia1 + inertia2) - shape1 = pymunk.Poly(body, vertices1) - shape2 = pymunk.Poly(body, vertices2) - shape1.color = pygame.Color(color) - shape2.color = pygame.Color(color) - shape1.filter = pymunk.ShapeFilter(mask=mask) - shape2.filter = pymunk.ShapeFilter(mask=mask) - body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2 - body.position = position - body.angle = angle - body.friction = 1 - space.add(body, shape1, shape2) - return body - - class PushtDataset(torch.utils.data.Dataset): """ @@ -156,6 +95,13 @@ class PushtDataset(torch.utils.data.Dataset): return item def _download_and_preproc_obsolete(self): + try: + import pymunk + from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely + except ModuleNotFoundError as e: + print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`") + raise e + assert self.root is not None raw_dir = self.root / f"{self.dataset_id}_raw" zarr_path = (raw_dir / PUSHT_ZARR).resolve() @@ -180,7 +126,7 @@ class PushtDataset(torch.utils.data.Dataset): # TODO: verify that goal pose is expected to be fixed goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians) - goal_body = get_goal_pose_body(goal_pos_angle) + goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle) imgs = torch.from_numpy(dataset_dict["img"]) imgs = einops.rearrange(imgs, "b h w c -> b c h w") @@ -215,14 +161,14 @@ class PushtDataset(torch.utils.data.Dataset): # Add walls. walls = [ - add_segment(space, (5, 506), (5, 5), 2), - add_segment(space, (5, 5), (506, 5), 2), - add_segment(space, (506, 5), (506, 506), 2), - add_segment(space, (5, 506), (506, 506), 2), + PushTEnv.add_segment(space, (5, 506), (5, 5), 2), + PushTEnv.add_segment(space, (5, 5), (506, 5), 2), + PushTEnv.add_segment(space, (506, 5), (506, 506), 2), + PushTEnv.add_segment(space, (5, 506), (506, 506), 2), ] space.add(*walls) - block_body = add_tee(space, block_pos[i].tolist(), block_angle[i].item()) + block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item()) goal_geom = pymunk_to_shapely(goal_body, block_body.shapes) block_geom = pymunk_to_shapely(block_body, block_body.shapes) intersection_area = goal_geom.intersection(block_geom).area