Compare commits

..

9 Commits

Author SHA1 Message Date
Pepijn
1892aa1b08 update comment 2025-03-17 08:41:36 +01:00
Pepijn
3b6fff70e1 remove beta 2025-03-17 08:30:02 +01:00
Pepijn
6e97876e81 remove importance sampling 2025-03-17 08:27:17 +01:00
Pepijn
4bdbf2f6e0 update comment 2025-03-14 16:59:31 +01:00
Pepijn
4e9b4dd380 remove beta annealing 2025-03-14 13:22:22 +01:00
Pepijn
17d12db7c4 Add importance sampling, only use replacement, remove beta smoothing 2025-03-14 13:09:05 +01:00
Pepijn
6a8be97bb5 remove power of 2 optimization 2025-03-11 13:29:55 +01:00
Pepijn
841d54c050 Use sampler always (temp fix) 2025-03-11 12:23:51 +01:00
Pepijn
e3c3c165aa Add initial weighted sampling as prioritized experience replay with sum tree 2025-03-11 12:04:40 +01:00
11 changed files with 173 additions and 159 deletions

View File

@@ -126,7 +126,7 @@ jobs:
# portaudio19-dev is needed to install pyaudio
run: |
sudo apt-get update && \
sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev
sudo apt-get install -y libegl1-mesa-dev portaudio19-dev
- name: Install uv and python
uses: astral-sh/setup-uv@v5

View File

@@ -16,13 +16,6 @@ exclude: ^(tests/data)
default_language_version:
python: python3.10
repos:
##### Meta #####
- repo: meta
hooks:
- id: check-useless-excludes
- id: check-hooks-apply
##### Style / Misc. #####
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
@@ -35,18 +28,15 @@ repos:
- id: check-toml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/crate-ci/typos
rev: v1.30.2
rev: v1
hooks:
- id: typos
args: [--force-exclude]
- repo: https://github.com/asottile/pyupgrade
rev: v3.19.1
hooks:
- id: pyupgrade
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.10
hooks:
@@ -54,18 +44,15 @@ repos:
args: [--fix]
- id: ruff-format
##### Security #####
- repo: https://github.com/gitleaks/gitleaks
rev: v8.24.0
hooks:
- id: gitleaks
- repo: https://github.com/woodruffw/zizmor-pre-commit
rev: v1.4.1
hooks:
- id: zizmor
- repo: https://github.com/PyCQA/bandit
rev: 1.8.3
hooks:

View File

@@ -67,7 +67,7 @@ def parse_int_or_none(value) -> int | None:
def check_datasets_formats(repo_ids: list) -> None:
for repo_id in repo_ids:
dataset = LeRobotDataset(repo_id)
if len(dataset.meta.video_keys) > 0:
if dataset.video:
raise ValueError(
f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
)

View File

@@ -67,7 +67,7 @@ from lerobot.common.datasets.utils import (
)
from lerobot.common.datasets.video_utils import (
VideoFrame,
decode_video_frames,
decode_video_frames_torchvision,
encode_video_frames,
get_video_info,
)
@@ -462,8 +462,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
video files are already present on local disk, they won't be downloaded again. Defaults to
True.
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec.
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
"""
super().__init__()
self.repo_id = repo_id
@@ -473,7 +473,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episodes = episodes
self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION
self.video_backend = video_backend if video_backend else "torchcodec"
self.video_backend = video_backend if video_backend else "pyav"
self.delta_indices = None
# Unused attributes
@@ -707,7 +707,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
item = {}
for vid_key, query_ts in query_timestamps.items():
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
frames = decode_video_frames_torchvision(
video_path, query_ts, self.tolerance_s, self.video_backend
)
item[vid_key] = frames.squeeze(0)
return item
@@ -747,6 +749,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
task_idx = item["task_index"].item()
item["task"] = self.meta.tasks[task_idx]
# Add global index of frame (indices)
item["indices"] = torch.tensor(idx)
return item
def __repr__(self):
@@ -1027,7 +1032,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.delta_timestamps = None
obj.delta_indices = None
obj.episode_data_index = None
obj.video_backend = video_backend if video_backend is not None else "torchcodec"
obj.video_backend = video_backend if video_backend is not None else "pyav"
return obj

View File

@@ -13,9 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterator, Union
import random
from typing import Iterator, List, Optional, Union
import torch
from torch.utils.data import Sampler
class EpisodeAwareSampler:
@@ -59,3 +61,123 @@ class EpisodeAwareSampler:
def __len__(self) -> int:
return len(self.indices)
class SumTree:
    """
    A classic sum-tree data structure for storing priorities.

    The tree is stored as a flat array of length ``2 * capacity``: nodes
    ``1 .. capacity - 1`` are internal (each holding the sum of its two
    children) and nodes ``capacity .. 2 * capacity - 1`` are the leaves,
    one per stored priority. The root (node 1) therefore holds the total.
    """

    def __init__(self, capacity: int):
        """
        Args:
            capacity: Maximum number of elements.
        """
        self.capacity = capacity
        self.size = capacity
        self.tree = [0.0 for _ in range(2 * capacity)]

    def initialize_tree(self, priorities: List[float]):
        """
        Initializes the sum tree from a list of leaf priorities.
        """
        leaf_base = self.size
        # Write the leaves first...
        for offset, priority in enumerate(priorities):
            self.tree[leaf_base + offset] = priority
        # ...then fill internal nodes bottom-up so each parent is the sum
        # of its two children.
        for node in reversed(range(1, self.size)):
            self.tree[node] = self.tree[node * 2] + self.tree[node * 2 + 1]

    def update(self, idx: int, priority: float):
        """
        Update the priority at leaf index `idx` and propagate changes upwards.
        """
        node = idx + self.size
        self.tree[node] = priority
        # Walk towards the root, re-summing the two children of every ancestor.
        node >>= 1
        while node:
            self.tree[node] = self.tree[node * 2] + self.tree[node * 2 + 1]
            node >>= 1

    def total_priority(self) -> float:
        """Returns the sum of all priorities (stored at root)."""
        return self.tree[1]

    def sample(self, value: float) -> int:
        """
        Samples an index where the prefix sum up to that leaf is >= `value`.
        """
        # Clamp the query into the valid prefix-sum range.
        total = self.total_priority()
        if value < 0:
            value = 0
        elif value > total:
            value = total
        node = 1
        while node < self.size:
            left_child = node * 2
            left_sum = self.tree[left_child]
            if left_sum >= value:
                # Target lies in the left subtree.
                node = left_child
            else:
                # Skip the left subtree's mass and descend right.
                value -= left_sum
                node = left_child + 1
        return node - self.size  # Convert tree index to data index
class PrioritizedSampler(Sampler[int]):
    """
    PyTorch Sampler that draws samples in proportion to their priority using a SumTree.

    Priorities are computed as ``(difficulty + eps) ** alpha``; every sample
    starts with difficulty 1.0 and can be re-weighted via `update_priorities`.
    """

    def __init__(
        self,
        data_len: int,
        alpha: float = 0.6,
        eps: float = 1e-6,
        num_samples_per_epoch: Optional[int] = None,
    ):
        """
        Args:
            data_len: Total number of samples in the dataset.
            alpha: Exponent for priority scaling. Default is 0.6.
            eps: Small constant to avoid zero priorities.
            num_samples_per_epoch: Number of samples per epoch (default is data_len).
        """
        self.data_len = data_len
        self.alpha = alpha
        self.eps = eps
        self.num_samples_per_epoch = num_samples_per_epoch if num_samples_per_epoch else data_len

        # All samples start equally difficult, hence equally likely.
        self.difficulties = [1.0] * data_len
        base_priority = (1.0 + eps) ** alpha
        self.priorities = [base_priority] * data_len
        self.sumtree = SumTree(data_len)
        self.sumtree.initialize_tree(self.priorities)

    def update_priorities(self, indices: List[int], difficulties: List[float]):
        """
        Updates the priorities in the sum-tree.
        """
        for idx, diff in zip(indices, difficulties, strict=False):
            self.difficulties[idx] = diff
            priority = (diff + self.eps) ** self.alpha
            self.priorities[idx] = priority
            self.sumtree.update(idx, priority)

    def __iter__(self) -> Iterator[int]:
        """
        Samples indices based on their priority weights.
        """
        # Draw a uniform point in [0, total) and map it to a leaf via the tree.
        total = self.sumtree.total_priority()
        for _ in range(self.num_samples_per_epoch):
            yield self.sumtree.sample(random.random() * total)

    def __len__(self) -> int:
        return self.num_samples_per_epoch

View File

@@ -27,35 +27,6 @@ import torch
import torchvision
from datasets.features.features import register_feature
from PIL import Image
from torchcodec.decoders import VideoDecoder
def decode_video_frames(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    backend: str = "torchcodec",
) -> torch.Tensor:
    """
    Decodes video frames using the specified backend.

    Args:
        video_path (Path): Path to the video file.
        timestamps (list[float]): List of timestamps to extract frames.
        tolerance_s (float): Allowed deviation in seconds for frame retrieval.
        backend (str, optional): Backend to use for decoding. Defaults to "torchcodec".

    Returns:
        torch.Tensor: Decoded frames.

    Currently supports torchcodec on cpu and pyav.
    """
    # Dispatch to the decoder matching the requested backend; the two
    # torchvision backends ("pyav", "video_reader") share one implementation.
    if backend == "torchcodec":
        return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
    if backend in ("pyav", "video_reader"):
        return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
    raise ValueError(f"Unsupported video backend: {backend}")
def decode_video_frames_torchvision(
@@ -156,76 +127,6 @@ def decode_video_frames_torchvision(
return closest_frames
def decode_video_frames_torchcodec(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    device: str = "cpu",
    log_loaded_timestamps: bool = False,
) -> torch.Tensor:
    """Loads frames associated with the requested timestamps of a video using torchcodec.

    Args:
        video_path: Path to the video file.
        timestamps: Timestamps (in seconds) of the frames to retrieve.
        tolerance_s: Maximum allowed gap (in seconds) between each query timestamp
            and the closest decoded frame; a violation raises an AssertionError.
        device: Device passed to torchcodec's VideoDecoder. Defaults to "cpu".
        log_loaded_timestamps: If True, logs the timestamp of each loaded frame.

    Returns:
        torch.Tensor: Float32 frames scaled to [0, 1], one per query timestamp,
        stacked in query order.

    Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors.

    Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
    the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
    that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
    and all subsequent frames until reaching the requested frame. The number of key frames in a video
    can be adjusted during encoding to take into account decoding time and video size in bytes.
    """
    video_path = str(video_path)
    # initialize video decoder
    decoder = VideoDecoder(video_path, device=device)
    loaded_frames = []
    loaded_ts = []
    # get metadata for frame information
    metadata = decoder.metadata
    average_fps = metadata.average_fps
    # convert timestamps to frame indices
    # NOTE(review): rounding via the *average* fps assumes a roughly constant
    # frame rate; may pick a neighboring frame on variable-frame-rate videos —
    # the tolerance check below catches gross mismatches.
    frame_indices = [round(ts * average_fps) for ts in timestamps]
    # retrieve frames based on indices
    frames_batch = decoder.get_frames_at(indices=frame_indices)
    # Collect decoded frames together with their actual presentation timestamps.
    for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False):
        loaded_frames.append(frame)
        loaded_ts.append(pts.item())
        if log_loaded_timestamps:
            logging.info(f"Frame loaded at timestamp={pts:.4f}")
    query_ts = torch.tensor(timestamps)
    loaded_ts = torch.tensor(loaded_ts)
    # compute distances between each query timestamp and loaded timestamps
    dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
    min_, argmin_ = dist.min(1)
    # Every query must have a decoded frame within tolerance_s seconds.
    is_within_tol = min_ < tolerance_s
    assert is_within_tol.all(), (
        f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
        "It means that the closest frame that can be loaded from the video is too far away in time."
        "This might be due to synchronization issues with timestamps during data collection."
        "To be safe, we advise to ignore this item during training."
        f"\nqueried timestamps: {query_ts}"
        f"\nloaded timestamps: {loaded_ts}"
        f"\nvideo: {video_path}"
    )
    # get closest frames to the query timestamps
    closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
    closest_ts = loaded_ts[argmin_]
    if log_loaded_timestamps:
        logging.info(f"{closest_ts=}")
    # convert to float32 in [0,1] range (channel first)
    closest_frames = closest_frames.type(torch.float32) / 255
    assert len(timestamps) == len(closest_frames)
    return closest_frames
def encode_video_frames(
imgs_dir: Path | str,
video_path: Path | str,

View File

@@ -155,11 +155,14 @@ class ACTPolicy(PreTrainedPolicy):
batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
l1_loss = (
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
).mean()
elementwise_l1 = F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch[
"action_is_pad"
].unsqueeze(-1)
l1_loss = elementwise_l1.mean()
l1_per_sample = elementwise_l1.mean(dim=(1, 2))
loss_dict = {"l1_loss": l1_loss.item()}
if self.config.use_vae:
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
# each dimension independently, we sum over the latent dimension to get the total
@@ -168,9 +171,17 @@ class ACTPolicy(PreTrainedPolicy):
mean_kld = (
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
)
loss_dict["kld_loss"] = mean_kld.item()
loss_dict = {
"l1_loss": l1_loss.item(),
"kld_loss": mean_kld.item(),
"per_sample_l1": l1_per_sample,
}
loss = l1_loss + mean_kld * self.config.kl_weight
else:
loss_dict = {
"l1_loss": l1_loss.item(),
"per_sample_l1": l1_per_sample,
}
loss = l1_loss
return loss, loss_dict

View File

@@ -25,7 +25,7 @@ from torch.amp import GradScaler
from torch.optim import Optimizer
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.sampler import EpisodeAwareSampler, PrioritizedSampler
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.optim.factory import make_optimizer_and_scheduler
@@ -70,6 +70,7 @@ def update_policy(
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
loss, output_dict = policy.forward(batch)
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
grad_scaler.scale(loss).backward()
# Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
@@ -126,6 +127,7 @@ def train(cfg: TrainPipelineConfig):
logging.info("Creating dataset")
dataset = make_dataset(cfg)
data_len = len(dataset)
# Create environment used for evaluating checkpoints during training on simulation data.
# On real-world data, no need to create an environment as evaluations are done outside train.py,
@@ -174,6 +176,15 @@ def train(cfg: TrainPipelineConfig):
shuffle = True
sampler = None
# TODO(pepijn): If experiment works integrate this
shuffle = False
sampler = PrioritizedSampler(
data_len=data_len,
alpha=0.6,
eps=1e-6,
num_samples_per_epoch=data_len,
)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=cfg.num_workers,
@@ -220,6 +231,12 @@ def train(cfg: TrainPipelineConfig):
use_amp=cfg.policy.use_amp,
)
# Update sampler
if "indices" in batch and "per_sample_l1" in output_dict:
idxs = batch["indices"].cpu().tolist()
diffs = output_dict["per_sample_l1"].detach().cpu().tolist()
sampler.update_priorities(idxs, diffs)
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
# increment `step` here.
step += 1

View File

@@ -265,25 +265,13 @@ def main():
),
)
parser.add_argument(
"--tolerance-s",
type=float,
default=1e-4,
help=(
"Tolerance in seconds used to ensure data timestamps respect the dataset fps value"
"This is argument passed to the constructor of LeRobotDataset and maps to its tolerance_s constructor argument"
"If not given, defaults to 1e-4."
),
)
args = parser.parse_args()
kwargs = vars(args)
repo_id = kwargs.pop("repo_id")
root = kwargs.pop("root")
tolerance_s = kwargs.pop("tolerance_s")
logging.info("Loading dataset")
dataset = LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s)
dataset = LeRobotDataset(repo_id, root=root)
visualize_dataset(dataset, **vars(args))

View File

@@ -446,31 +446,15 @@ def main():
help="Delete the output directory if it exists already.",
)
parser.add_argument(
"--tolerance-s",
type=float,
default=1e-4,
help=(
"Tolerance in seconds used to ensure data timestamps respect the dataset fps value"
"This is argument passed to the constructor of LeRobotDataset and maps to its tolerance_s constructor argument"
"If not given, defaults to 1e-4."
),
)
args = parser.parse_args()
kwargs = vars(args)
repo_id = kwargs.pop("repo_id")
load_from_hf_hub = kwargs.pop("load_from_hf_hub")
root = kwargs.pop("root")
tolerance_s = kwargs.pop("tolerance_s")
dataset = None
if repo_id:
dataset = (
LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s)
if not load_from_hf_hub
else get_dataset_info(repo_id)
)
dataset = LeRobotDataset(repo_id, root=root) if not load_from_hf_hub else get_dataset_info(repo_id)
visualize_dataset_html(dataset, **vars(args))

View File

@@ -69,7 +69,6 @@ dependencies = [
"rerun-sdk>=0.21.0",
"termcolor>=2.4.0",
"torch>=2.2.1",
"torchcodec>=0.2.1",
"torchvision>=0.21.0",
"wandb>=0.16.3",
"zarr>=2.17.0",