Add pusht dataset (TODO verify reward is aligned), Refactor visualize_dataset, Add video_dir, fps, state_dim, action_dim to config (Training works)
This commit is contained in:
@@ -5,18 +5,21 @@ from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
|
||||
from rl.torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler
|
||||
|
||||
|
||||
def make_offline_buffer(cfg):
|
||||
def make_offline_buffer(cfg, sampler=None):
|
||||
|
||||
num_traj_per_batch = cfg.batch_size # // cfg.horizon
|
||||
# TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
|
||||
# We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
|
||||
sampler = PrioritizedSliceSampler(
|
||||
max_capacity=100_000,
|
||||
alpha=cfg.per_alpha,
|
||||
beta=cfg.per_beta,
|
||||
num_slices=num_traj_per_batch,
|
||||
strict_length=False,
|
||||
)
|
||||
overwrite_sampler = sampler is not None
|
||||
|
||||
if not overwrite_sampler:
|
||||
num_traj_per_batch = cfg.batch_size # // cfg.horizon
|
||||
# TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
|
||||
# We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
|
||||
sampler = PrioritizedSliceSampler(
|
||||
max_capacity=100_000,
|
||||
alpha=cfg.per_alpha,
|
||||
beta=cfg.per_beta,
|
||||
num_slices=num_traj_per_batch,
|
||||
strict_length=False,
|
||||
)
|
||||
|
||||
if cfg.env == "simxarm":
|
||||
# TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here
|
||||
@@ -30,9 +33,9 @@ def make_offline_buffer(cfg):
|
||||
)
|
||||
elif cfg.env == "pusht":
|
||||
offline_buffer = PushtExperienceReplay(
|
||||
f"xarm_{cfg.task}_medium",
|
||||
"pusht",
|
||||
# download="force",
|
||||
download=True,
|
||||
download=False,
|
||||
streaming=False,
|
||||
root="data",
|
||||
sampler=sampler,
|
||||
@@ -40,8 +43,9 @@ def make_offline_buffer(cfg):
|
||||
else:
|
||||
raise ValueError(cfg.env)
|
||||
|
||||
num_steps = len(offline_buffer)
|
||||
index = torch.arange(0, num_steps, 1)
|
||||
sampler.extend(index)
|
||||
if not overwrite_sampler:
|
||||
num_steps = len(offline_buffer)
|
||||
index = torch.arange(0, num_steps, 1)
|
||||
sampler.extend(index)
|
||||
|
||||
return offline_buffer
|
||||
|
||||
@@ -3,9 +3,15 @@ import pickle
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Tuple
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import pygame
|
||||
import pymunk
|
||||
import torch
|
||||
import torchrl
|
||||
import tqdm
|
||||
from diffusion_policy.common.replay_buffer import ReplayBuffer
|
||||
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.datasets.utils import _get_root_dir
|
||||
from torchrl.data.replay_buffers.replay_buffers import (
|
||||
@@ -20,12 +26,71 @@ from torchrl.data.replay_buffers.samplers import (
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
||||
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
||||
|
||||
# as define in env
|
||||
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
|
||||
|
||||
|
||||
def get_goal_pose_body(pose):
|
||||
mass = 1
|
||||
inertia = pymunk.moment_for_box(mass, (50, 100))
|
||||
body = pymunk.Body(mass, inertia)
|
||||
# preserving the legacy assignment order for compatibility
|
||||
# the order here doesn't matter somehow, maybe because CoM is aligned with body origin
|
||||
body.position = pose[:2].tolist()
|
||||
body.angle = pose[2]
|
||||
return body
|
||||
|
||||
|
||||
def add_segment(space, a, b, radius):
|
||||
shape = pymunk.Segment(space.static_body, a, b, radius)
|
||||
shape.color = pygame.Color("LightGray") # https://htmlcolorcodes.com/color-names
|
||||
return shape
|
||||
|
||||
|
||||
def add_tee(
|
||||
space,
|
||||
position,
|
||||
angle,
|
||||
scale=30,
|
||||
color="LightSlateGray",
|
||||
mask=pymunk.ShapeFilter.ALL_MASKS(),
|
||||
):
|
||||
mass = 1
|
||||
length = 4
|
||||
vertices1 = [
|
||||
(-length * scale / 2, scale),
|
||||
(length * scale / 2, scale),
|
||||
(length * scale / 2, 0),
|
||||
(-length * scale / 2, 0),
|
||||
]
|
||||
inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1)
|
||||
vertices2 = [
|
||||
(-scale / 2, scale),
|
||||
(-scale / 2, length * scale),
|
||||
(scale / 2, length * scale),
|
||||
(scale / 2, scale),
|
||||
]
|
||||
inertia2 = pymunk.moment_for_poly(mass, vertices=vertices1)
|
||||
body = pymunk.Body(mass, inertia1 + inertia2)
|
||||
shape1 = pymunk.Poly(body, vertices1)
|
||||
shape2 = pymunk.Poly(body, vertices2)
|
||||
shape1.color = pygame.Color(color)
|
||||
shape2.color = pygame.Color(color)
|
||||
shape1.filter = pymunk.ShapeFilter(mask=mask)
|
||||
shape2.filter = pymunk.ShapeFilter(mask=mask)
|
||||
body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2
|
||||
body.position = position
|
||||
body.angle = angle
|
||||
body.friction = 1
|
||||
space.add(body, shape1, shape2)
|
||||
return body
|
||||
|
||||
|
||||
class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
|
||||
available_datasets = [
|
||||
"xarm_lift_medium",
|
||||
]
|
||||
# available_datasets = [
|
||||
# "xarm_lift_medium",
|
||||
# ]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -49,8 +114,6 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
split_trajs: bool = False,
|
||||
strict_length: bool = True,
|
||||
):
|
||||
# TODO
|
||||
raise NotImplementedError()
|
||||
self.download = download
|
||||
if streaming:
|
||||
raise NotImplementedError
|
||||
@@ -68,8 +131,11 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
if split_trajs:
|
||||
raise NotImplementedError
|
||||
|
||||
if self.download == True:
|
||||
raise NotImplementedError()
|
||||
|
||||
if root is None:
|
||||
root = _get_root_dir("simxarm")
|
||||
root = _get_root_dir("pusht")
|
||||
os.makedirs(root, exist_ok=True)
|
||||
self.root = Path(root)
|
||||
if self.download == "force" or (self.download and not self._is_downloaded()):
|
||||
@@ -77,29 +143,29 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
else:
|
||||
storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
|
||||
|
||||
if num_slices is not None or slice_len is not None:
|
||||
if sampler is not None:
|
||||
raise ValueError(
|
||||
"`num_slices` and `slice_len` are exclusive with the `sampler` argument."
|
||||
)
|
||||
# if num_slices is not None or slice_len is not None:
|
||||
# if sampler is not None:
|
||||
# raise ValueError(
|
||||
# "`num_slices` and `slice_len` are exclusive with the `sampler` argument."
|
||||
# )
|
||||
|
||||
if replacement:
|
||||
if not self.shuffle:
|
||||
raise RuntimeError(
|
||||
"shuffle=False can only be used when replacement=False."
|
||||
)
|
||||
sampler = SliceSampler(
|
||||
num_slices=num_slices,
|
||||
slice_len=slice_len,
|
||||
strict_length=strict_length,
|
||||
)
|
||||
else:
|
||||
sampler = SliceSamplerWithoutReplacement(
|
||||
num_slices=num_slices,
|
||||
slice_len=slice_len,
|
||||
strict_length=strict_length,
|
||||
shuffle=self.shuffle,
|
||||
)
|
||||
# if replacement:
|
||||
# if not self.shuffle:
|
||||
# raise RuntimeError(
|
||||
# "shuffle=False can only be used when replacement=False."
|
||||
# )
|
||||
# sampler = SliceSampler(
|
||||
# num_slices=num_slices,
|
||||
# slice_len=slice_len,
|
||||
# strict_length=strict_length,
|
||||
# )
|
||||
# else:
|
||||
# sampler = SliceSamplerWithoutReplacement(
|
||||
# num_slices=num_slices,
|
||||
# slice_len=slice_len,
|
||||
# strict_length=strict_length,
|
||||
# shuffle=self.shuffle,
|
||||
# )
|
||||
|
||||
if writer is None:
|
||||
writer = ImmutableDatasetWriter()
|
||||
@@ -131,49 +197,82 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
# TODO(rcadene)
|
||||
|
||||
# load
|
||||
dataset_dir = Path("data") / self.dataset_id
|
||||
dataset_path = dataset_dir / f"buffer.pkl"
|
||||
print(f"Using offline dataset '{dataset_path}'")
|
||||
with open(dataset_path, "rb") as f:
|
||||
dataset_dict = pickle.load(f)
|
||||
zarr_path = (
|
||||
"/home/rcadene/code/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr"
|
||||
)
|
||||
dataset_dict = ReplayBuffer.copy_from_path(
|
||||
zarr_path
|
||||
) # , keys=['img', 'state', 'action'])
|
||||
|
||||
total_frames = dataset_dict["actions"].shape[0]
|
||||
episode_ids = dataset_dict.get_episode_idxs()
|
||||
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
|
||||
total_frames = dataset_dict["action"].shape[0]
|
||||
assert len(
|
||||
set([dataset_dict[key].shape[0] for key in dataset_dict.keys()])
|
||||
), "Some data type dont have the same number of total frames."
|
||||
|
||||
# TODO: verify that goal pose is expected to be fixed
|
||||
goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
|
||||
goal_body = get_goal_pose_body(goal_pos_angle)
|
||||
|
||||
idx0 = 0
|
||||
idx1 = 0
|
||||
episode_id = 0
|
||||
for i in tqdm.tqdm(range(total_frames)):
|
||||
idx1 += 1
|
||||
|
||||
if not dataset_dict["dones"][i]:
|
||||
continue
|
||||
idxtd = 0
|
||||
for episode_id in tqdm.tqdm(range(num_episodes)):
|
||||
idx1 = dataset_dict.meta["episode_ends"][episode_id]
|
||||
|
||||
num_frames = idx1 - idx0
|
||||
|
||||
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
|
||||
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
|
||||
next_image = torch.tensor(
|
||||
dataset_dict["next_observations"]["rgb"][idx0:idx1]
|
||||
)
|
||||
next_state = torch.tensor(
|
||||
dataset_dict["next_observations"]["state"][idx0:idx1]
|
||||
)
|
||||
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
|
||||
next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])
|
||||
assert (episode_ids[idx0:idx1] == episode_id).all()
|
||||
|
||||
image = torch.from_numpy(dataset_dict["img"][idx0:idx1])
|
||||
image = einops.rearrange(image, "b h w c -> b c h w")
|
||||
|
||||
state = torch.from_numpy(dataset_dict["state"][idx0:idx1])
|
||||
agent_pos = state[:, :2]
|
||||
block_pos = state[:, 2:4]
|
||||
block_angle = state[:, 4]
|
||||
|
||||
reward = torch.zeros(num_frames, 1)
|
||||
done = torch.zeros(num_frames, 1, dtype=torch.bool)
|
||||
for i in range(num_frames):
|
||||
space = pymunk.Space()
|
||||
space.gravity = 0, 0
|
||||
space.damping = 0
|
||||
|
||||
# Add walls.
|
||||
walls = [
|
||||
add_segment(space, (5, 506), (5, 5), 2),
|
||||
add_segment(space, (5, 5), (506, 5), 2),
|
||||
add_segment(space, (506, 5), (506, 506), 2),
|
||||
add_segment(space, (5, 506), (506, 506), 2),
|
||||
]
|
||||
space.add(*walls)
|
||||
|
||||
block_body = add_tee(
|
||||
space, block_pos[i].tolist(), block_angle[i].item()
|
||||
)
|
||||
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
|
||||
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
|
||||
intersection_area = goal_geom.intersection(block_geom).area
|
||||
goal_area = goal_geom.area
|
||||
coverage = intersection_area / goal_area
|
||||
reward[i] = np.clip(coverage / SUCCESS_THRESHOLD, 0, 1)
|
||||
done[i] = coverage > SUCCESS_THRESHOLD
|
||||
|
||||
episode = TensorDict(
|
||||
{
|
||||
("observation", "image"): image,
|
||||
("observation", "state"): state,
|
||||
"action": torch.tensor(dataset_dict["actions"][idx0:idx1]),
|
||||
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
("next", "observation", "image"): next_image,
|
||||
("next", "observation", "state"): next_state,
|
||||
("next", "observation", "reward"): next_reward,
|
||||
("next", "observation", "done"): next_done,
|
||||
("observation", "image"): image[:-1],
|
||||
("observation", "state"): agent_pos[:-1],
|
||||
"action": torch.from_numpy(dataset_dict["action"][idx0:idx1])[:-1],
|
||||
"episode": torch.from_numpy(episode_ids[idx0:idx1])[:-1],
|
||||
"frame_id": torch.arange(0, num_frames - 1, 1),
|
||||
("next", "observation", "image"): image[1:],
|
||||
("next", "observation", "state"): agent_pos[1:],
|
||||
# TODO: verify that reward and done are aligned with image and agent_pos
|
||||
("next", "reward"): reward[1:],
|
||||
("next", "done"): done[1:],
|
||||
},
|
||||
batch_size=num_frames,
|
||||
batch_size=num_frames - 1,
|
||||
)
|
||||
|
||||
if episode_id == 0:
|
||||
@@ -184,9 +283,9 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
.memmap_like(self.root / self.dataset_id)
|
||||
)
|
||||
|
||||
td_data[idx0:idx1] = episode
|
||||
td_data[idxtd : idxtd + len(episode)] = episode
|
||||
|
||||
episode_id += 1
|
||||
idx0 = idx1
|
||||
idxtd = idxtd + len(episode)
|
||||
|
||||
return TensorStorage(td_data.lock_())
|
||||
|
||||
Reference in New Issue
Block a user