Files
lerobot_piper/lerobot/common/datasets/simxarm.py

191 lines
6.2 KiB
Python

import os
import pickle
from pathlib import Path
from typing import Any, Callable, Dict, Tuple
import torch
import torchrl
import tqdm
from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import (
TensorDictPrioritizedReplayBuffer,
TensorDictReplayBuffer,
)
from torchrl.data.replay_buffers.samplers import (
Sampler,
SliceSampler,
SliceSamplerWithoutReplacement,
)
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
class SimxarmExperienceReplay(TensorDictReplayBuffer):
available_datasets = [
"xarm_lift_medium",
]
def __init__(
self,
dataset_id,
batch_size: int = None,
*,
shuffle: bool = True,
num_slices: int = None,
slice_len: int = None,
pad: float = None,
replacement: bool = None,
streaming: bool = False,
root: Path = None,
download: bool = False,
sampler: Sampler = None,
writer: Writer = None,
collate_fn: Callable = None,
pin_memory: bool = False,
prefetch: int = None,
transform: "torchrl.envs.Transform" = None, # noqa-F821
split_trajs: bool = False,
strict_length: bool = True,
):
self.download = download
if streaming:
raise NotImplementedError
self.streaming = streaming
self.dataset_id = dataset_id
self.split_trajs = split_trajs
self.shuffle = shuffle
self.num_slices = num_slices
self.slice_len = slice_len
self.pad = pad
self.strict_length = strict_length
if (self.num_slices is not None) and (self.slice_len is not None):
raise ValueError("num_slices or slice_len can be not None, but not both.")
if split_trajs:
raise NotImplementedError
if root is None:
root = _get_root_dir("simxarm")
os.makedirs(root, exist_ok=True)
self.root = Path(root)
if self.download == "force" or (self.download and not self._is_downloaded()):
storage = self._download_and_preproc()
else:
storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
if num_slices is not None or slice_len is not None:
if sampler is not None:
raise ValueError(
"`num_slices` and `slice_len` are exclusive with the `sampler` argument."
)
if replacement:
if not self.shuffle:
raise RuntimeError(
"shuffle=False can only be used when replacement=False."
)
sampler = SliceSampler(
num_slices=num_slices,
slice_len=slice_len,
strict_length=strict_length,
)
else:
sampler = SliceSamplerWithoutReplacement(
num_slices=num_slices,
slice_len=slice_len,
strict_length=strict_length,
shuffle=self.shuffle,
)
if writer is None:
writer = ImmutableDatasetWriter()
if collate_fn is None:
collate_fn = _collate_id
super().__init__(
storage=storage,
sampler=sampler,
writer=writer,
collate_fn=collate_fn,
pin_memory=pin_memory,
prefetch=prefetch,
batch_size=batch_size,
transform=transform,
)
@property
def data_path_root(self):
if self.streaming:
return None
return self.root / self.dataset_id
def _is_downloaded(self):
return os.path.exists(self.data_path_root)
def _download_and_preproc(self):
# download
# TODO(rcadene)
# load
dataset_dir = Path("data") / self.dataset_id
dataset_path = dataset_dir / f"buffer.pkl"
print(f"Using offline dataset '{dataset_path}'")
with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f)
total_frames = dataset_dict["actions"].shape[0]
idx0 = 0
idx1 = 0
episode_id = 0
for i in tqdm.tqdm(range(total_frames)):
idx1 += 1
if not dataset_dict["dones"][i]:
continue
num_frames = idx1 - idx0
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
next_image = torch.tensor(
dataset_dict["next_observations"]["rgb"][idx0:idx1]
)
next_state = torch.tensor(
dataset_dict["next_observations"]["state"][idx0:idx1]
)
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])
episode = TensorDict(
{
("observation", "image"): image,
("observation", "state"): state,
"action": torch.tensor(dataset_dict["actions"][idx0:idx1]),
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_id": torch.arange(0, num_frames, 1),
("next", "observation", "image"): next_image,
("next", "observation", "state"): next_state,
("next", "observation", "reward"): next_reward,
("next", "observation", "done"): next_done,
},
batch_size=num_frames,
)
if episode_id == 0:
# hack to initialize tensordict data structure to store episodes
td_data = (
episode[0]
.expand(total_frames)
.memmap_like(self.root / self.dataset_id)
)
td_data[idx0:idx1] = episode
episode_id += 1
idx0 = idx1
return TensorStorage(td_data.lock_())