from pathlib import Path

import einops
import numpy as np
import pygame
import pymunk
import torch
import tqdm

from lerobot.common.datasets.utils import download_and_extract_zip, load_data_with_delta_timestamps
from lerobot.common.envs.pusht.pusht_env import pymunk_to_shapely
from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer

# as defined in env
SUCCESS_THRESHOLD = 0.95  # 95% coverage

PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr")


def get_goal_pose_body(pose):
    mass = 1
    inertia = pymunk.moment_for_box(mass, (50, 100))
    body = pymunk.Body(mass, inertia)
    # preserving the legacy assignment order for compatibility
    # the order here doesn't matter somehow, maybe because CoM is aligned with body origin
    body.position = pose[:2].tolist()
    body.angle = pose[2]
    return body


def add_segment(space, a, b, radius):
    shape = pymunk.Segment(space.static_body, a, b, radius)
    shape.color = pygame.Color("LightGray")  # https://htmlcolorcodes.com/color-names
    return shape


def add_tee(
    space,
    position,
    angle,
    scale=30,
    color="LightSlateGray",
    mask=None,
):
    if mask is None:
        mask = pymunk.ShapeFilter.ALL_MASKS()
    mass = 1
    length = 4
    vertices1 = [
        (-length * scale / 2, scale),
        (length * scale / 2, scale),
        (length * scale / 2, 0),
        (-length * scale / 2, 0),
    ]
    inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1)
    vertices2 = [
        (-scale / 2, scale),
        (-scale / 2, length * scale),
        (scale / 2, length * scale),
        (scale / 2, scale),
    ]
    # note: vertices1 is reused here, matching the original pusht env implementation
    inertia2 = pymunk.moment_for_poly(mass, vertices=vertices1)
    body = pymunk.Body(mass, inertia1 + inertia2)
    shape1 = pymunk.Poly(body, vertices1)
    shape2 = pymunk.Poly(body, vertices2)
    shape1.color = pygame.Color(color)
    shape2.color = pygame.Color(color)
    shape1.filter = pymunk.ShapeFilter(mask=mask)
    shape2.filter = pymunk.ShapeFilter(mask=mask)
    body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2
    body.position = position
    body.angle = angle
    body.friction = 1
    space.add(body, shape1, shape2)
    return body


class PushtDataset(torch.utils.data.Dataset):
    """
    Arguments
    ----------
    delta_timestamps : dict[list[float]] | None, optional
        Loads data from frames with a shift in timestamps, with a different strategy for each data key
        (e.g. state, action or image). If `None`, no shift is applied to the current timestamp and the
        data from the current frame is loaded.
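    For example, the mapping used in the `__main__` block at the bottom of this file requests, for
    every sampled frame, the observations at the current time and at 1.0 s, 0.2 s and 0.1 s in the
    past, plus the action 0.1 s in the past, the current action and the actions 1, 2 and 3 s in the
    future (offsets are in seconds relative to the current frame's timestamp, negative meaning the
    past):

        delta_timestamps = {
            "observation.image": [0, -1, -0.2, -0.1],
            "observation.state": [0, -1, -0.2, -0.1],
            "action": [-0.1, 0, 1, 2, 3],
        }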
""" available_datasets = ["pusht"] fps = 10 image_keys = ["observation.image"] def __init__( self, dataset_id: str, version: str | None = "v1.2", root: Path | None = None, transform: callable = None, delta_timestamps: dict[list[float]] | None = None, ): super().__init__() self.dataset_id = dataset_id self.version = version self.root = root self.transform = transform self.delta_timestamps = delta_timestamps self.data_dir = self.root / f"{self.dataset_id}" if (self.data_dir / "data_dict.pth").exists() and ( self.data_dir / "data_ids_per_episode.pth" ).exists(): self.data_dict = torch.load(self.data_dir / "data_dict.pth") self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth") else: self._download_and_preproc_obsolete() self.data_dir.mkdir(parents=True, exist_ok=True) torch.save(self.data_dict, self.data_dir / "data_dict.pth") torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth") @property def num_samples(self) -> int: return len(self.data_dict["index"]) @property def num_episodes(self) -> int: return len(self.data_ids_per_episode) def __len__(self): return self.num_samples def __getitem__(self, idx): item = {} # get episode id and timestamp of the sampled frame current_ts = self.data_dict["timestamp"][idx].item() episode = self.data_dict["episode"][idx].item() for key in self.data_dict: if self.delta_timestamps is not None and key in self.delta_timestamps: data, is_pad = load_data_with_delta_timestamps( self.data_dict, self.data_ids_per_episode, self.delta_timestamps, key, current_ts, episode, ) item[key] = data item[f"{key}_is_pad"] = is_pad else: item[key] = self.data_dict[key][idx] if self.transform is not None: item = self.transform(item) return item def _download_and_preproc_obsolete(self): assert self.root is not None raw_dir = self.root / f"{self.dataset_id}_raw" zarr_path = (raw_dir / PUSHT_ZARR).resolve() if not zarr_path.is_dir(): raw_dir.mkdir(parents=True, exist_ok=True) download_and_extract_zip(PUSHT_URL, raw_dir) # load dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path( zarr_path ) # , keys=['img', 'state', 'action']) episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs()) num_episodes = dataset_dict.meta["episode_ends"].shape[0] total_frames = dataset_dict["action"].shape[0] # to create test artifact # num_episodes = 1 # total_frames = 50 assert len( {dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118 ), "Some data type dont have the same number of total frames." # TODO: verify that goal pose is expected to be fixed goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians) goal_body = get_goal_pose_body(goal_pos_angle) imgs = torch.from_numpy(dataset_dict["img"]) imgs = einops.rearrange(imgs, "b h w c -> b c h w") states = torch.from_numpy(dataset_dict["state"]) actions = torch.from_numpy(dataset_dict["action"]) self.data_ids_per_episode = {} ep_dicts = [] idx0 = 0 for episode_id in tqdm.tqdm(range(num_episodes)): idx1 = dataset_dict.meta["episode_ends"][episode_id] # to create test artifact # idx1 = 51 num_frames = idx1 - idx0 assert (episode_ids[idx0:idx1] == episode_id).all() image = imgs[idx0:idx1] state = states[idx0:idx1] agent_pos = state[:, :2] block_pos = state[:, 2:4] block_angle = state[:, 4] reward = torch.zeros(num_frames, 1) success = torch.zeros(num_frames, 1, dtype=torch.bool) done = torch.zeros(num_frames, 1, dtype=torch.bool) for i in range(num_frames): space = pymunk.Space() space.gravity = 0, 0 space.damping = 0 # Add walls. 
                walls = [
                    add_segment(space, (5, 506), (5, 5), 2),
                    add_segment(space, (5, 5), (506, 5), 2),
                    add_segment(space, (506, 5), (506, 506), 2),
                    add_segment(space, (5, 506), (506, 506), 2),
                ]
                space.add(*walls)

                block_body = add_tee(space, block_pos[i].tolist(), block_angle[i].item())
                goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
                block_geom = pymunk_to_shapely(block_body, block_body.shapes)
                intersection_area = goal_geom.intersection(block_geom).area
                goal_area = goal_geom.area
                coverage = intersection_area / goal_area
                reward[i] = np.clip(coverage / SUCCESS_THRESHOLD, 0, 1)
                success[i] = coverage > SUCCESS_THRESHOLD

            # last step of demonstration is considered done
            done[-1] = True

            ep_dict = {
                "observation.image": image,
                "observation.state": agent_pos,
                "action": actions[idx0:idx1],
                "episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
                "frame_id": torch.arange(0, num_frames, 1),
                "timestamp": torch.arange(0, num_frames, 1) / self.fps,
                # "next.observation.image": image[1:],
                # "next.observation.state": agent_pos[1:],
                # TODO(rcadene): verify that reward and done are aligned with image and agent_pos
                "next.reward": torch.cat([reward[1:], reward[[-1]]]),
                "next.done": torch.cat([done[1:], done[[-1]]]),
                "next.success": torch.cat([success[1:], success[[-1]]]),
            }
            ep_dicts.append(ep_dict)

            assert isinstance(episode_id, int)
            self.data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1)
            assert len(self.data_ids_per_episode[episode_id]) == num_frames

            idx0 = idx1

        self.data_dict = {}

        keys = ep_dicts[0].keys()
        for key in keys:
            self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])

        self.data_dict["index"] = torch.arange(0, total_frames, 1)


if __name__ == "__main__":
    dataset = PushtDataset(
        "pusht",
        root=Path("data"),
        delta_timestamps={
            "observation.image": [0, -1, -0.2, -0.1],
            "observation.state": [0, -1, -0.2, -0.1],
            "action": [-0.1, 0, 1, 2, 3],
        },
    )
    dataset[10]
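    # Rough sketch of what the item returned by `dataset[10]` is expected to contain with the
    # delta_timestamps above (an illustration, not a guarantee; key names follow __getitem__ and
    # the ep_dict built in _download_and_preproc_obsolete, exact shapes depend on the raw data):
    #   "observation.image"         stacked frames, one per requested time offset
    #   "observation.image_is_pad"  mask flagging offsets that fall outside the episode
    #   "action" / "action_is_pad"  stacked actions and the corresponding padding mask
    #   "index", "episode", ...     remaining keys, returned for the current frame only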