Added pusht dataset auto-download

This commit is contained in:
Simon Alibert
2024-03-01 14:31:54 +01:00
parent ca948c1e5b
commit b862145e22
3 changed files with 57 additions and 28 deletions

View File

@@ -9,8 +9,6 @@ import pymunk
import torch
import torchrl
import tqdm
from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
@@ -18,10 +16,16 @@ from torchrl.data.replay_buffers.samplers import Sampler
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
from lerobot.common import utils
# as defined in env
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
DEFAULT_TEE_MASK = pymunk.ShapeFilter.ALL_MASKS()
PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr")
def get_goal_pose_body(pose):
@@ -83,7 +87,7 @@ def add_tee(
class PushtExperienceReplay(TensorDictReplayBuffer):
def __init__(
self,
dataset_id,
dataset_id: str,
batch_size: int = None,
*,
shuffle: bool = True,
@@ -93,7 +97,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
replacement: bool = None,
streaming: bool = False,
root: Path = None,
download: bool = False,
download: bool | str = False,
sampler: Sampler = None,
writer: Writer = None,
collate_fn: Callable = None,
@@ -120,13 +124,12 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
if split_trajs:
raise NotImplementedError
if self.download:
raise NotImplementedError()
if root is None:
root = _get_root_dir("pusht")
os.makedirs(root, exist_ok=True)
self.root = Path(root)
self.root = root
self.raw = self.root / "raw"
if self.download == "force" or (self.download and not self._is_downloaded()):
storage = self._download_and_preproc()
else:
@@ -173,39 +176,34 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
)
@property
def num_samples(self):
def num_samples(self) -> int:
return len(self)
@property
def num_episodes(self):
def num_episodes(self) -> int:
return len(self._storage._storage["episode"].unique())
@property
def data_path_root(self):
if self.streaming:
return None
return self.root / self.dataset_id
def data_path_root(self) -> Path:
return None if self.streaming else self.root / self.dataset_id
def _is_downloaded(self):
return os.path.exists(self.data_path_root)
def _is_downloaded(self) -> bool:
return self.data_path_root.is_dir()
def _download_and_preproc(self):
# download
# TODO(rcadene)
self.raw.mkdir(exist_ok=True)
utils.download_and_extract_zip(PUSHT_URL, self.raw)
zarr_path = (self.raw / PUSHT_ZARR).resolve()
# load
# TODO(aliberts): Dynamic paths
zarr_path = (
"/home/rcadene/code/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr"
# "/home/simon/build/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr"
)
dataset_dict = ReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action'])
episode_ids = dataset_dict.get_episode_idxs()
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
total_frames = dataset_dict["action"].shape[0]
assert len(
{dataset_dict[key].shape[0] for key in dataset_dict}
{dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118
), "Some data type dont have the same number of total frames."
# TODO: verify that goal pose is expected to be fixed