Ran pre-commit run --all-files

This commit is contained in:
Simon Alibert
2024-02-29 13:37:48 +01:00
parent 0e0e760e43
commit 7e024fdce6
16 changed files with 124 additions and 237 deletions

View File

@@ -1,7 +1,7 @@
import os
import pickle
from pathlib import Path
from typing import Any, Callable, Dict, Tuple
from typing import Callable
import torch
import torchrl
@@ -9,7 +9,6 @@ import tqdm
from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import (
TensorDictPrioritizedReplayBuffer,
TensorDictReplayBuffer,
)
from torchrl.data.replay_buffers.samplers import (
@@ -22,7 +21,6 @@ from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
class SimxarmExperienceReplay(TensorDictReplayBuffer):
available_datasets = [
"xarm_lift_medium",
]
@@ -77,15 +75,11 @@ class SimxarmExperienceReplay(TensorDictReplayBuffer):
if num_slices is not None or slice_len is not None:
if sampler is not None:
raise ValueError(
"`num_slices` and `slice_len` are exclusive with the `sampler` argument."
)
raise ValueError("`num_slices` and `slice_len` are exclusive with the `sampler` argument.")
if replacement:
if not self.shuffle:
raise RuntimeError(
"shuffle=False can only be used when replacement=False."
)
raise RuntimeError("shuffle=False can only be used when replacement=False.")
sampler = SliceSampler(
num_slices=num_slices,
slice_len=slice_len,
@@ -130,7 +124,7 @@ class SimxarmExperienceReplay(TensorDictReplayBuffer):
# load
dataset_dir = Path("data") / self.dataset_id
dataset_path = dataset_dir / f"buffer.pkl"
dataset_path = dataset_dir / "buffer.pkl"
print(f"Using offline dataset '{dataset_path}'")
with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f)
@@ -150,12 +144,8 @@ class SimxarmExperienceReplay(TensorDictReplayBuffer):
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
next_image = torch.tensor(
dataset_dict["next_observations"]["rgb"][idx0:idx1]
)
next_state = torch.tensor(
dataset_dict["next_observations"]["state"][idx0:idx1]
)
next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])
@@ -176,11 +166,7 @@ class SimxarmExperienceReplay(TensorDictReplayBuffer):
if episode_id == 0:
# hack to initialize tensordict data structure to store episodes
td_data = (
episode[0]
.expand(total_frames)
.memmap_like(self.root / self.dataset_id)
)
td_data = episode[0].expand(total_frames).memmap_like(self.root / self.dataset_id)
td_data[idx0:idx1] = episode