Ran pre-commit run --all-files
This commit is contained in:
@@ -70,6 +70,7 @@ def make_offline_buffer(cfg, sampler=None):
|
||||
offline_buffer = PushtExperienceReplay(
|
||||
"pusht",
|
||||
# download="force",
|
||||
# TODO(aliberts): automate download
|
||||
download=False,
|
||||
streaming=False,
|
||||
root="data",
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import os
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Tuple
|
||||
from typing import Callable
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
@@ -10,25 +9,25 @@ import pymunk
|
||||
import torch
|
||||
import torchrl
|
||||
import tqdm
|
||||
from diffusion_policy.common.replay_buffer import ReplayBuffer
|
||||
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.datasets.utils import _get_root_dir
|
||||
from torchrl.data.replay_buffers.replay_buffers import (
|
||||
TensorDictPrioritizedReplayBuffer,
|
||||
TensorDictReplayBuffer,
|
||||
)
|
||||
from torchrl.data.replay_buffers.samplers import (
|
||||
Sampler,
|
||||
SliceSampler,
|
||||
SliceSamplerWithoutReplacement,
|
||||
)
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
||||
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
||||
|
||||
from diffusion_policy.common.replay_buffer import ReplayBuffer
|
||||
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
|
||||
|
||||
# as define in env
|
||||
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
|
||||
|
||||
DEFAULT_TEE_MASK = pymunk.ShapeFilter.ALL_MASKS()
|
||||
|
||||
|
||||
def get_goal_pose_body(pose):
|
||||
mass = 1
|
||||
@@ -53,7 +52,7 @@ def add_tee(
|
||||
angle,
|
||||
scale=30,
|
||||
color="LightSlateGray",
|
||||
mask=pymunk.ShapeFilter.ALL_MASKS(),
|
||||
mask=DEFAULT_TEE_MASK,
|
||||
):
|
||||
mass = 1
|
||||
length = 4
|
||||
@@ -87,7 +86,6 @@ def add_tee(
|
||||
|
||||
|
||||
class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id,
|
||||
@@ -127,7 +125,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
if split_trajs:
|
||||
raise NotImplementedError
|
||||
|
||||
if self.download == True:
|
||||
if self.download:
|
||||
raise NotImplementedError()
|
||||
|
||||
if root is None:
|
||||
@@ -193,18 +191,18 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
# TODO(rcadene)
|
||||
|
||||
# load
|
||||
# TODO(aliberts): Dynamic paths
|
||||
zarr_path = (
|
||||
"/home/rcadene/code/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr"
|
||||
# "/home/simon/build/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr"
|
||||
)
|
||||
dataset_dict = ReplayBuffer.copy_from_path(
|
||||
zarr_path
|
||||
) # , keys=['img', 'state', 'action'])
|
||||
dataset_dict = ReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action'])
|
||||
|
||||
episode_ids = dataset_dict.get_episode_idxs()
|
||||
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
|
||||
total_frames = dataset_dict["action"].shape[0]
|
||||
assert len(
|
||||
set([dataset_dict[key].shape[0] for key in dataset_dict.keys()])
|
||||
{dataset_dict[key].shape[0] for key in dataset_dict}
|
||||
), "Some data type dont have the same number of total frames."
|
||||
|
||||
# TODO: verify that goal pose is expected to be fixed
|
||||
@@ -245,9 +243,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
]
|
||||
space.add(*walls)
|
||||
|
||||
block_body = add_tee(
|
||||
space, block_pos[i].tolist(), block_angle[i].item()
|
||||
)
|
||||
block_body = add_tee(space, block_pos[i].tolist(), block_angle[i].item())
|
||||
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
|
||||
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
|
||||
intersection_area = goal_geom.intersection(block_geom).area
|
||||
@@ -278,11 +274,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
|
||||
if episode_id == 0:
|
||||
# hack to initialize tensordict data structure to store episodes
|
||||
td_data = (
|
||||
episode[0]
|
||||
.expand(total_frames)
|
||||
.memmap_like(self.root / self.dataset_id)
|
||||
)
|
||||
td_data = episode[0].expand(total_frames).memmap_like(self.root / self.dataset_id)
|
||||
|
||||
td_data[idxtd : idxtd + len(episode)] = episode
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Tuple
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torchrl
|
||||
@@ -9,7 +9,6 @@ import tqdm
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.datasets.utils import _get_root_dir
|
||||
from torchrl.data.replay_buffers.replay_buffers import (
|
||||
TensorDictPrioritizedReplayBuffer,
|
||||
TensorDictReplayBuffer,
|
||||
)
|
||||
from torchrl.data.replay_buffers.samplers import (
|
||||
@@ -22,7 +21,6 @@ from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
||||
|
||||
|
||||
class SimxarmExperienceReplay(TensorDictReplayBuffer):
|
||||
|
||||
available_datasets = [
|
||||
"xarm_lift_medium",
|
||||
]
|
||||
@@ -77,15 +75,11 @@ class SimxarmExperienceReplay(TensorDictReplayBuffer):
|
||||
|
||||
if num_slices is not None or slice_len is not None:
|
||||
if sampler is not None:
|
||||
raise ValueError(
|
||||
"`num_slices` and `slice_len` are exclusive with the `sampler` argument."
|
||||
)
|
||||
raise ValueError("`num_slices` and `slice_len` are exclusive with the `sampler` argument.")
|
||||
|
||||
if replacement:
|
||||
if not self.shuffle:
|
||||
raise RuntimeError(
|
||||
"shuffle=False can only be used when replacement=False."
|
||||
)
|
||||
raise RuntimeError("shuffle=False can only be used when replacement=False.")
|
||||
sampler = SliceSampler(
|
||||
num_slices=num_slices,
|
||||
slice_len=slice_len,
|
||||
@@ -130,7 +124,7 @@ class SimxarmExperienceReplay(TensorDictReplayBuffer):
|
||||
|
||||
# load
|
||||
dataset_dir = Path("data") / self.dataset_id
|
||||
dataset_path = dataset_dir / f"buffer.pkl"
|
||||
dataset_path = dataset_dir / "buffer.pkl"
|
||||
print(f"Using offline dataset '{dataset_path}'")
|
||||
with open(dataset_path, "rb") as f:
|
||||
dataset_dict = pickle.load(f)
|
||||
@@ -150,12 +144,8 @@ class SimxarmExperienceReplay(TensorDictReplayBuffer):
|
||||
|
||||
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
|
||||
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
|
||||
next_image = torch.tensor(
|
||||
dataset_dict["next_observations"]["rgb"][idx0:idx1]
|
||||
)
|
||||
next_state = torch.tensor(
|
||||
dataset_dict["next_observations"]["state"][idx0:idx1]
|
||||
)
|
||||
next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
|
||||
next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
|
||||
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
|
||||
next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])
|
||||
|
||||
@@ -176,11 +166,7 @@ class SimxarmExperienceReplay(TensorDictReplayBuffer):
|
||||
|
||||
if episode_id == 0:
|
||||
# hack to initialize tensordict data structure to store episodes
|
||||
td_data = (
|
||||
episode[0]
|
||||
.expand(total_frames)
|
||||
.memmap_like(self.root / self.dataset_id)
|
||||
)
|
||||
td_data = episode[0].expand(total_frames).memmap_like(self.root / self.dataset_id)
|
||||
|
||||
td_data[idx0:idx1] = episode
|
||||
|
||||
|
||||
Reference in New Issue
Block a user