fix environment seeding

add fixes for reproducibility

only try to start env if it is closed

revision

fix normalization and data type

Improve README

Improve README

Tests are passing, Eval pretrained model works, Add gif

Update gif

Update gif

Update gif

Update gif

Update README

Update README

update minor

Update README.md

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>

Update README.md

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>

Address suggestions

Update thumbnail + stats

Update thumbnail + stats

Update README.md

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>

Add more comments

Add test_examples.py
This commit is contained in:
Alexander Soare
2024-03-22 13:25:23 +00:00
committed by Cadene
parent 203bcd7ca5
commit 1a1308d62f
32 changed files with 686 additions and 282 deletions

View File

@@ -9,11 +9,11 @@ import torch
import torchrl
import tqdm
from tensordict import TensorDict
from torchrl.data.replay_buffers.samplers import SliceSampler
from torchrl.data.replay_buffers.samplers import Sampler
from torchrl.data.replay_buffers.storages import TensorStorage
from torchrl.data.replay_buffers.writers import Writer
from lerobot.common.datasets.abstract import AbstractExperienceReplay
from lerobot.common.datasets.abstract import AbstractDataset
from lerobot.common.datasets.utils import download_and_extract_zip
from lerobot.common.envs.pusht.pusht_env import pymunk_to_shapely
from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
@@ -83,20 +83,22 @@ def add_tee(
return body
class PushtExperienceReplay(AbstractExperienceReplay):
class PushtDataset(AbstractDataset):
available_datasets = ["pusht"]
def __init__(
self,
dataset_id: str,
version: str | None = "v1.2",
batch_size: int = None,
batch_size: int | None = None,
*,
shuffle: bool = True,
root: Path | None = None,
pin_memory: bool = False,
prefetch: int = None,
sampler: SliceSampler = None,
collate_fn: Callable = None,
writer: Writer = None,
sampler: Sampler | None = None,
collate_fn: Callable | None = None,
writer: Writer | None = None,
transform: "torchrl.envs.Transform" = None,
):
super().__init__(