Compare commits

..

1 Commits

Author SHA1 Message Date
Pepijn
0108caacdc Add focal loss 2025-03-11 16:38:26 +01:00
9 changed files with 53 additions and 158 deletions

View File

@@ -126,7 +126,7 @@ jobs:
# portaudio19-dev is needed to install pyaudio
run: |
sudo apt-get update && \
sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev
sudo apt-get install -y libegl1-mesa-dev portaudio19-dev
- name: Install uv and python
uses: astral-sh/setup-uv@v5

View File

@@ -16,13 +16,6 @@ exclude: ^(tests/data)
default_language_version:
python: python3.10
repos:
##### Meta #####
- repo: meta
hooks:
- id: check-useless-excludes
- id: check-hooks-apply
##### Style / Misc. #####
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
@@ -35,18 +28,15 @@ repos:
- id: check-toml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/crate-ci/typos
rev: v1.30.2
rev: v1
hooks:
- id: typos
args: [--force-exclude]
- repo: https://github.com/asottile/pyupgrade
rev: v3.19.1
hooks:
- id: pyupgrade
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.10
hooks:
@@ -54,18 +44,15 @@ repos:
args: [--fix]
- id: ruff-format
##### Security #####
- repo: https://github.com/gitleaks/gitleaks
rev: v8.24.0
hooks:
- id: gitleaks
- repo: https://github.com/woodruffw/zizmor-pre-commit
rev: v1.4.1
hooks:
- id: zizmor
- repo: https://github.com/PyCQA/bandit
rev: 1.8.3
hooks:

View File

@@ -67,7 +67,7 @@ def parse_int_or_none(value) -> int | None:
def check_datasets_formats(repo_ids: list) -> None:
for repo_id in repo_ids:
dataset = LeRobotDataset(repo_id)
if len(dataset.meta.video_keys) > 0:
if dataset.video:
raise ValueError(
f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
)

View File

@@ -67,7 +67,7 @@ from lerobot.common.datasets.utils import (
)
from lerobot.common.datasets.video_utils import (
VideoFrame,
decode_video_frames,
decode_video_frames_torchvision,
encode_video_frames,
get_video_info,
)
@@ -462,8 +462,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
video files are already present on local disk, they won't be downloaded again. Defaults to
True.
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec.
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
"""
super().__init__()
self.repo_id = repo_id
@@ -473,7 +473,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episodes = episodes
self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION
self.video_backend = video_backend if video_backend else "torchcodec"
self.video_backend = video_backend if video_backend else "pyav"
self.delta_indices = None
# Unused attributes
@@ -707,7 +707,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
item = {}
for vid_key, query_ts in query_timestamps.items():
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
frames = decode_video_frames_torchvision(
video_path, query_ts, self.tolerance_s, self.video_backend
)
item[vid_key] = frames.squeeze(0)
return item
@@ -1027,7 +1029,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.delta_timestamps = None
obj.delta_indices = None
obj.episode_data_index = None
obj.video_backend = video_backend if video_backend is not None else "torchcodec"
obj.video_backend = video_backend if video_backend is not None else "pyav"
return obj

View File

@@ -27,35 +27,6 @@ import torch
import torchvision
from datasets.features.features import register_feature
from PIL import Image
from torchcodec.decoders import VideoDecoder
def decode_video_frames(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    backend: str = "torchcodec",
) -> torch.Tensor:
    """Decode video frames at the requested timestamps with the chosen backend.

    This is a thin dispatcher: it forwards its arguments to the decoder that
    matches ``backend`` and returns whatever that decoder returns.

    Args:
        video_path (Path | str): Path to the video file.
        timestamps (list[float]): Timestamps of the frames to extract.
        tolerance_s (float): Allowed deviation in seconds for frame retrieval;
            passed through unchanged to the selected decoder.
        backend (str, optional): "torchcodec" (default), "pyav" or
            "video_reader". The torchcodec decoder is called with its default
            device.

    Returns:
        torch.Tensor: Decoded frames.

    Raises:
        ValueError: If ``backend`` is not one of the supported names.
    """
    # Both torchvision-based backends share one decoder, which needs the backend name.
    if backend in ("pyav", "video_reader"):
        return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)

    if backend != "torchcodec":
        raise ValueError(f"Unsupported video backend: {backend}")

    return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
def decode_video_frames_torchvision(
@@ -156,76 +127,6 @@ def decode_video_frames_torchvision(
return closest_frames
def decode_video_frames_torchcodec(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    device: str = "cpu",
    log_loaded_timestamps: bool = False,
) -> torch.Tensor:
    """Loads frames associated with the requested timestamps of a video using torchcodec.

    Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors.

    Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
    the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
    that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
    and all subsequent frames until reaching the requested frame. The number of key frames in a video
    can be adjusted during encoding to take into account decoding time and video size in bytes.

    Args:
        video_path: Path to the video file (converted to ``str`` before being handed to the decoder).
        timestamps: Query timestamps; each one is mapped to a frame index with
            ``round(ts * average_fps)``.
        tolerance_s: Maximum allowed distance between a query timestamp and the timestamp
            of the closest loaded frame.
        device: Device passed to ``VideoDecoder``.
        log_loaded_timestamps: When True, log the timestamp of every loaded frame and the
            timestamps finally selected.

    Returns:
        A float32 tensor holding one frame per query timestamp, with pixel values divided by 255.

    Raises:
        AssertionError: If the closest loaded frame of any query timestamp is not strictly
            within ``tolerance_s``.
    """
    video_path = str(video_path)

    # initialize video decoder
    decoder = VideoDecoder(video_path, device=device)
    loaded_frames = []
    loaded_ts = []

    # get metadata for frame information
    metadata = decoder.metadata
    average_fps = metadata.average_fps

    # convert timestamps to frame indices
    frame_indices = [round(ts * average_fps) for ts in timestamps]

    # retrieve frames based on indices
    frames_batch = decoder.get_frames_at(indices=frame_indices)

    for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False):
        loaded_frames.append(frame)
        loaded_ts.append(pts.item())
        if log_loaded_timestamps:
            logging.info(f"Frame loaded at timestamp={pts:.4f}")

    query_ts = torch.tensor(timestamps)
    loaded_ts = torch.tensor(loaded_ts)

    # compute distances between each query timestamp and loaded timestamps
    dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
    min_, argmin_ = dist.min(1)

    is_within_tol = min_ < tolerance_s
    # Each literal below ends with a space so the sentences do not run together
    # once Python concatenates the adjacent strings.
    assert is_within_tol.all(), (
        f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=}). "
        "It means that the closest frame that can be loaded from the video is too far away in time. "
        "This might be due to synchronization issues with timestamps during data collection. "
        "To be safe, we advise to ignore this item during training."
        f"\nqueried timestamps: {query_ts}"
        f"\nloaded timestamps: {loaded_ts}"
        f"\nvideo: {video_path}"
    )

    # get closest frames to the query timestamps
    closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
    closest_ts = loaded_ts[argmin_]

    if log_loaded_timestamps:
        logging.info(f"{closest_ts=}")

    # convert to float32 in [0,1] range (channel first)
    closest_frames = closest_frames.type(torch.float32) / 255

    # sanity check: exactly one frame must be returned per query timestamp
    assert len(timestamps) == len(closest_frames)
    return closest_frames
def encode_video_frames(
imgs_dir: Path | str,
video_path: Path | str,

View File

@@ -38,6 +38,38 @@ from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
def focal_regression_loss(
    input: torch.Tensor,
    target: torch.Tensor,
    gamma: float = 2.0,
    alpha: float = 0.25,
    reduction: str = "mean",
) -> torch.Tensor:
    """
    Computes a focal version of the L1 loss for regression tasks.

    Each element's absolute error ``e = |input - target|`` is scaled by the focal
    weight ``(1 - exp(-e)) ** gamma`` and by ``alpha``, i.e.
    ``loss = alpha * (1 - exp(-e)) ** gamma * e``. Small errors are therefore
    down-weighted relative to large ones.

    Args:
        input (Tensor): Predicted values.
        target (Tensor): Ground-truth values.
        gamma (float): Focusing parameter. With ``gamma = 0`` the focal weight is 1 and
            the loss is a plain L1 loss scaled by ``alpha``; larger values put more
            emphasis on elements with a large error.
        alpha (float): Constant factor multiplying every element of the loss. Lower
            values shrink the loss (and its gradients) uniformly.
        reduction (str): 'mean', 'sum', or 'none'.

    Returns:
        Tensor: The computed loss. A scalar for 'mean' and 'sum'; for 'none', the
        element-wise loss with the shape of ``input - target``.

    Raises:
        ValueError: If ``reduction`` is not 'mean', 'sum' or 'none'.
    """
    # Fail fast on a misspelled reduction instead of silently returning the
    # unreduced tensor.
    if reduction not in ("mean", "sum", "none"):
        raise ValueError(f"{reduction} is not a valid value for reduction")

    # Standard element-wise L1 error
    abs_error = torch.abs(input - target)

    # Weight in [0, 1) that grows with the error
    focal_weight = (1 - torch.exp(-abs_error)) ** gamma

    loss = alpha * focal_weight * abs_error

    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss
class ACTPolicy(PreTrainedPolicy):
"""
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
@@ -155,11 +187,13 @@ class ACTPolicy(PreTrainedPolicy):
batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
l1_loss = (
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
).mean()
focal_loss = focal_regression_loss(
batch["action"], actions_hat, gamma=2.0, alpha=0.25, reduction="none"
)
focal_loss = focal_loss * ~batch["action_is_pad"].unsqueeze(-1)
focal_loss = focal_loss.mean()
loss_dict = {"l1_loss": l1_loss.item()}
loss_dict = {"focal_loss": focal_loss.item()}
if self.config.use_vae:
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
# each dimension independently, we sum over the latent dimension to get the total
@@ -169,9 +203,9 @@ class ACTPolicy(PreTrainedPolicy):
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
)
loss_dict["kld_loss"] = mean_kld.item()
loss = l1_loss + mean_kld * self.config.kl_weight
loss = focal_loss + mean_kld * self.config.kl_weight
else:
loss = l1_loss
loss = focal_loss
return loss, loss_dict

View File

@@ -265,25 +265,13 @@ def main():
),
)
parser.add_argument(
"--tolerance-s",
type=float,
default=1e-4,
help=(
"Tolerance in seconds used to ensure data timestamps respect the dataset fps value"
"This is argument passed to the constructor of LeRobotDataset and maps to its tolerance_s constructor argument"
"If not given, defaults to 1e-4."
),
)
args = parser.parse_args()
kwargs = vars(args)
repo_id = kwargs.pop("repo_id")
root = kwargs.pop("root")
tolerance_s = kwargs.pop("tolerance_s")
logging.info("Loading dataset")
dataset = LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s)
dataset = LeRobotDataset(repo_id, root=root)
visualize_dataset(dataset, **vars(args))

View File

@@ -446,31 +446,15 @@ def main():
help="Delete the output directory if it exists already.",
)
parser.add_argument(
"--tolerance-s",
type=float,
default=1e-4,
help=(
"Tolerance in seconds used to ensure data timestamps respect the dataset fps value"
"This is argument passed to the constructor of LeRobotDataset and maps to its tolerance_s constructor argument"
"If not given, defaults to 1e-4."
),
)
args = parser.parse_args()
kwargs = vars(args)
repo_id = kwargs.pop("repo_id")
load_from_hf_hub = kwargs.pop("load_from_hf_hub")
root = kwargs.pop("root")
tolerance_s = kwargs.pop("tolerance_s")
dataset = None
if repo_id:
dataset = (
LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s)
if not load_from_hf_hub
else get_dataset_info(repo_id)
)
dataset = LeRobotDataset(repo_id, root=root) if not load_from_hf_hub else get_dataset_info(repo_id)
visualize_dataset_html(dataset, **vars(args))

View File

@@ -69,7 +69,6 @@ dependencies = [
"rerun-sdk>=0.21.0",
"termcolor>=2.4.0",
"torch>=2.2.1",
"torchcodec>=0.2.1",
"torchvision>=0.21.0",
"wandb>=0.16.3",
"zarr>=2.17.0",