Fix unit tests, refactor, add pusht env (TODO: pusht replay buffer, image preprocessing)

This commit is contained in:
Cadene
2024-02-20 12:26:57 +00:00
parent fdfb2010fd
commit 3da6ffb2cb
10 changed files with 559 additions and 89 deletions

View File

@@ -0,0 +1,47 @@
import torch
from lerobot.common.datasets.pusht import PushtExperienceReplay
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
from rl.torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler
def make_offline_buffer(cfg):
    """Build the offline replay buffer selected by ``cfg.env``.

    A ``PrioritizedSliceSampler`` is created from ``cfg.batch_size``,
    ``cfg.per_alpha`` and ``cfg.per_beta``, handed to the experience-replay
    class matching ``cfg.env`` (``"simxarm"`` or ``"pusht"``), and then
    extended with one index per element of the resulting buffer.

    Args:
        cfg: configuration object exposing ``batch_size``, ``per_alpha``,
            ``per_beta``, ``env`` and ``task`` attributes.

    Returns:
        The constructed experience-replay buffer.

    Raises:
        ValueError: if ``cfg.env`` is neither ``"simxarm"`` nor ``"pusht"``.
    """
    # One slice per batch element (dividing by cfg.horizon is disabled for now).
    slices_per_batch = cfg.batch_size

    # TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
    # We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
    sampler = PrioritizedSliceSampler(
        max_capacity=100_000,
        alpha=cfg.per_alpha,
        beta=cfg.per_beta,
        num_slices=slices_per_batch,
        strict_length=False,
    )

    env_name = cfg.env
    if env_name == "simxarm":
        # TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here
        buffer_cls = SimxarmExperienceReplay
    elif env_name == "pusht":
        buffer_cls = PushtExperienceReplay
    else:
        raise ValueError(env_name)

    # Both buffer classes are built with the same arguments; pass
    # download="force" instead of True to rebuild the on-disk storage.
    offline_buffer = buffer_cls(
        f"xarm_{cfg.task}_medium",
        download=True,
        streaming=False,
        root="data",
        sampler=sampler,
    )

    # Register every stored step with the prioritized sampler.
    sampler.extend(torch.arange(len(offline_buffer)))
    return offline_buffer

View File

@@ -0,0 +1,192 @@
import os
import pickle
from pathlib import Path
from typing import Any, Callable, Dict, Tuple
import torch
import torchrl
import tqdm
from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import (
TensorDictPrioritizedReplayBuffer,
TensorDictReplayBuffer,
)
from torchrl.data.replay_buffers.samplers import (
Sampler,
SliceSampler,
SliceSamplerWithoutReplacement,
)
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
class PushtExperienceReplay(TensorDictReplayBuffer):
    """Replay buffer over an offline pusht dataset (work in progress).

    The constructor currently raises ``NotImplementedError`` before doing any
    work; the remaining initialisation code is kept as the intended
    implementation for when the pusht dataset is wired in.
    """

    # Dataset identifiers listed for this buffer.
    available_datasets = [
        "xarm_lift_medium",
    ]

    def __init__(
        self,
        dataset_id,
        batch_size: int = None,
        *,
        shuffle: bool = True,
        num_slices: int = None,
        slice_len: int = None,
        pad: float = None,
        replacement: bool = None,
        streaming: bool = False,
        root: Path = None,
        download: bool = False,
        sampler: Sampler = None,
        writer: Writer = None,
        collate_fn: Callable = None,
        pin_memory: bool = False,
        prefetch: int = None,
        transform: "torchrl.envs.Transform" = None,  # noqa-F821
        split_trajs: bool = False,
        strict_length: bool = True,
    ):
        """Set up storage, sampler and writer for the dataset.

        Args:
            dataset_id: name of the dataset directory under ``root``.
            batch_size: forwarded to ``TensorDictReplayBuffer``.
            shuffle: forwarded to ``SliceSamplerWithoutReplacement``; must be
                true when ``replacement`` is true.
            num_slices: number of slices for the default slice sampler.
                Mutually exclusive with ``slice_len`` and with ``sampler``.
            slice_len: slice length for the default slice sampler.
                Mutually exclusive with ``num_slices`` and with ``sampler``.
            pad: stored on the instance; not otherwise used here.
            replacement: when truthy, use ``SliceSampler``; otherwise
                ``SliceSamplerWithoutReplacement``. Only consulted when
                ``num_slices`` or ``slice_len`` is given.
            streaming: not supported; a truthy value raises.
            root: directory holding the preprocessed dataset. Defaults to
                ``_get_root_dir("simxarm")``.
            download: ``True`` builds the storage if it is not already on
                disk; the string ``"force"`` always rebuilds it.
            sampler: custom sampler forwarded to the parent class.
            writer: custom writer; defaults to ``ImmutableDatasetWriter()``.
            collate_fn: defaults to ``_collate_id``.
            pin_memory: forwarded to ``TensorDictReplayBuffer``.
            prefetch: forwarded to ``TensorDictReplayBuffer``.
            transform: forwarded to ``TensorDictReplayBuffer``.
            split_trajs: not supported; a truthy value raises.
            strict_length: forwarded to the default slice sampler.

        Raises:
            NotImplementedError: always, for now (see TODO below).
        """
        # TODO
        # Deliberate placeholder: everything below is unreachable until this
        # raise is removed.
        raise NotImplementedError()
        self.download = download
        if streaming:
            raise NotImplementedError
        self.streaming = streaming
        self.dataset_id = dataset_id
        self.split_trajs = split_trajs
        self.shuffle = shuffle
        self.num_slices = num_slices
        self.slice_len = slice_len
        self.pad = pad

        self.strict_length = strict_length
        if (self.num_slices is not None) and (self.slice_len is not None):
            raise ValueError("num_slices or slice_len can be not None, but not both.")
        if split_trajs:
            raise NotImplementedError

        if root is None:
            # NOTE(review): the default root is keyed on "simxarm", presumably
            # carried over from the simxarm buffer -- confirm for pusht.
            root = _get_root_dir("simxarm")
            os.makedirs(root, exist_ok=True)
        self.root = Path(root)

        if self.download == "force" or (self.download and not self._is_downloaded()):
            storage = self._download_and_preproc()
        else:
            storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))

        if num_slices is not None or slice_len is not None:
            if sampler is not None:
                raise ValueError(
                    "`num_slices` and `slice_len` are exclusive with the `sampler` argument."
                )

            if replacement:
                if not self.shuffle:
                    raise RuntimeError(
                        "shuffle=False can only be used when replacement=False."
                    )
                sampler = SliceSampler(
                    num_slices=num_slices,
                    slice_len=slice_len,
                    strict_length=strict_length,
                )
            else:
                sampler = SliceSamplerWithoutReplacement(
                    num_slices=num_slices,
                    slice_len=slice_len,
                    strict_length=strict_length,
                    shuffle=self.shuffle,
                )

        if writer is None:
            writer = ImmutableDatasetWriter()
        if collate_fn is None:
            collate_fn = _collate_id

        super().__init__(
            storage=storage,
            sampler=sampler,
            writer=writer,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
            prefetch=prefetch,
            batch_size=batch_size,
            transform=transform,
        )

    @property
    def data_path_root(self):
        """Path of the dataset directory, or ``None`` when streaming."""
        if self.streaming:
            return None
        return self.root / self.dataset_id

    def _is_downloaded(self):
        """Return whether ``data_path_root`` exists on disk."""
        return os.path.exists(self.data_path_root)

    def _download_and_preproc(self):
        """Load the pickled buffer and repack it, episode by episode.

        Reads ``data/<dataset_id>/buffer.pkl``, cuts the frame sequence at
        every truthy ``dones`` entry, and writes each episode into a
        tensordict created under ``self.root / self.dataset_id``.

        Returns:
            A ``TensorStorage`` wrapping the locked tensordict.

        Raises:
            ValueError: if no frame is flagged as done, so no episode can be
                extracted.
        """
        # download
        # TODO(rcadene)

        # load
        # NOTE(review): the raw pickle is read from the hard-coded "data"
        # directory rather than from self.root -- confirm this is intended.
        dataset_dir = Path("data") / self.dataset_id
        dataset_path = dataset_dir / "buffer.pkl"
        print(f"Using offline dataset '{dataset_path}'")
        # SECURITY: pickle.load can execute arbitrary code; only use this
        # with dataset files from a trusted source.
        with open(dataset_path, "rb") as f:
            dataset_dict = pickle.load(f)

        total_frames = dataset_dict["actions"].shape[0]

        td_data = None  # allocated once the first complete episode is known
        idx0 = 0  # first frame of the episode being accumulated
        episode_id = 0
        for i in tqdm.tqdm(range(total_frames)):
            if not dataset_dict["dones"][i]:
                continue

            # Frame i closes the current episode: it spans [idx0, idx1).
            idx1 = i + 1
            num_frames = idx1 - idx0

            image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
            state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
            next_image = torch.tensor(
                dataset_dict["next_observations"]["rgb"][idx0:idx1]
            )
            next_state = torch.tensor(
                dataset_dict["next_observations"]["state"][idx0:idx1]
            )
            next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
            next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])

            episode = TensorDict(
                {
                    ("observation", "image"): image,
                    ("observation", "state"): state,
                    "action": torch.tensor(dataset_dict["actions"][idx0:idx1]),
                    "episode": torch.full((num_frames,), episode_id, dtype=torch.int),
                    "frame_id": torch.arange(0, num_frames, 1),
                    ("next", "observation", "image"): next_image,
                    ("next", "observation", "state"): next_state,
                    ("next", "observation", "reward"): next_reward,
                    ("next", "observation", "done"): next_done,
                },
                batch_size=num_frames,
            )

            if td_data is None:
                # hack to initialize tensordict data structure to store episodes:
                # use the first frame as a template for all total_frames entries.
                td_data = (
                    episode[0]
                    .expand(total_frames)
                    .memmap_like(self.root / self.dataset_id)
                )

            td_data[idx0:idx1] = episode

            episode_id += 1
            idx0 = idx1

        if td_data is None:
            # Without this guard the return below fails with an opaque
            # UnboundLocalError.
            raise ValueError(
                f"No completed episode found in '{dataset_path}': "
                "no frame is flagged as done."
            )

        # NOTE(review): frames after the last done flag are never written
        # into td_data -- confirm whether trailing partial episodes can occur.
        return TensorStorage(td_data.lock_())