fix environment seeding

add fixes for reproducibility

only try to start env if it is closed

revision

fix normalization and data type

Improve README

Improve README

Tests are passing, Eval pretrained model works, Add gif

Update gif

Update gif

Update gif

Update gif

Update README

Update README

update minor

Update README.md

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>

Update README.md

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>

Address suggestions

Update thumbnail + stats

Update thumbnail + stats

Update README.md

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>

Add more comments

Add test_examples.py
This commit is contained in:
Alexander Soare
2024-03-22 13:25:23 +00:00
committed by Cadene
parent 203bcd7ca5
commit 1a1308d62f
32 changed files with 686 additions and 282 deletions

View File

@@ -9,7 +9,7 @@ import tqdm
from huggingface_hub import snapshot_download
from tensordict import TensorDict
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SliceSampler
from torchrl.data.replay_buffers.samplers import Sampler
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
from torchrl.envs.transforms.transforms import Compose
@@ -17,22 +17,56 @@ from torchrl.envs.transforms.transforms import Compose
HF_USER = "lerobot"
class AbstractExperienceReplay(TensorDictReplayBuffer):
class AbstractDataset(TensorDictReplayBuffer):
"""
AbstractDataset represents a dataset in the context of imitation learning or reinforcement learning.
This class is designed to be subclassed by concrete implementations that specify particular types of datasets.
These implementations can vary based on the source of the data, the environment the data pertains to,
or the specific kind of data manipulation applied.
Note:
- `TensorDictReplayBuffer` is the base class from which `AbstractDataset` inherits. It provides the foundational
functionality for storing and retrieving `TensorDict`-like data.
- `available_datasets` should be overridden by concrete subclasses to list the specific dataset variants supported.
It is expected that these variants correspond to a HuggingFace dataset on the hub.
For instance, the `AlohaDataset` which inherites from `AbstractDataset` has 4 available dataset variants:
- [aloha_sim_transfer_cube_scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
- [aloha_sim_insertion_scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
- [aloha_sim_transfer_cube_human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
- [aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
- When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
1. set the required class attributes:
- for classes inheriting from `AbstractDataset`: `available_datasets`
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
- for classes inheriting from `AbstractPolicy`: `name`
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
3. update variables in `tests/test_available.py` by importing your new class
"""
available_datasets: list[str] | None = None
def __init__(
self,
dataset_id: str,
version: str | None = None,
batch_size: int = None,
batch_size: int | None = None,
*,
shuffle: bool = True,
root: Path | None = None,
pin_memory: bool = False,
prefetch: int = None,
sampler: SliceSampler = None,
collate_fn: Callable = None,
writer: Writer = None,
sampler: Sampler | None = None,
collate_fn: Callable | None = None,
writer: Writer | None = None,
transform: "torchrl.envs.Transform" = None,
):
assert (
self.available_datasets is not None
), "Subclasses of `AbstractDataset` should set the `available_datasets` class attribute."
assert (
dataset_id in self.available_datasets
), f"The provided dataset ({dataset_id}) is not on the list of available datasets {self.available_datasets}."
self.dataset_id = dataset_id
self.version = version
self.shuffle = shuffle