Merge remote-tracking branch 'Cadene/user/rcadene/2024_03_31_remove_torchrl' into refactor_act_remove_torchrl
This commit is contained in:
@@ -1,234 +0,0 @@
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torchrl
|
||||
import tqdm
|
||||
from huggingface_hub import snapshot_download
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
|
||||
from torchrl.data.replay_buffers.samplers import Sampler, SamplerWithoutReplacement
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
||||
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
||||
from torchrl.envs.transforms.transforms import Compose
|
||||
|
||||
HF_USER = "lerobot"
|
||||
|
||||
|
||||
class AbstractDataset(TensorDictReplayBuffer):
|
||||
"""
|
||||
AbstractDataset represents a dataset in the context of imitation learning or reinforcement learning.
|
||||
This class is designed to be subclassed by concrete implementations that specify particular types of datasets.
|
||||
These implementations can vary based on the source of the data, the environment the data pertains to,
|
||||
or the specific kind of data manipulation applied.
|
||||
|
||||
Note:
|
||||
- `TensorDictReplayBuffer` is the base class from which `AbstractDataset` inherits. It provides the foundational
|
||||
functionality for storing and retrieving `TensorDict`-like data.
|
||||
- `available_datasets` should be overridden by concrete subclasses to list the specific dataset variants supported.
|
||||
It is expected that these variants correspond to a HuggingFace dataset on the hub.
|
||||
For instance, the `AlohaDataset` which inherites from `AbstractDataset` has 4 available dataset variants:
|
||||
- [aloha_sim_transfer_cube_scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
|
||||
- [aloha_sim_insertion_scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
|
||||
- [aloha_sim_transfer_cube_human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
|
||||
- [aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
|
||||
- When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
|
||||
1. set the required class attributes:
|
||||
- for classes inheriting from `AbstractDataset`: `available_datasets`
|
||||
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
|
||||
- for classes inheriting from `AbstractPolicy`: `name`
|
||||
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
|
||||
3. update variables in `tests/test_available.py` by importing your new class
|
||||
"""
|
||||
|
||||
available_datasets: list[str] | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = None,
|
||||
batch_size: int | None = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path | None = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: Sampler | None = None,
|
||||
collate_fn: Callable | None = None,
|
||||
writer: Writer | None = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
):
|
||||
assert (
|
||||
self.available_datasets is not None
|
||||
), "Subclasses of `AbstractDataset` should set the `available_datasets` class attribute."
|
||||
assert (
|
||||
dataset_id in self.available_datasets
|
||||
), f"The provided dataset ({dataset_id}) is not on the list of available datasets {self.available_datasets}."
|
||||
|
||||
self.dataset_id = dataset_id
|
||||
self.version = version
|
||||
self.shuffle = shuffle
|
||||
self.root = root if root is None else Path(root)
|
||||
|
||||
if self.root is not None and self.version is not None:
|
||||
logging.warning(
|
||||
f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})."
|
||||
)
|
||||
|
||||
storage = self._download_or_load_dataset()
|
||||
|
||||
super().__init__(
|
||||
storage=storage,
|
||||
sampler=sampler,
|
||||
writer=ImmutableDatasetWriter() if writer is None else writer,
|
||||
collate_fn=_collate_id if collate_fn is None else collate_fn,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch,
|
||||
batch_size=batch_size,
|
||||
transform=transform,
|
||||
)
|
||||
|
||||
@property
|
||||
def stats_patterns(self) -> dict:
|
||||
return {
|
||||
("observation", "state"): "b c -> c",
|
||||
("observation", "image"): "b c h w -> c 1 1",
|
||||
("action",): "b c -> c",
|
||||
}
|
||||
|
||||
@property
|
||||
def image_keys(self) -> list:
|
||||
return [("observation", "image")]
|
||||
|
||||
@property
|
||||
def num_cameras(self) -> int:
|
||||
return len(self.image_keys)
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self)
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(self._storage._storage["episode"].unique())
|
||||
|
||||
@property
|
||||
def transform(self):
|
||||
return self._transform
|
||||
|
||||
def set_transform(self, transform):
|
||||
if not isinstance(transform, Compose):
|
||||
# required since torchrl calls `len(self._transform)` downstream
|
||||
if isinstance(transform, list):
|
||||
self._transform = Compose(*transform)
|
||||
else:
|
||||
self._transform = Compose(transform)
|
||||
else:
|
||||
self._transform = transform
|
||||
|
||||
def compute_or_load_stats(self, batch_size: int = 32) -> TensorDict:
|
||||
stats_path = self.data_dir / "stats.pth"
|
||||
if stats_path.exists():
|
||||
stats = torch.load(stats_path)
|
||||
else:
|
||||
logging.info(f"compute_stats and save to {stats_path}")
|
||||
stats = self._compute_stats(batch_size)
|
||||
torch.save(stats, stats_path)
|
||||
return stats
|
||||
|
||||
def _download_or_load_dataset(self) -> torch.StorageBase:
|
||||
if self.root is None:
|
||||
self.data_dir = Path(
|
||||
snapshot_download(
|
||||
repo_id=f"{HF_USER}/{self.dataset_id}", repo_type="dataset", revision=self.version
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.data_dir = self.root / self.dataset_id
|
||||
return TensorStorage(TensorDict.load_memmap(self.data_dir / "replay_buffer"))
|
||||
|
||||
def _compute_stats(self, batch_size: int = 32):
|
||||
"""Compute dataset statistics including minimum, maximum, mean, and standard deviation.
|
||||
|
||||
TODO(alexander-soare): Add a num_batches argument which essentially allows one to use a subset of the
|
||||
full dataset (for handling very large datasets). The sampling would then have to be random
|
||||
(preferably without replacement). Both stats computation loops would ideally sample the same
|
||||
items.
|
||||
"""
|
||||
rb = TensorDictReplayBuffer(
|
||||
storage=self._storage,
|
||||
batch_size=32,
|
||||
prefetch=True,
|
||||
# Note: Due to be refactored soon. The point is that we should go through the whole dataset.
|
||||
sampler=SamplerWithoutReplacement(drop_last=False, shuffle=False),
|
||||
)
|
||||
|
||||
# mean and std will be computed incrementally while max and min will track the running value.
|
||||
mean, std, max, min = {}, {}, {}, {}
|
||||
for key in self.stats_patterns:
|
||||
mean[key] = torch.tensor(0.0).float()
|
||||
std[key] = torch.tensor(0.0).float()
|
||||
max[key] = torch.tensor(-float("inf")).float()
|
||||
min[key] = torch.tensor(float("inf")).float()
|
||||
|
||||
# Compute mean, min, max.
|
||||
# Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
|
||||
# surprises when rerunning the sampler.
|
||||
first_batch = None
|
||||
running_item_count = 0 # for online mean computation
|
||||
for _ in tqdm.tqdm(range(ceil(len(rb) / batch_size))):
|
||||
batch = rb.sample()
|
||||
this_batch_size = batch.batch_size[0]
|
||||
running_item_count += this_batch_size
|
||||
if first_batch is None:
|
||||
first_batch = deepcopy(batch)
|
||||
for key, pattern in self.stats_patterns.items():
|
||||
batch[key] = batch[key].float()
|
||||
# Numerically stable update step for mean computation.
|
||||
batch_mean = einops.reduce(batch[key], pattern, "mean")
|
||||
# Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
|
||||
# the update step, N is the running item count, B is this batch size, x̄ is the running mean,
|
||||
# and x is the current batch mean. Some rearrangement is then required to avoid risking
|
||||
# numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
|
||||
# x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
|
||||
mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
|
||||
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
|
||||
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
|
||||
|
||||
# Compute std.
|
||||
first_batch_ = None
|
||||
running_item_count = 0 # for online std computation
|
||||
for _ in tqdm.tqdm(range(ceil(len(rb) / batch_size))):
|
||||
batch = rb.sample()
|
||||
this_batch_size = batch.batch_size[0]
|
||||
running_item_count += this_batch_size
|
||||
# Sanity check to make sure the batches are still in the same order as before.
|
||||
if first_batch_ is None:
|
||||
first_batch_ = deepcopy(batch)
|
||||
for key in self.stats_patterns:
|
||||
assert torch.equal(first_batch_[key], first_batch[key])
|
||||
for key, pattern in self.stats_patterns.items():
|
||||
batch[key] = batch[key].float()
|
||||
# Numerically stable update step for mean computation (where the mean is over squared
|
||||
# residuals).See notes in the mean computation loop above.
|
||||
batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
|
||||
std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
|
||||
|
||||
for key in self.stats_patterns:
|
||||
std[key] = torch.sqrt(std[key])
|
||||
|
||||
stats = TensorDict({}, batch_size=[])
|
||||
for key in self.stats_patterns:
|
||||
stats[(*key, "mean")] = mean[key]
|
||||
stats[(*key, "std")] = std[key]
|
||||
stats[(*key, "max")] = max[key]
|
||||
stats[(*key, "min")] = min[key]
|
||||
|
||||
if key[0] == "observation":
|
||||
# use same stats for the next observations
|
||||
stats[("next", *key)] = stats[key]
|
||||
return stats
|
||||
@@ -1,26 +1,13 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import einops
|
||||
import gdown
|
||||
import h5py
|
||||
import torch
|
||||
import torchrl
|
||||
import tqdm
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.replay_buffers.samplers import Sampler
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||
from torchrl.data.replay_buffers.writers import Writer
|
||||
|
||||
from lerobot.common.datasets.abstract import AbstractDataset
|
||||
|
||||
DATASET_IDS = [
|
||||
"aloha_sim_insertion_human",
|
||||
"aloha_sim_insertion_scripted",
|
||||
"aloha_sim_transfer_cube_human",
|
||||
"aloha_sim_transfer_cube_scripted",
|
||||
]
|
||||
from lerobot.common.datasets.utils import load_data_with_delta_timestamps
|
||||
|
||||
FOLDER_URLS = {
|
||||
"aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
|
||||
@@ -66,7 +53,6 @@ CAMERAS = {
|
||||
|
||||
|
||||
def download(data_dir, dataset_id):
|
||||
assert dataset_id in DATASET_IDS
|
||||
assert dataset_id in FOLDER_URLS
|
||||
assert dataset_id in EP48_URLS
|
||||
assert dataset_id in EP49_URLS
|
||||
@@ -80,51 +66,80 @@ def download(data_dir, dataset_id):
|
||||
gdown.download(EP49_URLS[dataset_id], output=str(data_dir / "episode_49.hdf5"), fuzzy=True)
|
||||
|
||||
|
||||
class AlohaDataset(AbstractDataset):
|
||||
available_datasets = DATASET_IDS
|
||||
class AlohaDataset(torch.utils.data.Dataset):
|
||||
available_datasets = [
|
||||
"aloha_sim_insertion_human",
|
||||
"aloha_sim_insertion_scripted",
|
||||
"aloha_sim_transfer_cube_human",
|
||||
"aloha_sim_transfer_cube_scripted",
|
||||
]
|
||||
fps = 50
|
||||
image_keys = ["observation.images.top"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.2",
|
||||
batch_size: int | None = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path | None = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: Sampler | None = None,
|
||||
collate_fn: Callable | None = None,
|
||||
writer: Writer | None = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
transform: callable = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
dataset_id,
|
||||
version,
|
||||
batch_size,
|
||||
shuffle=shuffle,
|
||||
root=root,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch,
|
||||
sampler=sampler,
|
||||
collate_fn=collate_fn,
|
||||
writer=writer,
|
||||
transform=transform,
|
||||
)
|
||||
super().__init__()
|
||||
self.dataset_id = dataset_id
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
|
||||
self.data_dir = self.root / f"{self.dataset_id}"
|
||||
if (self.data_dir / "data_dict.pth").exists() and (
|
||||
self.data_dir / "data_ids_per_episode.pth"
|
||||
).exists():
|
||||
self.data_dict = torch.load(self.data_dir / "data_dict.pth")
|
||||
self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
|
||||
else:
|
||||
self._download_and_preproc_obsolete()
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(self.data_dict, self.data_dir / "data_dict.pth")
|
||||
torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
|
||||
|
||||
@property
|
||||
def stats_patterns(self) -> dict:
|
||||
d = {
|
||||
("observation", "state"): "b c -> c",
|
||||
("action",): "b c -> c",
|
||||
}
|
||||
for cam in CAMERAS[self.dataset_id]:
|
||||
d[("observation", "image", cam)] = "b c h w -> c 1 1"
|
||||
return d
|
||||
def num_samples(self) -> int:
|
||||
return len(self.data_dict["index"])
|
||||
|
||||
@property
|
||||
def image_keys(self) -> list:
|
||||
return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]]
|
||||
def num_episodes(self) -> int:
|
||||
return len(self.data_ids_per_episode)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = {}
|
||||
|
||||
# get episode id and timestamp of the sampled frame
|
||||
current_ts = self.data_dict["timestamp"][idx].item()
|
||||
episode = self.data_dict["episode"][idx].item()
|
||||
|
||||
for key in self.data_dict:
|
||||
if self.delta_timestamps is not None and key in self.delta_timestamps:
|
||||
data, is_pad = load_data_with_delta_timestamps(
|
||||
self.data_dict,
|
||||
self.data_ids_per_episode,
|
||||
self.delta_timestamps,
|
||||
key,
|
||||
current_ts,
|
||||
episode,
|
||||
)
|
||||
item[key] = data
|
||||
item[f"{key}_is_pad"] = is_pad
|
||||
else:
|
||||
item[key] = self.data_dict[key][idx]
|
||||
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
|
||||
return item
|
||||
|
||||
def _download_and_preproc_obsolete(self):
|
||||
assert self.root is not None
|
||||
@@ -132,54 +147,55 @@ class AlohaDataset(AbstractDataset):
|
||||
if not raw_dir.is_dir():
|
||||
download(raw_dir, self.dataset_id)
|
||||
|
||||
total_num_frames = 0
|
||||
total_frames = 0
|
||||
logging.info("Compute total number of frames to initialize offline buffer")
|
||||
for ep_id in range(NUM_EPISODES[self.dataset_id]):
|
||||
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
|
||||
with h5py.File(ep_path, "r") as ep:
|
||||
total_num_frames += ep["/action"].shape[0] - 1
|
||||
logging.info(f"{total_num_frames=}")
|
||||
total_frames += ep["/action"].shape[0] - 1
|
||||
logging.info(f"{total_frames=}")
|
||||
|
||||
self.data_ids_per_episode = {}
|
||||
ep_dicts = []
|
||||
|
||||
logging.info("Initialize and feed offline buffer")
|
||||
idxtd = 0
|
||||
for ep_id in tqdm.tqdm(range(NUM_EPISODES[self.dataset_id])):
|
||||
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
|
||||
with h5py.File(ep_path, "r") as ep:
|
||||
ep_num_frames = ep["/action"].shape[0]
|
||||
num_frames = ep["/action"].shape[0]
|
||||
|
||||
# last step of demonstration is considered done
|
||||
done = torch.zeros(ep_num_frames, 1, dtype=torch.bool)
|
||||
done = torch.zeros(num_frames, 1, dtype=torch.bool)
|
||||
done[-1] = True
|
||||
|
||||
state = torch.from_numpy(ep["/observations/qpos"][:])
|
||||
action = torch.from_numpy(ep["/action"][:])
|
||||
|
||||
ep_td = TensorDict(
|
||||
{
|
||||
("observation", "state"): state[:-1],
|
||||
"action": action[:-1],
|
||||
"episode": torch.tensor([ep_id] * (ep_num_frames - 1)),
|
||||
"frame_id": torch.arange(0, ep_num_frames - 1, 1),
|
||||
("next", "observation", "state"): state[1:],
|
||||
# TODO: compute reward and success
|
||||
# ("next", "reward"): reward[1:],
|
||||
("next", "done"): done[1:],
|
||||
# ("next", "success"): success[1:],
|
||||
},
|
||||
batch_size=ep_num_frames - 1,
|
||||
)
|
||||
ep_dict = {
|
||||
"observation.state": state,
|
||||
"action": action,
|
||||
"episode": torch.tensor([ep_id] * num_frames),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
|
||||
# "next.observation.state": state,
|
||||
# TODO(rcadene): compute reward and success
|
||||
# "next.reward": reward[1:],
|
||||
"next.done": done[1:],
|
||||
# "next.success": success[1:],
|
||||
}
|
||||
|
||||
for cam in CAMERAS[self.dataset_id]:
|
||||
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])
|
||||
image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
|
||||
ep_td["observation", "image", cam] = image[:-1]
|
||||
ep_td["next", "observation", "image", cam] = image[1:]
|
||||
ep_dict[f"observation.images.{cam}"] = image[:-1]
|
||||
# ep_dict[f"next.observation.images.{cam}"] = image[1:]
|
||||
|
||||
if ep_id == 0:
|
||||
# hack to initialize tensordict data structure to store episodes
|
||||
td_data = ep_td[0].expand(total_num_frames).memmap_like(self.root / f"{self.dataset_id}")
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
td_data[idxtd : idxtd + len(ep_td)] = ep_td
|
||||
idxtd = idxtd + len(ep_td)
|
||||
self.data_dict = {}
|
||||
|
||||
return TensorStorage(td_data.lock_())
|
||||
keys = ep_dicts[0].keys()
|
||||
for key in keys:
|
||||
self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
|
||||
self.data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler
|
||||
from torchvision.transforms import v2
|
||||
|
||||
from lerobot.common.datasets.utils import compute_or_load_stats
|
||||
from lerobot.common.transforms import NormalizeTransform, Prod
|
||||
|
||||
# DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and
|
||||
@@ -13,57 +13,12 @@ from lerobot.common.transforms import NormalizeTransform, Prod
|
||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
|
||||
|
||||
def make_offline_buffer(
|
||||
def make_dataset(
|
||||
cfg,
|
||||
overwrite_sampler=None,
|
||||
# set normalize=False to remove all transformations and keep images unnormalized in [0,255]
|
||||
normalize=True,
|
||||
overwrite_batch_size=None,
|
||||
overwrite_prefetch=None,
|
||||
stats_path=None,
|
||||
):
|
||||
if cfg.policy.balanced_sampling:
|
||||
assert cfg.online_steps > 0
|
||||
batch_size = None
|
||||
pin_memory = False
|
||||
prefetch = None
|
||||
else:
|
||||
assert cfg.online_steps == 0
|
||||
num_slices = cfg.policy.batch_size
|
||||
batch_size = cfg.policy.horizon * num_slices
|
||||
pin_memory = cfg.device == "cuda"
|
||||
prefetch = cfg.prefetch
|
||||
|
||||
if overwrite_batch_size is not None:
|
||||
batch_size = overwrite_batch_size
|
||||
|
||||
if overwrite_prefetch is not None:
|
||||
prefetch = overwrite_prefetch
|
||||
|
||||
if overwrite_sampler is None:
|
||||
# TODO(rcadene): move batch_size outside
|
||||
num_traj_per_batch = cfg.policy.batch_size # // cfg.horizon
|
||||
# TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
|
||||
# We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
|
||||
|
||||
if cfg.offline_prioritized_sampler:
|
||||
logging.info("use prioritized sampler for offline dataset")
|
||||
sampler = PrioritizedSliceSampler(
|
||||
max_capacity=100_000,
|
||||
alpha=cfg.policy.per_alpha,
|
||||
beta=cfg.policy.per_beta,
|
||||
num_slices=num_traj_per_batch,
|
||||
strict_length=False,
|
||||
)
|
||||
else:
|
||||
logging.info("use simple sampler for offline dataset")
|
||||
sampler = SliceSampler(
|
||||
num_slices=num_traj_per_batch,
|
||||
strict_length=False,
|
||||
)
|
||||
else:
|
||||
sampler = overwrite_sampler
|
||||
|
||||
if cfg.env.name == "simxarm":
|
||||
from lerobot.common.datasets.simxarm import SimxarmDataset
|
||||
|
||||
@@ -81,47 +36,29 @@ def make_offline_buffer(
|
||||
else:
|
||||
raise ValueError(cfg.env.name)
|
||||
|
||||
offline_buffer = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
sampler=sampler,
|
||||
batch_size=batch_size,
|
||||
root=DATA_DIR,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch if isinstance(prefetch, int) else None,
|
||||
)
|
||||
|
||||
if cfg.policy.name == "tdmpc":
|
||||
img_keys = []
|
||||
for key in offline_buffer.image_keys:
|
||||
img_keys.append(("next", *key))
|
||||
img_keys += offline_buffer.image_keys
|
||||
else:
|
||||
img_keys = offline_buffer.image_keys
|
||||
|
||||
transforms = None
|
||||
if normalize:
|
||||
transforms = [Prod(in_keys=img_keys, prod=1 / 255)]
|
||||
|
||||
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
|
||||
# min_max_from_spec
|
||||
stats = offline_buffer.compute_or_load_stats() if stats_path is None else torch.load(stats_path)
|
||||
|
||||
# we only normalize the state and action, since the images are usually normalized inside the model for
|
||||
# now (except for tdmpc: see the following)
|
||||
in_keys = [("observation", "state"), ("action")]
|
||||
|
||||
if cfg.policy.name == "tdmpc":
|
||||
# TODO(rcadene): we add img_keys to the keys to normalize for tdmpc only, since diffusion and act policies normalize the image inside the model for now
|
||||
in_keys += img_keys
|
||||
# TODO(racdene): since we use next observations in tdmpc, we also add them to the normalization. We are wasting a bit of compute on this for now.
|
||||
in_keys += [("next", *key) for key in img_keys]
|
||||
in_keys.append(("next", "observation", "state"))
|
||||
# stats = dataset.compute_or_load_stats() if stats_path is None else torch.load(stats_path)
|
||||
|
||||
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
|
||||
stats = {}
|
||||
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
|
||||
stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
|
||||
stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
|
||||
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
stats["observation.state"] = {}
|
||||
stats["observation.state"]["min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
|
||||
stats["observation.state"]["max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
|
||||
stats["action"] = {}
|
||||
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
else:
|
||||
# instantiate a one frame dataset with light transform
|
||||
stats_dataset = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
root=DATA_DIR,
|
||||
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
|
||||
)
|
||||
stats = compute_or_load_stats(stats_dataset)
|
||||
|
||||
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
|
||||
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
|
||||
@@ -211,12 +148,38 @@ def make_offline_buffer(
|
||||
0.38381037,
|
||||
]
|
||||
)
|
||||
transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))
|
||||
transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode)) # noqa: F821
|
||||
|
||||
offline_buffer.set_transform(transforms)
|
||||
transforms = v2.Compose(
|
||||
[
|
||||
# TODO(rcadene): we need to do something about image_keys
|
||||
Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
|
||||
NormalizeTransform(
|
||||
stats,
|
||||
in_keys=[
|
||||
"observation.state",
|
||||
"action",
|
||||
],
|
||||
mode=normalization_mode,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
if not overwrite_sampler:
|
||||
index = torch.arange(0, offline_buffer.num_samples, 1)
|
||||
sampler.extend(index)
|
||||
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
|
||||
# TODO(rcadene): implement delta_timestamps in config
|
||||
delta_timestamps = {
|
||||
"observation.image": [-0.1, 0],
|
||||
"observation.state": [-0.1, 0],
|
||||
"action": [-0.1] + [i / clsfunc.fps for i in range(15)],
|
||||
}
|
||||
else:
|
||||
delta_timestamps = None
|
||||
|
||||
return offline_buffer
|
||||
dataset = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
root=DATA_DIR,
|
||||
delta_timestamps=delta_timestamps,
|
||||
transform=transforms,
|
||||
)
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -1,20 +1,13 @@
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import pygame
|
||||
import pymunk
|
||||
import torch
|
||||
import torchrl
|
||||
import tqdm
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.replay_buffers.samplers import Sampler
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||
from torchrl.data.replay_buffers.writers import Writer
|
||||
|
||||
from lerobot.common.datasets.abstract import AbstractDataset
|
||||
from lerobot.common.datasets.utils import download_and_extract_zip
|
||||
from lerobot.common.datasets.utils import download_and_extract_zip, load_data_with_delta_timestamps
|
||||
from lerobot.common.envs.pusht.pusht_env import pymunk_to_shapely
|
||||
from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
|
||||
|
||||
@@ -83,37 +76,84 @@ def add_tee(
|
||||
return body
|
||||
|
||||
|
||||
class PushtDataset(AbstractDataset):
|
||||
class PushtDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
|
||||
Arguments
|
||||
----------
|
||||
delta_timestamps : dict[list[float]] | None, optional
|
||||
Loads data from frames with a shift in timestamps with a different strategy for each data key (e.g. state, action or image)
|
||||
If `None`, no shift is applied to current timestamp and the data from the current frame is loaded.
|
||||
"""
|
||||
|
||||
available_datasets = ["pusht"]
|
||||
fps = 10
|
||||
image_keys = ["observation.image"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.2",
|
||||
batch_size: int | None = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path | None = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: Sampler | None = None,
|
||||
collate_fn: Callable | None = None,
|
||||
writer: Writer | None = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
transform: callable = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
dataset_id,
|
||||
version,
|
||||
batch_size,
|
||||
shuffle=shuffle,
|
||||
root=root,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch,
|
||||
sampler=sampler,
|
||||
collate_fn=collate_fn,
|
||||
writer=writer,
|
||||
transform=transform,
|
||||
)
|
||||
super().__init__()
|
||||
self.dataset_id = dataset_id
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
|
||||
self.data_dir = self.root / f"{self.dataset_id}"
|
||||
if (self.data_dir / "data_dict.pth").exists() and (
|
||||
self.data_dir / "data_ids_per_episode.pth"
|
||||
).exists():
|
||||
self.data_dict = torch.load(self.data_dir / "data_dict.pth")
|
||||
self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
|
||||
else:
|
||||
self._download_and_preproc_obsolete()
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(self.data_dict, self.data_dir / "data_dict.pth")
|
||||
torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self.data_dict["index"])
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(self.data_ids_per_episode)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = {}
|
||||
|
||||
# get episode id and timestamp of the sampled frame
|
||||
current_ts = self.data_dict["timestamp"][idx].item()
|
||||
episode = self.data_dict["episode"][idx].item()
|
||||
|
||||
for key in self.data_dict:
|
||||
if self.delta_timestamps is not None and key in self.delta_timestamps:
|
||||
data, is_pad = load_data_with_delta_timestamps(
|
||||
self.data_dict,
|
||||
self.data_ids_per_episode,
|
||||
self.delta_timestamps,
|
||||
key,
|
||||
current_ts,
|
||||
episode,
|
||||
)
|
||||
item[key] = data
|
||||
item[f"{key}_is_pad"] = is_pad
|
||||
else:
|
||||
item[key] = self.data_dict[key][idx]
|
||||
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
|
||||
return item
|
||||
|
||||
def _download_and_preproc_obsolete(self):
|
||||
assert self.root is not None
|
||||
@@ -147,8 +187,10 @@ class PushtDataset(AbstractDataset):
|
||||
states = torch.from_numpy(dataset_dict["state"])
|
||||
actions = torch.from_numpy(dataset_dict["action"])
|
||||
|
||||
self.data_ids_per_episode = {}
|
||||
ep_dicts = []
|
||||
|
||||
idx0 = 0
|
||||
idxtd = 0
|
||||
for episode_id in tqdm.tqdm(range(num_episodes)):
|
||||
idx1 = dataset_dict.meta["episode_ends"][episode_id]
|
||||
# to create test artifact
|
||||
@@ -194,30 +236,45 @@ class PushtDataset(AbstractDataset):
|
||||
# last step of demonstration is considered done
|
||||
done[-1] = True
|
||||
|
||||
ep_td = TensorDict(
|
||||
{
|
||||
("observation", "image"): image[:-1],
|
||||
("observation", "state"): agent_pos[:-1],
|
||||
"action": actions[idx0:idx1][:-1],
|
||||
"episode": episode_ids[idx0:idx1][:-1],
|
||||
"frame_id": torch.arange(0, num_frames - 1, 1),
|
||||
("next", "observation", "image"): image[1:],
|
||||
("next", "observation", "state"): agent_pos[1:],
|
||||
# TODO: verify that reward and done are aligned with image and agent_pos
|
||||
("next", "reward"): reward[1:],
|
||||
("next", "done"): done[1:],
|
||||
("next", "success"): success[1:],
|
||||
},
|
||||
batch_size=num_frames - 1,
|
||||
)
|
||||
ep_dict = {
|
||||
"observation.image": image,
|
||||
"observation.state": agent_pos,
|
||||
"action": actions[idx0:idx1],
|
||||
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
|
||||
# "next.observation.image": image[1:],
|
||||
# "next.observation.state": agent_pos[1:],
|
||||
# TODO(rcadene): verify that reward and done are aligned with image and agent_pos
|
||||
"next.reward": torch.cat([reward[1:], reward[[-1]]]),
|
||||
"next.done": torch.cat([done[1:], done[[-1]]]),
|
||||
"next.success": torch.cat([success[1:], success[[-1]]]),
|
||||
}
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
if episode_id == 0:
|
||||
# hack to initialize tensordict data structure to store episodes
|
||||
td_data = ep_td[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}")
|
||||
|
||||
td_data[idxtd : idxtd + len(ep_td)] = ep_td
|
||||
assert isinstance(episode_id, int)
|
||||
self.data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1)
|
||||
assert len(self.data_ids_per_episode[episode_id]) == num_frames
|
||||
|
||||
idx0 = idx1
|
||||
idxtd = idxtd + len(ep_td)
|
||||
|
||||
return TensorStorage(td_data.lock_())
|
||||
self.data_dict = {}
|
||||
|
||||
keys = ep_dicts[0].keys()
|
||||
for key in keys:
|
||||
self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
|
||||
self.data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dataset = PushtDataset(
|
||||
"pusht",
|
||||
root=Path("data"),
|
||||
delta_timestamps={
|
||||
"observation.image": [0, -1, -0.2, -0.1],
|
||||
"observation.state": [0, -1, -0.2, -0.1],
|
||||
"action": [-0.1, 0, 1, 2, 3],
|
||||
},
|
||||
)
|
||||
dataset[10]
|
||||
|
||||
@@ -1,75 +1,106 @@
|
||||
import pickle
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torchrl
|
||||
import tqdm
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.replay_buffers.samplers import (
|
||||
Sampler,
|
||||
)
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||
from torchrl.data.replay_buffers.writers import Writer
|
||||
|
||||
from lerobot.common.datasets.abstract import AbstractDataset
|
||||
from lerobot.common.datasets.utils import load_data_with_delta_timestamps
|
||||
|
||||
|
||||
def download():
|
||||
raise NotImplementedError()
|
||||
def download(raw_dir):
|
||||
import gdown
|
||||
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
|
||||
download_path = "data.zip"
|
||||
gdown.download(url, download_path, quiet=False)
|
||||
zip_path = raw_dir / "data.zip"
|
||||
gdown.download(url, str(zip_path), quiet=False)
|
||||
print("Extracting...")
|
||||
with zipfile.ZipFile(download_path, "r") as zip_f:
|
||||
with zipfile.ZipFile(str(zip_path), "r") as zip_f:
|
||||
for member in zip_f.namelist():
|
||||
if member.startswith("data/xarm") and member.endswith(".pkl"):
|
||||
print(member)
|
||||
zip_f.extract(member=member)
|
||||
Path(download_path).unlink()
|
||||
zip_path.unlink()
|
||||
|
||||
|
||||
class SimxarmDataset(AbstractDataset):
|
||||
class SimxarmDataset(torch.utils.data.Dataset):
|
||||
available_datasets = [
|
||||
"xarm_lift_medium",
|
||||
]
|
||||
fps = 15
|
||||
image_keys = ["observation.image"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.1",
|
||||
batch_size: int | None = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path | None = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: Sampler | None = None,
|
||||
collate_fn: Callable | None = None,
|
||||
writer: Writer | None = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
transform: callable = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
dataset_id,
|
||||
version,
|
||||
batch_size,
|
||||
shuffle=shuffle,
|
||||
root=root,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch,
|
||||
sampler=sampler,
|
||||
collate_fn=collate_fn,
|
||||
writer=writer,
|
||||
transform=transform,
|
||||
)
|
||||
super().__init__()
|
||||
self.dataset_id = dataset_id
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
|
||||
self.data_dir = self.root / f"{self.dataset_id}"
|
||||
if (self.data_dir / "data_dict.pth").exists() and (
|
||||
self.data_dir / "data_ids_per_episode.pth"
|
||||
).exists():
|
||||
self.data_dict = torch.load(self.data_dir / "data_dict.pth")
|
||||
self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
|
||||
else:
|
||||
self._download_and_preproc_obsolete()
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(self.data_dict, self.data_dir / "data_dict.pth")
|
||||
torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self.data_dict["index"])
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(self.data_ids_per_episode)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = {}
|
||||
|
||||
# get episode id and timestamp of the sampled frame
|
||||
current_ts = self.data_dict["timestamp"][idx].item()
|
||||
episode = self.data_dict["episode"][idx].item()
|
||||
|
||||
for key in self.data_dict:
|
||||
if self.delta_timestamps is not None and key in self.delta_timestamps:
|
||||
data, is_pad = load_data_with_delta_timestamps(
|
||||
self.data_dict,
|
||||
self.data_ids_per_episode,
|
||||
self.delta_timestamps,
|
||||
key,
|
||||
current_ts,
|
||||
episode,
|
||||
)
|
||||
item[key] = data
|
||||
item[f"{key}_is_pad"] = is_pad
|
||||
else:
|
||||
item[key] = self.data_dict[key][idx]
|
||||
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
|
||||
return item
|
||||
|
||||
def _download_and_preproc_obsolete(self):
|
||||
# assert self.root is not None
|
||||
# TODO(rcadene): finish download
|
||||
# download()
|
||||
assert self.root is not None
|
||||
raw_dir = self.root / f"{self.dataset_id}_raw"
|
||||
if not raw_dir.exists():
|
||||
download(raw_dir)
|
||||
|
||||
dataset_path = self.root / f"{self.dataset_id}" / "buffer.pkl"
|
||||
print(f"Using offline dataset '{dataset_path}'")
|
||||
@@ -78,6 +109,9 @@ class SimxarmDataset(AbstractDataset):
|
||||
|
||||
total_frames = dataset_dict["actions"].shape[0]
|
||||
|
||||
self.data_ids_per_episode = {}
|
||||
ep_dicts = []
|
||||
|
||||
idx0 = 0
|
||||
idx1 = 0
|
||||
episode_id = 0
|
||||
@@ -91,37 +125,38 @@ class SimxarmDataset(AbstractDataset):
|
||||
|
||||
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
|
||||
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
|
||||
next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
|
||||
next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
|
||||
action = torch.tensor(dataset_dict["actions"][idx0:idx1])
|
||||
# TODO(rcadene): concat the last "next_observations" to "observations"
|
||||
# next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
|
||||
# next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
|
||||
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
|
||||
next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])
|
||||
|
||||
episode = TensorDict(
|
||||
{
|
||||
("observation", "image"): image,
|
||||
("observation", "state"): state,
|
||||
"action": torch.tensor(dataset_dict["actions"][idx0:idx1]),
|
||||
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
("next", "observation", "image"): next_image,
|
||||
("next", "observation", "state"): next_state,
|
||||
("next", "reward"): next_reward,
|
||||
("next", "done"): next_done,
|
||||
},
|
||||
batch_size=num_frames,
|
||||
)
|
||||
ep_dict = {
|
||||
"observation.image": image,
|
||||
"observation.state": state,
|
||||
"action": action,
|
||||
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
|
||||
# "next.observation.image": next_image,
|
||||
# "next.observation.state": next_state,
|
||||
"next.reward": next_reward,
|
||||
"next.done": next_done,
|
||||
}
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
if episode_id == 0:
|
||||
# hack to initialize tensordict data structure to store episodes
|
||||
td_data = (
|
||||
episode[0]
|
||||
.expand(total_frames)
|
||||
.memmap_like(self.root / f"{self.dataset_id}" / "replay_buffer")
|
||||
)
|
||||
assert isinstance(episode_id, int)
|
||||
self.data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1)
|
||||
assert len(self.data_ids_per_episode[episode_id]) == num_frames
|
||||
|
||||
td_data[idx0:idx1] = episode
|
||||
|
||||
episode_id += 1
|
||||
idx0 = idx1
|
||||
episode_id += 1
|
||||
|
||||
return TensorStorage(td_data.lock_())
|
||||
self.data_dict = {}
|
||||
|
||||
keys = ep_dicts[0].keys()
|
||||
for key in keys:
|
||||
self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
|
||||
self.data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
import io
|
||||
import logging
|
||||
import zipfile
|
||||
from copy import deepcopy
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
import requests
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
|
||||
@@ -28,3 +33,173 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def euclidean_distance_matrix(mat0, mat1):
|
||||
# Compute the square of the distance matrix
|
||||
sq0 = torch.sum(mat0**2, dim=1, keepdim=True)
|
||||
sq1 = torch.sum(mat1**2, dim=1, keepdim=True)
|
||||
distance_sq = sq0 + sq1.transpose(0, 1) - 2 * mat0 @ mat1.transpose(0, 1)
|
||||
|
||||
# Taking the square root to get the euclidean distance
|
||||
distance = torch.sqrt(torch.clamp(distance_sq, min=0))
|
||||
return distance
|
||||
|
||||
|
||||
def is_contiguously_true_or_false(bool_vector):
|
||||
assert bool_vector.ndim == 1
|
||||
assert bool_vector.dtype == torch.bool
|
||||
|
||||
# Compare each element with its neighbor to find changes
|
||||
changes = bool_vector[1:] != bool_vector[:-1]
|
||||
|
||||
# Count the number of changes
|
||||
num_changes = changes.sum().item()
|
||||
|
||||
# If there's more than one change, the list is not contiguous
|
||||
return num_changes <= 1
|
||||
|
||||
# examples = [
|
||||
# ([True, False, True, False, False, False], False),
|
||||
# ([True, True, True, False, False, False], True),
|
||||
# ([False, False, False, False, False, False], True)
|
||||
# ]
|
||||
# for bool_list, expected in examples:
|
||||
# result = is_contiguously_true_or_false(bool_list)
|
||||
|
||||
|
||||
def load_data_with_delta_timestamps(
|
||||
data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode
|
||||
):
|
||||
# get indices of the frames associated to the episode, and their timestamps
|
||||
ep_data_ids = data_ids_per_episode[episode]
|
||||
ep_timestamps = data_dict["timestamp"][ep_data_ids]
|
||||
|
||||
# get timestamps used as query to retrieve data of previous/future frames
|
||||
delta_ts = delta_timestamps[key]
|
||||
query_ts = current_ts + torch.tensor(delta_ts)
|
||||
|
||||
# compute distances between each query timestamp and all timestamps of all the frames belonging to the episode
|
||||
dist = euclidean_distance_matrix(query_ts[:, None], ep_timestamps[:, None])
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
# get the indices of the data that are closest to the query timestamps
|
||||
data_ids = ep_data_ids[argmin_]
|
||||
# closest_ts = ep_timestamps[argmin_]
|
||||
|
||||
# get the data
|
||||
data = data_dict[key][data_ids].clone()
|
||||
|
||||
# TODO(rcadene): synchronize timestamps + interpolation if needed
|
||||
|
||||
tol = 0.02
|
||||
is_pad = min_ > tol
|
||||
|
||||
assert is_contiguously_true_or_false(is_pad), (
|
||||
"One or several timestamps unexpectedly violate the tolerance."
|
||||
"This might be due to synchronization issues with timestamps during data collection."
|
||||
)
|
||||
|
||||
return data, is_pad
|
||||
|
||||
|
||||
def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
|
||||
stats_path = dataset.data_dir / "stats.pth"
|
||||
if stats_path.exists():
|
||||
return torch.load(stats_path)
|
||||
|
||||
logging.info(f"compute_stats and save to {stats_path}")
|
||||
|
||||
if max_num_samples is None:
|
||||
max_num_samples = len(dataset)
|
||||
else:
|
||||
raise NotImplementedError("We need to set shuffle=True, but this violate an assert for now.")
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
# pin_memory=cfg.device != "cpu",
|
||||
drop_last=False,
|
||||
)
|
||||
|
||||
# these einops patterns will be used to aggregate batches and compute statistics
|
||||
stats_patterns = {
|
||||
"action": "b c -> c",
|
||||
"observation.state": "b c -> c",
|
||||
}
|
||||
for key in dataset.image_keys:
|
||||
stats_patterns[key] = "b c h w -> c 1 1"
|
||||
|
||||
# mean and std will be computed incrementally while max and min will track the running value.
|
||||
mean, std, max, min = {}, {}, {}, {}
|
||||
for key in stats_patterns:
|
||||
mean[key] = torch.tensor(0.0).float()
|
||||
std[key] = torch.tensor(0.0).float()
|
||||
max[key] = torch.tensor(-float("inf")).float()
|
||||
min[key] = torch.tensor(float("inf")).float()
|
||||
|
||||
# Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
|
||||
# surprises when rerunning the sampler.
|
||||
first_batch = None
|
||||
running_item_count = 0 # for online mean computation
|
||||
for i, batch in enumerate(
|
||||
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
|
||||
):
|
||||
this_batch_size = len(batch["index"])
|
||||
running_item_count += this_batch_size
|
||||
if first_batch is None:
|
||||
first_batch = deepcopy(batch)
|
||||
for key, pattern in stats_patterns.items():
|
||||
batch[key] = batch[key].float()
|
||||
# Numerically stable update step for mean computation.
|
||||
batch_mean = einops.reduce(batch[key], pattern, "mean")
|
||||
# Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
|
||||
# the update step, N is the running item count, B is this batch size, x̄ is the running mean,
|
||||
# and x is the current batch mean. Some rearrangement is then required to avoid risking
|
||||
# numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
|
||||
# x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
|
||||
mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
|
||||
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
|
||||
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
|
||||
|
||||
if i == ceil(max_num_samples / batch_size) - 1:
|
||||
break
|
||||
|
||||
first_batch_ = None
|
||||
running_item_count = 0 # for online std computation
|
||||
for i, batch in enumerate(
|
||||
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
|
||||
):
|
||||
this_batch_size = len(batch["index"])
|
||||
running_item_count += this_batch_size
|
||||
# Sanity check to make sure the batches are still in the same order as before.
|
||||
if first_batch_ is None:
|
||||
first_batch_ = deepcopy(batch)
|
||||
for key in stats_patterns:
|
||||
assert torch.equal(first_batch_[key], first_batch[key])
|
||||
for key, pattern in stats_patterns.items():
|
||||
batch[key] = batch[key].float()
|
||||
# Numerically stable update step for mean computation (where the mean is over squared
|
||||
# residuals).See notes in the mean computation loop above.
|
||||
batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
|
||||
std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
|
||||
|
||||
if i == ceil(max_num_samples / batch_size) - 1:
|
||||
break
|
||||
|
||||
for key in stats_patterns:
|
||||
std[key] = torch.sqrt(std[key])
|
||||
|
||||
stats = {}
|
||||
for key in stats_patterns:
|
||||
stats[key] = {
|
||||
"mean": mean[key],
|
||||
"std": std[key],
|
||||
"max": max[key],
|
||||
"min": min[key],
|
||||
}
|
||||
|
||||
torch.save(stats, stats_path)
|
||||
return stats
|
||||
|
||||
@@ -1,64 +1,40 @@
|
||||
from torchrl.envs import SerialEnv
|
||||
from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv
|
||||
import gymnasium as gym
|
||||
|
||||
|
||||
def make_env(cfg, transform=None):
|
||||
def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
|
||||
"""
|
||||
Note: The returned environment is wrapped in a torchrl.SerialEnv with cfg.rollout_batch_size underlying
|
||||
environments. The env therefore returns batches.`
|
||||
Note: When `num_parallel_envs > 0`, this function returns a `SyncVectorEnv` which takes batched action as input and
|
||||
returns batched observation, reward, terminated, truncated of `num_parallel_envs` items.
|
||||
"""
|
||||
|
||||
kwargs = {
|
||||
"frame_skip": cfg.env.action_repeat,
|
||||
"from_pixels": cfg.env.from_pixels,
|
||||
"pixels_only": cfg.env.pixels_only,
|
||||
"image_size": cfg.env.image_size,
|
||||
"num_prev_obs": cfg.n_obs_steps - 1,
|
||||
}
|
||||
kwargs = {}
|
||||
|
||||
if cfg.env.name == "simxarm":
|
||||
from lerobot.common.envs.simxarm.env import SimxarmEnv
|
||||
|
||||
kwargs["task"] = cfg.env.task
|
||||
clsfunc = SimxarmEnv
|
||||
elif cfg.env.name == "pusht":
|
||||
from lerobot.common.envs.pusht.env import PushtEnv
|
||||
import gym_pusht # noqa
|
||||
|
||||
# assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range."
|
||||
|
||||
clsfunc = PushtEnv
|
||||
kwargs.update(
|
||||
{
|
||||
"obs_type": "pixels_agent_pos",
|
||||
"render_action": False,
|
||||
}
|
||||
)
|
||||
env_fn = lambda: gym.make( # noqa: E731
|
||||
"gym_pusht/PushTPixels-v0",
|
||||
render_mode="rgb_array",
|
||||
max_episode_steps=cfg.env.episode_length,
|
||||
**kwargs,
|
||||
)
|
||||
elif cfg.env.name == "aloha":
|
||||
from lerobot.common.envs.aloha.env import AlohaEnv
|
||||
|
||||
kwargs["task"] = cfg.env.task
|
||||
clsfunc = AlohaEnv
|
||||
else:
|
||||
raise ValueError(cfg.env.name)
|
||||
|
||||
def _make_env(seed):
|
||||
nonlocal kwargs
|
||||
kwargs["seed"] = seed
|
||||
env = clsfunc(**kwargs)
|
||||
|
||||
# limit rollout to max_steps
|
||||
env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length))
|
||||
|
||||
if transform is not None:
|
||||
# useful to add normalization
|
||||
if isinstance(transform, Compose):
|
||||
for tf in transform:
|
||||
env.append_transform(tf.clone())
|
||||
elif isinstance(transform, Transform):
|
||||
env.append_transform(transform.clone())
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
return env
|
||||
|
||||
return SerialEnv(
|
||||
cfg.rollout_batch_size,
|
||||
create_env_fn=_make_env,
|
||||
create_env_kwargs=[
|
||||
{"seed": env_seed} for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
|
||||
],
|
||||
)
|
||||
if num_parallel_envs == 0:
|
||||
# non-batched version of the env that returns an observation of shape (c)
|
||||
env = env_fn()
|
||||
else:
|
||||
# batched version of the env that returns an observation of shape (b, c)
|
||||
env = gym.vector.SyncVectorEnv([env_fn for _ in range(num_parallel_envs)])
|
||||
return env
|
||||
|
||||
@@ -55,7 +55,7 @@ class SimxarmEnv(AbstractEnv):
|
||||
if not _has_gym:
|
||||
raise ImportError("Cannot import gymnasium.")
|
||||
|
||||
import gymnasium
|
||||
import gymnasium as gym
|
||||
|
||||
from lerobot.common.envs.simxarm.simxarm import TASKS
|
||||
|
||||
@@ -65,7 +65,7 @@ class SimxarmEnv(AbstractEnv):
|
||||
self._env = TASKS[self.task]["env"]()
|
||||
|
||||
num_actions = len(TASKS[self.task]["action_space"])
|
||||
self._action_space = gymnasium.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,))
|
||||
self._action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,))
|
||||
self._action_padding = np.zeros((MAX_NUM_ACTIONS - num_actions), dtype=np.float32)
|
||||
if "w" not in TASKS[self.task]["action_space"]:
|
||||
self._action_padding[-1] = 1.0
|
||||
|
||||
@@ -243,10 +243,9 @@ class DiffusionUnetImagePolicy(BaseImagePolicy):
|
||||
result = {"action": action, "action_pred": action_pred}
|
||||
return result
|
||||
|
||||
def compute_loss(self, batch):
|
||||
assert "valid_mask" not in batch
|
||||
nobs = batch["obs"]
|
||||
nactions = batch["action"]
|
||||
def compute_loss(self, obs_dict, action):
|
||||
nobs = obs_dict
|
||||
nactions = action
|
||||
batch_size = nactions.shape[0]
|
||||
horizon = nactions.shape[1]
|
||||
|
||||
|
||||
@@ -1,18 +1,20 @@
|
||||
import copy
|
||||
import logging
|
||||
import time
|
||||
from collections import deque
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
|
||||
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
|
||||
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder
|
||||
from lerobot.common.policies.utils import populate_queues
|
||||
from lerobot.common.utils import get_safe_torch_device
|
||||
|
||||
|
||||
class DiffusionPolicy(AbstractPolicy):
|
||||
class DiffusionPolicy(nn.Module):
|
||||
name = "diffusion"
|
||||
|
||||
def __init__(
|
||||
@@ -38,8 +40,12 @@ class DiffusionPolicy(AbstractPolicy):
|
||||
# parameters passed to step
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(n_action_steps)
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
self.n_obs_steps = n_obs_steps
|
||||
self.n_action_steps = n_action_steps
|
||||
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||
self._queues = None
|
||||
|
||||
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
|
||||
rgb_model_input_shape = copy.deepcopy(shape_meta.obs.image.shape)
|
||||
@@ -100,76 +106,59 @@ class DiffusionPolicy(AbstractPolicy):
|
||||
last_epoch=self.global_step - 1,
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Clear observation and action queues. Should be called on `env.reset()`
|
||||
"""
|
||||
self._queues = {
|
||||
"observation.image": deque(maxlen=self.n_obs_steps),
|
||||
"observation.state": deque(maxlen=self.n_obs_steps),
|
||||
"action": deque(maxlen=self.n_action_steps),
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def select_actions(self, observation, step_count):
|
||||
def select_action(self, batch, step):
|
||||
"""
|
||||
Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights.
|
||||
"""
|
||||
# TODO(rcadene): remove unused step_count
|
||||
del step_count
|
||||
# TODO(rcadene): remove unused step
|
||||
del step
|
||||
assert "observation.image" in batch
|
||||
assert "observation.state" in batch
|
||||
assert len(batch) == 2
|
||||
|
||||
obs_dict = {
|
||||
"image": observation["image"],
|
||||
"agent_pos": observation["state"],
|
||||
}
|
||||
if self.training:
|
||||
out = self.diffusion.predict_action(obs_dict)
|
||||
else:
|
||||
out = self.ema_diffusion.predict_action(obs_dict)
|
||||
action = out["action"]
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
if len(self._queues["action"]) == 0:
|
||||
# stack n latest observations from the queue
|
||||
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
|
||||
|
||||
obs_dict = {
|
||||
"image": batch["observation.image"],
|
||||
"agent_pos": batch["observation.state"],
|
||||
}
|
||||
if self.training:
|
||||
out = self.diffusion.predict_action(obs_dict)
|
||||
else:
|
||||
out = self.ema_diffusion.predict_action(obs_dict)
|
||||
self._queues["action"].extend(out["action"].transpose(0, 1))
|
||||
|
||||
action = self._queues["action"].popleft()
|
||||
return action
|
||||
|
||||
def update(self, replay_buffer, step):
|
||||
def forward(self, batch, step):
|
||||
start_time = time.time()
|
||||
|
||||
self.diffusion.train()
|
||||
|
||||
num_slices = self.cfg.batch_size
|
||||
batch_size = self.cfg.horizon * num_slices
|
||||
|
||||
assert batch_size % self.cfg.horizon == 0
|
||||
assert batch_size % num_slices == 0
|
||||
|
||||
def process_batch(batch, horizon, num_slices):
|
||||
# trajectory t = 64, horizon h = 16
|
||||
# (t h) ... -> t h ...
|
||||
batch = batch.reshape(num_slices, horizon) # .transpose(1, 0).contiguous()
|
||||
|
||||
# |-1|0|1|2|3|4|5|6|7|8|9|10|11|12|13|14| timestamps: 16
|
||||
# |o|o| observations: 2
|
||||
# | |a|a|a|a|a|a|a|a| actions executed: 8
|
||||
# |p|p|p|p|p|p|p|p|p|p|p| p| p| p| p| p| actions predicted: 16
|
||||
# note: we predict the action needed to go from t=-1 to t=0 similarly to an inverse kinematic model
|
||||
|
||||
image = batch["observation", "image"]
|
||||
state = batch["observation", "state"]
|
||||
action = batch["action"]
|
||||
assert image.shape[1] == horizon
|
||||
assert state.shape[1] == horizon
|
||||
assert action.shape[1] == horizon
|
||||
|
||||
if not (horizon == 16 and self.cfg.n_obs_steps == 2):
|
||||
raise NotImplementedError()
|
||||
|
||||
# keep first 2 observations of the slice corresponding to t=[-1,0]
|
||||
image = image[:, : self.cfg.n_obs_steps]
|
||||
state = state[:, : self.cfg.n_obs_steps]
|
||||
|
||||
out = {
|
||||
"obs": {
|
||||
"image": image.to(self.device, non_blocking=True),
|
||||
"agent_pos": state.to(self.device, non_blocking=True),
|
||||
},
|
||||
"action": action.to(self.device, non_blocking=True),
|
||||
}
|
||||
return out
|
||||
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
batch = process_batch(batch, self.cfg.horizon, num_slices)
|
||||
|
||||
data_s = time.time() - start_time
|
||||
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
obs_dict = {
|
||||
"image": batch["observation.image"],
|
||||
"agent_pos": batch["observation.state"],
|
||||
}
|
||||
action = batch["action"]
|
||||
loss = self.diffusion.compute_loss(obs_dict, action)
|
||||
loss.backward()
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
|
||||
@@ -1,53 +1,49 @@
|
||||
from typing import Sequence
|
||||
|
||||
import torch
|
||||
from tensordict import TensorDictBase
|
||||
from tensordict.nn import dispatch
|
||||
from tensordict.utils import NestedKey
|
||||
from torchrl.envs.transforms import ObservationTransform, Transform
|
||||
from torchvision.transforms.v2 import Compose, Transform
|
||||
|
||||
|
||||
class Prod(ObservationTransform):
|
||||
def apply_inverse_transform(item, transform):
|
||||
transforms = transform.transforms if isinstance(transform, Compose) else [transform]
|
||||
for tf in transforms[::-1]:
|
||||
if tf.invertible:
|
||||
item = tf.inverse_transform(item)
|
||||
else:
|
||||
raise ValueError(f"Inverse transform called on a non invertible transform ({tf}).")
|
||||
return item
|
||||
|
||||
|
||||
class Prod(Transform):
|
||||
invertible = True
|
||||
|
||||
def __init__(self, in_keys: Sequence[NestedKey], prod: float):
|
||||
def __init__(self, in_keys: list[str], prod: float):
|
||||
super().__init__()
|
||||
self.in_keys = in_keys
|
||||
self.prod = prod
|
||||
self.original_dtypes = {}
|
||||
|
||||
def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
|
||||
# _reset is called once when the environment reset to normalize the first observation
|
||||
tensordict_reset = self._call(tensordict_reset)
|
||||
return tensordict_reset
|
||||
|
||||
@dispatch(source="in_keys", dest="out_keys")
|
||||
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
|
||||
return self._call(tensordict)
|
||||
|
||||
def _call(self, td):
|
||||
def forward(self, item):
|
||||
for key in self.in_keys:
|
||||
if td.get(key, None) is None:
|
||||
if key not in item:
|
||||
continue
|
||||
self.original_dtypes[key] = td[key].dtype
|
||||
td[key] = td[key].type(torch.float32) * self.prod
|
||||
return td
|
||||
self.original_dtypes[key] = item[key].dtype
|
||||
item[key] = item[key].type(torch.float32) * self.prod
|
||||
return item
|
||||
|
||||
def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
|
||||
def inverse_transform(self, item):
|
||||
for key in self.in_keys:
|
||||
if td.get(key, None) is None:
|
||||
if key not in item:
|
||||
continue
|
||||
td[key] = (td[key] / self.prod).type(self.original_dtypes[key])
|
||||
return td
|
||||
item[key] = (item[key] / self.prod).type(self.original_dtypes[key])
|
||||
return item
|
||||
|
||||
def transform_observation_spec(self, obs_spec):
|
||||
for key in self.in_keys:
|
||||
if obs_spec.get(key, None) is None:
|
||||
continue
|
||||
obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod
|
||||
obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod
|
||||
obs_spec[key].dtype = torch.float32
|
||||
return obs_spec
|
||||
# def transform_observation_spec(self, obs_spec):
|
||||
# for key in self.in_keys:
|
||||
# if obs_spec.get(key, None) is None:
|
||||
# continue
|
||||
# obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod
|
||||
# obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod
|
||||
# obs_spec[key].dtype = torch.float32
|
||||
# return obs_spec
|
||||
|
||||
|
||||
class NormalizeTransform(Transform):
|
||||
@@ -55,65 +51,50 @@ class NormalizeTransform(Transform):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stats: TensorDictBase,
|
||||
in_keys: Sequence[NestedKey] = None,
|
||||
out_keys: Sequence[NestedKey] | None = None,
|
||||
in_keys_inv: Sequence[NestedKey] | None = None,
|
||||
out_keys_inv: Sequence[NestedKey] | None = None,
|
||||
stats: dict,
|
||||
in_keys: list[str] = None,
|
||||
out_keys: list[str] | None = None,
|
||||
in_keys_inv: list[str] | None = None,
|
||||
out_keys_inv: list[str] | None = None,
|
||||
mode="mean_std",
|
||||
):
|
||||
if out_keys is None:
|
||||
out_keys = in_keys
|
||||
if in_keys_inv is None:
|
||||
in_keys_inv = out_keys
|
||||
if out_keys_inv is None:
|
||||
out_keys_inv = in_keys
|
||||
super().__init__(
|
||||
in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv
|
||||
)
|
||||
super().__init__()
|
||||
self.in_keys = in_keys
|
||||
self.out_keys = in_keys if out_keys is None else out_keys
|
||||
self.in_keys_inv = self.out_keys if in_keys_inv is None else in_keys_inv
|
||||
self.out_keys_inv = self.in_keys if out_keys_inv is None else out_keys_inv
|
||||
self.stats = stats
|
||||
assert mode in ["mean_std", "min_max"]
|
||||
self.mode = mode
|
||||
|
||||
def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
|
||||
# _reset is called once when the environment reset to normalize the first observation
|
||||
tensordict_reset = self._call(tensordict_reset)
|
||||
return tensordict_reset
|
||||
|
||||
@dispatch(source="in_keys", dest="out_keys")
|
||||
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
|
||||
return self._call(tensordict)
|
||||
|
||||
def _call(self, td: TensorDictBase) -> TensorDictBase:
|
||||
def forward(self, item):
|
||||
for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
|
||||
# TODO(rcadene): don't know how to do `inkey not in td`
|
||||
if td.get(inkey, None) is None:
|
||||
if inkey not in item:
|
||||
continue
|
||||
if self.mode == "mean_std":
|
||||
mean = self.stats[inkey]["mean"]
|
||||
std = self.stats[inkey]["std"]
|
||||
td[outkey] = (td[inkey] - mean) / (std + 1e-8)
|
||||
item[outkey] = (item[inkey] - mean) / (std + 1e-8)
|
||||
else:
|
||||
min = self.stats[inkey]["min"]
|
||||
max = self.stats[inkey]["max"]
|
||||
# normalize to [0,1]
|
||||
td[outkey] = (td[inkey] - min) / (max - min)
|
||||
item[outkey] = (item[inkey] - min) / (max - min)
|
||||
# normalize to [-1, 1]
|
||||
td[outkey] = td[outkey] * 2 - 1
|
||||
return td
|
||||
item[outkey] = item[outkey] * 2 - 1
|
||||
return item
|
||||
|
||||
def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
|
||||
def inverse_transform(self, item):
|
||||
for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv, strict=False):
|
||||
# TODO(rcadene): don't know how to do `inkey not in td`
|
||||
if td.get(inkey, None) is None:
|
||||
if inkey not in item:
|
||||
continue
|
||||
if self.mode == "mean_std":
|
||||
mean = self.stats[inkey]["mean"]
|
||||
std = self.stats[inkey]["std"]
|
||||
td[outkey] = td[inkey] * std + mean
|
||||
item[outkey] = item[inkey] * std + mean
|
||||
else:
|
||||
min = self.stats[inkey]["min"]
|
||||
max = self.stats[inkey]["max"]
|
||||
td[outkey] = (td[inkey] + 1) / 2
|
||||
td[outkey] = td[outkey] * (max - min) + min
|
||||
return td
|
||||
item[outkey] = (item[inkey] + 1) / 2
|
||||
item[outkey] = item[outkey] * (max - min) + min
|
||||
return item
|
||||
|
||||
Reference in New Issue
Block a user