Merge remote-tracking branch 'upstream/main' into unify_policy_api
This commit is contained in:
43
lerobot/commands/env.py
Normal file
43
lerobot/commands/env.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import platform
|
||||
|
||||
import huggingface_hub
|
||||
|
||||
# import dataset
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot import __version__ as version
|
||||
|
||||
# Snapshot of the local PyTorch build, taken once at import time so that
# `get_env_info` can report it.
pt_version = torch.__version__
pt_cuda_available = torch.cuda.is_available()
# Only query the compiled CUDA version when `torch.version.cuda` is set;
# otherwise report "N/A".
cuda_version = torch._C._cuda_getCompiledVersion() if torch.version.cuda is not None else "N/A"
|
||||
|
||||
|
||||
# TODO(aliberts): refactor into an actual command `lerobot env`
|
||||
def get_env_info() -> dict:
    """Run this to get basic system info to help for tracking issues & bugs.

    The collected information is printed as a bullet list (preceded by a
    short instruction to paste it into a GitHub issue) and is also returned.

    Returns:
        dict: Maps a human-readable label to the corresponding value. The last
            two entries are "<fill in>" placeholders meant to be completed by
            the person reporting the issue.
    """
    info = {
        "`lerobot` version": version,
        "Platform": platform.platform(),
        "Python version": platform.python_version(),
        "Huggingface_hub version": huggingface_hub.__version__,
        # TODO(aliberts): Add dataset when https://github.com/huggingface/lerobot/pull/73 is merged
        # "Dataset version": dataset.__version__,
        "Numpy version": np.__version__,
        "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
        "Cuda version": cuda_version,
        "Using GPU in script?": "<fill in>",
        "Using distributed or parallel set-up in script?": "<fill in>",
    }
    print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
    print(format_dict(info))
    return info
|
||||
|
||||
|
||||
def format_dict(d: dict) -> str:
    """Render a mapping as a bullet list with one "- key: value" line per entry.

    Args:
        d: Mapping to render; entries are emitted in iteration order.

    Returns:
        The bullet lines joined by newlines and followed by one trailing
        newline. An empty mapping yields just that single newline.
    """
    return "\n".join(f"- {prop}: {val}" for prop, val in d.items()) + "\n"
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: print the environment report.
    get_env_info()
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Helper code for loading PushT dataset from Diffusion Policy (https://diffusion-policy.cs.columbia.edu/)
|
||||
|
||||
Copied from the original Diffusion Policy repository.
|
||||
Copied from the original Diffusion Policy repository and used in our `download_and_upload_dataset.py` script.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -1,72 +1,19 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
import gdown
|
||||
import h5py
|
||||
import torch
|
||||
import tqdm
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
||||
from lerobot.common.datasets.utils import load_data_with_delta_timestamps
|
||||
|
||||
FOLDER_URLS = {
|
||||
"aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
|
||||
"aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N",
|
||||
"aloha_sim_transfer_cube_human": "https://drive.google.com/drive/folders/1sc-E4QYW7A0o23m1u2VWNGVq5smAsfCo",
|
||||
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/drive/folders/1aRyoOhQwxhyt1J8XgEig4s6kzaw__LXj",
|
||||
}
|
||||
|
||||
EP48_URLS = {
|
||||
"aloha_sim_insertion_human": "https://drive.google.com/file/d/18Cudl6nikDtgRolea7je8iF_gGKzynOP/view?usp=drive_link",
|
||||
"aloha_sim_insertion_scripted": "https://drive.google.com/file/d/1wfMSZ24oOh5KR_0aaP3Cnu_c4ZCveduB/view?usp=drive_link",
|
||||
"aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/18smMymtr8tIxaNUQ61gW6dG50pt3MvGq/view?usp=drive_link",
|
||||
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1pnGIOd-E4-rhz2P3VxpknMKRZCoKt6eI/view?usp=drive_link",
|
||||
}
|
||||
|
||||
EP49_URLS = {
|
||||
"aloha_sim_insertion_human": "https://drive.google.com/file/d/1C1kZYyROzs-PrLc0SkDgUgMi4-L3lauE/view?usp=drive_link",
|
||||
"aloha_sim_insertion_scripted": "https://drive.google.com/file/d/17EuCUWS6uCCr6yyNzpXdcdE-_TTNCKtf/view?usp=drive_link",
|
||||
"aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/1Nk7l53d9sJoGDBKAOnNrExX5nLacATc6/view?usp=drive_link",
|
||||
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1GKReZHrXU73NMiC5zKCq_UtqPVtYq8eo/view?usp=drive_link",
|
||||
}
|
||||
|
||||
NUM_EPISODES = {
|
||||
"aloha_sim_insertion_human": 50,
|
||||
"aloha_sim_insertion_scripted": 50,
|
||||
"aloha_sim_transfer_cube_human": 50,
|
||||
"aloha_sim_transfer_cube_scripted": 50,
|
||||
}
|
||||
|
||||
EPISODE_LEN = {
|
||||
"aloha_sim_insertion_human": 500,
|
||||
"aloha_sim_insertion_scripted": 400,
|
||||
"aloha_sim_transfer_cube_human": 400,
|
||||
"aloha_sim_transfer_cube_scripted": 400,
|
||||
}
|
||||
|
||||
CAMERAS = {
|
||||
"aloha_sim_insertion_human": ["top"],
|
||||
"aloha_sim_insertion_scripted": ["top"],
|
||||
"aloha_sim_transfer_cube_human": ["top"],
|
||||
"aloha_sim_transfer_cube_scripted": ["top"],
|
||||
}
|
||||
|
||||
|
||||
def download(data_dir, dataset_id):
    """Download the raw episode files of one dataset from Google Drive.

    Args:
        data_dir: Directory the files are written to (used like a
            `pathlib.Path`); it is created, with parents, if missing.
        dataset_id: Key that must be present in `FOLDER_URLS`, `EP48_URLS`
            and `EP49_URLS`.

    Raises:
        AssertionError: If `dataset_id` is missing from any of the URL tables.
    """
    url_tables = (
        ("FOLDER_URLS", FOLDER_URLS),
        ("EP48_URLS", EP48_URLS),
        ("EP49_URLS", EP49_URLS),
    )
    for table_name, table in url_tables:
        assert dataset_id in table, f"Unknown dataset_id {dataset_id!r}: not found in {table_name}"

    data_dir.mkdir(parents=True, exist_ok=True)

    gdown.download_folder(FOLDER_URLS[dataset_id], output=str(data_dir))

    # because of the 50 files limit per directory, two files episode 48 and 49 were missing
    # from the folder download, so they are fetched individually.
    gdown.download(EP48_URLS[dataset_id], output=str(data_dir / "episode_48.hdf5"), fuzzy=True)
    gdown.download(EP49_URLS[dataset_id], output=str(data_dir / "episode_49.hdf5"), fuzzy=True)
|
||||
from lerobot.common.datasets.utils import load_previous_and_future_frames
|
||||
|
||||
|
||||
class AlohaDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human
|
||||
https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted
|
||||
https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human
|
||||
https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted
|
||||
"""
|
||||
|
||||
available_datasets = [
|
||||
"aloha_sim_insertion_human",
|
||||
"aloha_sim_insertion_scripted",
|
||||
@@ -79,8 +26,9 @@ class AlohaDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.2",
|
||||
version: str | None = "v1.0",
|
||||
root: Path | None = None,
|
||||
split: str = "train",
|
||||
transform: callable = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
@@ -88,120 +36,48 @@ class AlohaDataset(torch.utils.data.Dataset):
|
||||
self.dataset_id = dataset_id
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
|
||||
self.data_dir = self.root / f"{self.dataset_id}"
|
||||
if (self.data_dir / "data_dict.pth").exists() and (
|
||||
self.data_dir / "data_ids_per_episode.pth"
|
||||
).exists():
|
||||
self.data_dict = torch.load(self.data_dir / "data_dict.pth")
|
||||
self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
|
||||
if self.root is not None:
|
||||
self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
|
||||
else:
|
||||
self._download_and_preproc_obsolete()
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(self.data_dict, self.data_dir / "data_dict.pth")
|
||||
torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
|
||||
self.data_dict = load_dataset(
|
||||
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
|
||||
)
|
||||
self.data_dict = self.data_dict.with_format("torch")
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self.data_dict["index"]) if "index" in self.data_dict else 0
|
||||
return len(self.data_dict)
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(self.data_ids_per_episode)
|
||||
return len(self.data_dict.unique("episode_id"))
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = {}
|
||||
item = self.data_dict[idx]
|
||||
|
||||
# get episode id and timestamp of the sampled frame
|
||||
current_ts = self.data_dict["timestamp"][idx].item()
|
||||
episode = self.data_dict["episode"][idx].item()
|
||||
if self.delta_timestamps is not None:
|
||||
item = load_previous_and_future_frames(
|
||||
item,
|
||||
self.data_dict,
|
||||
self.delta_timestamps,
|
||||
)
|
||||
|
||||
for key in self.data_dict:
|
||||
if self.delta_timestamps is not None and key in self.delta_timestamps:
|
||||
data, is_pad = load_data_with_delta_timestamps(
|
||||
self.data_dict,
|
||||
self.data_ids_per_episode,
|
||||
self.delta_timestamps,
|
||||
key,
|
||||
current_ts,
|
||||
episode,
|
||||
)
|
||||
item[key] = data
|
||||
item[f"{key}_is_pad"] = is_pad
|
||||
# convert images from channel last (PIL) to channel first (pytorch)
|
||||
for key in self.image_keys:
|
||||
if item[key].ndim == 3:
|
||||
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
|
||||
elif item[key].ndim == 4:
|
||||
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
|
||||
else:
|
||||
item[key] = self.data_dict[key][idx]
|
||||
raise ValueError(item[key].ndim)
|
||||
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
|
||||
return item
|
||||
|
||||
def _download_and_preproc_obsolete(self):
|
||||
assert self.root is not None
|
||||
raw_dir = self.root / f"{self.dataset_id}_raw"
|
||||
if not raw_dir.is_dir():
|
||||
download(raw_dir, self.dataset_id)
|
||||
|
||||
total_frames = 0
|
||||
logging.info("Compute total number of frames to initialize offline buffer")
|
||||
for ep_id in range(NUM_EPISODES[self.dataset_id]):
|
||||
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
|
||||
with h5py.File(ep_path, "r") as ep:
|
||||
total_frames += ep["/action"].shape[0] - 1
|
||||
logging.info(f"{total_frames=}")
|
||||
|
||||
self.data_ids_per_episode = {}
|
||||
ep_dicts = []
|
||||
|
||||
frame_idx = 0
|
||||
for ep_id in tqdm.tqdm(range(NUM_EPISODES[self.dataset_id])):
|
||||
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
|
||||
with h5py.File(ep_path, "r") as ep:
|
||||
num_frames = ep["/action"].shape[0]
|
||||
|
||||
# last step of demonstration is considered done
|
||||
done = torch.zeros(num_frames, dtype=torch.bool)
|
||||
done[-1] = True
|
||||
|
||||
state = torch.from_numpy(ep["/observations/qpos"][:])
|
||||
action = torch.from_numpy(ep["/action"][:])
|
||||
|
||||
ep_dict = {
|
||||
"observation.state": state,
|
||||
"action": action,
|
||||
"episode": torch.tensor([ep_id] * num_frames),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
|
||||
# "next.observation.state": state,
|
||||
# TODO(rcadene): compute reward and success
|
||||
# "next.reward": reward[1:],
|
||||
"next.done": done[1:],
|
||||
# "next.success": success[1:],
|
||||
}
|
||||
|
||||
for cam in CAMERAS[self.dataset_id]:
|
||||
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])
|
||||
image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
|
||||
ep_dict[f"observation.images.{cam}"] = image[:-1]
|
||||
# ep_dict[f"next.observation.images.{cam}"] = image[1:]
|
||||
|
||||
assert isinstance(ep_id, int)
|
||||
self.data_ids_per_episode[ep_id] = torch.arange(frame_idx, frame_idx + num_frames, 1)
|
||||
assert len(self.data_ids_per_episode[ep_id]) == num_frames
|
||||
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
frame_idx += num_frames
|
||||
|
||||
self.data_dict = {}
|
||||
|
||||
keys = ep_dicts[0].keys()
|
||||
for key in keys:
|
||||
self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
|
||||
self.data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
||||
@@ -8,9 +8,6 @@ from torchvision.transforms import v2
|
||||
from lerobot.common.datasets.utils import compute_stats
|
||||
from lerobot.common.transforms import NormalizeTransform, Prod
|
||||
|
||||
# DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and
|
||||
# we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data`
|
||||
# to load a subset of our datasets for faster continuous integration.
|
||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
|
||||
|
||||
@@ -19,6 +16,7 @@ def make_dataset(
|
||||
# set normalize=False to remove all transformations and keep images unnormalized in [0,255]
|
||||
normalize=True,
|
||||
stats_path=None,
|
||||
split="train",
|
||||
):
|
||||
if cfg.env.name == "xarm":
|
||||
from lerobot.common.datasets.xarm import XarmDataset
|
||||
@@ -54,19 +52,23 @@ def make_dataset(
|
||||
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
elif stats_path is None:
|
||||
# instantiate a one frame dataset with light transform
|
||||
stats_dataset = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
root=DATA_DIR,
|
||||
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
|
||||
)
|
||||
|
||||
# load stats if the file exists already or compute stats and save it
|
||||
precomputed_stats_path = stats_dataset.data_dir / "stats.pth"
|
||||
if DATA_DIR is None:
|
||||
# TODO(rcadene): clean stats
|
||||
precomputed_stats_path = Path("data") / cfg.dataset_id / "stats.pth"
|
||||
else:
|
||||
precomputed_stats_path = DATA_DIR / cfg.dataset_id / "stats.pth"
|
||||
if precomputed_stats_path.exists():
|
||||
stats = torch.load(precomputed_stats_path)
|
||||
else:
|
||||
logging.info(f"compute_stats and save to {precomputed_stats_path}")
|
||||
# instantiate a one frame dataset with light transform
|
||||
stats_dataset = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
split="train",
|
||||
root=DATA_DIR,
|
||||
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
|
||||
)
|
||||
stats = compute_stats(stats_dataset)
|
||||
torch.save(stats, stats_path)
|
||||
else:
|
||||
@@ -94,6 +96,7 @@ def make_dataset(
|
||||
|
||||
dataset = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
split=split,
|
||||
root=DATA_DIR,
|
||||
delta_timestamps=delta_timestamps,
|
||||
transform=transforms,
|
||||
|
||||
@@ -1,24 +1,14 @@
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
||||
from lerobot.common.datasets._diffusion_policy_replay_buffer import (
|
||||
ReplayBuffer as DiffusionPolicyReplayBuffer,
|
||||
)
|
||||
from lerobot.common.datasets.utils import download_and_extract_zip, load_data_with_delta_timestamps
|
||||
|
||||
# as define in env
|
||||
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
|
||||
|
||||
PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
|
||||
PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr")
|
||||
from lerobot.common.datasets.utils import load_previous_and_future_frames
|
||||
|
||||
|
||||
class PushtDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
https://huggingface.co/datasets/lerobot/pusht
|
||||
|
||||
Arguments
|
||||
----------
|
||||
@@ -34,8 +24,9 @@ class PushtDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.2",
|
||||
version: str | None = "v1.0",
|
||||
root: Path | None = None,
|
||||
split: str = "train",
|
||||
transform: callable = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
@@ -43,174 +34,48 @@ class PushtDataset(torch.utils.data.Dataset):
|
||||
self.dataset_id = dataset_id
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
|
||||
self.data_dir = self.root / f"{self.dataset_id}"
|
||||
if (self.data_dir / "data_dict.pth").exists() and (
|
||||
self.data_dir / "data_ids_per_episode.pth"
|
||||
).exists():
|
||||
self.data_dict = torch.load(self.data_dir / "data_dict.pth")
|
||||
self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
|
||||
if self.root is not None:
|
||||
self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
|
||||
else:
|
||||
self._download_and_preproc_obsolete()
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(self.data_dict, self.data_dir / "data_dict.pth")
|
||||
torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
|
||||
self.data_dict = load_dataset(
|
||||
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
|
||||
)
|
||||
self.data_dict = self.data_dict.with_format("torch")
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self.data_dict["index"]) if "index" in self.data_dict else 0
|
||||
return len(self.data_dict)
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(self.data_ids_per_episode)
|
||||
return len(self.data_dict.unique("episode_id"))
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = {}
|
||||
item = self.data_dict[idx]
|
||||
|
||||
# get episode id and timestamp of the sampled frame
|
||||
current_ts = self.data_dict["timestamp"][idx].item()
|
||||
episode = self.data_dict["episode"][idx].item()
|
||||
if self.delta_timestamps is not None:
|
||||
item = load_previous_and_future_frames(
|
||||
item,
|
||||
self.data_dict,
|
||||
self.delta_timestamps,
|
||||
)
|
||||
|
||||
for key in self.data_dict:
|
||||
if self.delta_timestamps is not None and key in self.delta_timestamps:
|
||||
data, is_pad = load_data_with_delta_timestamps(
|
||||
self.data_dict,
|
||||
self.data_ids_per_episode,
|
||||
self.delta_timestamps,
|
||||
key,
|
||||
current_ts,
|
||||
episode,
|
||||
)
|
||||
item[key] = data
|
||||
item[f"{key}_is_pad"] = is_pad
|
||||
# convert images from channel last (PIL) to channel first (pytorch)
|
||||
for key in self.image_keys:
|
||||
if item[key].ndim == 3:
|
||||
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
|
||||
elif item[key].ndim == 4:
|
||||
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
|
||||
else:
|
||||
item[key] = self.data_dict[key][idx]
|
||||
raise ValueError(item[key].ndim)
|
||||
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
|
||||
return item
|
||||
|
||||
def _download_and_preproc_obsolete(self):
|
||||
try:
|
||||
import pymunk
|
||||
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
|
||||
except ModuleNotFoundError as e:
|
||||
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
|
||||
raise e
|
||||
|
||||
assert self.root is not None
|
||||
raw_dir = self.root / f"{self.dataset_id}_raw"
|
||||
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
|
||||
if not zarr_path.is_dir():
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
download_and_extract_zip(PUSHT_URL, raw_dir)
|
||||
|
||||
# load
|
||||
dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(
|
||||
zarr_path
|
||||
) # , keys=['img', 'state', 'action'])
|
||||
|
||||
episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
|
||||
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
|
||||
total_frames = dataset_dict["action"].shape[0]
|
||||
# to create test artifact
|
||||
# num_episodes = 1
|
||||
# total_frames = 50
|
||||
assert len(
|
||||
{dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118
|
||||
), "Some data type dont have the same number of total frames."
|
||||
|
||||
# TODO: verify that goal pose is expected to be fixed
|
||||
goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
|
||||
goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)
|
||||
|
||||
imgs = torch.from_numpy(dataset_dict["img"])
|
||||
imgs = einops.rearrange(imgs, "b h w c -> b c h w")
|
||||
states = torch.from_numpy(dataset_dict["state"])
|
||||
actions = torch.from_numpy(dataset_dict["action"])
|
||||
|
||||
self.data_ids_per_episode = {}
|
||||
ep_dicts = []
|
||||
|
||||
idx0 = 0
|
||||
for episode_id in tqdm.tqdm(range(num_episodes)):
|
||||
idx1 = dataset_dict.meta["episode_ends"][episode_id]
|
||||
|
||||
num_frames = idx1 - idx0
|
||||
|
||||
assert (episode_ids[idx0:idx1] == episode_id).all()
|
||||
|
||||
image = imgs[idx0:idx1]
|
||||
assert image.min() >= 0.0
|
||||
assert image.max() <= 255.0
|
||||
image = image.type(torch.uint8)
|
||||
|
||||
state = states[idx0:idx1]
|
||||
agent_pos = state[:, :2]
|
||||
block_pos = state[:, 2:4]
|
||||
block_angle = state[:, 4]
|
||||
|
||||
reward = torch.zeros(num_frames)
|
||||
success = torch.zeros(num_frames, dtype=torch.bool)
|
||||
done = torch.zeros(num_frames, dtype=torch.bool)
|
||||
for i in range(num_frames):
|
||||
space = pymunk.Space()
|
||||
space.gravity = 0, 0
|
||||
space.damping = 0
|
||||
|
||||
# Add walls.
|
||||
walls = [
|
||||
PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
|
||||
PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
|
||||
PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
|
||||
PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
|
||||
]
|
||||
space.add(*walls)
|
||||
|
||||
block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
|
||||
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
|
||||
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
|
||||
intersection_area = goal_geom.intersection(block_geom).area
|
||||
goal_area = goal_geom.area
|
||||
coverage = intersection_area / goal_area
|
||||
reward[i] = np.clip(coverage / SUCCESS_THRESHOLD, 0, 1)
|
||||
success[i] = coverage > SUCCESS_THRESHOLD
|
||||
|
||||
# last step of demonstration is considered done
|
||||
done[-1] = True
|
||||
|
||||
ep_dict = {
|
||||
"observation.image": image,
|
||||
"observation.state": agent_pos,
|
||||
"action": actions[idx0:idx1],
|
||||
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
|
||||
# "next.observation.image": image[1:],
|
||||
# "next.observation.state": agent_pos[1:],
|
||||
# TODO(rcadene): verify that reward and done are aligned with image and agent_pos
|
||||
"next.reward": torch.cat([reward[1:], reward[[-1]]]),
|
||||
"next.done": torch.cat([done[1:], done[[-1]]]),
|
||||
"next.success": torch.cat([success[1:], success[[-1]]]),
|
||||
}
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
assert isinstance(episode_id, int)
|
||||
self.data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1)
|
||||
assert len(self.data_ids_per_episode[episode_id]) == num_frames
|
||||
|
||||
idx0 = idx1
|
||||
|
||||
self.data_dict = {}
|
||||
|
||||
keys = ep_dicts[0].keys()
|
||||
for key in keys:
|
||||
self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
|
||||
self.data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
||||
@@ -1,51 +1,20 @@
|
||||
import io
|
||||
import zipfile
|
||||
from copy import deepcopy
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
import requests
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
|
||||
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
    """Download a zip archive over HTTP and extract it into a folder.

    The archive is buffered entirely in memory before it is extracted.

    Args:
        url: Address of the zip archive.
        destination_folder: Directory the archive members are extracted into.

    Returns:
        True when the server answered with status 200 and the archive was
        extracted, False for any other status code.
    """
    print(f"downloading from {url}")
    # Using the response as a context manager releases the streamed
    # connection on every exit path, including the early return below.
    with requests.get(url, stream=True) as response:
        if response.status_code != 200:
            return False

        total_size = int(response.headers.get("content-length", 0))
        zip_file = io.BytesIO()
        # `with` closes the progress bar even if the download raises midway.
        with tqdm.tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar:
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:
                    zip_file.write(chunk)
                    progress_bar.update(len(chunk))

    # Rewind so ZipFile reads the buffered archive from its beginning.
    zip_file.seek(0)
    with zipfile.ZipFile(zip_file, "r") as zip_ref:
        zip_ref.extractall(destination_folder)
    return True
|
||||
|
||||
|
||||
def load_data_with_delta_timestamps(
|
||||
data_dict: dict[torch.Tensor],
|
||||
data_ids_per_episode: dict[torch.Tensor],
|
||||
delta_timestamps: list[float],
|
||||
key: str,
|
||||
current_ts: float,
|
||||
episode: int,
|
||||
def load_previous_and_future_frames(
|
||||
item: dict[str, torch.Tensor],
|
||||
data_dict: dict[str, torch.Tensor],
|
||||
delta_timestamps: dict[str, list[float]],
|
||||
tol: float = 0.04,
|
||||
):
|
||||
) -> dict[torch.Tensor]:
|
||||
"""
|
||||
Given a current timestamp (e.g. current_ts=0.6) and a list of timestamps differences (e.g. delta_timestamps=[-0.8, -0.2, 0, 0.2]),
|
||||
this function compute the query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames of the specified modality (e.g. key="observation.image").
|
||||
Given a current item in the dataset containing a timestamp (e.g. 0.6 seconds), and a list of time differences of some modalities (e.g. delta_timestamps={"observation.image": [-0.8, -0.2, 0, 0.2]}),
|
||||
this function computes for each given modality a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames in the dataset.
|
||||
|
||||
Importantly, when no frame can be found around a query timestamp within a specified tolerance window (e.g. tol=0.04), this function raises an AssertionError.
|
||||
When a timestamp is queried before the first available timestamp of the episode or after the last available timestamp,
|
||||
@@ -54,56 +23,57 @@ def load_data_with_delta_timestamps(
|
||||
or to pad the observations in a specific way. Note that by default the observation frames before the start of the episode are the same as the first frame of the episode.
|
||||
|
||||
Parameters:
|
||||
- data_dict (dict): A dictionary containing the data, where each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
|
||||
- data_ids_per_episode (dict): A dictionary where keys are episode identifiers and values are lists of indices corresponding to frames associated with each episode.
|
||||
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible key to be retrieved. These deltas are added to the current_ts to form the query timestamps.
|
||||
- key (str): The key specifying which data modality is to be retrieved from the data_dict.
|
||||
- current_ts (float): The current timestamp to which the delta timestamps are added to form the query timestamps.
|
||||
- episode (int): The identifier of the episode from which frames are to be retrieved.
|
||||
- item (dict): A dictionary containing all the data related to a frame. It is the result of `dataset[idx]`. Each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
|
||||
- data_dict (dict): A dictionary containing the full dataset. Each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
|
||||
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be retrieved. These deltas are added to the item timestamp to form the query timestamps.
|
||||
- tol (float, optional): The tolerance level used to determine if a data point is close enough to the query timestamp. Defaults to 0.04.
|
||||
|
||||
Returns:
|
||||
- tuple: A tuple containing two elements:
|
||||
- The first element is the data retrieved from the specified modality based on the closest match to the query timestamps.
|
||||
- The second element is a boolean array indicating which frames were considered as padding (True if the distance to the closest timestamp was greater than the tolerance level).
|
||||
- The same item with the queried frames for each modality specified in delta_timestamps, with an additional key for each modality (e.g. "observation.image_is_pad").
|
||||
|
||||
Raises:
|
||||
- AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization issues with timestamps during data collection.
|
||||
"""
|
||||
# get indices of the frames associated to the episode, and their timestamps
|
||||
ep_data_ids = data_ids_per_episode[episode]
|
||||
ep_timestamps = data_dict["timestamp"][ep_data_ids]
|
||||
ep_data_id_from = item["episode_data_index_from"].item()
|
||||
ep_data_id_to = item["episode_data_index_to"].item()
|
||||
ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)
|
||||
|
||||
# load timestamps
|
||||
ep_timestamps = data_dict.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]
|
||||
|
||||
# we make the assumption that the timestamps are sorted
|
||||
ep_first_ts = ep_timestamps[0]
|
||||
ep_last_ts = ep_timestamps[-1]
|
||||
current_ts = item["timestamp"].item()
|
||||
|
||||
# get timestamps used as query to retrieve data of previous/future frames
|
||||
delta_ts = delta_timestamps[key]
|
||||
query_ts = current_ts + torch.tensor(delta_ts)
|
||||
for key in delta_timestamps:
|
||||
# get timestamps used as query to retrieve data of previous/future frames
|
||||
delta_ts = delta_timestamps[key]
|
||||
query_ts = current_ts + torch.tensor(delta_ts)
|
||||
|
||||
# compute distances between each query timestamp and all timestamps of all the frames belonging to the episode
|
||||
dist = torch.cdist(query_ts[:, None], ep_timestamps[:, None], p=1)
|
||||
min_, argmin_ = dist.min(1)
|
||||
# compute distances between each query timestamp and all timestamps of all the frames belonging to the episode
|
||||
dist = torch.cdist(query_ts[:, None], ep_timestamps[:, None], p=1)
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
# get the indices of the data that are closest to the query timestamps
|
||||
data_ids = ep_data_ids[argmin_]
|
||||
# closest_ts = ep_timestamps[argmin_]
|
||||
# TODO(rcadene): synchronize timestamps + interpolation if needed
|
||||
|
||||
# get the data
|
||||
data = data_dict[key][data_ids].clone()
|
||||
is_pad = min_ > tol
|
||||
|
||||
# TODO(rcadene): synchronize timestamps + interpolation if needed
|
||||
# check violated query timestamps are all outside the episode range
|
||||
assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), (
|
||||
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=}) inside episode range."
|
||||
"This might be due to synchronization issues with timestamps during data collection."
|
||||
)
|
||||
|
||||
is_pad = min_ > tol
|
||||
# get dataset indices corresponding to frames to be loaded
|
||||
data_ids = ep_data_ids[argmin_]
|
||||
|
||||
# check violated query timestamps are all outside the episode range
|
||||
assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), (
|
||||
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=}) inside episode range."
|
||||
"This might be due to synchronization issues with timestamps during data collection."
|
||||
)
|
||||
# load frames modality
|
||||
item[key] = data_dict.select_columns(key)[data_ids][key]
|
||||
item[f"{key}_is_pad"] = is_pad
|
||||
|
||||
return data, is_pad
|
||||
return item
|
||||
|
||||
|
||||
def get_stats_einops_patterns(dataset):
|
||||
|
||||
@@ -1,30 +1,16 @@
|
||||
import pickle
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
||||
from lerobot.common.datasets.utils import load_data_with_delta_timestamps
|
||||
|
||||
|
||||
def download(raw_dir):
    """Download the data archive from Google Drive and extract its xarm pickle files.

    Args:
        raw_dir: Directory (used like a `pathlib.Path`) in which the temporary
            `data.zip` is stored; it is created, with parents, if missing.
    """
    # Imported lazily so `gdown` is only required when a download is requested.
    import gdown

    raw_dir.mkdir(parents=True, exist_ok=True)
    url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
    zip_path = raw_dir / "data.zip"
    gdown.download(url, str(zip_path), quiet=False)
    print("Extracting...")
    with zipfile.ZipFile(str(zip_path), "r") as zip_f:
        for member in zip_f.namelist():
            # Only the pickle files under `data/xarm*` are extracted.
            if member.startswith("data/xarm") and member.endswith(".pkl"):
                print(member)
                # NOTE(review): no `path` argument is given, so members are
                # extracted relative to the current working directory rather
                # than `raw_dir` — confirm this is intended.
                zip_f.extract(member=member)
    # The archive is removed once extraction has completed.
    zip_path.unlink()
|
||||
from lerobot.common.datasets.utils import load_previous_and_future_frames
|
||||
|
||||
|
||||
class XarmDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
https://huggingface.co/datasets/lerobot/xarm_lift_medium
|
||||
"""
|
||||
|
||||
available_datasets = [
|
||||
"xarm_lift_medium",
|
||||
]
|
||||
@@ -34,8 +20,9 @@ class XarmDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.1",
|
||||
version: str | None = "v1.0",
|
||||
root: Path | None = None,
|
||||
split: str = "train",
|
||||
transform: callable = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
@@ -43,121 +30,48 @@ class XarmDataset(torch.utils.data.Dataset):
|
||||
self.dataset_id = dataset_id
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
|
||||
self.data_dir = self.root / f"{self.dataset_id}"
|
||||
if (self.data_dir / "data_dict.pth").exists() and (
|
||||
self.data_dir / "data_ids_per_episode.pth"
|
||||
).exists():
|
||||
self.data_dict = torch.load(self.data_dir / "data_dict.pth")
|
||||
self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
|
||||
if self.root is not None:
|
||||
self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
|
||||
else:
|
||||
self._download_and_preproc_obsolete()
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(self.data_dict, self.data_dir / "data_dict.pth")
|
||||
torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
|
||||
self.data_dict = load_dataset(
|
||||
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
|
||||
)
|
||||
self.data_dict = self.data_dict.with_format("torch")
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self.data_dict["index"]) if "index" in self.data_dict else 0
|
||||
return len(self.data_dict)
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(self.data_ids_per_episode)
|
||||
return len(self.data_dict.unique("episode_id"))
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = {}
|
||||
item = self.data_dict[idx]
|
||||
|
||||
# get episode id and timestamp of the sampled frame
|
||||
current_ts = self.data_dict["timestamp"][idx].item()
|
||||
episode = self.data_dict["episode"][idx].item()
|
||||
if self.delta_timestamps is not None:
|
||||
item = load_previous_and_future_frames(
|
||||
item,
|
||||
self.data_dict,
|
||||
self.delta_timestamps,
|
||||
)
|
||||
|
||||
for key in self.data_dict:
|
||||
if self.delta_timestamps is not None and key in self.delta_timestamps:
|
||||
data, is_pad = load_data_with_delta_timestamps(
|
||||
self.data_dict,
|
||||
self.data_ids_per_episode,
|
||||
self.delta_timestamps,
|
||||
key,
|
||||
current_ts,
|
||||
episode,
|
||||
)
|
||||
item[key] = data
|
||||
item[f"{key}_is_pad"] = is_pad
|
||||
# convert images from channel last (PIL) to channel first (pytorch)
|
||||
for key in self.image_keys:
|
||||
if item[key].ndim == 3:
|
||||
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
|
||||
elif item[key].ndim == 4:
|
||||
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
|
||||
else:
|
||||
item[key] = self.data_dict[key][idx]
|
||||
raise ValueError(item[key].ndim)
|
||||
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
|
||||
return item
|
||||
|
||||
def _download_and_preproc_obsolete(self):
|
||||
assert self.root is not None
|
||||
raw_dir = self.root / f"{self.dataset_id}_raw"
|
||||
if not raw_dir.exists():
|
||||
download(raw_dir)
|
||||
|
||||
dataset_path = self.root / f"{self.dataset_id}" / "buffer.pkl"
|
||||
print(f"Using offline dataset '{dataset_path}'")
|
||||
with open(dataset_path, "rb") as f:
|
||||
dataset_dict = pickle.load(f)
|
||||
|
||||
total_frames = dataset_dict["actions"].shape[0]
|
||||
|
||||
self.data_ids_per_episode = {}
|
||||
ep_dicts = []
|
||||
|
||||
idx0 = 0
|
||||
idx1 = 0
|
||||
episode_id = 0
|
||||
for i in tqdm.tqdm(range(total_frames)):
|
||||
idx1 += 1
|
||||
|
||||
if not dataset_dict["dones"][i]:
|
||||
continue
|
||||
|
||||
num_frames = idx1 - idx0
|
||||
|
||||
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
|
||||
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
|
||||
action = torch.tensor(dataset_dict["actions"][idx0:idx1])
|
||||
# TODO(rcadene): we have a missing last frame which is the observation when the env is done
|
||||
# it is critical to have this frame for tdmpc to predict a "done observation/state"
|
||||
# next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
|
||||
# next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
|
||||
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
|
||||
next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])
|
||||
|
||||
ep_dict = {
|
||||
"observation.image": image,
|
||||
"observation.state": state,
|
||||
"action": action,
|
||||
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
|
||||
# "next.observation.image": next_image,
|
||||
# "next.observation.state": next_state,
|
||||
"next.reward": next_reward,
|
||||
"next.done": next_done,
|
||||
}
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
assert isinstance(episode_id, int)
|
||||
self.data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1)
|
||||
assert len(self.data_ids_per_episode[episode_id]) == num_frames
|
||||
|
||||
idx0 = idx1
|
||||
episode_id += 1
|
||||
|
||||
self.data_dict = {}
|
||||
|
||||
keys = ep_dicts[0].keys()
|
||||
for key in keys:
|
||||
self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
|
||||
self.data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
||||
@@ -14,7 +14,7 @@ def preprocess_observation(observation, transform=None):
|
||||
imgs = {"observation.image": observation["pixels"]}
|
||||
|
||||
for imgkey, img in imgs.items():
|
||||
img = torch.from_numpy(img).float()
|
||||
img = torch.from_numpy(img)
|
||||
# convert to (b c h w) torch format
|
||||
img = einops.rearrange(img, "b h w c -> b c h w")
|
||||
obs[imgkey] = img
|
||||
|
||||
@@ -41,7 +41,9 @@ import gymnasium as gym
|
||||
import imageio
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import snapshot_download
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.envs.factory import make_env
|
||||
@@ -199,38 +201,48 @@ def eval_policy(
|
||||
ep_dicts = []
|
||||
num_episodes = dones.shape[0]
|
||||
total_frames = 0
|
||||
idx0 = idx1 = 0
|
||||
data_ids_per_episode = {}
|
||||
idx_from = 0
|
||||
for ep_id in range(num_episodes):
|
||||
num_frames = done_indices[ep_id].item() + 1
|
||||
total_frames += num_frames
|
||||
|
||||
# TODO(rcadene): We need to add a missing last frame which is the observation
|
||||
# of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state"
|
||||
ep_dict = {
|
||||
"action": actions[ep_id, :num_frames],
|
||||
"episode": torch.tensor([ep_id] * num_frames),
|
||||
"episode_id": torch.tensor([ep_id] * num_frames),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
||||
"next.done": dones[ep_id, :num_frames],
|
||||
"next.reward": rewards[ep_id, :num_frames].type(torch.float32),
|
||||
"episode_data_index_from": torch.tensor([idx_from] * num_frames),
|
||||
"episode_data_index_to": torch.tensor([idx_from + num_frames] * num_frames),
|
||||
}
|
||||
for key in observations:
|
||||
ep_dict[key] = observations[key][ep_id, :num_frames]
|
||||
ep_dict[key] = observations[key][ep_id][:num_frames]
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
total_frames += num_frames
|
||||
idx1 += num_frames
|
||||
|
||||
data_ids_per_episode[ep_id] = torch.arange(idx0, idx1, 1)
|
||||
|
||||
idx0 = idx1
|
||||
idx_from += num_frames
|
||||
|
||||
# similar logic is implemented in dataset preprocessing
|
||||
data_dict = {}
|
||||
keys = ep_dicts[0].keys()
|
||||
for key in keys:
|
||||
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
if "image" not in key:
|
||||
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
else:
|
||||
if key not in data_dict:
|
||||
data_dict[key] = []
|
||||
for ep_dict in ep_dicts:
|
||||
for x in ep_dict[key]:
|
||||
# c h w -> h w c
|
||||
img = PILImage.fromarray(x.permute(1, 2, 0).numpy())
|
||||
data_dict[key].append(img)
|
||||
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
||||
data_dict = Dataset.from_dict(data_dict).with_format("torch")
|
||||
|
||||
if max_episodes_rendered > 0:
|
||||
batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *)
|
||||
|
||||
@@ -280,10 +292,7 @@ def eval_policy(
|
||||
"eval_s": time.time() - start,
|
||||
"eval_ep_s": (time.time() - start) / num_episodes,
|
||||
},
|
||||
"episodes": {
|
||||
"data_dict": data_dict,
|
||||
"data_ids_per_episode": data_ids_per_episode,
|
||||
},
|
||||
"episodes": data_dict,
|
||||
}
|
||||
if max_episodes_rendered > 0:
|
||||
info["videos"] = videos
|
||||
|
||||
@@ -4,6 +4,8 @@ from pathlib import Path
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
from datasets import concatenate_datasets
|
||||
from datasets.utils.logging import disable_progress_bar
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.utils import cycle
|
||||
@@ -128,29 +130,33 @@ def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
|
||||
return -(n_off * pc_on) / (n_on * (pc_on - 1))
|
||||
|
||||
|
||||
def add_episodes_inplace(episodes, online_dataset, concat_dataset, sampler, pc_online_samples):
|
||||
data_dict = episodes["data_dict"]
|
||||
data_ids_per_episode = episodes["data_ids_per_episode"]
|
||||
def add_episodes_inplace(data_dict, online_dataset, concat_dataset, sampler, pc_online_samples):
|
||||
first_episode_id = data_dict.select_columns("episode_id")[0]["episode_id"].item()
|
||||
first_index = data_dict.select_columns("index")[0]["index"].item()
|
||||
assert first_episode_id == 0, f"We expect the first episode_id to be 0 and not {first_episode_id}"
|
||||
assert first_index == 0, f"We expect the first first_index to be 0 and not {first_index}"
|
||||
|
||||
if len(online_dataset) == 0:
|
||||
# initialize online dataset
|
||||
online_dataset.data_dict = data_dict
|
||||
online_dataset.data_ids_per_episode = data_ids_per_episode
|
||||
else:
|
||||
# find episode index and data frame indices according to previous episode in online_dataset
|
||||
start_episode = max(online_dataset.data_ids_per_episode.keys()) + 1
|
||||
start_index = online_dataset.data_dict["index"][-1].item() + 1
|
||||
data_dict["episode"] += start_episode
|
||||
data_dict["index"] += start_index
|
||||
start_episode = online_dataset.select_columns("episode_id")[-1]["episode_id"].item() + 1
|
||||
start_index = online_dataset.select_columns("index")[-1]["index"].item() + 1
|
||||
|
||||
def shift_indices(example):
|
||||
# note: we dont shift "frame_id" since it represents the index of the frame in the episode it belongs to
|
||||
example["episode_id"] += start_episode
|
||||
example["index"] += start_index
|
||||
example["episode_data_index_from"] += start_index
|
||||
example["episode_data_index_to"] += start_index
|
||||
return example
|
||||
|
||||
disable_progress_bar() # map has a tqdm progress bar
|
||||
data_dict = data_dict.map(shift_indices)
|
||||
|
||||
# extend online dataset
|
||||
for key in data_dict:
|
||||
# TODO(rcadene): avoid reallocating memory at every step by preallocating memory or changing our data structure
|
||||
online_dataset.data_dict[key] = torch.cat([online_dataset.data_dict[key], data_dict[key]])
|
||||
for ep_id in data_ids_per_episode:
|
||||
online_dataset.data_ids_per_episode[ep_id + start_episode] = (
|
||||
data_ids_per_episode[ep_id] + start_index
|
||||
)
|
||||
online_dataset.data_dict = concatenate_datasets([online_dataset.data_dict, data_dict])
|
||||
|
||||
# update the concatenated dataset length used during sampling
|
||||
concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
|
||||
@@ -269,7 +275,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
# create an empty online dataset similar to offline dataset
|
||||
online_dataset = deepcopy(offline_dataset)
|
||||
online_dataset.data_dict = {}
|
||||
online_dataset.data_ids_per_episode = {}
|
||||
|
||||
# create dataloader for online training
|
||||
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
|
||||
|
||||
@@ -62,12 +62,12 @@ def render_dataset(dataset, out_dir, max_num_episodes):
|
||||
)
|
||||
dl_iter = iter(dataloader)
|
||||
|
||||
num_episodes = len(dataset.data_ids_per_episode)
|
||||
for ep_id in range(min(max_num_episodes, num_episodes)):
|
||||
for ep_id in range(min(max_num_episodes, dataset.num_episodes)):
|
||||
logging.info(f"Rendering episode {ep_id}")
|
||||
|
||||
frames = {}
|
||||
for _ in dataset.data_ids_per_episode[ep_id]:
|
||||
end_of_episode = False
|
||||
while not end_of_episode:
|
||||
item = next(dl_iter)
|
||||
|
||||
for im_key in dataset.image_keys:
|
||||
@@ -77,6 +77,8 @@ def render_dataset(dataset, out_dir, max_num_episodes):
|
||||
# add current frame to list of frames to render
|
||||
frames[im_key].append(item[im_key])
|
||||
|
||||
end_of_episode = item["index"].item() == item["episode_data_index_to"].item() - 1
|
||||
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
for im_key in dataset.image_keys:
|
||||
if len(dataset.image_keys) > 1:
|
||||
|
||||
Reference in New Issue
Block a user