Merge remote-tracking branch 'upstream/main' into refactor_dp
This commit is contained in:
@@ -105,7 +105,7 @@ class AlohaDataset(torch.utils.data.Dataset):
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self.data_dict["index"])
|
||||
return len(self.data_dict["index"]) if "index" in self.data_dict else 0
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torchvision.transforms import v2
|
||||
|
||||
from lerobot.common.datasets.utils import compute_or_load_stats
|
||||
from lerobot.common.datasets.utils import compute_stats
|
||||
from lerobot.common.transforms import NormalizeTransform, Prod
|
||||
|
||||
# DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and
|
||||
@@ -40,7 +41,8 @@ def make_dataset(
|
||||
if normalize:
|
||||
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
|
||||
# min_max_from_spec
|
||||
# stats = dataset.compute_or_load_stats() if stats_path is None else torch.load(stats_path)
|
||||
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
|
||||
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
|
||||
|
||||
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
|
||||
stats = {}
|
||||
@@ -51,21 +53,27 @@ def make_dataset(
|
||||
stats["action"] = {}
|
||||
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
else:
|
||||
elif stats_path is None:
|
||||
# instantiate a one frame dataset with light transform
|
||||
stats_dataset = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
root=DATA_DIR,
|
||||
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
|
||||
)
|
||||
stats = compute_or_load_stats(stats_dataset)
|
||||
|
||||
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
|
||||
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
|
||||
# load stats if the file exists already or compute stats and save it
|
||||
precomputed_stats_path = stats_dataset.data_dir / "stats.pth"
|
||||
if precomputed_stats_path.exists():
|
||||
stats = torch.load(precomputed_stats_path)
|
||||
else:
|
||||
logging.info(f"compute_stats and save to {precomputed_stats_path}")
|
||||
stats = compute_stats(stats_dataset)
|
||||
torch.save(stats, stats_path)
|
||||
else:
|
||||
stats = torch.load(stats_path)
|
||||
|
||||
transforms = v2.Compose(
|
||||
[
|
||||
# TODO(rcadene): we need to do something about image_keys
|
||||
Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
|
||||
NormalizeTransform(
|
||||
stats,
|
||||
|
||||
@@ -2,11 +2,8 @@ from pathlib import Path
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import pygame
|
||||
import pymunk
|
||||
import torch
|
||||
import tqdm
|
||||
from gym_pusht.envs.pusht import pymunk_to_shapely
|
||||
|
||||
from lerobot.common.datasets._diffusion_policy_replay_buffer import (
|
||||
ReplayBuffer as DiffusionPolicyReplayBuffer,
|
||||
@@ -20,64 +17,6 @@ PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
|
||||
PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr")
|
||||
|
||||
|
||||
def get_goal_pose_body(pose):
|
||||
mass = 1
|
||||
inertia = pymunk.moment_for_box(mass, (50, 100))
|
||||
body = pymunk.Body(mass, inertia)
|
||||
# preserving the legacy assignment order for compatibility
|
||||
# the order here doesn't matter somehow, maybe because CoM is aligned with body origin
|
||||
body.position = pose[:2].tolist()
|
||||
body.angle = pose[2]
|
||||
return body
|
||||
|
||||
|
||||
def add_segment(space, a, b, radius):
|
||||
shape = pymunk.Segment(space.static_body, a, b, radius)
|
||||
shape.color = pygame.Color("LightGray") # https://htmlcolorcodes.com/color-names
|
||||
return shape
|
||||
|
||||
|
||||
def add_tee(
|
||||
space,
|
||||
position,
|
||||
angle,
|
||||
scale=30,
|
||||
color="LightSlateGray",
|
||||
mask=None,
|
||||
):
|
||||
if mask is None:
|
||||
mask = pymunk.ShapeFilter.ALL_MASKS()
|
||||
mass = 1
|
||||
length = 4
|
||||
vertices1 = [
|
||||
(-length * scale / 2, scale),
|
||||
(length * scale / 2, scale),
|
||||
(length * scale / 2, 0),
|
||||
(-length * scale / 2, 0),
|
||||
]
|
||||
inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1)
|
||||
vertices2 = [
|
||||
(-scale / 2, scale),
|
||||
(-scale / 2, length * scale),
|
||||
(scale / 2, length * scale),
|
||||
(scale / 2, scale),
|
||||
]
|
||||
inertia2 = pymunk.moment_for_poly(mass, vertices=vertices1)
|
||||
body = pymunk.Body(mass, inertia1 + inertia2)
|
||||
shape1 = pymunk.Poly(body, vertices1)
|
||||
shape2 = pymunk.Poly(body, vertices2)
|
||||
shape1.color = pygame.Color(color)
|
||||
shape2.color = pygame.Color(color)
|
||||
shape1.filter = pymunk.ShapeFilter(mask=mask)
|
||||
shape2.filter = pymunk.ShapeFilter(mask=mask)
|
||||
body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2
|
||||
body.position = position
|
||||
body.angle = angle
|
||||
body.friction = 1
|
||||
space.add(body, shape1, shape2)
|
||||
return body
|
||||
|
||||
|
||||
class PushtDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
|
||||
@@ -121,7 +60,7 @@ class PushtDataset(torch.utils.data.Dataset):
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self.data_dict["index"])
|
||||
return len(self.data_dict["index"]) if "index" in self.data_dict else 0
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
@@ -158,6 +97,13 @@ class PushtDataset(torch.utils.data.Dataset):
|
||||
return item
|
||||
|
||||
def _download_and_preproc_obsolete(self):
|
||||
try:
|
||||
import pymunk
|
||||
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
|
||||
except ModuleNotFoundError as e:
|
||||
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
|
||||
raise e
|
||||
|
||||
assert self.root is not None
|
||||
raw_dir = self.root / f"{self.dataset_id}_raw"
|
||||
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
|
||||
@@ -182,7 +128,7 @@ class PushtDataset(torch.utils.data.Dataset):
|
||||
|
||||
# TODO: verify that goal pose is expected to be fixed
|
||||
goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
|
||||
goal_body = get_goal_pose_body(goal_pos_angle)
|
||||
goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)
|
||||
|
||||
imgs = torch.from_numpy(dataset_dict["img"])
|
||||
imgs = einops.rearrange(imgs, "b h w c -> b c h w")
|
||||
@@ -201,6 +147,9 @@ class PushtDataset(torch.utils.data.Dataset):
|
||||
assert (episode_ids[idx0:idx1] == episode_id).all()
|
||||
|
||||
image = imgs[idx0:idx1]
|
||||
assert image.min() >= 0.0
|
||||
assert image.max() <= 255.0
|
||||
image = image.type(torch.uint8)
|
||||
|
||||
state = states[idx0:idx1]
|
||||
agent_pos = state[:, :2]
|
||||
@@ -217,14 +166,14 @@ class PushtDataset(torch.utils.data.Dataset):
|
||||
|
||||
# Add walls.
|
||||
walls = [
|
||||
add_segment(space, (5, 506), (5, 5), 2),
|
||||
add_segment(space, (5, 5), (506, 5), 2),
|
||||
add_segment(space, (506, 5), (506, 506), 2),
|
||||
add_segment(space, (5, 506), (506, 506), 2),
|
||||
PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
|
||||
PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
|
||||
PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
|
||||
PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
|
||||
]
|
||||
space.add(*walls)
|
||||
|
||||
block_body = add_tee(space, block_pos[i].tolist(), block_angle[i].item())
|
||||
block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
|
||||
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
|
||||
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
|
||||
intersection_area = goal_geom.intersection(block_geom).area
|
||||
@@ -265,16 +214,3 @@ class PushtDataset(torch.utils.data.Dataset):
|
||||
self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
|
||||
self.data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dataset = PushtDataset(
|
||||
"pusht",
|
||||
root=Path("data"),
|
||||
delta_timestamps={
|
||||
"observation.image": [0, -1, -0.2, -0.1],
|
||||
"observation.state": [0, -1, -0.2, -0.1],
|
||||
"action": [-0.1, 0, 1, 2, 3],
|
||||
},
|
||||
)
|
||||
dataset[10]
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import io
|
||||
import logging
|
||||
import zipfile
|
||||
from copy import deepcopy
|
||||
from math import ceil
|
||||
@@ -35,52 +34,56 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def euclidean_distance_matrix(mat0, mat1):
|
||||
# Compute the square of the distance matrix
|
||||
sq0 = torch.sum(mat0**2, dim=1, keepdim=True)
|
||||
sq1 = torch.sum(mat1**2, dim=1, keepdim=True)
|
||||
distance_sq = sq0 + sq1.transpose(0, 1) - 2 * mat0 @ mat1.transpose(0, 1)
|
||||
|
||||
# Taking the square root to get the euclidean distance
|
||||
distance = torch.sqrt(torch.clamp(distance_sq, min=0))
|
||||
return distance
|
||||
|
||||
|
||||
def is_contiguously_true_or_false(bool_vector):
|
||||
assert bool_vector.ndim == 1
|
||||
assert bool_vector.dtype == torch.bool
|
||||
|
||||
# Compare each element with its neighbor to find changes
|
||||
changes = bool_vector[1:] != bool_vector[:-1]
|
||||
|
||||
# Count the number of changes
|
||||
num_changes = changes.sum().item()
|
||||
|
||||
# If there's more than one change, the list is not contiguous
|
||||
return num_changes <= 1
|
||||
|
||||
# examples = [
|
||||
# ([True, False, True, False, False, False], False),
|
||||
# ([True, True, True, False, False, False], True),
|
||||
# ([False, False, False, False, False, False], True)
|
||||
# ]
|
||||
# for bool_list, expected in examples:
|
||||
# result = is_contiguously_true_or_false(bool_list)
|
||||
|
||||
|
||||
def load_data_with_delta_timestamps(
|
||||
data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode
|
||||
data_dict: dict[torch.Tensor],
|
||||
data_ids_per_episode: dict[torch.Tensor],
|
||||
delta_timestamps: list[float],
|
||||
key: str,
|
||||
current_ts: float,
|
||||
episode: int,
|
||||
tol: float = 0.04,
|
||||
):
|
||||
"""
|
||||
Given a current timestamp (e.g. current_ts=0.6) and a list of timestamps differences (e.g. delta_timestamps=[-0.8, -0.2, 0, 0.2]),
|
||||
this function compute the query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames of the specified modality (e.g. key="observation.image").
|
||||
|
||||
Importantly, when no frame can be found around a query timestamp within a specified tolerance window (e.g. tol=0.04), this function raises an AssertionError.
|
||||
When a timestamp is queried before the first available timestamp of the episode or after the last available timestamp,
|
||||
the violation of the tolerance doesnt raise an AssertionError, and the function populates a boolean array indicating which frames are outside of the episode range.
|
||||
For instance, this boolean array is useful during batched training to not supervise actions associated to timestamps coming after the end of the episode,
|
||||
or to pad the observations in a specific way. Note that by default the observation frames before the start of the episode are the same as the first frame of the episode.
|
||||
|
||||
Parameters:
|
||||
- data_dict (dict): A dictionary containing the data, where each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
|
||||
- data_ids_per_episode (dict): A dictionary where keys are episode identifiers and values are lists of indices corresponding to frames associated with each episode.
|
||||
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible key to be retrieved. These deltas are added to the current_ts to form the query timestamps.
|
||||
- key (str): The key specifying which data modality is to be retrieved from the data_dict.
|
||||
- current_ts (float): The current timestamp to which the delta timestamps are added to form the query timestamps.
|
||||
- episode (int): The identifier of the episode from which frames are to be retrieved.
|
||||
- tol (float, optional): The tolerance level used to determine if a data point is close enough to the query timestamp. Defaults to 0.04.
|
||||
|
||||
Returns:
|
||||
- tuple: A tuple containing two elements:
|
||||
- The first element is the data retrieved from the specified modality based on the closest match to the query timestamps.
|
||||
- The second element is a boolean array indicating which frames were considered as padding (True if the distance to the closest timestamp was greater than the tolerance level).
|
||||
|
||||
Raises:
|
||||
- AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization issues with timestamps during data collection.
|
||||
"""
|
||||
# get indices of the frames associated to the episode, and their timestamps
|
||||
ep_data_ids = data_ids_per_episode[episode]
|
||||
ep_timestamps = data_dict["timestamp"][ep_data_ids]
|
||||
|
||||
# we make the assumption that the timestamps are sorted
|
||||
ep_first_ts = ep_timestamps[0]
|
||||
ep_last_ts = ep_timestamps[-1]
|
||||
|
||||
# get timestamps used as query to retrieve data of previous/future frames
|
||||
delta_ts = delta_timestamps[key]
|
||||
query_ts = current_ts + torch.tensor(delta_ts)
|
||||
|
||||
# compute distances between each query timestamp and all timestamps of all the frames belonging to the episode
|
||||
dist = euclidean_distance_matrix(query_ts[:, None], ep_timestamps[:, None])
|
||||
dist = torch.cdist(query_ts[:, None], ep_timestamps[:, None], p=1)
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
# get the indices of the data that are closest to the query timestamps
|
||||
@@ -92,24 +95,29 @@ def load_data_with_delta_timestamps(
|
||||
|
||||
# TODO(rcadene): synchronize timestamps + interpolation if needed
|
||||
|
||||
tol = 0.04
|
||||
is_pad = min_ > tol
|
||||
|
||||
assert is_contiguously_true_or_false(is_pad), (
|
||||
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=})."
|
||||
# check violated query timestamps are all outside the episode range
|
||||
assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), (
|
||||
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=}) inside episode range."
|
||||
"This might be due to synchronization issues with timestamps during data collection."
|
||||
)
|
||||
|
||||
return data, is_pad
|
||||
|
||||
|
||||
def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
|
||||
stats_path = dataset.data_dir / "stats.pth"
|
||||
if stats_path.exists():
|
||||
return torch.load(stats_path)
|
||||
def get_stats_einops_patterns(dataset):
|
||||
"""These einops patterns will be used to aggregate batches and compute statistics."""
|
||||
stats_patterns = {
|
||||
"action": "b c -> c",
|
||||
"observation.state": "b c -> c",
|
||||
}
|
||||
for key in dataset.image_keys:
|
||||
stats_patterns[key] = "b c h w -> c 1 1"
|
||||
return stats_patterns
|
||||
|
||||
logging.info(f"compute_stats and save to {stats_path}")
|
||||
|
||||
def compute_stats(dataset, batch_size=32, max_num_samples=None):
|
||||
if max_num_samples is None:
|
||||
max_num_samples = len(dataset)
|
||||
else:
|
||||
@@ -124,13 +132,8 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
|
||||
drop_last=False,
|
||||
)
|
||||
|
||||
# these einops patterns will be used to aggregate batches and compute statistics
|
||||
stats_patterns = {
|
||||
"action": "b c -> c",
|
||||
"observation.state": "b c -> c",
|
||||
}
|
||||
for key in dataset.image_keys:
|
||||
stats_patterns[key] = "b c h w -> c 1 1"
|
||||
# get einops patterns to aggregate batches and compute statistics
|
||||
stats_patterns = get_stats_einops_patterns(dataset)
|
||||
|
||||
# mean and std will be computed incrementally while max and min will track the running value.
|
||||
mean, std, max, min = {}, {}, {}, {}
|
||||
@@ -201,7 +204,6 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
|
||||
"min": min[key],
|
||||
}
|
||||
|
||||
torch.save(stats, stats_path)
|
||||
return stats
|
||||
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ class XarmDataset(torch.utils.data.Dataset):
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self.data_dict["index"])
|
||||
return len(self.data_dict["index"]) if "index" in self.data_dict else 0
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
@@ -126,7 +126,8 @@ class XarmDataset(torch.utils.data.Dataset):
|
||||
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
|
||||
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
|
||||
action = torch.tensor(dataset_dict["actions"][idx0:idx1])
|
||||
# TODO(rcadene): concat the last "next_observations" to "observations"
|
||||
# TODO(rcadene): we have a missing last frame which is the observation when the env is done
|
||||
# it is critical to have this frame for tdmpc to predict a "done observation/state"
|
||||
# next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
|
||||
# next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
|
||||
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
|
||||
|
||||
Reference in New Issue
Block a user