finish examples 2 and 3
This commit is contained in:
@@ -59,6 +59,8 @@ class AbstractDataset(TensorDictReplayBuffer):
|
||||
collate_fn: Callable | None = None,
|
||||
writer: Writer | None = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
# Don't actually load any data. This is a stand-in solution to get the transforms.
|
||||
dummy: bool = False,
|
||||
):
|
||||
assert (
|
||||
self.available_datasets is not None
|
||||
@@ -77,7 +79,7 @@ class AbstractDataset(TensorDictReplayBuffer):
|
||||
f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})."
|
||||
)
|
||||
|
||||
storage = self._download_or_load_dataset()
|
||||
storage = self._download_or_load_dataset() if not dummy else None
|
||||
|
||||
super().__init__(
|
||||
storage=storage,
|
||||
|
||||
@@ -97,6 +97,7 @@ class AlohaDataset(AbstractDataset):
|
||||
collate_fn: Callable | None = None,
|
||||
writer: Writer | None = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
dummy: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
dataset_id,
|
||||
@@ -110,6 +111,7 @@ class AlohaDataset(AbstractDataset):
|
||||
collate_fn=collate_fn,
|
||||
writer=writer,
|
||||
transform=transform,
|
||||
dummy=dummy,
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@@ -21,7 +21,12 @@ def make_offline_buffer(
|
||||
overwrite_batch_size=None,
|
||||
overwrite_prefetch=None,
|
||||
stats_path=None,
|
||||
# Don't actually load any data. This is a stand-in solution to get the transforms.
|
||||
dummy=False,
|
||||
):
|
||||
if dummy and normalize and stats_path is None:
|
||||
raise ValueError("`stats_path` is required if `dummy` and `normalize` are True.")
|
||||
|
||||
if cfg.policy.balanced_sampling:
|
||||
assert cfg.online_steps > 0
|
||||
batch_size = None
|
||||
@@ -93,6 +98,7 @@ def make_offline_buffer(
|
||||
root=DATA_DIR,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch if isinstance(prefetch, int) else None,
|
||||
dummy=dummy,
|
||||
)
|
||||
|
||||
if cfg.policy.name == "tdmpc":
|
||||
|
||||
@@ -100,6 +100,7 @@ class PushtDataset(AbstractDataset):
|
||||
collate_fn: Callable | None = None,
|
||||
writer: Writer | None = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
dummy: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
dataset_id,
|
||||
@@ -113,6 +114,7 @@ class PushtDataset(AbstractDataset):
|
||||
collate_fn=collate_fn,
|
||||
writer=writer,
|
||||
transform=transform,
|
||||
dummy=dummy,
|
||||
)
|
||||
|
||||
def _download_and_preproc_obsolete(self):
|
||||
|
||||
@@ -51,6 +51,7 @@ class SimxarmDataset(AbstractDataset):
|
||||
collate_fn: Callable | None = None,
|
||||
writer: Writer | None = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
dummy: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
dataset_id,
|
||||
@@ -64,6 +65,7 @@ class SimxarmDataset(AbstractDataset):
|
||||
collate_fn=collate_fn,
|
||||
writer=writer,
|
||||
transform=transform,
|
||||
dummy=dummy,
|
||||
)
|
||||
|
||||
def _download_and_preproc_obsolete(self):
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
import logging
|
||||
import os.path as osp
|
||||
import random
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
|
||||
|
||||
def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
|
||||
@@ -63,3 +67,29 @@ def format_big_number(num):
|
||||
num /= divisor
|
||||
|
||||
return num
|
||||
|
||||
|
||||
def _relative_path_between(path1: Path, path2: Path) -> Path:
|
||||
"""Returns path1 relative to path2."""
|
||||
path1 = path1.absolute()
|
||||
path2 = path2.absolute()
|
||||
try:
|
||||
return path1.relative_to(path2)
|
||||
except ValueError: # most likely because path1 is not a subpath of path2
|
||||
common_parts = Path(osp.commonpath([path1, path2])).parts
|
||||
return Path(
|
||||
"/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
|
||||
)
|
||||
|
||||
|
||||
def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> DictConfig:
    """Initialize a Hydra config given only the path to the relevant config file.

    For config resolution, it is assumed that the config file's parent is the Hydra config dir.
    """
    config_file = Path(config_path).absolute()
    this_dir = Path(__file__).absolute().parent
    # Hydra needs a path relative to this file.
    config_dir = _relative_path_between(config_file.parent, this_dir)
    hydra.initialize(str(config_dir))
    return hydra.compose(config_file.stem, overrides)
|
||||
|
||||
Reference in New Issue
Block a user