Integrate pusht env from diffusion
This commit is contained in:
@@ -5,7 +5,6 @@ import einops
|
||||
import numpy as np
|
||||
import pygame
|
||||
import pymunk
|
||||
import shapely.geometry as sg
|
||||
import torch
|
||||
import torchrl
|
||||
import tqdm
|
||||
@@ -16,29 +15,16 @@ from torchrl.data.replay_buffers.writers import Writer
|
||||
|
||||
from lerobot.common.datasets.abstract import AbstractExperienceReplay
|
||||
from lerobot.common.datasets.utils import download_and_extract_zip
|
||||
from lerobot.common.envs.pusht.pusht_env import pymunk_to_shapely
|
||||
from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
|
||||
|
||||
# as define in env
|
||||
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
|
||||
|
||||
DEFAULT_TEE_MASK = pymunk.ShapeFilter.ALL_MASKS()
|
||||
PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
|
||||
PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr")
|
||||
|
||||
|
||||
def pymunk_to_shapely(body, shapes):
|
||||
geoms = []
|
||||
for shape in shapes:
|
||||
if isinstance(shape, pymunk.shapes.Poly):
|
||||
verts = [body.local_to_world(v) for v in shape.get_vertices()]
|
||||
verts += [verts[0]]
|
||||
geoms.append(sg.Polygon(verts))
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported shape type {type(shape)}")
|
||||
geom = sg.MultiPolygon(geoms)
|
||||
return geom
|
||||
|
||||
|
||||
def get_goal_pose_body(pose):
|
||||
mass = 1
|
||||
inertia = pymunk.moment_for_box(mass, (50, 100))
|
||||
@@ -62,8 +48,10 @@ def add_tee(
|
||||
angle,
|
||||
scale=30,
|
||||
color="LightSlateGray",
|
||||
mask=DEFAULT_TEE_MASK,
|
||||
mask=None,
|
||||
):
|
||||
if mask is None:
|
||||
mask = pymunk.ShapeFilter.ALL_MASKS()
|
||||
mass = 1
|
||||
length = 4
|
||||
vertices1 = [
|
||||
|
||||
Reference in New Issue
Block a user