Integrate pusht env from diffusion

This commit is contained in:
Simon Alibert
2024-03-10 16:33:03 +01:00
parent 302b78962c
commit 6c867d78ef
10 changed files with 801 additions and 26 deletions

View File

@@ -5,7 +5,6 @@ import einops
import numpy as np
import pygame
import pymunk
import shapely.geometry as sg
import torch
import torchrl
import tqdm
@@ -16,29 +15,16 @@ from torchrl.data.replay_buffers.writers import Writer
from lerobot.common.datasets.abstract import AbstractExperienceReplay
from lerobot.common.datasets.utils import download_and_extract_zip
from lerobot.common.envs.pusht.pusht_env import pymunk_to_shapely
from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
# as define in env
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
DEFAULT_TEE_MASK = pymunk.ShapeFilter.ALL_MASKS()
PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr")
def pymunk_to_shapely(body, shapes):
geoms = []
for shape in shapes:
if isinstance(shape, pymunk.shapes.Poly):
verts = [body.local_to_world(v) for v in shape.get_vertices()]
verts += [verts[0]]
geoms.append(sg.Polygon(verts))
else:
raise RuntimeError(f"Unsupported shape type {type(shape)}")
geom = sg.MultiPolygon(geoms)
return geom
def get_goal_pose_body(pose):
mass = 1
inertia = pymunk.moment_for_box(mass, (50, 100))
@@ -62,8 +48,10 @@ def add_tee(
angle,
scale=30,
color="LightSlateGray",
mask=DEFAULT_TEE_MASK,
mask=None,
):
if mask is None:
mask = pymunk.ShapeFilter.ALL_MASKS()
mass = 1
length = 4
vertices1 = [