Refactor datasets with abstract class
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
@@ -7,130 +6,52 @@ import torch
|
||||
import torchrl
|
||||
import tqdm
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.datasets.utils import _get_root_dir
|
||||
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
|
||||
from torchrl.data.replay_buffers.samplers import (
|
||||
Sampler,
|
||||
SliceSampler,
|
||||
SliceSamplerWithoutReplacement,
|
||||
)
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
||||
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||
from torchrl.data.replay_buffers.writers import Writer
|
||||
|
||||
from lerobot.common.datasets.abstract import AbstractExperienceReplay
|
||||
|
||||
|
||||
class SimxarmExperienceReplay(TensorDictReplayBuffer):
|
||||
class SimxarmExperienceReplay(AbstractExperienceReplay):
|
||||
available_datasets = [
|
||||
"xarm_lift_medium",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id,
|
||||
dataset_id: str,
|
||||
batch_size: int = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
num_slices: int = None,
|
||||
slice_len: int = None,
|
||||
pad: float = None,
|
||||
replacement: bool = None,
|
||||
streaming: bool = False,
|
||||
root: Path = None,
|
||||
download: bool = False,
|
||||
sampler: Sampler = None,
|
||||
writer: Writer = None,
|
||||
collate_fn: Callable = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
transform: "torchrl.envs.Transform" = None, # noqa-F821
|
||||
split_trajs: bool = False,
|
||||
strict_length: bool = True,
|
||||
sampler: SliceSampler = None,
|
||||
collate_fn: Callable = None,
|
||||
writer: Writer = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
):
|
||||
self.download = download
|
||||
if streaming:
|
||||
raise NotImplementedError
|
||||
self.streaming = streaming
|
||||
self.dataset_id = dataset_id
|
||||
self.split_trajs = split_trajs
|
||||
self.shuffle = shuffle
|
||||
self.num_slices = num_slices
|
||||
self.slice_len = slice_len
|
||||
self.pad = pad
|
||||
|
||||
self.strict_length = strict_length
|
||||
if (self.num_slices is not None) and (self.slice_len is not None):
|
||||
raise ValueError("num_slices or slice_len can be not None, but not both.")
|
||||
if split_trajs:
|
||||
raise NotImplementedError
|
||||
|
||||
if root is None:
|
||||
root = _get_root_dir("simxarm")
|
||||
os.makedirs(root, exist_ok=True)
|
||||
self.root = Path(root)
|
||||
if self.download == "force" or (self.download and not self._is_downloaded()):
|
||||
storage = self._download_and_preproc()
|
||||
else:
|
||||
storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
|
||||
|
||||
if num_slices is not None or slice_len is not None:
|
||||
if sampler is not None:
|
||||
raise ValueError("`num_slices` and `slice_len` are exclusive with the `sampler` argument.")
|
||||
|
||||
if replacement:
|
||||
if not self.shuffle:
|
||||
raise RuntimeError("shuffle=False can only be used when replacement=False.")
|
||||
sampler = SliceSampler(
|
||||
num_slices=num_slices,
|
||||
slice_len=slice_len,
|
||||
strict_length=strict_length,
|
||||
)
|
||||
else:
|
||||
sampler = SliceSamplerWithoutReplacement(
|
||||
num_slices=num_slices,
|
||||
slice_len=slice_len,
|
||||
strict_length=strict_length,
|
||||
shuffle=self.shuffle,
|
||||
)
|
||||
|
||||
if writer is None:
|
||||
writer = ImmutableDatasetWriter()
|
||||
if collate_fn is None:
|
||||
collate_fn = _collate_id
|
||||
|
||||
super().__init__(
|
||||
storage=storage,
|
||||
sampler=sampler,
|
||||
writer=writer,
|
||||
collate_fn=collate_fn,
|
||||
dataset_id,
|
||||
batch_size,
|
||||
shuffle=shuffle,
|
||||
root=root,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch,
|
||||
batch_size=batch_size,
|
||||
sampler=sampler,
|
||||
collate_fn=collate_fn,
|
||||
writer=writer,
|
||||
transform=transform,
|
||||
)
|
||||
|
||||
@property
|
||||
def num_samples(self):
|
||||
return len(self)
|
||||
|
||||
@property
|
||||
def num_episodes(self):
|
||||
return len(self._storage._storage["episode"].unique())
|
||||
|
||||
@property
|
||||
def data_path_root(self):
|
||||
if self.streaming:
|
||||
return None
|
||||
return self.root / self.dataset_id
|
||||
|
||||
def _is_downloaded(self):
|
||||
return os.path.exists(self.data_path_root)
|
||||
|
||||
def _download_and_preproc(self):
|
||||
# download
|
||||
# TODO(rcadene)
|
||||
|
||||
# load
|
||||
dataset_dir = Path("data") / self.dataset_id
|
||||
dataset_path = dataset_dir / "buffer.pkl"
|
||||
dataset_path = self.data_dir / "buffer.pkl"
|
||||
print(f"Using offline dataset '{dataset_path}'")
|
||||
with open(dataset_path, "rb") as f:
|
||||
dataset_dict = pickle.load(f)
|
||||
@@ -172,7 +93,7 @@ class SimxarmExperienceReplay(TensorDictReplayBuffer):
|
||||
|
||||
if episode_id == 0:
|
||||
# hack to initialize tensordict data structure to store episodes
|
||||
td_data = episode[0].expand(total_frames).memmap_like(self.root / self.dataset_id)
|
||||
td_data = episode[0].expand(total_frames).memmap_like(self.data_dir)
|
||||
|
||||
td_data[idx0:idx1] = episode
|
||||
|
||||
|
||||
Reference in New Issue
Block a user