Compare commits
21 Commits
user/alexa
...
user/rcade
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9cdc24bc0e | ||
|
|
a346469a5a | ||
|
|
2bef00c317 | ||
|
|
9954994a4b | ||
|
|
0fc94b81b3 | ||
|
|
d32a279435 | ||
|
|
75cc10198f | ||
|
|
4ecfd17f9e | ||
|
|
58d1787ee3 | ||
|
|
b752833f3f | ||
|
|
9c88071bc7 | ||
|
|
5805a7ffb1 | ||
|
|
41521f7e96 | ||
|
|
b10c9507d4 | ||
|
|
a311d38796 | ||
|
|
19730b3412 | ||
|
|
95e84079ef | ||
|
|
8e856f1bf7 | ||
|
|
8c2b47752a | ||
|
|
f515cb6efd | ||
|
|
c3f8d14fd8 |
3126
.github/poetry/cpu/poetry.lock
generated
vendored
Normal file
3126
.github/poetry/cpu/poetry.lock
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
107
.github/poetry/cpu/pyproject.toml
vendored
Normal file
107
.github/poetry/cpu/pyproject.toml
vendored
Normal file
@@ -0,0 +1,107 @@
|
||||
[tool.poetry]
|
||||
name = "lerobot"
|
||||
version = "0.1.0"
|
||||
description = "Le robot is learning"
|
||||
authors = [
|
||||
"Rémi Cadène <re.cadene@gmail.com>",
|
||||
"Simon Alibert <alibert.sim@gmail.com>",
|
||||
]
|
||||
repository = "https://github.com/Cadene/lerobot"
|
||||
readme = "README.md"
|
||||
license = "MIT"
|
||||
classifiers=[
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Intended Audience :: Developers",
|
||||
"Topic :: Software Development :: Build Tools",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
]
|
||||
packages = [{include = "lerobot"}]
|
||||
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.10"
|
||||
cython = "^3.0.8"
|
||||
termcolor = "^2.4.0"
|
||||
omegaconf = "^2.3.0"
|
||||
dm-env = "^1.6"
|
||||
pandas = "^2.2.1"
|
||||
wandb = "^0.16.3"
|
||||
moviepy = "^1.0.3"
|
||||
imageio = {extras = ["pyav"], version = "^2.34.0"}
|
||||
gdown = "^5.1.0"
|
||||
hydra-core = "^1.3.2"
|
||||
einops = "^0.7.0"
|
||||
pygame = "^2.5.2"
|
||||
pymunk = "^6.6.0"
|
||||
zarr = "^2.17.0"
|
||||
shapely = "^2.0.3"
|
||||
scikit-image = "^0.22.0"
|
||||
numba = "^0.59.0"
|
||||
mpmath = "^1.3.0"
|
||||
torch = {version = "^2.2.1", source = "torch-cpu"}
|
||||
tensordict = {git = "https://github.com/pytorch/tensordict"}
|
||||
torchrl = {git = "https://github.com/pytorch/rl", rev = "13bef426dcfa5887c6e5034a6e9697993fa92c37"}
|
||||
mujoco = "^3.1.2"
|
||||
mujoco-py = "^2.1.2.14"
|
||||
gym = "^0.26.2"
|
||||
opencv-python = "^4.9.0.80"
|
||||
diffusers = "^0.26.3"
|
||||
torchvision = {version = "^0.17.1", source = "torch-cpu"}
|
||||
h5py = "^3.10.0"
|
||||
dm = "^1.3"
|
||||
dm-control = "^1.0.16"
|
||||
huggingface-hub = "^0.21.4"
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pre-commit = "^3.6.2"
|
||||
debugpy = "^1.8.1"
|
||||
pytest = "^8.1.0"
|
||||
|
||||
|
||||
[[tool.poetry.source]]
|
||||
name = "torch-cpu"
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
priority = "supplemental"
|
||||
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 110
|
||||
target-version = "py310"
|
||||
exclude = [
|
||||
".bzr",
|
||||
".direnv",
|
||||
".eggs",
|
||||
".git",
|
||||
".git-rewrite",
|
||||
".hg",
|
||||
".mypy_cache",
|
||||
".nox",
|
||||
".pants.d",
|
||||
".pytype",
|
||||
".ruff_cache",
|
||||
".svn",
|
||||
".tox",
|
||||
".venv",
|
||||
"__pypackages__",
|
||||
"_build",
|
||||
"buck-out",
|
||||
"build",
|
||||
"dist",
|
||||
"node_modules",
|
||||
"venv",
|
||||
]
|
||||
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
|
||||
|
||||
|
||||
[tool.poetry-dynamic-versioning]
|
||||
enable = true
|
||||
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning>=1.0.0,<2.0.0"]
|
||||
build-backend = "poetry_dynamic_versioning.backend"
|
||||
20
.github/workflows/test.yml
vendored
20
.github/workflows/test.yml
vendored
@@ -46,8 +46,8 @@ jobs:
|
||||
id: restore-poetry-cache
|
||||
uses: actions/cache/restore@v3
|
||||
with:
|
||||
path: ~/.local # the path depends on the OS
|
||||
key: poetry-${{ env.POETRY_VERSION }} # increment to reset cache
|
||||
path: ~/.local
|
||||
key: poetry-${{ env.POETRY_VERSION }}
|
||||
|
||||
- name: Install Poetry
|
||||
if: steps.restore-poetry-cache.outputs.cache-hit != 'true'
|
||||
@@ -64,8 +64,8 @@ jobs:
|
||||
id: save-poetry-cache
|
||||
uses: actions/cache/save@v3
|
||||
with:
|
||||
path: ~/.local # the path depends on the OS
|
||||
key: poetry-${{ env.POETRY_VERSION }} # increment to reset cache
|
||||
path: ~/.local
|
||||
key: poetry-${{ env.POETRY_VERSION }}
|
||||
|
||||
- name: Configure Poetry
|
||||
run: poetry config virtualenvs.in-project true
|
||||
@@ -73,6 +73,10 @@ jobs:
|
||||
#----------------------------------------------
|
||||
# install dependencies
|
||||
#----------------------------------------------
|
||||
# TODO(aliberts): move to gpu runners
|
||||
- name: Select cpu dependencies # HACK
|
||||
run: cp -t . .github/poetry/cpu/pyproject.toml .github/poetry/cpu/poetry.lock
|
||||
|
||||
- name: Load cached venv
|
||||
id: restore-dependencies-cache
|
||||
uses: actions/cache/restore@v3
|
||||
@@ -80,18 +84,10 @@ jobs:
|
||||
path: .venv
|
||||
key: venv-${{ steps.setup-python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ hashFiles('**/poetry.lock') }}
|
||||
|
||||
- name: Info
|
||||
run: |
|
||||
sudo du -sh /tmp
|
||||
sudo df -h
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.restore-dependencies-cache.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
mkdir ~/tmp
|
||||
echo $TMPDIR
|
||||
echo $TEMP
|
||||
echo $TMP
|
||||
poetry install --no-interaction --no-root
|
||||
|
||||
- name: Save cached venv
|
||||
|
||||
31
README.md
31
README.md
@@ -103,6 +103,18 @@ pre-commit install
|
||||
pre-commit run -a
|
||||
```
|
||||
|
||||
**Adding dependencies (temporary)**
|
||||
|
||||
Right now, for the CI to work, whenever a new dependency is added it needs to be also added to the cpu env, eg:
|
||||
|
||||
```
|
||||
# Run in this directory, adds the package to the main env with cuda
|
||||
poetry add some-package
|
||||
|
||||
# Adds the same package to the cpu env
|
||||
cd .github/poetry/cpu && poetry add some-package
|
||||
```
|
||||
|
||||
**Tests**
|
||||
|
||||
Install [git lfs](https://git-lfs.com/) to retrieve test artifacts (if you don't have it already).
|
||||
@@ -134,6 +146,25 @@ Run tests
|
||||
DATA_DIR="tests/data" pytest -sx tests
|
||||
```
|
||||
|
||||
**Datasets**
|
||||
|
||||
To add a pytorch rl dataset to the hub, first login and use a token generated from [huggingface settings](https://huggingface.co/settings/tokens) with write access:
|
||||
```
|
||||
huggingface-cli login --token $HUGGINGFACE_TOKEN --add-to-git-credential
|
||||
```
|
||||
|
||||
Then you can upload it to the hub with:
|
||||
```
|
||||
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload --repo-type dataset $HF_USER/$DATASET data/$DATASET
|
||||
```
|
||||
|
||||
For instance, for [cadene/pusht](https://huggingface.co/datasets/cadene/pusht), we used:
|
||||
```
|
||||
HF_USER=cadene
|
||||
DATASET=pusht
|
||||
```
|
||||
|
||||
|
||||
## Acknowledgment
|
||||
- Our Diffusion policy and Pusht environment are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/)
|
||||
- Our TDMPC policy and Simxarm environment are adapted from [FOWM](https://www.yunhaifeng.com/FOWM/)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import abc
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
@@ -7,14 +6,16 @@ import einops
|
||||
import torch
|
||||
import torchrl
|
||||
import tqdm
|
||||
from huggingface_hub import snapshot_download
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.datasets.utils import _get_root_dir
|
||||
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
|
||||
from torchrl.data.replay_buffers.samplers import SliceSampler
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
||||
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
||||
from torchrl.envs.transforms.transforms import Compose
|
||||
|
||||
from lerobot.common.datasets.transforms import DecodeVideoTransform
|
||||
|
||||
|
||||
class AbstractExperienceReplay(TensorDictReplayBuffer):
|
||||
def __init__(
|
||||
@@ -23,7 +24,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
|
||||
batch_size: int = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path = None,
|
||||
root: Path | None = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: SliceSampler = None,
|
||||
@@ -33,11 +34,15 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
|
||||
):
|
||||
self.dataset_id = dataset_id
|
||||
self.shuffle = shuffle
|
||||
self.root = _get_root_dir(self.dataset_id) if root is None else root
|
||||
self.root = Path(self.root)
|
||||
self.data_dir = self.root / self.dataset_id
|
||||
self.root = root
|
||||
storage, meta_data = self._download_or_load_dataset()
|
||||
|
||||
storage = self._download_or_load_storage()
|
||||
if transform is not None and "video_id_to_path" in meta_data:
|
||||
# hack to access video paths
|
||||
assert isinstance(transform, Compose)
|
||||
for tf in transform:
|
||||
if isinstance(tf, DecodeVideoTransform):
|
||||
tf.set_video_id_to_path(meta_data["video_id_to_path"])
|
||||
|
||||
super().__init__(
|
||||
storage=storage,
|
||||
@@ -98,19 +103,18 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
|
||||
torch.save(stats, stats_path)
|
||||
return stats
|
||||
|
||||
@abc.abstractmethod
|
||||
def _download_and_preproc(self) -> torch.StorageBase:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _download_or_load_storage(self):
|
||||
if not self._is_downloaded():
|
||||
storage = self._download_and_preproc()
|
||||
def _download_or_load_dataset(self) -> torch.StorageBase:
|
||||
if self.root is None:
|
||||
self.data_dir = Path(snapshot_download(repo_id=f"cadene/{self.dataset_id}", repo_type="dataset"))
|
||||
else:
|
||||
storage = TensorStorage(TensorDict.load_memmap(self.data_dir))
|
||||
return storage
|
||||
self.data_dir = self.root / self.dataset_id
|
||||
|
||||
def _is_downloaded(self) -> bool:
|
||||
return self.data_dir.is_dir()
|
||||
storage = TensorStorage(TensorDict.load_memmap(self.data_dir / "replay_buffer"))
|
||||
# required to not send cuda frames to cpu by default
|
||||
storage._storage.clear_device_()
|
||||
|
||||
meta_data = torch.load(self.data_dir / "meta_data.pth")
|
||||
return storage, meta_data
|
||||
|
||||
def _compute_stats(self, num_batch=100, batch_size=32):
|
||||
rb = TensorDictReplayBuffer(
|
||||
|
||||
@@ -87,7 +87,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
|
||||
batch_size: int = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path = None,
|
||||
root: Path | None = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: SliceSampler = None,
|
||||
@@ -124,8 +124,9 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
|
||||
def image_keys(self) -> list:
|
||||
return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]]
|
||||
|
||||
def _download_and_preproc(self):
|
||||
raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
|
||||
def _download_and_preproc_obsolete(self):
|
||||
assert self.root is not None
|
||||
raw_dir = self.root / f"{self.dataset_id}_raw"
|
||||
if not raw_dir.is_dir():
|
||||
download(raw_dir, self.dataset_id)
|
||||
|
||||
@@ -174,7 +175,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
|
||||
|
||||
if ep_id == 0:
|
||||
# hack to initialize tensordict data structure to store episodes
|
||||
td_data = ep_td[0].expand(total_num_frames).memmap_like(self.data_dir)
|
||||
td_data = ep_td[0].expand(total_num_frames).memmap_like(self.root / f"{self.dataset_id}")
|
||||
|
||||
td_data[idxtd : idxtd + len(ep_td)] = ep_td
|
||||
idxtd = idxtd + len(ep_td)
|
||||
|
||||
@@ -7,7 +7,10 @@ from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler
|
||||
|
||||
from lerobot.common.envs.transforms import NormalizeTransform, Prod
|
||||
|
||||
DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))
|
||||
# DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and
|
||||
# we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data`
|
||||
# to load a subset of our datasets for faster continuous integration.
|
||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
|
||||
|
||||
def make_offline_buffer(
|
||||
@@ -77,24 +80,37 @@ def make_offline_buffer(
|
||||
|
||||
offline_buffer = clsfunc(
|
||||
dataset_id=dataset_id,
|
||||
root=DATA_DIR,
|
||||
sampler=sampler,
|
||||
batch_size=batch_size,
|
||||
root=DATA_DIR,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch if isinstance(prefetch, int) else None,
|
||||
)
|
||||
|
||||
if cfg.policy.name == "tdmpc":
|
||||
img_keys = []
|
||||
for key in offline_buffer.image_keys:
|
||||
img_keys.append(("next", *key))
|
||||
img_keys += offline_buffer.image_keys
|
||||
else:
|
||||
img_keys = offline_buffer.image_keys
|
||||
transforms = []
|
||||
|
||||
transforms = [Prod(in_keys=img_keys, prod=1 / 255)]
|
||||
# transforms = [
|
||||
# ViewSliceHorizonTransform(num_slices, cfg.policy.horizon),
|
||||
# KeepFrames(positions=[0], in_keys=[("observation")]),
|
||||
# DecodeVideoTransform(
|
||||
# data_dir=offline_buffer.data_dir,
|
||||
# device=cfg.device,
|
||||
# frame_rate=None,
|
||||
# in_keys=[("observation", "frame")],
|
||||
# out_keys=[("observation", "frame", "data")],
|
||||
# ),
|
||||
# ]
|
||||
|
||||
if normalize:
|
||||
if cfg.policy.name == "tdmpc":
|
||||
img_keys = []
|
||||
for key in offline_buffer.image_keys:
|
||||
img_keys.append(("next", *key))
|
||||
img_keys += offline_buffer.image_keys
|
||||
else:
|
||||
img_keys = offline_buffer.image_keys
|
||||
transforms.append(Prod(in_keys=img_keys, prod=1 / 255))
|
||||
|
||||
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
|
||||
stats = offline_buffer.compute_or_load_stats()
|
||||
|
||||
|
||||
@@ -90,7 +90,7 @@ class PushtExperienceReplay(AbstractExperienceReplay):
|
||||
batch_size: int = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path = None,
|
||||
root: Path | None = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: SliceSampler = None,
|
||||
@@ -111,8 +111,9 @@ class PushtExperienceReplay(AbstractExperienceReplay):
|
||||
transform=transform,
|
||||
)
|
||||
|
||||
def _download_and_preproc(self):
|
||||
raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
|
||||
def _download_and_preproc_obsolete(self):
|
||||
assert self.root is not None
|
||||
raw_dir = self.root / f"{self.dataset_id}_raw"
|
||||
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
|
||||
if not zarr_path.is_dir():
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -208,7 +209,7 @@ class PushtExperienceReplay(AbstractExperienceReplay):
|
||||
|
||||
if episode_id == 0:
|
||||
# hack to initialize tensordict data structure to store episodes
|
||||
td_data = ep_td[0].expand(total_frames).memmap_like(self.data_dir)
|
||||
td_data = ep_td[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}")
|
||||
|
||||
td_data[idxtd : idxtd + len(ep_td)] = ep_td
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ class SimxarmExperienceReplay(AbstractExperienceReplay):
|
||||
batch_size: int = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path = None,
|
||||
root: Path | None = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: SliceSampler = None,
|
||||
@@ -64,11 +64,12 @@ class SimxarmExperienceReplay(AbstractExperienceReplay):
|
||||
transform=transform,
|
||||
)
|
||||
|
||||
def _download_and_preproc(self):
|
||||
def _download_and_preproc_obsolete(self):
|
||||
assert self.root is not None
|
||||
# TODO(rcadene): finish download
|
||||
download()
|
||||
|
||||
dataset_path = self.data_dir / "buffer.pkl"
|
||||
dataset_path = self.root / f"{self.dataset_id}_raw" / "buffer.pkl"
|
||||
print(f"Using offline dataset '{dataset_path}'")
|
||||
with open(dataset_path, "rb") as f:
|
||||
dataset_dict = pickle.load(f)
|
||||
@@ -110,7 +111,7 @@ class SimxarmExperienceReplay(AbstractExperienceReplay):
|
||||
|
||||
if episode_id == 0:
|
||||
# hack to initialize tensordict data structure to store episodes
|
||||
td_data = episode[0].expand(total_frames).memmap_like(self.data_dir)
|
||||
td_data = episode[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}")
|
||||
|
||||
td_data[idx0:idx1] = episode
|
||||
|
||||
|
||||
310
lerobot/common/datasets/transforms.py
Normal file
310
lerobot/common/datasets/transforms.py
Normal file
@@ -0,0 +1,310 @@
|
||||
from pathlib import Path
|
||||
from typing import Sequence
|
||||
|
||||
import einops
|
||||
import torch
|
||||
from tensordict import TensorDictBase
|
||||
from tensordict.nn import dispatch
|
||||
from tensordict.utils import NestedKey
|
||||
from torchaudio.io import StreamReader
|
||||
from torchrl.envs.transforms import Transform
|
||||
|
||||
|
||||
def yuv_to_rgb(frames):
|
||||
assert frames.dtype == torch.uint8
|
||||
assert frames.ndim == 4
|
||||
assert frames.shape[1] == 3
|
||||
|
||||
frames = frames.cpu().to(torch.float)
|
||||
y = frames[..., 0, :, :]
|
||||
u = frames[..., 1, :, :]
|
||||
v = frames[..., 2, :, :]
|
||||
|
||||
y /= 255
|
||||
u = u / 255 - 0.5
|
||||
v = v / 255 - 0.5
|
||||
|
||||
r = y + 1.13983 * v
|
||||
g = y + -0.39465 * u - 0.58060 * v
|
||||
b = y + 2.03211 * u
|
||||
|
||||
rgb = torch.stack([r, g, b], 1)
|
||||
rgb = (rgb * 255).clamp(0, 255).to(torch.uint8)
|
||||
return rgb
|
||||
|
||||
|
||||
def yuv_to_rgb_cv2(frames, return_hwc=True):
|
||||
assert frames.dtype == torch.uint8
|
||||
assert frames.ndim == 4
|
||||
assert frames.shape[1] == 3
|
||||
frames = frames.cpu()
|
||||
import cv2
|
||||
|
||||
frames = einops.rearrange(frames, "b c h w -> b h w c")
|
||||
frames = frames.numpy()
|
||||
frames = [cv2.cvtColor(frame, cv2.COLOR_YUV2RGB) for frame in frames]
|
||||
frames = [torch.from_numpy(frame) for frame in frames]
|
||||
frames = torch.stack(frames)
|
||||
if not return_hwc:
|
||||
frames = einops.rearrange(frames, "b h w c -> b c h w")
|
||||
return frames
|
||||
|
||||
|
||||
class ViewSliceHorizonTransform(Transform):
|
||||
invertible = False
|
||||
|
||||
def __init__(self, num_slices, horizon):
|
||||
super().__init__()
|
||||
self.num_slices = num_slices
|
||||
self.horizon = horizon
|
||||
|
||||
def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
|
||||
# _reset is called once when the environment reset to normalize the first observation
|
||||
tensordict_reset = self._call(tensordict_reset)
|
||||
return tensordict_reset
|
||||
|
||||
@dispatch(source="in_keys", dest="out_keys")
|
||||
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
|
||||
return self._call(tensordict)
|
||||
|
||||
def _call(self, td: TensorDictBase) -> TensorDictBase:
|
||||
td = td.view(self.num_slices, self.horizon)
|
||||
return td
|
||||
|
||||
|
||||
class KeepFrames(Transform):
|
||||
invertible = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
positions,
|
||||
in_keys: Sequence[NestedKey],
|
||||
out_keys: Sequence[NestedKey] = None,
|
||||
):
|
||||
if isinstance(positions, list):
|
||||
assert isinstance(positions[0], int)
|
||||
# TODO(rcadene)L add support for `isinstance(positions, int)`?
|
||||
|
||||
self.positions = positions
|
||||
if out_keys is None:
|
||||
out_keys = in_keys
|
||||
super().__init__(in_keys=in_keys, out_keys=out_keys)
|
||||
|
||||
def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
|
||||
# _reset is called once when the environment reset to normalize the first observation
|
||||
tensordict_reset = self._call(tensordict_reset)
|
||||
return tensordict_reset
|
||||
|
||||
@dispatch(source="in_keys", dest="out_keys")
|
||||
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
|
||||
return self._call(tensordict)
|
||||
|
||||
def _call(self, td: TensorDictBase) -> TensorDictBase:
|
||||
# we need set batch_size=[] before assigning a different shape to td[outkey]
|
||||
td.batch_size = []
|
||||
|
||||
for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
|
||||
# TODO(rcadene): don't know how to do `inkey not in td`
|
||||
if td.get(inkey, None) is None:
|
||||
continue
|
||||
td[outkey] = td[inkey][:, self.positions]
|
||||
return td
|
||||
|
||||
|
||||
class DecodeVideoTransform(Transform):
|
||||
invertible = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: Path | str,
|
||||
device="cpu",
|
||||
decoding_lib: str = "torchaudio",
|
||||
# format options are None=yuv420p (usually), rgb24, bgr24, etc.
|
||||
format: str | None = None,
|
||||
frame_rate: int | None = None,
|
||||
width: int | None = None,
|
||||
height: int | None = None,
|
||||
in_keys: Sequence[NestedKey] = None,
|
||||
out_keys: Sequence[NestedKey] = None,
|
||||
in_keys_inv: Sequence[NestedKey] | None = None,
|
||||
out_keys_inv: Sequence[NestedKey] | None = None,
|
||||
):
|
||||
self.data_dir = Path(data_dir)
|
||||
self.device = device
|
||||
self.decoding_lib = decoding_lib
|
||||
self.format = format
|
||||
self.frame_rate = frame_rate
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.video_id_to_path = None
|
||||
if out_keys is None:
|
||||
out_keys = in_keys
|
||||
if in_keys_inv is None:
|
||||
in_keys_inv = out_keys
|
||||
if out_keys_inv is None:
|
||||
out_keys_inv = in_keys
|
||||
super().__init__(
|
||||
in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv
|
||||
)
|
||||
|
||||
def set_video_id_to_path(self, video_id_to_path):
|
||||
self.video_id_to_path = video_id_to_path
|
||||
|
||||
def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
|
||||
# _reset is called once when the environment reset to normalize the first observation
|
||||
tensordict_reset = self._call(tensordict_reset)
|
||||
return tensordict_reset
|
||||
|
||||
@dispatch(source="in_keys", dest="out_keys")
|
||||
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
|
||||
return self._call(tensordict)
|
||||
|
||||
def _call(self, td: TensorDictBase) -> TensorDictBase:
|
||||
assert (
|
||||
self.video_id_to_path is not None
|
||||
), "Setting a video_id_to_path dictionary with `self.set_video_id_to_path(video_id_to_path)` is required."
|
||||
|
||||
for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
|
||||
# TODO(rcadene): don't know how to do `inkey not in td`
|
||||
if td.get(inkey, None) is None:
|
||||
continue
|
||||
|
||||
bsize = len(td[inkey]) # num episodes in the batch
|
||||
b_frames = []
|
||||
for i in range(bsize):
|
||||
assert (
|
||||
td["observation", "frame", "video_id"].ndim == 3
|
||||
and td["observation", "frame", "video_id"].shape[2] == 1
|
||||
), "We expect 2 dims. Respectively, number of episodes in the batch, number of observations, 1"
|
||||
|
||||
ep_video_ids = td[inkey]["video_id"][i]
|
||||
timestamps = td[inkey]["timestamp"][i]
|
||||
frame_ids = td["frame_id"][i]
|
||||
|
||||
unique_video_id = (ep_video_ids.min() == ep_video_ids.max()).item()
|
||||
assert unique_video_id
|
||||
|
||||
is_ascending = torch.all(timestamps[:-1] <= timestamps[1:]).item()
|
||||
assert is_ascending
|
||||
|
||||
is_contiguous = ((frame_ids[1:] - frame_ids[:-1]) == 1).all().item()
|
||||
assert is_contiguous
|
||||
|
||||
FIRST_FRAME = 0 # noqa: N806
|
||||
video_id = ep_video_ids[FIRST_FRAME].squeeze(0).item()
|
||||
video_path = self.data_dir / self.video_id_to_path[video_id]
|
||||
first_frame_ts = timestamps[FIRST_FRAME].squeeze(0).item()
|
||||
num_contiguous_frames = len(timestamps)
|
||||
|
||||
if self.decoding_lib == "torchaudio":
|
||||
frames = self._decode_frames_torchaudio(video_path, first_frame_ts, num_contiguous_frames)
|
||||
elif self.decoding_lib == "ffmpegio":
|
||||
frames = self._decode_frames_ffmpegio(video_path, first_frame_ts, num_contiguous_frames)
|
||||
elif self.decoding_lib == "decord":
|
||||
frames = self._decode_frames_decord(video_path, first_frame_ts, num_contiguous_frames)
|
||||
else:
|
||||
raise ValueError(self.decoding_lib)
|
||||
|
||||
assert frames.ndim == 4
|
||||
assert frames.shape[1] == 3
|
||||
|
||||
b_frames.append(frames)
|
||||
|
||||
td[outkey] = torch.stack(b_frames)
|
||||
|
||||
if self.device == "cuda":
|
||||
# make sure we return a cuda tensor, since the frames can be unwillingly sent to cpu
|
||||
assert "cuda" in str(td[outkey].device), f"{td[outkey].device} instead of cuda"
|
||||
return td
|
||||
|
||||
def _decode_frames_torchaudio(self, video_path, first_frame_ts, num_contiguous_frames):
|
||||
filter_desc = []
|
||||
video_stream_kwgs = {
|
||||
"frames_per_chunk": num_contiguous_frames,
|
||||
"buffer_chunk_size": num_contiguous_frames,
|
||||
}
|
||||
|
||||
# choice of decoder
|
||||
if self.device == "cuda":
|
||||
video_stream_kwgs["hw_accel"] = "cuda"
|
||||
video_stream_kwgs["decoder"] = "h264_cuvid"
|
||||
# video_stream_kwgs["decoder"] = "hevc_cuvid"
|
||||
# video_stream_kwgs["decoder"] = "av1_cuvid"
|
||||
# video_stream_kwgs["decoder"] = "ffv1_cuvid"
|
||||
else:
|
||||
video_stream_kwgs["decoder"] = "h264"
|
||||
# video_stream_kwgs["decoder"] = "hevc"
|
||||
# video_stream_kwgs["decoder"] = "av1"
|
||||
# video_stream_kwgs["decoder"] = "ffv1"
|
||||
|
||||
# resize
|
||||
resize_width = self.width is not None
|
||||
resize_height = self.height is not None
|
||||
if resize_width or resize_height:
|
||||
if self.device == "cuda":
|
||||
assert resize_width and resize_height
|
||||
video_stream_kwgs["decoder_option"] = {"resize": f"{self.width}x{self.height}"}
|
||||
else:
|
||||
scales = []
|
||||
if resize_width:
|
||||
scales.append(f"width={self.width}")
|
||||
if resize_height:
|
||||
scales.append(f"height={self.height}")
|
||||
filter_desc.append(f"scale={':'.join(scales)}")
|
||||
|
||||
# choice of format
|
||||
if self.format is not None:
|
||||
if self.device == "cuda":
|
||||
# TODO(rcadene): rebuild ffmpeg with --enable-cuda-nvcc, --enable-cuvid, and --enable-libnpp
|
||||
raise NotImplementedError()
|
||||
# filter_desc = f"scale=format={self.format}"
|
||||
# filter_desc = f"scale_cuda=format={self.format}"
|
||||
# filter_desc = f"scale_npp=format={self.format}"
|
||||
else:
|
||||
filter_desc.append(f"format=pix_fmts={self.format}")
|
||||
|
||||
# choice of frame rate
|
||||
if self.frame_rate is not None:
|
||||
filter_desc.append(f"fps={self.frame_rate}")
|
||||
|
||||
filter_desc.append("scale=in_range=limited:out_range=full")
|
||||
|
||||
if len(filter_desc) > 0:
|
||||
video_stream_kwgs["filter_desc"] = ",".join(filter_desc)
|
||||
|
||||
# create a stream and load a certain number of frame at a certain frame rate
|
||||
# TODO(rcadene): make sure it's the most optimal way to do it
|
||||
s = StreamReader(str(video_path))
|
||||
s.seek(first_frame_ts)
|
||||
s.add_video_stream(**video_stream_kwgs)
|
||||
s.fill_buffer()
|
||||
(frames,) = s.pop_chunks()
|
||||
|
||||
if "yuv" in self.format:
|
||||
frames = yuv_to_rgb(frames)
|
||||
return frames
|
||||
|
||||
def _decode_frames_ffmpegio(self, video_path, first_frame_ts, num_contiguous_frames):
|
||||
import ffmpegio
|
||||
|
||||
fs, frames = ffmpegio.video.read(
|
||||
str(video_path), ss=str(first_frame_ts), vframes=num_contiguous_frames, pix_fmt=self.format
|
||||
)
|
||||
frames = torch.from_numpy(frames)
|
||||
frames = einops.rearrange(frames, "b h w c -> b c h w")
|
||||
if self.device == "cuda":
|
||||
frames = frames.to(self.device)
|
||||
return frames
|
||||
|
||||
def _decode_frames_decord(self, video_path, first_frame_ts, num_contiguous_frames):
|
||||
from decord import VideoReader, cpu, gpu
|
||||
|
||||
with open(str(video_path), "rb") as f:
|
||||
ctx = gpu if self.device == "cuda" else cpu
|
||||
vr = VideoReader(f, ctx=ctx(0)) # noqa: F841
|
||||
raise NotImplementedError("Convert `first_frame_ts` into frame_id")
|
||||
# frame_id = frame_ids[0].item()
|
||||
# frames = vr.get_batch([frame_id])
|
||||
# frames = torch.from_numpy(frames.asnumpy())
|
||||
# frames = einops.rearrange(frames, "b h w c -> b c h w")
|
||||
# return frames
|
||||
@@ -3,6 +3,7 @@ import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
|
||||
@@ -28,3 +29,26 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def yuv_to_rgb(frames):
|
||||
assert frames.dtype == torch.uint8
|
||||
assert frames.ndim == 4
|
||||
assert frames.shape[1] == 3
|
||||
|
||||
frames = frames.cpu().to(torch.float)
|
||||
y = frames[..., 0, :, :]
|
||||
u = frames[..., 1, :, :]
|
||||
v = frames[..., 2, :, :]
|
||||
|
||||
y /= 255
|
||||
u = u / 255 - 0.5
|
||||
v = v / 255 - 0.5
|
||||
|
||||
r = y + 1.14 * v
|
||||
g = y + -0.396 * u - 0.581 * v
|
||||
b = y + 2.029 * u
|
||||
|
||||
rgb = torch.stack([r, g, b], 1)
|
||||
rgb = (rgb * 255).clamp(0, 255).to(torch.uint8)
|
||||
return rgb
|
||||
|
||||
@@ -30,6 +30,7 @@ class Logger:
|
||||
self._model_dir = self._log_dir / "models"
|
||||
self._buffer_dir = self._log_dir / "buffers"
|
||||
self._save_model = cfg.save_model
|
||||
self._disable_wandb_artifact = cfg.wandb.disable_artifact
|
||||
self._save_buffer = cfg.save_buffer
|
||||
self._group = cfg_to_group(cfg)
|
||||
self._seed = cfg.seed
|
||||
@@ -71,9 +72,10 @@ class Logger:
|
||||
self._model_dir.mkdir(parents=True, exist_ok=True)
|
||||
fp = self._model_dir / f"{str(identifier)}.pt"
|
||||
policy.save(fp)
|
||||
if self._wandb:
|
||||
if self._wandb and not self._disable_wandb_artifact:
|
||||
# note wandb artifact does not accept ":" in its name
|
||||
artifact = self._wandb.Artifact(
|
||||
self._group + "-" + str(self._seed) + "-" + str(identifier),
|
||||
self._group.replace(":", "_") + "-" + str(self._seed) + "-" + str(identifier),
|
||||
type="model",
|
||||
)
|
||||
artifact.add_file(fp)
|
||||
|
||||
@@ -54,7 +54,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
def update(self, replay_buffer, step):
|
||||
del step
|
||||
|
||||
start_time = time.time()
|
||||
start_time = time.monotonic()
|
||||
|
||||
self.train()
|
||||
|
||||
@@ -104,7 +104,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
batch = process_batch(batch, self.cfg.horizon, num_slices)
|
||||
|
||||
data_s = time.time() - start_time
|
||||
data_s = time.monotonic() - start_time
|
||||
|
||||
loss = self.compute_loss(batch)
|
||||
loss.backward()
|
||||
@@ -125,7 +125,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
# "lr": self.lr_scheduler.get_last_lr()[0],
|
||||
"lr": self.cfg.lr,
|
||||
"data_s": data_s,
|
||||
"update_s": time.time() - start_time,
|
||||
"update_s": time.monotonic() - start_time,
|
||||
}
|
||||
|
||||
return info
|
||||
|
||||
@@ -188,8 +188,8 @@ class MetricLogger:
|
||||
def log_every(self, iterable, print_freq, header=None):
|
||||
if not header:
|
||||
header = ""
|
||||
start_time = time.time()
|
||||
end = time.time()
|
||||
start_time = time.monotonic()
|
||||
end = time.monotonic()
|
||||
iter_time = SmoothedValue(fmt="{avg:.4f}")
|
||||
data_time = SmoothedValue(fmt="{avg:.4f}")
|
||||
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
|
||||
@@ -218,9 +218,9 @@ class MetricLogger:
|
||||
)
|
||||
mega_b = 1024.0 * 1024.0
|
||||
for i, obj in enumerate(iterable):
|
||||
data_time.update(time.time() - end)
|
||||
data_time.update(time.monotonic() - end)
|
||||
yield obj
|
||||
iter_time.update(time.time() - end)
|
||||
iter_time.update(time.monotonic() - end)
|
||||
if i % print_freq == 0 or i == len(iterable) - 1:
|
||||
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
||||
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
||||
@@ -247,8 +247,8 @@ class MetricLogger:
|
||||
data=str(data_time),
|
||||
)
|
||||
)
|
||||
end = time.time()
|
||||
total_time = time.time() - start_time
|
||||
end = time.monotonic()
|
||||
total_time = time.monotonic() - start_time
|
||||
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||
print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable)))
|
||||
|
||||
|
||||
@@ -111,7 +111,7 @@ class DiffusionPolicy(nn.Module):
|
||||
return action
|
||||
|
||||
def update(self, replay_buffer, step):
|
||||
start_time = time.time()
|
||||
start_time = time.monotonic()
|
||||
|
||||
self.diffusion.train()
|
||||
|
||||
@@ -158,7 +158,7 @@ class DiffusionPolicy(nn.Module):
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
batch = process_batch(batch, self.cfg.horizon, num_slices)
|
||||
|
||||
data_s = time.time() - start_time
|
||||
data_s = time.monotonic() - start_time
|
||||
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
loss.backward()
|
||||
@@ -181,7 +181,7 @@ class DiffusionPolicy(nn.Module):
|
||||
"grad_norm": float(grad_norm),
|
||||
"lr": self.lr_scheduler.get_last_lr()[0],
|
||||
"data_s": data_s,
|
||||
"update_s": time.time() - start_time,
|
||||
"update_s": time.monotonic() - start_time,
|
||||
}
|
||||
|
||||
# TODO(rcadene): remove hardcoding
|
||||
|
||||
@@ -291,7 +291,7 @@ class TDMPC(nn.Module):
|
||||
|
||||
def update(self, replay_buffer, step, demo_buffer=None):
|
||||
"""Main update function. Corresponds to one iteration of the model learning."""
|
||||
start_time = time.time()
|
||||
start_time = time.monotonic()
|
||||
|
||||
num_slices = self.cfg.batch_size
|
||||
batch_size = self.cfg.horizon * num_slices
|
||||
@@ -405,7 +405,7 @@ class TDMPC(nn.Module):
|
||||
self.std = h.linear_schedule(self.cfg.std_schedule, step)
|
||||
self.model.train()
|
||||
|
||||
data_s = time.time() - start_time
|
||||
data_s = time.monotonic() - start_time
|
||||
|
||||
# Compute targets
|
||||
with torch.no_grad():
|
||||
@@ -501,7 +501,7 @@ class TDMPC(nn.Module):
|
||||
"grad_norm": float(grad_norm),
|
||||
"lr": self.cfg.lr,
|
||||
"data_s": data_s,
|
||||
"update_s": time.time() - start_time,
|
||||
"update_s": time.monotonic() - start_time,
|
||||
}
|
||||
info["demo_batch_size"] = demo_batch_size
|
||||
info["expectile"] = expectile
|
||||
|
||||
@@ -30,5 +30,7 @@ policy: ???
|
||||
|
||||
wandb:
|
||||
enable: true
|
||||
# Set to true to disable saving an artifact despite save_model == True
|
||||
disable_artifact: false
|
||||
project: lerobot
|
||||
notes: ""
|
||||
|
||||
129
lerobot/scripts/convert_dataset_uint8_to_mp4.py
Normal file
129
lerobot/scripts/convert_dataset_uint8_to_mp4.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""
|
||||
usage: `python lerobot/scripts/convert_dataset_uint8_to_mp4.py --in-data-dir data/pusht --out-data-dir tests/data/pusht`
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from tensordict import TensorDict
|
||||
|
||||
|
||||
def convert_dataset_uint8_to_mp4(in_data_dir, out_data_dir, fps, overwrite_num_frames=None):
|
||||
assert fps is not None and isinstance(fps, float)
|
||||
# load full dataset as a tensor dict
|
||||
in_td_data = TensorDict.load_memmap(in_data_dir)
|
||||
|
||||
out_data_dir = Path(out_data_dir)
|
||||
# use 1 frame to know the specification of the dataset
|
||||
# and copy it over `n` frames in the test artifact directory
|
||||
out_rb_dir = out_data_dir / "replay_buffer"
|
||||
if out_rb_dir.exists():
|
||||
shutil.rmtree(out_rb_dir)
|
||||
|
||||
num_frames = len(in_td_data) if overwrite_num_frames is None else overwrite_num_frames
|
||||
|
||||
# del in_td_data["observation", "image"]
|
||||
# del in_td_data["next", "observation", "image"]
|
||||
|
||||
out_td_data = in_td_data[0].memmap_().clone()
|
||||
|
||||
out_td_data["observation", "frame", "video_id"] = torch.zeros(1, dtype=torch.int)
|
||||
out_td_data["observation", "frame", "timestamp"] = torch.zeros(1)
|
||||
out_td_data["next", "observation", "frame", "video_id"] = torch.zeros(1, dtype=torch.int)
|
||||
out_td_data["next", "observation", "frame", "timestamp"] = torch.zeros(1)
|
||||
|
||||
out_td_data = out_td_data.expand(num_frames)
|
||||
out_td_data = out_td_data.memmap_like(out_rb_dir)
|
||||
|
||||
out_vid_dir = out_data_dir / "videos"
|
||||
out_vid_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
video_id_to_path = {}
|
||||
|
||||
for key in out_td_data.keys(include_nested=True, leaves_only=True):
|
||||
if in_td_data.get(key, None) is None:
|
||||
continue
|
||||
if overwrite_num_frames is None:
|
||||
out_td_data[key].copy_(in_td_data[key].clone())
|
||||
else:
|
||||
out_td_data[key][:num_frames].copy_(in_td_data[key][:num_frames].clone())
|
||||
|
||||
for i in range(num_frames):
|
||||
video_id = in_td_data["episode"][i]
|
||||
frame_id = in_td_data["frame_id"][i]
|
||||
|
||||
out_td_data["observation", "frame", "video_id"][i] = video_id
|
||||
out_td_data["observation", "frame", "timestamp"][i] = frame_id / fps
|
||||
out_td_data["next", "observation", "frame", "video_id"][i] = video_id
|
||||
out_td_data["next", "observation", "frame", "timestamp"][i] = (frame_id + 1) / fps
|
||||
|
||||
video_id = video_id.item()
|
||||
if video_id not in video_id_to_path:
|
||||
video_id_to_path[video_id] = f"videos/episode_{video_id}.mp4"
|
||||
|
||||
# copy the first `n` frames so that we have real data
|
||||
|
||||
# make sure everything has been properly written
|
||||
out_td_data.lock_()
|
||||
|
||||
# copy the full statistics of dataset since it's pretty small
|
||||
in_stats_path = Path(in_data_dir) / "stats.pth"
|
||||
|
||||
out_stats_path = Path(out_data_dir) / "stats.pth"
|
||||
shutil.copy(in_stats_path, out_stats_path)
|
||||
|
||||
meta_data = {
|
||||
"video_id_to_path": video_id_to_path,
|
||||
}
|
||||
torch.save(meta_data, out_data_dir / "meta_data.pth")
|
||||
|
||||
|
||||
# def write_to_mp4():
|
||||
# buffer = io.BytesIO()
|
||||
# swriter = StreamWriter(buffer, format="mp4")
|
||||
|
||||
# device = "cuda"
|
||||
|
||||
# c,h,w = in_td_data[0]["observation", "image"].shape
|
||||
|
||||
# swriter.add_video_stream(
|
||||
# frame_rate=fps,
|
||||
# width=w,
|
||||
# height=h,
|
||||
# # frame_rate=30000 / 1001,
|
||||
# format="yuv444p",
|
||||
# encoder="h264_nvenc",
|
||||
# encoder_format="yuv444p",
|
||||
# hw_accel=device,
|
||||
# )
|
||||
|
||||
# for i in range(num_frames):
|
||||
# ep_id = in_td_data[i]["episode"]
|
||||
# data = in_td_data[i]["observation", "image"]
|
||||
# with swriter.open():
|
||||
# t0 = time.monotonic()
|
||||
# data = data.to(device)
|
||||
# swriter.write_video_chunk(0, data)
|
||||
# elapsed = time.monotonic() - t0
|
||||
# size = buffer.tell()
|
||||
# print(f"{elapsed=}")
|
||||
# print(f"{size=}")
|
||||
# buffer.seek(0)
|
||||
# video = buffer.read()
|
||||
|
||||
# vid_path = out_vid_dir / f"episode_{ep_id}.mp4"
|
||||
# with open(vid_path, 'wb+') as f:
|
||||
# f.write(video)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Create dataset")
|
||||
|
||||
parser.add_argument("--in-data-dir", type=str, help="Path to input data")
|
||||
parser.add_argument("--out-data-dir", type=str, help="Path to save the output data")
|
||||
parser.add_argument("--fps", type=float)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_dataset_uint8_to_mp4(args.in_data_dir, args.out_data_dir, args.fps)
|
||||
@@ -32,7 +32,7 @@ def eval_policy(
|
||||
fps: int = 15,
|
||||
return_first_video: bool = False,
|
||||
):
|
||||
start = time.time()
|
||||
start = time.monotonic()
|
||||
sum_rewards = []
|
||||
max_rewards = []
|
||||
successes = []
|
||||
@@ -85,8 +85,8 @@ def eval_policy(
|
||||
"avg_sum_reward": np.nanmean(sum_rewards),
|
||||
"avg_max_reward": np.nanmean(max_rewards),
|
||||
"pc_success": np.nanmean(successes) * 100,
|
||||
"eval_s": time.time() - start,
|
||||
"eval_ep_s": (time.time() - start) / num_episodes,
|
||||
"eval_s": time.monotonic() - start,
|
||||
"eval_ep_s": (time.monotonic() - start) / num_episodes,
|
||||
}
|
||||
if return_first_video:
|
||||
return info, first_video
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import logging
|
||||
import shutil
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
@@ -26,9 +29,49 @@ def visualize_dataset_cli(cfg: dict):
|
||||
|
||||
def cat_and_write_video(video_path, frames, fps):
|
||||
frames = torch.cat(frames)
|
||||
assert frames.dtype == torch.uint8
|
||||
if frames.dtype != torch.uint8:
|
||||
logging.warning(f"frames are expected to be uint8 to {frames.dtype}")
|
||||
frames = frames.type(torch.uint8)
|
||||
|
||||
_, _, h, w = frames.shape
|
||||
frames = einops.rearrange(frames, "b c h w -> b h w c").numpy()
|
||||
imageio.mimsave(video_path, frames, fps=fps)
|
||||
|
||||
img_dir = Path(video_path.split(".")[0])
|
||||
if img_dir.exists():
|
||||
shutil.rmtree(img_dir)
|
||||
img_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for i in range(len(frames)):
|
||||
imageio.imwrite(str(img_dir / f"frame_{i:04d}.png"), frames[i])
|
||||
|
||||
ffmpeg_command = [
|
||||
"ffmpeg",
|
||||
"-r",
|
||||
str(fps),
|
||||
"-f",
|
||||
"image2",
|
||||
"-s",
|
||||
f"{w}x{h}",
|
||||
"-i",
|
||||
str(img_dir / "frame_%04d.png"),
|
||||
"-vcodec",
|
||||
"libx264",
|
||||
#'-vcodec', 'libx265',
|
||||
#'-vcodec', 'libaom-av1',
|
||||
"-crf",
|
||||
"0", # Lossless option
|
||||
"-pix_fmt",
|
||||
# "yuv420p", # Specify pixel format
|
||||
"yuv444p", # Specify pixel format
|
||||
video_path,
|
||||
# video_path.replace(".mp4", ".mkv")
|
||||
]
|
||||
subprocess.run(ffmpeg_command, check=True)
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
# clean temporary image directory
|
||||
# shutil.rmtree(img_dir)
|
||||
|
||||
|
||||
def visualize_dataset(cfg: dict, out_dir=None):
|
||||
@@ -61,7 +104,10 @@ def visualize_dataset(cfg: dict, out_dir=None):
|
||||
|
||||
# TODO(rcaene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
|
||||
no_more_frames = offline_buffer._sampler._sample_list.numel() == 0
|
||||
new_episode = ep_idx != current_ep_idx
|
||||
new_episode = ep_idx > current_ep_idx
|
||||
|
||||
if ep_idx < current_ep_idx:
|
||||
break
|
||||
|
||||
if new_episode:
|
||||
logging.info(f"Visualizing episode {current_ep_idx}")
|
||||
@@ -71,7 +117,7 @@ def visualize_dataset(cfg: dict, out_dir=None):
|
||||
# append last observed frames (the ones after last action taken)
|
||||
frames[im_key].append(ep_td[("next", *im_key)])
|
||||
|
||||
video_dir = Path(out_dir) / "visualize_dataset"
|
||||
video_dir = Path(out_dir) / "videos"
|
||||
video_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if len(offline_buffer.image_keys) > 1:
|
||||
|
||||
141
poetry.lock
generated
141
poetry.lock
generated
@@ -838,6 +838,78 @@ files = [
|
||||
[package.dependencies]
|
||||
numpy = ">=1.17.3"
|
||||
|
||||
[[package]]
|
||||
name = "hf-transfer"
|
||||
version = "0.1.6"
|
||||
description = ""
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "hf_transfer-0.1.6-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:6fd3d61f9229d27def007e53540412507b74ac2fdb1a29985ae0b6a5137749a2"},
|
||||
{file = "hf_transfer-0.1.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b043bb78df1225de043eb041de9d97783fcca14a0bdc1b1d560fc172fc21b648"},
|
||||
{file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7db60dd18eae4fa6ea157235fb82196cde5313995b396d1b591aad3b790a7f8f"},
|
||||
{file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:30d31dbab9b5a558cce407b8728e39d87d7af1ef8745ddb90187e9ae0b9e1e90"},
|
||||
{file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f6b368bddd757efc7af3126ba81f9ac8f9435e2cc00902cb3d64f2be28d8f719"},
|
||||
{file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa2086d8aefaaa3e144e167324574882004c0cec49bf2d0638ec4b74732d8da0"},
|
||||
{file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:45d8985a0940bfe1535cb4ca781f5c11e47c83798ef3373ee1f5d57bbe527a9c"},
|
||||
{file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2f42b89735f1cde22f2a795d1f0915741023235666be7de45879e533c7d6010c"},
|
||||
{file = "hf_transfer-0.1.6-cp310-none-win32.whl", hash = "sha256:2d2c4c4613f3ad45b6ce6291e347b2d3ba1b86816635681436567e461cb3c961"},
|
||||
{file = "hf_transfer-0.1.6-cp310-none-win_amd64.whl", hash = "sha256:78b0eed8d8dce60168a46e584b9742b816af127d7e410a713e12c31249195342"},
|
||||
{file = "hf_transfer-0.1.6-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:f1d8c172153f9a6cdaecf137612c42796076f61f6bea1072c90ac2e17c1ab6fa"},
|
||||
{file = "hf_transfer-0.1.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2c601996351f90c514a75a0eeb02bf700b1ad1db2d946cbfe4b60b79e29f0b2f"},
|
||||
{file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e585c808405557d3f5488f385706abb696997bbae262ea04520757e30836d9d"},
|
||||
{file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ec51af1e8cf4268c268bd88932ade3d7ca895a3c661b42493503f02610ae906b"},
|
||||
{file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d106fdf996332f6df3ed3fab6d6332df82e8c1fb4b20fd81a491ca4d2ab5616a"},
|
||||
{file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e9c2ee9e9fde5a0319cc0e8ddfea10897482bc06d5709b10a238f1bc2ebcbc0b"},
|
||||
{file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f394ea32bc7802b061e549d3133efc523b4ae4fd19bf4b74b183ca6066eef94e"},
|
||||
{file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4282f09902114cd67fca98a1a1bad569a44521a8395fedf327e966714f68b977"},
|
||||
{file = "hf_transfer-0.1.6-cp311-none-win32.whl", hash = "sha256:276dbf307d5ab6f1bcbf57b5918bfcf9c59d6848ccb28242349e1bb5985f983b"},
|
||||
{file = "hf_transfer-0.1.6-cp311-none-win_amd64.whl", hash = "sha256:fa475175c51451186bea804471995fa8e7b2a48a61dcca55534911dc25955527"},
|
||||
{file = "hf_transfer-0.1.6-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:23d157a67acfa00007799323a1c441b2bbacc7dee625b016b7946fe0e25e6c89"},
|
||||
{file = "hf_transfer-0.1.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6067342a2864b988f861cd2d31bd78eb1e84d153a3f6df38485b6696d9ad3013"},
|
||||
{file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91cfcb3070e205b58fa8dc8bcb6a62ccc40913fcdb9cd1ff7c364c8e3aa85345"},
|
||||
{file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb76064ac5165d5eeaaf8d0903e8bf55477221ecc2a4a4d69f0baca065ab905b"},
|
||||
{file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9dabd3a177d83028f164984cf4dd859f77ec1e20c97a6f307ff8fcada0785ef1"},
|
||||
{file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d0bf4254e44f64a26e0a5b73b5d7e8d91bb36870718fb4f8e126ec943ff4c805"},
|
||||
{file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d32c1b106f38f336ceb21531f4db9b57d777b9a33017dafdb6a5316388ebe50"},
|
||||
{file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff05aba3c83921e5c7635ba9f07c693cc893350c447644824043aeac27b285f5"},
|
||||
{file = "hf_transfer-0.1.6-cp312-none-win32.whl", hash = "sha256:051ef0c55607652cb5974f59638da035773254b9a07d7ee5b574fe062de4c9d1"},
|
||||
{file = "hf_transfer-0.1.6-cp312-none-win_amd64.whl", hash = "sha256:716fb5c574fcbdd8092ce73f9b6c66f42e3544337490f77c60ec07df02bd081b"},
|
||||
{file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c0c981134a55965e279cb7be778c1ccaf93f902fc9ebe31da4f30caf824cc4d"},
|
||||
{file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1ef1f145f04c5b573915bcb1eb5db4039c74f6b46fce73fc473c4287e613b623"},
|
||||
{file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d0a7609b004db3347dbb7796df45403eceb171238210d054d93897d6d84c63a4"},
|
||||
{file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:60f0864bf5996773dbd5f8ae4d1649041f773fe9d5769f4c0eeb5553100acef3"},
|
||||
{file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d01e55d630ffe70a4f5d0ed576a04c6a48d7c65ca9a7d18f2fca385f20685a9"},
|
||||
{file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d855946c5062b665190de15b2bdbd4c8eddfee35350bfb7564592e23d36fbbd3"},
|
||||
{file = "hf_transfer-0.1.6-cp37-none-win32.whl", hash = "sha256:fd40b2409cfaf3e8aba20169ee09552f69140e029adeec261b988903ff0c8f6f"},
|
||||
{file = "hf_transfer-0.1.6-cp37-none-win_amd64.whl", hash = "sha256:0e0eba49d46d3b5481919aea0794aec625fbc6ecdf13fe7e0e9f3fc5d5ad5971"},
|
||||
{file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7e669fecb29fc454449739f9f53ed9253197e7c19e6a6eaa0f08334207af4287"},
|
||||
{file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:89f701802892e5eb84f89f402686861f87dc227d6082b05f4e9d9b4e8015a3c3"},
|
||||
{file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b6f2b0c8b95b01409275d789a9b74d5f2e146346f985d384bf50ec727caf1ccc"},
|
||||
{file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa855a2fa262792a230f9efcdb5da6d431b747d1861d2a69fe7834b19aea077e"},
|
||||
{file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4aa8ca349afb2f0713475426946261eb2035e4efb50ebd2c1d5ad04f395f4217"},
|
||||
{file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01255f043996bc7d1bae62d8afc5033a90c7e36ce308b988eeb84afe0a69562f"},
|
||||
{file = "hf_transfer-0.1.6-cp38-none-win32.whl", hash = "sha256:60b1db183e8a7540cd4f8b2160ff4de55f77cb0c3fc6a10be1e7c30eb1b2bdeb"},
|
||||
{file = "hf_transfer-0.1.6-cp38-none-win_amd64.whl", hash = "sha256:fb8be3cba6aaa50ab2e9dffbd25c8eb2046785eeff642cf0cdd0dd9ae6be3539"},
|
||||
{file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d09af35e3e3f09b664e6429e9a0dc200f29c5bdfd88bdd9666de51183b1fe202"},
|
||||
{file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a4505bd707cc14d85c800f961fad8ca76f804a8ad22fbb7b1a217d8d0c15e6a5"},
|
||||
{file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2c453fd8b0be9740faa23cecd1f28ee9ead7d900cefa64ff836960c503a744c9"},
|
||||
{file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:13cb8884e718a78c3b81a8cdec9c7ac196dd42961fce55c3ccff3dd783e5ad7a"},
|
||||
{file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:39cd39df171a2b5404de69c4e6cd14eee47f6fe91c1692f939bfb9e59a0110d8"},
|
||||
{file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ff0629ee9f98df57a783599602eb498f9ec3619dc69348b12e4d9d754abf0e9"},
|
||||
{file = "hf_transfer-0.1.6-cp39-none-win32.whl", hash = "sha256:164a6ce445eb0cc7c645f5b6e1042c003d33292520c90052b6325f30c98e4c5f"},
|
||||
{file = "hf_transfer-0.1.6-cp39-none-win_amd64.whl", hash = "sha256:11b8b4b73bf455f13218c5f827698a30ae10998ca31b8264b51052868c7a9f11"},
|
||||
{file = "hf_transfer-0.1.6-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:16957ba057376a99ea361074ce1094f61b58e769defa6be2422ae59c0b6a6530"},
|
||||
{file = "hf_transfer-0.1.6-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7db952112e3b8ee1a5cbf500d2443e9ce4fb893281c5310a3e31469898628005"},
|
||||
{file = "hf_transfer-0.1.6-pp37-pypy37_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d39d826a7344f5e39f438d62632acd00467aa54a083b66496f61ef67a9885a56"},
|
||||
{file = "hf_transfer-0.1.6-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4e2653fbfa92e7651db73d99b697c8684e7345c479bd6857da80bed6138abb2"},
|
||||
{file = "hf_transfer-0.1.6-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:144277e6a86add10b90ec3b583253aec777130312256bfc8d5ade5377e253807"},
|
||||
{file = "hf_transfer-0.1.6-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3bb53bcd16365313b2aa0dbdc28206f577d70770f31249cdabc387ac5841edcc"},
|
||||
{file = "hf_transfer-0.1.6-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:990d73a5a68d8261980f146c51f4c5f9995314011cb225222021ad7c39f3af2d"},
|
||||
{file = "hf_transfer-0.1.6-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:652406037029ab9b4097b4c5f29321bad5f64c2b46fbff142509d918aec87c29"},
|
||||
{file = "hf_transfer-0.1.6.tar.gz", hash = "sha256:deb505a7d417d7055fd7b3549eadb91dfe782941261f3344025c486c16d1d2f9"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "huggingface-hub"
|
||||
version = "0.21.4"
|
||||
@@ -852,6 +924,7 @@ files = [
|
||||
[package.dependencies]
|
||||
filelock = "*"
|
||||
fsspec = ">=2023.5.0"
|
||||
hf-transfer = {version = ">=0.1.4", optional = true, markers = "extra == \"hf_transfer\""}
|
||||
packaging = ">=20.9"
|
||||
pyyaml = ">=5.1"
|
||||
requests = "*"
|
||||
@@ -2611,13 +2684,13 @@ test = ["asv", "gmpy2", "hypothesis", "mpmath", "pooch", "pytest", "pytest-cov",
|
||||
|
||||
[[package]]
|
||||
name = "sentry-sdk"
|
||||
version = "1.41.0"
|
||||
version = "1.42.0"
|
||||
description = "Python client for Sentry (https://sentry.io)"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "sentry-sdk-1.41.0.tar.gz", hash = "sha256:4f2d6c43c07925d8cd10dfbd0970ea7cb784f70e79523cca9dbcd72df38e5a46"},
|
||||
{file = "sentry_sdk-1.41.0-py2.py3-none-any.whl", hash = "sha256:be4f8f4b29a80b6a3b71f0f31487beb9e296391da20af8504498a328befed53f"},
|
||||
{file = "sentry-sdk-1.42.0.tar.gz", hash = "sha256:4a8364b8f7edbf47f95f7163e48334c96100d9c098f0ae6606e2e18183c223e6"},
|
||||
{file = "sentry_sdk-1.42.0-py2.py3-none-any.whl", hash = "sha256:a654ee7e497a3f5f6368b36d4f04baeab1fe92b3105f7f6965d6ef0de35a9ba4"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -2641,6 +2714,7 @@ grpcio = ["grpcio (>=1.21.1)"]
|
||||
httpx = ["httpx (>=0.16.0)"]
|
||||
huey = ["huey (>=2)"]
|
||||
loguru = ["loguru (>=0.5)"]
|
||||
openai = ["openai (>=1.0.0)", "tiktoken (>=0.3.0)"]
|
||||
opentelemetry = ["opentelemetry-distro (>=0.35b0)"]
|
||||
opentelemetry-experimental = ["opentelemetry-distro (>=0.40b0,<1.0)", "opentelemetry-instrumentation-aiohttp-client (>=0.40b0,<1.0)", "opentelemetry-instrumentation-django (>=0.40b0,<1.0)", "opentelemetry-instrumentation-fastapi (>=0.40b0,<1.0)", "opentelemetry-instrumentation-flask (>=0.40b0,<1.0)", "opentelemetry-instrumentation-requests (>=0.40b0,<1.0)", "opentelemetry-instrumentation-sqlite3 (>=0.40b0,<1.0)", "opentelemetry-instrumentation-urllib (>=0.40b0,<1.0)"]
|
||||
pure-eval = ["asttokens", "executing", "pure-eval"]
|
||||
@@ -2756,18 +2830,18 @@ test = ["pytest"]
|
||||
|
||||
[[package]]
|
||||
name = "setuptools"
|
||||
version = "69.1.1"
|
||||
version = "69.2.0"
|
||||
description = "Easily download, build, install, upgrade, and uninstall Python packages"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "setuptools-69.1.1-py3-none-any.whl", hash = "sha256:02fa291a0471b3a18b2b2481ed902af520c69e8ae0919c13da936542754b4c56"},
|
||||
{file = "setuptools-69.1.1.tar.gz", hash = "sha256:5c0806c7d9af348e6dd3777b4f4dbb42c7ad85b190104837488eab9a7c945cf8"},
|
||||
{file = "setuptools-69.2.0-py3-none-any.whl", hash = "sha256:c21c49fb1042386df081cb5d86759792ab89efca84cf114889191cd09aacc80c"},
|
||||
{file = "setuptools-69.2.0.tar.gz", hash = "sha256:0ff4183f8f42cd8fa3acea16c45205521a4ef28f73c6391d8a25e92893134f2e"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"]
|
||||
testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.2)", "pip (>=19.1)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
|
||||
testing = ["build[virtualenv]", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
|
||||
testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.2)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"]
|
||||
|
||||
[[package]]
|
||||
@@ -2876,7 +2950,7 @@ mpmath = ">=0.19"
|
||||
|
||||
[[package]]
|
||||
name = "tensordict"
|
||||
version = "0.4.0+551331d"
|
||||
version = "0.4.0+6a56ecd"
|
||||
description = ""
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
@@ -2897,7 +2971,7 @@ tests = ["pytest", "pytest-benchmark", "pytest-instafail", "pytest-rerunfailures
|
||||
type = "git"
|
||||
url = "https://github.com/pytorch/tensordict"
|
||||
reference = "HEAD"
|
||||
resolved_reference = "ed22554d6860731610df784b2f5d09f31d3dbc7a"
|
||||
resolved_reference = "6a56ecd728757feee387f946b7da66dd452b739b"
|
||||
|
||||
[[package]]
|
||||
name = "termcolor"
|
||||
@@ -2999,6 +3073,43 @@ typing-extensions = ">=4.8.0"
|
||||
opt-einsum = ["opt-einsum (>=3.3)"]
|
||||
optree = ["optree (>=0.9.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "torchaudio"
|
||||
version = "2.2.1"
|
||||
description = "An audio package for PyTorch"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "torchaudio-2.2.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:580eefd764a01a64d5b6aa260c0c47974be6a6964892d54029a73b17f4611fcd"},
|
||||
{file = "torchaudio-2.2.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8ad55c2069b27bbe18e14783a202e3f3f8082fe9e59281436ba797edb0fc94d5"},
|
||||
{file = "torchaudio-2.2.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:55d23986254f7af689695f3fc214c4aa3e73dc931289ecdba7262d73fea7af7a"},
|
||||
{file = "torchaudio-2.2.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:b916b7764698ba9319aa3b25519139892de8665d84438969bac5e1d8578c6a11"},
|
||||
{file = "torchaudio-2.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:281cd4bdb9e65c0618a028b809df9e06f9bd9592aeef8f2b37b4d8a788ce5f2b"},
|
||||
{file = "torchaudio-2.2.1-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:274cb8474bc1e56b768ef347d3188661c5a9d5e68e2df56fc0aff11cc73c916a"},
|
||||
{file = "torchaudio-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1e62c27b17672cc2bdd9663681e533000f9c0984e6a0f3d455f7051bc005bb02"},
|
||||
{file = "torchaudio-2.2.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:7df7d5d9100116be38ff7b27b628820dca4a9e3fe79394605141d339e3b3e46d"},
|
||||
{file = "torchaudio-2.2.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:20b2965db4f843021636f53d3fab1075c3f8959c450c647629124d24c7e6cbb0"},
|
||||
{file = "torchaudio-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:63dd0e840bcf2e4aceb7a98daccfaf7a2a5b3a927647b98bbef449b0b190f2cc"},
|
||||
{file = "torchaudio-2.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:2c232dc8bee97d303b90833ba934d8905eb7326456236efcd9fa71ccb92fd363"},
|
||||
{file = "torchaudio-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2419387cf04d33047369337bf09c00c2a7673a8f52f80258454c7eca7d205d23"},
|
||||
{file = "torchaudio-2.2.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:2483c0620a68a136359ae90c893608ad5cd73091fb0351b94d33af126a0e3d67"},
|
||||
{file = "torchaudio-2.2.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:bd389f33b7dbfc44e5f4070fc6db00cc560992bea8378a952889acfd772b7022"},
|
||||
{file = "torchaudio-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:d5af725a327b79f3bd8389c53ec51554ee003c18434fc47e68da49b09900132e"},
|
||||
{file = "torchaudio-2.2.1-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:81ef88d7693e3b99007d1ee742fd81b9a92399ecbf88eb7ed69949443005ffba"},
|
||||
{file = "torchaudio-2.2.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f487a7d3177ae6af016750850ee93788e880218a1a310bc6c76901e212f91cd3"},
|
||||
{file = "torchaudio-2.2.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:bee5478ec2cb7d0eaa97023d817aa4914010e1ab0c266f64ef1b0db893aceb49"},
|
||||
{file = "torchaudio-2.2.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:a4462b3f214f60b6b8f78e12a4cf1291c9bc353deed709ac3dfdedbed513a7a3"},
|
||||
{file = "torchaudio-2.2.1-cp38-cp38-win_amd64.whl", hash = "sha256:4bc43d11d9e086f0dfb29f6ea99517d8ec06fa80d97283f2c8b83c4cd467dd1a"},
|
||||
{file = "torchaudio-2.2.1-cp39-cp39-macosx_10_13_x86_64.whl", hash = "sha256:0339fe78ed9c29f704296761b28bb055b5350625ff503ad781704397934e6b58"},
|
||||
{file = "torchaudio-2.2.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:68b1d9f8ffe9b26ef04e80d82ae2dc2f74b1a1eb64c3e8ad21b525802b3bc7ac"},
|
||||
{file = "torchaudio-2.2.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:3962fea5d2511c9ab2b1dd515b45ec44d0c28e51f3b05c0b9fa7bbcc3c213bc1"},
|
||||
{file = "torchaudio-2.2.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:cb2da08abb7b68dc7b0105748b1a736dd33329f841374013ec02c54e04bedf29"},
|
||||
{file = "torchaudio-2.2.1-cp39-cp39-win_amd64.whl", hash = "sha256:54996977ab1c875729e8dedc4695609ca58f876c23756c79979c6b50136b3385"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
torch = "2.2.1"
|
||||
|
||||
[[package]]
|
||||
name = "torchrl"
|
||||
version = "0.4.0+13bef42"
|
||||
@@ -3238,20 +3349,20 @@ jupyter = ["ipytree (>=0.2.2)", "ipywidgets (>=8.0.0)", "notebook"]
|
||||
|
||||
[[package]]
|
||||
name = "zipp"
|
||||
version = "3.17.0"
|
||||
version = "3.18.0"
|
||||
description = "Backport of pathlib-compatible object wrapper for zip files"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "zipp-3.17.0-py3-none-any.whl", hash = "sha256:0e923e726174922dce09c53c59ad483ff7bbb8e572e00c7f7c46b88556409f31"},
|
||||
{file = "zipp-3.17.0.tar.gz", hash = "sha256:84e64a1c28cf7e91ed2078bb8cc8c259cb19b76942096c8d7b84947690cabaf0"},
|
||||
{file = "zipp-3.18.0-py3-none-any.whl", hash = "sha256:c1bb803ed69d2cce2373152797064f7e79bc43f0a3748eb494096a867e0ebf79"},
|
||||
{file = "zipp-3.18.0.tar.gz", hash = "sha256:df8d042b02765029a09b157efd8e820451045890acc30f8e37dd2f94a060221f"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"]
|
||||
testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"]
|
||||
docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
|
||||
testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"]
|
||||
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "3d82309a7b2388d774b56ceb6f6906ef0732d8cedda0d76cc84a30e239949be8"
|
||||
content-hash = "e0c9fa6894aaa917493f81028c1bcc3fff8c56d9025681af44534fc3dbe7646e"
|
||||
|
||||
@@ -50,6 +50,8 @@ diffusers = "^0.26.3"
|
||||
torchvision = "^0.17.1"
|
||||
h5py = "^3.10.0"
|
||||
dm-control = "1.0.14"
|
||||
huggingface-hub = {extras = ["hf-transfer"], version = "^0.21.4"}
|
||||
torchaudio = "^2.2.1"
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
|
||||
330
test.py
Normal file
330
test.py
Normal file
@@ -0,0 +1,330 @@
|
||||
# TODO(rcadene): add tests
|
||||
# TODO(rcadene): what is the best format to store/load videos?
|
||||
|
||||
import subprocess
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torchaudio
|
||||
import torchrl
|
||||
from matplotlib import pyplot as plt
|
||||
from tensordict import TensorDict
|
||||
from torchaudio.utils import ffmpeg_utils
|
||||
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
|
||||
from torchrl.data.replay_buffers.samplers import SliceSampler, SliceSamplerWithoutReplacement
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
||||
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
||||
from torchrl.envs.transforms.transforms import Compose
|
||||
|
||||
from lerobot.common.datasets.transforms import DecodeVideoTransform, KeepFrames, ViewSliceHorizonTransform
|
||||
from lerobot.common.utils import set_seed
|
||||
|
||||
# Dimensionality of the placeholder state/action vectors synthesized in
# VideoExperienceReplay._download (random data, not real robot readings).
NUM_STATE_CHANNELS = 12
NUM_ACTION_CHANNELS = 12
|
||||
|
||||
|
||||
def count_frames(video_path):
    """Return the number of frames in the first video stream of `video_path`.

    Uses ffprobe's `nb_frames` stream entry. Returns -1 on any failure
    (ffprobe missing, file unreadable, or a non-integer value such as the
    "N/A" some containers report), keeping the best-effort contract.
    """
    try:
        # Construct the ffprobe command to get the number of frames
        cmd = [
            "ffprobe",
            "-v",
            "error",
            "-select_streams",
            "v:0",
            "-show_entries",
            "stream=nb_frames",
            "-of",
            "default=nokey=1:noprint_wrappers=1",
            video_path,
        ]

        # Execute the ffprobe command and capture the output
        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

        # A non-zero exit code means ffprobe could not read the file; surface
        # its stderr instead of failing later on int("") with an opaque error.
        if result.returncode != 0:
            print(f"An error occurred: ffprobe failed: {result.stderr.strip()}")
            return -1

        # Convert the output to an integer
        return int(result.stdout.strip())
    except Exception as e:
        print(f"An error occurred: {e}")
        return -1
|
||||
|
||||
|
||||
def get_frame_rate(video_path):
    """Return the frame rate (fps, as a float) of the first video stream.

    Uses ffprobe's `r_frame_rate` stream entry. Returns -1 on any failure,
    keeping the best-effort contract.
    """
    # Local import keeps the module's top-level dependencies unchanged.
    from fractions import Fraction

    try:
        cmd = [
            "ffprobe",
            "-v",
            "error",
            "-select_streams",
            "v:0",
            "-show_entries",
            "stream=r_frame_rate",
            "-of",
            "default=nokey=1:noprint_wrappers=1",
            video_path,
        ]

        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

        if result.returncode != 0:
            print(f"An error occurred: ffprobe failed: {result.stderr.strip()}")
            return -1

        # The frame rate is typically represented as a fraction (e.g., "30000/1001").
        # Parse it with Fraction instead of eval(): evaluating subprocess output
        # as Python code is an injection hazard and accepts arbitrary expressions.
        frame_rate = float(Fraction(result.stdout.strip()))

        return frame_rate
    except Exception as e:
        print(f"An error occurred: {e}")
        return -1
|
||||
|
||||
|
||||
def get_frame_timestamps(frame_rate, num_frames):
    """Return the presentation time (seconds) of each of `num_frames` frames,
    assuming a constant frame rate with the first frame at t = 0."""
    frame_duration = 1 / frame_rate
    return [frame_duration * index for index in range(num_frames)]
|
||||
|
||||
|
||||
# class ClearDeviceTransform(Transform):
|
||||
# invertible = False
|
||||
|
||||
# def __init__(self):
|
||||
# super().__init__()
|
||||
|
||||
# def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
|
||||
# # _reset is called once when the environment reset to normalize the first observation
|
||||
# tensordict_reset = self._call(tensordict_reset)
|
||||
# return tensordict_reset
|
||||
|
||||
# @dispatch(source="in_keys", dest="out_keys")
|
||||
# def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
|
||||
# return self._call(tensordict)
|
||||
|
||||
# def _call(self, td: TensorDictBase) -> TensorDictBase:
|
||||
# td.clear_device_()
|
||||
# return td
|
||||
|
||||
|
||||
class VideoExperienceReplay(TensorDictReplayBuffer):
    """Replay buffer whose observations reference video frames by
    (video_id, timestamp) instead of storing decoded images; frames are
    decoded at sample time by a `DecodeVideoTransform` in the transform
    pipeline."""

    def __init__(
        self,
        batch_size: int = None,
        *,
        root: Path = None,
        pin_memory: bool = False,
        prefetch: int = None,
        sampler: SliceSampler = None,
        collate_fn: Callable = None,
        writer: Writer = None,
        transform: "torchrl.envs.Transform" = None,
    ):
        # On-disk layout: <root>/replay_buffer holds the memmapped tensordict,
        # <root>/meta_data.pth holds the video_id -> file path mapping.
        self.data_dir = root
        self.rb_dir = self.data_dir / "replay_buffer"

        storage, meta_data = self._load_or_download()

        # hack to access video paths
        # NOTE(review): a Compose transform is required here so we can inject
        # the video_id -> path mapping into every DecodeVideoTransform, which
        # only sees ids at sample time.
        assert isinstance(transform, Compose)
        for tf in transform:
            if isinstance(tf, DecodeVideoTransform):
                tf.set_video_id_to_path(meta_data["video_id_to_path"])

        super().__init__(
            storage=storage,
            sampler=sampler,
            writer=ImmutableDatasetWriter() if writer is None else writer,
            collate_fn=_collate_id if collate_fn is None else collate_fn,
            pin_memory=pin_memory,
            prefetch=prefetch,
            batch_size=batch_size,
            transform=transform,
        )

    def _load_or_download(self, force_download=False):
        """Load the memmapped storage and metadata from `data_dir` if it
        already exists; otherwise build them via `_download()` and persist the
        metadata. Returns a (TensorStorage, meta_data dict) pair."""
        if not force_download and self.data_dir.exists():
            storage = TensorStorage(TensorDict.load_memmap(self.rb_dir))
            meta_data = torch.load(self.data_dir / "meta_data.pth")
        else:
            storage, meta_data = self._download()
            torch.save(meta_data, self.data_dir / "meta_data.pth")

        # required to not send cuda frames to cpu by default
        storage._storage.clear_device_()
        return storage, meta_data

    def _download(self):
        """Download a sample video asset and synthesize a one-episode dataset.

        Only the video frames are real: reward/done/success are zeros and
        state/action are random placeholders. Returns (TensorStorage,
        meta_data) where meta_data carries the video_id -> path mapping.
        """
        num_episodes = 1
        video_id_to_path = {}
        for episode_id in range(num_episodes):
            video_path = torchaudio.utils.download_asset(
                "tutorial-assets/stream-api/NASAs_Most_Scientifically_Complex_Space_Observatory_Requires_Precision-MP4_small.mp4"
            )
            # several episodes can belong to the same video
            video_id = episode_id
            video_id_to_path[video_id] = video_path

            print(f"{video_path=}")
            num_frames = count_frames(video_path)
            print(f"{num_frames=}")
            frame_rate = get_frame_rate(video_path)
            print(f"{frame_rate=}")

            frame_timestamps = get_frame_timestamps(frame_rate, num_frames)

            # Placeholder transition fields, one row per frame.
            reward = torch.zeros(num_frames, 1, dtype=torch.float32)
            success = torch.zeros(num_frames, 1, dtype=torch.bool)
            done = torch.zeros(num_frames, 1, dtype=torch.bool)
            state = torch.randn(num_frames, NUM_STATE_CHANNELS, dtype=torch.float32)
            action = torch.randn(num_frames, NUM_ACTION_CHANNELS, dtype=torch.float32)
            timestamp = torch.tensor(frame_timestamps)
            frame_id = torch.arange(0, num_frames, 1)
            episode_id_tensor = torch.tensor([episode_id] * num_frames, dtype=torch.int)
            video_id_tensor = torch.tensor([video_id] * num_frames, dtype=torch.int)

            # last step of demonstration is considered done
            done[-1] = True

            # [:-1]/[1:] shifting turns per-frame data into (s, a, s')
            # transitions, hence batch_size = num_frames - 1.
            ep_td = TensorDict(
                {
                    ("observation", "frame", "video_id"): video_id_tensor[:-1],
                    ("observation", "frame", "timestamp"): timestamp[:-1],
                    ("observation", "state"): state[:-1],
                    "action": action[:-1],
                    "episode": episode_id_tensor[:-1],
                    "frame_id": frame_id[:-1],
                    ("next", "observation", "frame", "video_id"): video_id_tensor[1:],
                    ("next", "observation", "frame", "timestamp"): timestamp[1:],
                    ("next", "observation", "state"): state[1:],
                    ("next", "reward"): reward[1:],
                    ("next", "done"): done[1:],
                    ("next", "success"): success[1:],
                },
                batch_size=num_frames - 1,
            )

            # TODO:
            # NOTE(review): total_frames only covers one episode; with
            # num_episodes > 1 the memmap below would be too small — confirm.
            total_frames = num_frames - 1

            if episode_id == 0:
                # hack to initialize tensordict data structure to store episodes
                td_data = ep_td[0].expand(total_frames).memmap_like(self.rb_dir)

            td_data[:] = ep_td

        meta_data = {
            "video_id_to_path": video_id_to_path,
        }

        # Lock the storage: the writer is immutable, episodes are never appended.
        return TensorStorage(td_data.lock_()), meta_data
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import time

    import tqdm

    # Report the FFmpeg libraries torchaudio is linked against and list any
    # NVDEC (cuvid) hardware decoders available for GPU-side decoding.
    print("FFmpeg Library versions:")
    for k, ver in ffmpeg_utils.get_versions().items():
        print(f" {k}:\t{'.'.join(str(v) for v in ver)}")

    print("Available NVDEC Decoders:")
    for k in ffmpeg_utils.get_video_decoders().keys():  # noqa: SIM118
        if "cuvid" in k:
            print(f" - {k}")

    def create_replay_buffer(device, format=None):
        """Build a VideoExperienceReplay that decodes frames on `device`
        (optionally forcing the pixel `format`, e.g. "yuv444p")."""
        data_dir = Path("tmp/2024_03_17_data_video/pusht")

        # One slice of 2 consecutive frames per sampled batch.
        num_slices = 1
        horizon = 2
        batch_size = num_slices * horizon

        sampler = SliceSamplerWithoutReplacement(
            num_slices=num_slices,
            strict_length=True,
            shuffle=False,
        )

        transforms = [
            # ClearDeviceTransform(),
            ViewSliceHorizonTransform(num_slices, horizon),
            KeepFrames(positions=[0], in_keys=[("observation")]),
            DecodeVideoTransform(
                data_dir=data_dir,
                device=device,
                frame_rate=None,
                format=format,
                in_keys=[("observation", "frame")],
                out_keys=[("observation", "frame", "data")],
            ),
        ]

        replay_buffer = VideoExperienceReplay(
            root=data_dir,
            batch_size=batch_size,
            # prefetch=4,
            transform=Compose(*transforms),
            sampler=sampler,
        )
        return replay_buffer

    def test_time():
        """Time GPU-decoded sampling: a 2-iteration warm-up, then 10 timed
        iterations."""
        replay_buffer = create_replay_buffer(device="cuda")

        start = time.monotonic()
        for _ in tqdm.tqdm(range(2)):
            # include_info=False is required to not have a batch_size mismatch error with the truncated key (2,8) != (16, 1)
            replay_buffer.sample(include_info=False)
            # Decoding is asynchronous on CUDA; sync so the timing is real.
            torch.cuda.synchronize()
        print(time.monotonic() - start)

        start = time.monotonic()
        for _ in tqdm.tqdm(range(10)):
            replay_buffer.sample(include_info=False)
            torch.cuda.synchronize()
        print(time.monotonic() - start)

    def test_plot(seed=1337):
        """Visually compare software (CPU) vs hardware (CUDA) decoding: sample
        the same slices from both buffers and save a side-by-side image grid."""
        rb_cuda = create_replay_buffer(device="cuda", format="yuv444p")
        rb_cpu = create_replay_buffer(device="cpu", format="yuv444p")

        n_rows = 2  # len(replay_buffer)
        fig, axes = plt.subplots(n_rows, 3, figsize=[12.8, 16.0])
        for i in range(n_rows):
            # Re-seed before each sample so cpu and cuda draw the same slice.
            set_seed(seed + i)
            batch_cpu = rb_cpu.sample(include_info=False)
            print("frame_ids cpu", batch_cpu["frame_id"].tolist())
            print("episode cpu", batch_cpu["episode"].tolist())
            print("timestamps cpu", batch_cpu["observation", "frame", "timestamp"].tolist())
            frames = batch_cpu["observation", "frame", "data"]
            frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
            frames = einops.rearrange(frames, "bt c h w -> bt h w c")
            assert frames.shape[0] == 1
            axes[i][0].imshow(frames[0])

            set_seed(seed + i)
            batch_cuda = rb_cuda.sample(include_info=False)
            print("frame_ids cuda", batch_cuda["frame_id"].tolist())
            print("episode cuda", batch_cuda["episode"].tolist())
            print("timestamps cuda", batch_cuda["observation", "frame", "timestamp"].tolist())
            frames = batch_cuda["observation", "frame", "data"]
            frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
            frames = einops.rearrange(frames, "bt c h w -> bt h w c")
            assert frames.shape[0] == 1
            axes[i][1].imshow(frames[0])

            # NOTE(review): ("observation", "image") is not among the keys
            # built by VideoExperienceReplay._download — confirm a transform
            # adds it, otherwise this KeyErrors.
            frames = batch_cuda["observation", "image"].type(torch.uint8)
            frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
            frames = einops.rearrange(frames, "bt c h w -> bt h w c")
            assert frames.shape[0] == 1
            axes[i][2].imshow(frames[0])

        axes[0][0].set_title("Software decoder")
        axes[0][1].set_title("HW decoder")
        axes[0][2].set_title("uint8")
        plt.setp(axes, xticks=[], yticks=[])
        plt.tight_layout()
        fig.savefig(rb_cuda.data_dir / "test.png", dpi=300)

    # test_time()
    test_plot()
|
||||
Reference in New Issue
Block a user