Ran pre-commit run --all-files

This commit is contained in:
Simon Alibert
2024-02-29 13:37:48 +01:00
parent 0e0e760e43
commit 7e024fdce6
16 changed files with 124 additions and 237 deletions

View File

@@ -1,7 +1,6 @@
import os
import pickle
from pathlib import Path
from typing import Any, Callable, Dict, Tuple
from typing import Callable
import einops
import numpy as np
@@ -10,25 +9,25 @@ import pymunk
import torch
import torchrl
import tqdm
from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import (
TensorDictPrioritizedReplayBuffer,
TensorDictReplayBuffer,
)
from torchrl.data.replay_buffers.samplers import (
Sampler,
SliceSampler,
SliceSamplerWithoutReplacement,
)
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
# as define in env
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
DEFAULT_TEE_MASK = pymunk.ShapeFilter.ALL_MASKS()
def get_goal_pose_body(pose):
mass = 1
@@ -53,7 +52,7 @@ def add_tee(
angle,
scale=30,
color="LightSlateGray",
mask=pymunk.ShapeFilter.ALL_MASKS(),
mask=DEFAULT_TEE_MASK,
):
mass = 1
length = 4
@@ -87,7 +86,6 @@ def add_tee(
class PushtExperienceReplay(TensorDictReplayBuffer):
def __init__(
self,
dataset_id,
@@ -127,7 +125,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
if split_trajs:
raise NotImplementedError
if self.download == True:
if self.download:
raise NotImplementedError()
if root is None:
@@ -193,18 +191,18 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
# TODO(rcadene)
# load
# TODO(aliberts): Dynamic paths
zarr_path = (
"/home/rcadene/code/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr"
# "/home/simon/build/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr"
)
dataset_dict = ReplayBuffer.copy_from_path(
zarr_path
) # , keys=['img', 'state', 'action'])
dataset_dict = ReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action'])
episode_ids = dataset_dict.get_episode_idxs()
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
total_frames = dataset_dict["action"].shape[0]
assert len(
set([dataset_dict[key].shape[0] for key in dataset_dict.keys()])
{dataset_dict[key].shape[0] for key in dataset_dict}
), "Some data type dont have the same number of total frames."
# TODO: verify that goal pose is expected to be fixed
@@ -245,9 +243,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
]
space.add(*walls)
block_body = add_tee(
space, block_pos[i].tolist(), block_angle[i].item()
)
block_body = add_tee(space, block_pos[i].tolist(), block_angle[i].item())
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
intersection_area = goal_geom.intersection(block_geom).area
@@ -278,11 +274,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
if episode_id == 0:
# hack to initialize tensordict data structure to store episodes
td_data = (
episode[0]
.expand(total_frames)
.memmap_like(self.root / self.dataset_id)
)
td_data = episode[0].expand(total_frames).memmap_like(self.root / self.dataset_id)
td_data[idxtd : idxtd + len(episode)] = episode