Compare commits

...

25 Commits

Author SHA1 Message Date
Adil Zouitine
278b56bce9 Add rlpd tricks 2025-01-15 15:49:24 +01:00
Adil Zouitine
0ffc0a7170 SAC works 2025-01-14 11:34:52 +01:00
Adil Zouitine
43d9890489 remove breakpoint 2025-01-13 17:58:00 +01:00
Adil Zouitine
963be41003 [WIP] correct sac implementation 2025-01-13 17:54:11 +01:00
Adil Zouitine
9edae4a8de Correct losses and factorisation 2025-01-07 17:07:55 +01:00
Ke-Wang1017
89d8189d8b remove unused debug lines 2025-01-06 10:18:40 +00:00
Ke-Wang1017
8b70b129dc improvements from JClinton to speed up loading offline data 2025-01-06 10:15:45 +00:00
joeclinton1
db3925df28 Potential fixes for SAC instability and NAN bug 2025-01-06 10:15:01 +00:00
Ke-Wang1017
f99e670976 Refactor SACPolicy and configuration for improved training dynamics
- Introduced target critic networks in SACPolicy to enhance stability during training.
- Updated TD target calculation to incorporate entropy adjustments, improving robustness.
- Increased online buffer capacity in configuration from 10,000 to 40,000 for better data handling.
- Adjusted learning rates for critic, actor, and temperature to 3e-4 for optimized training performance.

These changes aim to refine the SAC implementation, enhancing its robustness and performance during training and inference.
2025-01-06 10:14:34 +00:00
KeWang1017
eec28baa63 fix the bug of target critic updates, roll back to origial temperature implementation, added debug logging info 2025-01-06 10:14:09 +00:00
KeWang1017
f1f04eb4f9 use mean instead of sampled action for the inference 2025-01-06 10:12:24 +00:00
KeWang1017
77a7f92139 1, add input normalization in configuration_sac.py 2, add masking on loss computation 2025-01-06 10:11:51 +00:00
Michel Aractingi
35de91ef2b added temporary fix for missing task_index key in online environment 2024-12-30 13:47:28 +00:00
Michel Aractingi
ee306e2f9b split encoder for critic and actor 2024-12-29 23:59:39 +00:00
Michel Aractingi
bae3b02928 style fixes 2024-12-29 14:35:21 +00:00
KeWang1017
5b4adc00bb Refactor SAC configuration and policy for improved action sampling and stability
- Updated SACConfig to replace standard deviation parameterization with log_std_min and log_std_max for better control over action distributions.
- Modified SACPolicy to streamline action selection and log probability calculations, enhancing stochastic behavior.
- Removed deprecated TanhMultivariateNormalDiag class to simplify the codebase and improve maintainability.

These changes aim to enhance the robustness and performance of the SAC implementation during training and inference.
2024-12-29 14:27:19 +00:00
KeWang1017
22fbc9ea4a Refine SAC configuration and policy for enhanced performance
- Updated standard deviation parameterization in SACConfig to 'softplus' with defined min and max values for improved stability.
- Modified action sampling in SACPolicy to use reparameterized sampling, ensuring better gradient flow and log probability calculations.
- Cleaned up log probability calculations in TanhMultivariateNormalDiag for clarity and efficiency.
- Increased evaluation frequency in YAML configuration to 50000 for more efficient training cycles.

These changes aim to enhance the robustness and performance of the SAC implementation during training and inference.
2024-12-29 14:21:49 +00:00
KeWang1017
ca74a13d61 Refactor SACPolicy for improved action sampling and standard deviation handling
- Updated action selection to use distribution sampling and log probabilities for better stochastic behavior.
- Enhanced standard deviation clamping to prevent extreme values, ensuring stability in policy outputs.
- Cleaned up code by removing unnecessary comments and improving readability.

These changes aim to refine the SAC implementation, enhancing its robustness and performance during training and inference.
2024-12-29 14:17:25 +00:00
KeWang1017
18a4598986 trying to get sac running 2024-12-29 14:14:13 +00:00
Michel Aractingi
dc54d357ca Added normalization schemes and style checks 2024-12-29 12:51:21 +00:00
Michel Aractingi
08ec971086 added optimizer and sac to factory.py 2024-12-23 14:12:03 +01:00
Eugene Mironov
b53d6e0ff2 [HIL-SERL PORT] Fix linter issues (#588) 2024-12-23 10:44:29 +01:00
Eugene Mironov
70b652f791 [Port Hil-SERL] Add unit tests for the reward classifier & fix imports & check script (#578) 2024-12-23 10:43:55 +01:00
Michel Aractingi
7b68bfb73b added comments from kewang 2024-12-17 18:03:46 +01:00
KeWang1017
7e0f20fbf2 Enhance SAC configuration and policy with new parameters and subsampling logic
- Added `num_subsample_critics`, `critic_target_update_weight`, and `utd_ratio` to SACConfig.
- Implemented target entropy calculation in SACPolicy if not provided.
- Introduced subsampling of critics to prevent overfitting during updates.
- Updated temperature loss calculation to use the new target entropy.
- Added comments for future UTD update implementation.

These changes improve the flexibility and performance of the SAC implementation.
2024-12-17 17:58:11 +01:00
19 changed files with 2095 additions and 416 deletions

View File

@@ -611,11 +611,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
return query_timestamps
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
return {
key: torch.stack(self.hf_dataset.select(q_idx)[key])
for key, q_idx in query_indices.items()
if key not in self.meta.video_keys
}
# Step 1: Combine all unique indices
all_indices = sorted({idx for indices in query_indices.values() for idx in indices})
# Step 2: Select all required data at once
selected_dataset = self.hf_dataset.select(all_indices).to_dict()
selected_dataset = {key: torch.tensor(values) for key, values in selected_dataset.items()}
# Step 3: Map original indices to their positions in the selected dataset
index_map = {original_idx: i for i, original_idx in enumerate(all_indices)}
# Step 4: Build the result for each key
results = {}
for key, q_indices in query_indices.items():
if key not in self.meta.video_keys:
mapped_indices = [index_map[idx] for idx in q_indices]
results[key] = torch.stack([selected_dataset[key][i] for i in mapped_indices])
return results
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function

View File

@@ -25,13 +25,13 @@ from glob import glob
from pathlib import Path
import torch
import wandb
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
import wandb
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state

View File

@@ -66,6 +66,12 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy
return VQBeTPolicy, VQBeTConfig
elif name == "sac":
from lerobot.common.policies.sac.configuration_sac import SACConfig
from lerobot.common.policies.sac.modeling_sac import SACPolicy
return SACPolicy, SACConfig
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")

View File

@@ -2,8 +2,6 @@ import json
import os
from dataclasses import asdict, dataclass
import torch
@dataclass
class ClassifierConfig:
@@ -13,7 +11,7 @@ class ClassifierConfig:
hidden_dim: int = 256
dropout_rate: float = 0.1
model_name: str = "microsoft/resnet-50"
device: str = "cuda" if torch.cuda.is_available() else "mps"
device: str = "cpu"
model_type: str = "cnn" # "transformer" or "cnn"
def save_pretrained(self, save_dir):

View File

@@ -22,6 +22,13 @@ class ClassifierOutput:
self.probabilities = probabilities
self.hidden_states = hidden_states
def __repr__(self):
return (
f"ClassifierOutput(logits={self.logits}, "
f"probabilities={self.probabilities}, "
f"hidden_states={self.hidden_states})"
)
class Classifier(
nn.Module,
@@ -70,6 +77,8 @@ class Classifier(
else:
raise ValueError("Unsupported CNN architecture")
self.encoder = self.encoder.to(self.config.device)
def _freeze_encoder(self) -> None:
"""Freeze the encoder parameters."""
for param in self.encoder.parameters():
@@ -93,6 +102,7 @@ class Classifier(
nn.ReLU(),
nn.Linear(self.config.hidden_dim, 1 if self.config.num_classes == 2 else self.config.num_classes),
)
self.classifier_head = self.classifier_head.to(self.config.device)
def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor:
"""Extract the appropriate output from the encoder."""

View File

@@ -1,6 +1,6 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,6 +1,6 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -26,4 +26,4 @@ class HILSerlPolicy(
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "hilserl"],
):
pass
pass

View File

@@ -1,6 +1,6 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,25 +15,59 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from dataclasses import dataclass, field
@dataclass
class SACConfig:
input_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"observation.image": [3, 84, 84],
"observation.state": [4],
}
)
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [2],
}
)
# Normalization / Unnormalization
input_normalization_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.image": "mean_std",
"observation.state": "min_max",
"observation.environment_state": "min_max",
}
)
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {"action": "min_max"},
)
discount = 0.99
temperature_init = 1.0
num_critics = 2
num_subsample_critics = None
critic_lr = 3e-4
actor_lr = 3e-4
temperature_lr = 3e-4
critic_target_update_weight = 0.005
utd_ratio = 2
state_encoder_hidden_dim = 256
latent_dim = 256
target_entropy = None
# backup_entropy = False
use_backup_entropy = True
critic_network_kwargs = {
"hidden_dims": [256, 256],
"activate_final": True,
}
"hidden_dims": [256, 256],
"activate_final": True,
}
actor_network_kwargs = {
"hidden_dims": [256, 256],
"activate_final": True,
}
"hidden_dims": [256, 256],
"activate_final": True,
}
policy_kwargs = {
"tanh_squash_distribution": True,
"std_parameterization": "uniform",
}
"use_tanh_squash": True,
"log_std_min": -5,
"log_std_max": 2,
}

View File

@@ -1,6 +1,6 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,22 +18,18 @@
# TODO: (1) better device management
from collections import deque
from copy import deepcopy
from functools import partial
from typing import Callable, Optional, Sequence, Tuple, Union
import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor
from huggingface_hub import PyTorchModelHubMixin
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.sac.configuration_sac import SACConfig
import numpy as np
from typing import Callable, Optional, Tuple, Sequence
class SACPolicy(
@@ -43,11 +39,11 @@ class SACPolicy(
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "RL", "SAC"],
):
name = "sac"
def __init__(
self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
):
super().__init__()
if config is None:
@@ -60,33 +56,61 @@ class SACPolicy(
)
else:
self.normalize_inputs = nn.Identity()
# HACK: we need to pass the dataset_stats to the normalization functions
dataset_stats = dataset_stats or {
"action": {
"min": torch.tensor([-1.0, -1.0, -1.0, -1.0]),
"max": torch.tensor([1.0, 1.0, 1.0, 1.0]),
}
}
self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
encoder = SACObservationEncoder(config)
encoder_critic = SACObservationEncoder(config)
encoder_actor = SACObservationEncoder(config)
# Define networks
critic_nets = []
for _ in range(config.num_critics):
critic_net = Critic(
encoder=encoder,
network=MLP(**config.critic_network_kwargs)
encoder=encoder_critic,
network=MLP(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs,
),
)
critic_nets.append(critic_net)
target_critic_nets = []
for _ in range(config.num_critics):
target_critic_net = Critic(
encoder=encoder_critic,
network=MLP(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs,
),
)
target_critic_nets.append(target_critic_net)
self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
self.critic_target = deepcopy(self.critic_ensemble)
self.critic_target = create_critic_ensemble(target_critic_nets, config.num_critics)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
self.actor_network = Policy(
encoder=encoder,
network=MLP(**config.actor_network_kwargs),
self.actor = Policy(
encoder=encoder_actor,
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
action_dim=config.output_shapes["action"][0],
**config.policy_kwargs
**config.policy_kwargs,
)
self.temperature = LagrangeMultiplier(init_value=config.temperature_init)
if config.target_entropy is None:
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
# TODO: fix later device
# TODO: Handle the case where the temparameter is a fixed
self.log_alpha = torch.zeros(1, requires_grad=True, device="cpu")
self.temperature = self.log_alpha.exp().item()
def reset(self):
"""
@@ -98,24 +122,49 @@ class SACPolicy(
"observation.state": deque(maxlen=1),
"action": deque(maxlen=1),
}
if self._use_image:
if "observation.image" in self.config.input_shapes:
self._queues["observation.image"] = deque(maxlen=1)
if self._use_env_state:
if "observation.environment_state" in self.config.input_shapes:
self._queues["observation.environment_state"] = deque(maxlen=1)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
actions, _ = self.actor_network(batch['observations'])###
"""Select action for inference/evaluation"""
actions, _, _ = self.actor(batch)
actions = self.unnormalize_outputs({"action": actions})["action"]
return actions
def critic_forward(
self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False
) -> Tensor:
"""Forward pass through a critic network ensemble
Args:
observations: Dictionary of observations
actions: Action tensor
use_target: If True, use target critics, otherwise use ensemble critics
Returns:
Tensor of Q-values from all critics
"""
critics = self.critic_target if use_target else self.critic_ensemble
q_values = torch.stack([critic(observations, actions) for critic in critics])
return q_values
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
"""Run the batch through the model and compute the loss.
Returns a dictionary with loss as a tensor, and other information as native floats.
"""
# We have to actualize the value of the temperature because in the previous
self.temperature = self.log_alpha.exp().item()
temperature = self.temperature
batch = self.normalize_inputs(batch)
# batch shape is (b, 2, ...) where index 1 returns the current observation and
# the next observation for caluculating the right td index.
actions = batch["action"][:, 0]
# batch shape is (b, 2, ...) where index 1 returns the current observation and
# the next observation for calculating the right td index.
# actions = batch["action"][:, 0]
actions = batch["action"]
rewards = batch["next.reward"][:, 0]
observations = {}
next_observations = {}
@@ -123,186 +172,227 @@ class SACPolicy(
if k.startswith("observation."):
observations[k] = batch[k][:, 0]
next_observations[k] = batch[k][:, 1]
# perform image augmentation
done = batch["next.done"]
# reward bias
# from HIL-SERL code base
# add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
with torch.no_grad():
next_action_preds, next_log_probs, _ = self.actor(next_observations)
# calculate critics loss
# 1- compute actions from policy
action_preds, log_probs = self.actor_network(observations)
# 2- compute q targets
q_targets = self.target_qs(next_observations, action_preds)
# 2- compute q targets
q_targets = self.critic_forward(next_observations, next_action_preds, use_target=True)
# critics subsample size
min_q = q_targets.min(dim=0)
# subsample critics to prevent overfitting if use high UTD (update to date)
if self.config.num_subsample_critics is not None:
indices = torch.randperm(self.config.num_critics)
indices = indices[: self.config.num_subsample_critics]
q_targets = q_targets[indices]
# backup entropy
td_target = rewards + self.discount * min_q
# critics subsample size
min_q, _ = q_targets.min(dim=0) # Get values from min operation
if self.config.use_backup_entropy:
min_q -= self.temperature * next_log_probs
td_target = rewards + self.config.discount * min_q * ~done
# 3- compute predicted qs
q_preds = self.critic_ensemble(observations, actions)
q_preds = self.critic_forward(observations, actions, use_target=False)
# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
critics_loss = (
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
critics_loss = (
F.mse_loss(
q_preds,
einops.repeat(td_target, "t b -> e t b", e=q_preds.shape[0]),
reduction="none",
).sum(0) # sum over ensemble
# `q_preds_ensemble` depends on the first observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
# q_targets depends on the reward and the next observations.
* ~batch["next.reward_is_pad"]
* ~batch["observation.state_is_pad"][1:]
).sum(0).mean()
# calculate actors loss
# 1- temperature
temperature = self.temperature()
input=q_preds,
target=td_target_duplicate,
reduction="none",
).mean(1)
).sum()
# 2- get actions (batch_size, action_dim) and log probs (batch_size,)
actions, log_probs = self.actor_network(observations) \
# 3- get q-value predictions
with torch.no_grad():
q_preds = self.critic_ensemble(observations, actions, return_type="mean")
actor_loss = (
-(q_preds - temperature * log_probs).mean()
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
).mean()
actions_pi, log_probs, _ = self.actor(observations)
with torch.inference_mode():
q_preds = self.critic_forward(observations, actions_pi, use_target=False)
min_q_preds = q_preds.min(dim=0)[0]
actor_loss = ((temperature * log_probs) - min_q_preds).mean()
# calculate temperature loss
# 1- calculate entropy
entropy = -log_probs.mean()
temperature_loss = temperature * (entropy - self.target_entropy).mean()
with torch.no_grad():
_, log_probs, _ = self.actor(observations)
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
loss = critics_loss + actor_loss + temperature_loss
return {
"critics_loss": critics_loss.item(),
"actor_loss": actor_loss.item(),
"temperature_loss": temperature_loss.item(),
"temperature": temperature.item(),
"entropy": entropy.item(),
"loss": loss,
"critics_loss": critics_loss.item(),
"actor_loss": actor_loss.item(),
"mean_q_predicts": min_q_preds.mean().item(),
"min_q_predicts": min_q_preds.min().item(),
"max_q_predicts": min_q_preds.max().item(),
"temperature_loss": temperature_loss.item(),
"temperature": temperature,
"mean_log_probs": log_probs.mean().item(),
"min_log_probs": log_probs.min().item(),
"max_log_probs": log_probs.max().item(),
"td_target_mean": td_target.mean().item(),
"td_target_max": td_target.max().item(),
"action_mean": actions.mean().item(),
"entropy": log_probs.mean().item(),
"loss": loss,
}
}
def update(self):
self.critic_target.lerp_(self.critic_ensemble, self.config.critic_target_update_weight)
#for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()):
# target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.critic_target_update_weight)
def update_target_networks(self):
"""Update target networks with exponential moving average"""
for target_critic, critic in zip(self.critic_target, self.critic_ensemble, strict=False):
for target_param, param in zip(target_critic.parameters(), critic.parameters(), strict=False):
target_param.data.copy_(
param.data * self.config.critic_target_update_weight
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
def compute_loss_critic(self, observations, actions, rewards, next_observations, done) -> Tensor:
temperature = self.log_alpha.exp().item()
with torch.no_grad():
next_action_preds, next_log_probs, _ = self.actor(next_observations)
# 2- compute q targets
q_targets = self.critic_forward(
observations=next_observations, actions=next_action_preds, use_target=True
)
# subsample critics to prevent overfitting if use high UTD (update to date)
if self.config.num_subsample_critics is not None:
indices = torch.randperm(self.config.num_critics)
indices = indices[: self.config.num_subsample_critics]
q_targets = q_targets[indices]
# critics subsample size
min_q, _ = q_targets.min(dim=0) # Get values from min operation
if self.config.use_backup_entropy:
min_q = min_q - (temperature * next_log_probs)
td_target = rewards + (1 - done) * self.config.discount * min_q
# 3- compute predicted qs
q_preds = self.critic_forward(observations, actions, use_target=False)
# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
critics_loss = (
F.mse_loss(
input=q_preds,
target=td_target_duplicate,
reduction="none",
).mean(1)
).sum()
return critics_loss
def compute_loss_temperature(self, observations) -> Tensor:
"""Compute the temperature loss"""
# calculate temperature loss
with torch.no_grad():
_, log_probs, _ = self.actor(observations)
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
return temperature_loss
def compute_loss_actor(self, observations) -> Tensor:
temperature = self.log_alpha.exp().item()
actions_pi, log_probs, _ = self.actor(observations)
q_preds = self.critic_forward(observations, actions_pi, use_target=False)
min_q_preds = q_preds.min(dim=0)[0]
actor_loss = ((temperature * log_probs) - min_q_preds).mean()
return actor_loss
class MLP(nn.Module):
def __init__(
self,
config: SACConfig,
input_dim: int,
hidden_dims: list[int],
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
activate_final: bool = False,
dropout_rate: Optional[float] = None,
):
super().__init__()
self.activate_final = config.activate_final
self.activate_final = activate_final
layers = []
for i, size in enumerate(config.network_hidden_dims):
layers.append(nn.Linear(config.network_hidden_dims[i-1] if i > 0 else config.network_hidden_dims[0], size))
if i + 1 < len(config.network_hidden_dims) or activate_final:
# First layer uses input_dim
layers.append(nn.Linear(input_dim, hidden_dims[0]))
# Add activation after first layer
if dropout_rate is not None and dropout_rate > 0:
layers.append(nn.Dropout(p=dropout_rate))
layers.append(nn.LayerNorm(hidden_dims[0]))
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
# Rest of the layers
for i in range(1, len(hidden_dims)):
layers.append(nn.Linear(hidden_dims[i - 1], hidden_dims[i]))
if i + 1 < len(hidden_dims) or activate_final:
if dropout_rate is not None and dropout_rate > 0:
layers.append(nn.Dropout(p=dropout_rate))
layers.append(nn.LayerNorm(size))
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
layers.append(nn.LayerNorm(hidden_dims[i]))
layers.append(
activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
)
self.net = nn.Sequential(*layers)
def forward(self, x: torch.Tensor, train: bool = False) -> torch.Tensor:
# in training mode or not. TODO: find better way to do this
self.train(train)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
class Critic(nn.Module):
def __init__(
self,
encoder: Optional[nn.Module],
network: nn.Module,
init_final: Optional[float] = None,
activate_final: bool = False,
device: str = "cuda"
device: str = "cpu",
):
super().__init__()
self.device = torch.device(device)
self.encoder = encoder
self.network = network
self.init_final = init_final
self.activate_final = activate_final
# Find the last Linear layer's output dimension
for layer in reversed(network.net):
if isinstance(layer, nn.Linear):
out_features = layer.out_features
break
# Output layer
if init_final is not None:
if self.activate_final:
self.output_layer = nn.Linear(network.net[-3].out_features, 1)
else:
self.output_layer = nn.Linear(network.net[-2].out_features, 1)
self.output_layer = nn.Linear(out_features, 1)
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
else:
if self.activate_final:
self.output_layer = nn.Linear(network.net[-3].out_features, 1)
else:
self.output_layer = nn.Linear(network.net[-2].out_features, 1)
self.output_layer = nn.Linear(out_features, 1)
orthogonal_init()(self.output_layer.weight)
self.to(self.device)
def forward(
self,
observations: torch.Tensor,
self,
observations: dict[str, torch.Tensor],
actions: torch.Tensor,
train: bool = False
) -> torch.Tensor:
self.train(train)
observations = observations.to(self.device)
# Move each tensor in observations to device
observations = {k: v.to(self.device) for k, v in observations.items()}
actions = actions.to(self.device)
if self.encoder is not None:
obs_enc = self.encoder(observations)
else:
obs_enc = observations
obs_enc = observations if self.encoder is None else self.encoder(observations)
inputs = torch.cat([obs_enc, actions], dim=-1)
x = self.network(inputs)
value = self.output_layer(x)
return value.squeeze(-1)
def q_value_ensemble(
self,
observations: torch.Tensor,
actions: torch.Tensor,
train: bool = False
) -> torch.Tensor:
observations = observations.to(self.device)
actions = actions.to(self.device)
if len(actions.shape) == 3: # [batch_size, num_actions, action_dim]
batch_size, num_actions = actions.shape[:2]
obs_expanded = observations.unsqueeze(1).expand(-1, num_actions, -1)
obs_flat = obs_expanded.reshape(-1, observations.shape[-1])
actions_flat = actions.reshape(-1, actions.shape[-1])
q_values = self(obs_flat, actions_flat, train)
return q_values.reshape(batch_size, num_actions)
else:
return self(observations, actions, train)
class Policy(nn.Module):
@@ -311,115 +401,93 @@ class Policy(nn.Module):
encoder: Optional[nn.Module],
network: nn.Module,
action_dim: int,
std_parameterization: str = "exp",
std_min: float = 1e-5,
std_max: float = 10.0,
tanh_squash_distribution: bool = False,
log_std_min: float = -5,
log_std_max: float = 2,
fixed_std: Optional[torch.Tensor] = None,
init_final: Optional[float] = None,
activate_final: bool = False,
device: str = "cuda"
use_tanh_squash: bool = False,
device: str = "cpu",
):
super().__init__()
self.device = torch.device(device)
self.encoder = encoder
self.network = network
self.action_dim = action_dim
self.std_parameterization = std_parameterization
self.std_min = std_min
self.std_max = std_max
self.tanh_squash_distribution = tanh_squash_distribution
self.log_std_min = log_std_min
self.log_std_max = log_std_max
self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
self.activate_final = activate_final
self.use_tanh_squash = use_tanh_squash
# Find the last Linear layer's output dimension
for layer in reversed(network.net):
if isinstance(layer, nn.Linear):
out_features = layer.out_features
break
# Mean layer
if self.activate_final:
self.mean_layer = nn.Linear(network.net[-3].out_features, action_dim)
else:
self.mean_layer = nn.Linear(network.net[-2].out_features, action_dim)
self.mean_layer = nn.Linear(out_features, action_dim)
if init_final is not None:
nn.init.uniform_(self.mean_layer.weight, -init_final, init_final)
nn.init.uniform_(self.mean_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.mean_layer.weight)
# Standard deviation layer or parameter
if fixed_std is None:
if std_parameterization == "uniform":
self.log_stds = nn.Parameter(torch.zeros(action_dim, device=self.device))
self.std_layer = nn.Linear(out_features, action_dim)
if init_final is not None:
nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
else:
if self.activate_final:
self.std_layer = nn.Linear(network.net[-3].out_features, action_dim)
else:
self.std_layer = nn.Linear(network.net[-2].out_features, action_dim)
if init_final is not None:
nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.std_layer.weight)
orthogonal_init()(self.std_layer.weight)
self.to(self.device)
def forward(
self,
self,
observations: torch.Tensor,
temperature: float = 1.0,
train: bool = False,
non_squash_distribution: bool = False
) -> torch.distributions.Distribution:
self.train(train)
) -> Tuple[torch.Tensor, torch.Tensor]:
# Encode observations if encoder exists
if self.encoder is not None:
with torch.set_grad_enabled(train):
obs_enc = self.encoder(observations, train=train)
else:
obs_enc = observations
obs_enc = observations if self.encoder is None else self.encoder(observations)
# Get network outputs
outputs = self.network(obs_enc)
means = self.mean_layer(outputs)
# Compute standard deviations
if self.fixed_std is None:
if self.std_parameterization == "exp":
log_stds = self.std_layer(outputs)
stds = torch.exp(log_stds)
elif self.std_parameterization == "softplus":
stds = torch.nn.functional.softplus(self.std_layer(outputs))
elif self.std_parameterization == "uniform":
stds = torch.exp(self.log_stds).expand_as(means)
log_std = self.std_layer(outputs)
assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"
if self.use_tanh_squash:
log_std = torch.tanh(log_std)
log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
else:
raise ValueError(
f"Invalid std_parameterization: {self.std_parameterization}"
)
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
else:
assert self.std_parameterization == "fixed"
stds = self.fixed_std.expand_as(means)
log_std = self.fixed_std.expand_as(means)
# Clip standard deviations and scale with temperature
temperature = torch.tensor(temperature, device=self.device)
stds = torch.clamp(stds, self.std_min, self.std_max) * torch.sqrt(temperature)
# uses tanh activation function to squash the action to be in the range of [-1, 1]
normal = torch.distributions.Normal(means, torch.exp(log_std))
x_t = normal.rsample() # Reparameterization trick (mean + std * N(0,1))
log_probs = normal.log_prob(x_t) # Base log probability before Tanh
# Create distribution
if self.tanh_squash_distribution and not non_squash_distribution:
distribution = TanhMultivariateNormalDiag(
loc=means,
scale_diag=stds,
)
if self.use_tanh_squash:
actions = torch.tanh(x_t)
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) # Adjust log-probs for Tanh
else:
distribution = torch.distributions.Normal(
loc=means,
scale=stds,
)
actions = x_t # No Tanh; raw Gaussian sample
log_probs = log_probs.sum(-1) # Sum over action dimensions
means = torch.tanh(means) if self.use_tanh_squash else means
return actions, log_probs, means
return distribution
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
"""Get encoded features from observations"""
observations = observations.to(self.device)
if self.encoder is not None:
with torch.no_grad():
return self.encoder(observations, train=False)
with torch.inference_mode():
return self.encoder(observations)
return observations
@@ -461,19 +529,13 @@ class SACObservationEncoder(nn.Module):
)
if "observation.state" in config.input_shapes:
self.state_enc_layers = nn.Sequential(
nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim),
nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
nn.Linear(config.input_shapes["observation.state"][0], config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
if "observation.environment_state" in config.input_shapes:
self.env_state_enc_layers = nn.Sequential(
nn.Linear(
config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim
),
nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
nn.Linear(config.input_shapes["observation.environment_state"][0], config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
@@ -493,175 +555,25 @@ class SACObservationEncoder(nn.Module):
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
if "observation.state" in self.config.input_shapes:
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
# TODO(ke-wang): currently average over all features, concatenate all features maybe a better way
return torch.stack(feat, dim=0).mean(0)
class LagrangeMultiplier(nn.Module):
def __init__(
self,
init_value: float = 1.0,
constraint_shape: Sequence[int] = (),
device: str = "cuda"
):
super().__init__()
self.device = torch.device(device)
init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1)
# Initialize the Lagrange multiplier as a parameter
self.lagrange = nn.Parameter(
torch.full(constraint_shape, init_value, dtype=torch.float32, device=self.device)
)
self.to(self.device)
def forward(
self,
lhs: Optional[torch.Tensor] = None,
rhs: Optional[torch.Tensor] = None
) -> torch.Tensor:
# Get the multiplier value based on parameterization
multiplier = torch.nn.functional.softplus(self.lagrange)
# Return the raw multiplier if no constraint values provided
if lhs is None:
return multiplier
# Move inputs to device
lhs = lhs.to(self.device)
if rhs is not None:
rhs = rhs.to(self.device)
# Use the multiplier to compute the Lagrange penalty
if rhs is None:
rhs = torch.zeros_like(lhs, device=self.device)
diff = lhs - rhs
assert diff.shape == multiplier.shape, f"Shape mismatch: {diff.shape} vs {multiplier.shape}"
return multiplier * diff
# The TanhMultivariateNormalDiag is a probability distribution that represents a transformed normal (Gaussian) distribution where:
# 1. The base distribution is a diagonal multivariate normal distribution
# 2. The samples from this normal distribution are transformed through a tanh function, which squashes the values to be between -1 and 1
# 3. Optionally, the values can be further transformed to fit within arbitrary bounds [low, high] using an affine transformation
# This type of distribution is commonly used in reinforcement learning, particularly for continuous action spaces
class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
def __init__(
self,
loc: torch.Tensor,
scale_diag: torch.Tensor,
low: Optional[torch.Tensor] = None,
high: Optional[torch.Tensor] = None,
):
# Create base normal distribution
base_distribution = torch.distributions.Normal(loc=loc, scale=scale_diag)
# Create list of transforms
transforms = []
# Add tanh transform
transforms.append(torch.distributions.transforms.TanhTransform())
# Add rescaling transform if bounds are provided
if low is not None and high is not None:
transforms.append(
torch.distributions.transforms.AffineTransform(
loc=(high + low) / 2,
scale=(high - low) / 2
)
)
# Initialize parent class
super().__init__(
base_distribution=base_distribution,
transforms=transforms
)
# Store parameters
self.loc = loc
self.scale_diag = scale_diag
self.low = low
self.high = high
def mode(self) -> torch.Tensor:
"""Get the mode of the transformed distribution"""
# The mode of a normal distribution is its mean
mode = self.loc
# Apply transforms
for transform in self.transforms:
mode = transform(mode)
return mode
def rsample(self, sample_shape=torch.Size()) -> torch.Tensor:
"""
Reparameterized sample from the distribution
"""
# Sample from base distribution
x = self.base_dist.rsample(sample_shape)
# Apply transforms
for transform in self.transforms:
x = transform(x)
return x
def log_prob(self, value: torch.Tensor) -> torch.Tensor:
"""
Compute log probability of a value
Includes the log det jacobian for the transforms
"""
# Initialize log prob
log_prob = torch.zeros_like(value[..., 0])
# Inverse transforms to get back to normal distribution
q = value
for transform in reversed(self.transforms):
q = transform.inv(q)
log_prob = log_prob - transform.log_abs_det_jacobian(q, transform(q))
# Add base distribution log prob
log_prob = log_prob + self.base_dist.log_prob(q).sum(-1)
return log_prob
def sample_and_log_prob(self, sample_shape=torch.Size()) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Sample from the distribution and compute log probability
"""
x = self.rsample(sample_shape)
log_prob = self.log_prob(x)
return x, log_prob
def entropy(self) -> torch.Tensor:
"""
Compute entropy of the distribution
"""
# Start with base distribution entropy
entropy = self.base_dist.entropy().sum(-1)
# Add log det jacobian for each transform
x = self.rsample()
for transform in self.transforms:
entropy = entropy + transform.log_abs_det_jacobian(x, transform(x))
x = transform(x)
return entropy
def create_critic_ensemble(critic_class, num_critics: int, device: str = "cuda") -> nn.ModuleList:
"""Creates an ensemble of critic networks"""
critics = nn.ModuleList([critic_class() for _ in range(num_critics)])
return critics.to(device)
@property
def output_dim(self) -> int:
"""Returns the dimension of the encoder output"""
return self.config.latent_dim
def orthogonal_init():
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cpu") -> nn.ModuleList:
"""Creates an ensemble of critic networks"""
assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
return nn.ModuleList(critics).to(device)
# borrowed from tdmpc
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
"""Helper to temporarily flatten extra dims at the start of the image tensor.
@@ -669,7 +581,7 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
Args:
fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
(B, *), where * is any number of dimensions.
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
can be more than 1 dimensions, generally different from *.
Returns:
A return value from the callable reshaped to (**, *).
@@ -680,4 +592,3 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
inp = torch.flatten(image_tensor, end_dim=-4)
flat_out = fn(inp)
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))

View File

@@ -0,0 +1,89 @@
# @package _global_
# Train with:
#
# python lerobot/scripts/train.py \
# env=pusht \
# +dataset=lerobot/pusht_keypoints
seed: 1
dataset_repo_id: lerobot/pusht_keypoints
training:
offline_steps: 0
# Offline training dataloader
num_workers: 4
batch_size: 256
grad_clip_norm: 10.0
lr: 3e-4
eval_freq: 2500
log_freq: 500
save_freq: 50000
online_steps: 1000000
online_rollout_n_episodes: 10
online_rollout_batch_size: 10
online_steps_between_rollouts: 1000
online_sampling_ratio: 1.0
online_env_seed: 10000
online_buffer_capacity: 40000
online_buffer_seed_size: 0
do_online_rollout_async: false
delta_timestamps:
observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
action: "[i / ${fps} for i in range(${policy.horizon})]"
next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
policy:
name: sac
pretrained_model_path:
# Input / output structure.
n_action_repeats: 1
horizon: 2
n_action_steps: 2
input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.environment_state: [16]
observation.state: ["${env.state_dim}"]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes:
observation.environment_state: min_max
observation.state: min_max
output_normalization_modes:
action: min_max
# Architecture / modeling.
# Neural networks.
# image_encoder_hidden_dim: 32
discount: 0.99
temperature_init: 1.0
num_critics: 2
num_subsample_critics: None
critic_lr: 3e-4
actor_lr: 3e-4
temperature_lr: 3e-4
critic_target_update_weight: 0.005
utd_ratio: 2
# # Loss coefficients.
# reward_coeff: 0.5
# expectile_weight: 0.9
# value_coeff: 0.1
# consistency_coeff: 20.0
# advantage_scaling: 3.0
# pi_coeff: 0.5
# temporal_decay_coeff: 0.5
# # Target model.
# target_model_momentum: 0.995

View File

@@ -24,7 +24,7 @@ python lerobot/scripts/eval_on_robot.py \
```
**NOTE** (michel-aractingi): This script is incomplete and it is being prepared
for running training on the real robot.
for running training on the real robot.
"""
import argparse
@@ -47,7 +47,7 @@ from lerobot.common.utils.utils import (
def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20, use_amp: bool = True) -> dict:
"""Run a batched policy rollout on the real robot.
"""Run a batched policy rollout on the real robot.
The return dictionary contains:
"robot": A a dictionary of (batch, sequence + 1, *) tensors mapped to observation
@@ -64,7 +64,7 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20,
extraneous elements from the sequences above.
Args:
robot: The robot class that defines the interface with the real robot.
robot: The robot class that defines the interface with the real robot.
policy: The policy. Must be a PyTorch nn module.
Returns:
@@ -77,7 +77,7 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20,
listener, events = init_keyboard_listener()
# Reset the policy. TODO (michel-aractingi) add real policy evaluation once the code is ready.
# policy.reset()
# policy.reset()
# Get observation from real robot
observation = robot.capture_observation()

View File

@@ -93,6 +93,18 @@ def make_optimizer_and_scheduler(cfg, policy):
elif policy.name == "tdmpc":
optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
lr_scheduler = None
elif policy.name == "sac":
optimizer = torch.optim.Adam(
[
{"params": policy.actor.parameters(), "lr": policy.config.actor_lr},
{"params": policy.critic_ensemble.parameters(), "lr": policy.config.critic_lr},
# We wrap policy log temperature in list because this is a torch tensor and not a nn.Module
{"params": [policy.log_alpha], "lr": policy.config.temperature_lr},
]
)
lr_scheduler = None
elif cfg.policy.name == "vqbet":
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
@@ -311,6 +323,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logging.info("make_dataset")
offline_dataset = make_dataset(cfg)
# TODO (michel-aractingi): temporary fix to avoid datasets with task_index key that doesn't exist in online environment
# i.e., pusht
if "task_index" in offline_dataset.hf_dataset[0]:
offline_dataset.hf_dataset = offline_dataset.hf_dataset.remove_columns(["task_index"])
if isinstance(offline_dataset, MultiLeRobotDataset):
logging.info(
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "

View File

@@ -22,7 +22,6 @@ from pprint import pformat
import hydra
import torch
import torch.nn as nn
import wandb
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
@@ -31,6 +30,7 @@ from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
from tqdm import tqdm
import wandb
from lerobot.common.datasets.factory import resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger

1112
lerobot/scripts/train_sac.py Normal file

File diff suppressed because it is too large Load Diff

153
poetry.lock generated
View File

@@ -3139,6 +3139,27 @@ dev = ["changelist (==0.5)"]
lint = ["pre-commit (==3.7.0)"]
test = ["pytest (>=7.4)", "pytest-cov (>=4.1)"]
[[package]]
name = "lightning-utilities"
version = "0.11.9"
description = "Lightning toolbox for across the our ecosystem."
optional = true
python-versions = ">=3.8"
files = [
{file = "lightning_utilities-0.11.9-py3-none-any.whl", hash = "sha256:ac6d4e9e28faf3ff4be997876750fee10dc604753dbc429bf3848a95c5d7e0d2"},
{file = "lightning_utilities-0.11.9.tar.gz", hash = "sha256:f5052b81344cc2684aa9afd74b7ce8819a8f49a858184ec04548a5a109dfd053"},
]
[package.dependencies]
packaging = ">=17.1"
setuptools = "*"
typing-extensions = "*"
[package.extras]
cli = ["fire"]
docs = ["requests (>=2.0.0)"]
typing = ["mypy (>=1.0.0)", "types-setuptools"]
[[package]]
name = "llvmlite"
version = "0.43.0"
@@ -6798,6 +6819,38 @@ webencodings = ">=0.4"
doc = ["sphinx", "sphinx_rtd_theme"]
test = ["pytest", "ruff"]
[[package]]
name = "tokenizers"
version = "0.21.0"
description = ""
optional = true
python-versions = ">=3.7"
files = [
{file = "tokenizers-0.21.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:3c4c93eae637e7d2aaae3d376f06085164e1660f89304c0ab2b1d08a406636b2"},
{file = "tokenizers-0.21.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:f53ea537c925422a2e0e92a24cce96f6bc5046bbef24a1652a5edc8ba975f62e"},
{file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b177fb54c4702ef611de0c069d9169f0004233890e0c4c5bd5508ae05abf193"},
{file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6b43779a269f4629bebb114e19c3fca0223296ae9fea8bb9a7a6c6fb0657ff8e"},
{file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9aeb255802be90acfd363626753fda0064a8df06031012fe7d52fd9a905eb00e"},
{file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d8b09dbeb7a8d73ee204a70f94fc06ea0f17dcf0844f16102b9f414f0b7463ba"},
{file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:400832c0904f77ce87c40f1a8a27493071282f785724ae62144324f171377273"},
{file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e84ca973b3a96894d1707e189c14a774b701596d579ffc7e69debfc036a61a04"},
{file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:eb7202d231b273c34ec67767378cd04c767e967fda12d4a9e36208a34e2f137e"},
{file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:089d56db6782a73a27fd8abf3ba21779f5b85d4a9f35e3b493c7bbcbbf0d539b"},
{file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:c87ca3dc48b9b1222d984b6b7490355a6fdb411a2d810f6f05977258400ddb74"},
{file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:4145505a973116f91bc3ac45988a92e618a6f83eb458f49ea0790df94ee243ff"},
{file = "tokenizers-0.21.0-cp39-abi3-win32.whl", hash = "sha256:eb1702c2f27d25d9dd5b389cc1f2f51813e99f8ca30d9e25348db6585a97e24a"},
{file = "tokenizers-0.21.0-cp39-abi3-win_amd64.whl", hash = "sha256:87841da5a25a3a5f70c102de371db120f41873b854ba65e52bccd57df5a3780c"},
{file = "tokenizers-0.21.0.tar.gz", hash = "sha256:ee0894bf311b75b0c03079f33859ae4b2334d675d4e93f5a4132e1eae2834fe4"},
]
[package.dependencies]
huggingface-hub = ">=0.16.4,<1.0"
[package.extras]
dev = ["tokenizers[testing]"]
docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"]
testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests", "ruff"]
[[package]]
name = "tomli"
version = "2.0.2"
@@ -6863,6 +6916,34 @@ typing-extensions = ">=4.8.0"
opt-einsum = ["opt-einsum (>=3.3)"]
optree = ["optree (>=0.11.0)"]
[[package]]
name = "torchmetrics"
version = "1.6.0"
description = "PyTorch native Metrics"
optional = true
python-versions = ">=3.9"
files = [
{file = "torchmetrics-1.6.0-py3-none-any.whl", hash = "sha256:a508cdd87766cedaaf55a419812bf9f493aff8fffc02cc19df5a8e2e7ccb942a"},
{file = "torchmetrics-1.6.0.tar.gz", hash = "sha256:aebba248708fb90def20cccba6f55bddd134a58de43fb22b0c5ca0f3a89fa984"},
]
[package.dependencies]
lightning-utilities = ">=0.8.0"
numpy = ">1.20.0"
packaging = ">17.1"
torch = ">=2.0.0"
[package.extras]
all = ["SciencePlots (>=2.0.0)", "gammatone (>=1.0.0)", "ipadic (>=1.0.0)", "librosa (>=0.10.0)", "matplotlib (>=3.6.0)", "mecab-python3 (>=1.0.6)", "mypy (==1.13.0)", "nltk (>3.8.1)", "numpy (<2.0)", "onnxruntime (>=1.12.0)", "pesq (>=0.0.4)", "piq (<=0.8.0)", "pycocotools (>2.0.0)", "pystoi (>=0.4.0)", "regex (>=2021.9.24)", "requests (>=2.19.0)", "scipy (>1.0.0)", "sentencepiece (>=0.2.0)", "torch (==2.5.1)", "torch-fidelity (<=0.4.0)", "torchaudio (>=2.0.1)", "torchvision (>=0.15.1)", "tqdm (<4.68.0)", "transformers (>4.4.0)", "transformers (>=4.42.3)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"]
audio = ["gammatone (>=1.0.0)", "librosa (>=0.10.0)", "numpy (<2.0)", "onnxruntime (>=1.12.0)", "pesq (>=0.0.4)", "pystoi (>=0.4.0)", "requests (>=2.19.0)", "torchaudio (>=2.0.1)"]
detection = ["pycocotools (>2.0.0)", "torchvision (>=0.15.1)"]
dev = ["PyTDC (==0.4.1)", "SciencePlots (>=2.0.0)", "bert-score (==0.3.13)", "dython (==0.7.6)", "dython (>=0.7.8,<0.8.0)", "fairlearn", "fast-bss-eval (>=0.1.0)", "faster-coco-eval (>=1.6.3)", "gammatone (>=1.0.0)", "huggingface-hub (<0.27)", "ipadic (>=1.0.0)", "jiwer (>=2.3.0)", "kornia (>=0.6.7)", "librosa (>=0.10.0)", "lpips (<=0.1.4)", "matplotlib (>=3.6.0)", "mecab-ko (>=1.0.0,<1.1.0)", "mecab-ko-dic (>=1.0.0)", "mecab-python3 (>=1.0.6)", "mir-eval (>=0.6)", "monai (==1.3.2)", "monai (==1.4.0)", "mypy (==1.13.0)", "netcal (>1.0.0)", "nltk (>3.8.1)", "numpy (<2.0)", "numpy (<2.2.0)", "onnxruntime (>=1.12.0)", "pandas (>1.4.0)", "permetrics (==2.0.0)", "pesq (>=0.0.4)", "piq (<=0.8.0)", "pycocotools (>2.0.0)", "pystoi (>=0.4.0)", "pytorch-msssim (==1.0.0)", "regex (>=2021.9.24)", "requests (>=2.19.0)", "rouge-score (>0.1.0)", "sacrebleu (>=2.3.0)", "scikit-image (>=0.19.0)", "scipy (>1.0.0)", "sentencepiece (>=0.2.0)", "sewar (>=0.4.4)", "statsmodels (>0.13.5)", "torch (==2.5.1)", "torch-complex (<0.5.0)", "torch-fidelity (<=0.4.0)", "torchaudio (>=2.0.1)", "torchvision (>=0.15.1)", "tqdm (<4.68.0)", "transformers (>4.4.0)", "transformers (>=4.42.3)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"]
image = ["scipy (>1.0.0)", "torch-fidelity (<=0.4.0)", "torchvision (>=0.15.1)"]
multimodal = ["piq (<=0.8.0)", "transformers (>=4.42.3)"]
text = ["ipadic (>=1.0.0)", "mecab-python3 (>=1.0.6)", "nltk (>3.8.1)", "regex (>=2021.9.24)", "sentencepiece (>=0.2.0)", "tqdm (<4.68.0)", "transformers (>4.4.0)"]
typing = ["mypy (==1.13.0)", "torch (==2.5.1)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"]
visual = ["SciencePlots (>=2.0.0)", "matplotlib (>=3.6.0)"]
[[package]]
name = "torchvision"
version = "0.19.1"
@@ -6956,6 +7037,75 @@ files = [
docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"]
test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<8.2)", "pytest-mock", "pytest-mypy-testing"]
[[package]]
name = "transformers"
version = "4.47.0"
description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow"
optional = true
python-versions = ">=3.9.0"
files = [
{file = "transformers-4.47.0-py3-none-any.whl", hash = "sha256:a8e1bafdaae69abdda3cad638fe392e37c86d2ce0ecfcae11d60abb8f949ff4d"},
{file = "transformers-4.47.0.tar.gz", hash = "sha256:f8ead7a5a4f6937bb507e66508e5e002dc5930f7b6122a9259c37b099d0f3b19"},
]
[package.dependencies]
filelock = "*"
huggingface-hub = ">=0.24.0,<1.0"
numpy = ">=1.17"
packaging = ">=20.0"
pyyaml = ">=5.1"
regex = "!=2019.12.17"
requests = "*"
safetensors = ">=0.4.1"
tokenizers = ">=0.21,<0.22"
tqdm = ">=4.27"
[package.extras]
accelerate = ["accelerate (>=0.26.0)"]
agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch"]
all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch", "torchaudio", "torchvision"]
audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
benchmark = ["optimum-benchmark (>=0.3.0)"]
codecarbon = ["codecarbon (==1.2.0)"]
deepspeed = ["accelerate (>=0.26.0)", "deepspeed (>=0.9.3)"]
deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.26.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.21,<0.22)", "urllib3 (<2.0.0)"]
dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "libcst", "librosa", "nltk (<=3.8.1)", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)", "scipy (<1.13.0)"]
flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
ftfy = ["ftfy"]
integrations = ["optuna", "ray[tune] (>=2.7.0)", "sigopt"]
ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"]
modelcreation = ["cookiecutter (==1.7.3)"]
natten = ["natten (>=0.14.6,<0.15.0)"]
onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"]
onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"]
optuna = ["optuna"]
quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "libcst", "rich", "ruff (==0.5.1)", "urllib3 (<2.0.0)"]
ray = ["ray[tune] (>=2.7.0)"]
retrieval = ["datasets (!=2.5.0)", "faiss-cpu"]
ruff = ["ruff (==0.5.1)"]
sagemaker = ["sagemaker (>=2.31.0)"]
sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"]
serving = ["fastapi", "pydantic", "starlette", "uvicorn"]
sigopt = ["sigopt"]
sklearn = ["scikit-learn"]
speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
tf = ["keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"]
tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<0.24)", "tensorflow-text (<2.16)", "tf2onnx"]
tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
tiktoken = ["blobfile", "tiktoken"]
timm = ["timm (<=1.0.11)"]
tokenizers = ["tokenizers (>=0.21,<0.22)"]
torch = ["accelerate (>=0.26.0)", "torch"]
torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"]
torchhub = ["filelock", "huggingface-hub (>=0.24.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.21,<0.22)", "torch", "tqdm (>=4.27)"]
video = ["av (==9.2.0)"]
vision = ["Pillow (>=10.0.1,<=15.0)"]
[[package]]
name = "transforms3d"
version = "0.4.2"
@@ -7558,6 +7708,7 @@ dev = ["debugpy", "pre-commit"]
dora = ["gym-dora"]
dynamixel = ["dynamixel-sdk", "pynput"]
feetech = ["feetech-servo-sdk", "pynput"]
hilserl = ["torchmetrics", "transformers"]
intelrealsense = ["pyrealsense2"]
pusht = ["gym-pusht"]
stretch = ["hello-robot-stretch-body", "pynput", "pyrealsense2", "pyrender"]
@@ -7569,4 +7720,4 @@ xarm = ["gym-xarm"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
content-hash = "41344f0eb2d06d9a378abcd10df8205aa3926ff0a08ac5ab1a0b1bcae7440fd8"
content-hash = "b9d299916ced6af1d243f961a32b0a4aacbef18e0b95337a5224e8511f5d6dda"

View File

@@ -71,6 +71,8 @@ pyrender = {git = "https://github.com/mmatl/pyrender.git", markers = "sys_platfo
hello-robot-stretch-body = {version = ">=0.7.27", markers = "sys_platform == 'linux'", optional = true}
pyserial = {version = ">=3.5", optional = true}
jsonlines = ">=4.0.0"
transformers = {version = "^4.47.0", optional = true}
torchmetrics = {version = "^1.6.0", optional = true}
[tool.poetry.extras]
@@ -86,6 +88,7 @@ dynamixel = ["dynamixel-sdk", "pynput"]
feetech = ["feetech-servo-sdk", "pynput"]
intelrealsense = ["pyrealsense2"]
stretch = ["hello-robot-stretch-body", "pyrender", "pyrealsense2", "pynput"]
hilserl = ["transformers", "torchmetrics"]
[tool.ruff]
line-length = 110

View File

@@ -14,9 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import traceback
import pytest
import torch
from serial import SerialException
from lerobot import available_cameras, available_motors, available_robots
@@ -124,3 +126,14 @@ def patch_builtins_input(monkeypatch):
print(text)
monkeypatch.setattr("builtins.input", print_text)
def pytest_addoption(parser):
parser.addoption("--seed", action="store", default="42", help="Set random seed for reproducibility")
@pytest.fixture(autouse=True)
def set_random_seed(request):
seed = int(request.config.getoption("--seed"))
random.seed(seed) # Python random
torch.manual_seed(seed) # PyTorch

View File

@@ -0,0 +1,244 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchmetrics import AUROC, Accuracy, F1Score, Precision, Recall
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier, ClassifierConfig
BATCH_SIZE = 1000
LR = 0.1
EPOCH_NUM = 2
if torch.cuda.is_available():
DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
DEVICE = torch.device("mps")
else:
DEVICE = torch.device("cpu")
def train_evaluate_multiclass_classifier():
logging.info(
f"Start multiclass classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}"
)
multiclass_config = ClassifierConfig(model_name="microsoft/resnet-18", device=DEVICE, num_classes=10)
multiclass_classifier = Classifier(multiclass_config)
trainset = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
testset = CIFAR10(root="data", train=False, download=True, transform=ToTensor())
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)
multiclass_num_classes = 10
epoch = 1
criterion = CrossEntropyLoss()
optimizer = Adam(multiclass_classifier.parameters(), lr=LR)
multiclass_classifier.train()
logging.info("Start multiclass classifier training")
# Training loop
while epoch < EPOCH_NUM: # loop over the dataset multiple times
for i, data in enumerate(trainloader):
inputs, labels = data
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
# Zero the parameter gradients
optimizer.zero_grad()
# Forward pass
outputs = multiclass_classifier(inputs)
loss = criterion(outputs.logits, labels)
loss.backward()
optimizer.step()
if i % 10 == 0: # print every 10 mini-batches
logging.info(f"[Epoch {epoch}, Batch {i}] loss: {loss.item():.3f}")
epoch += 1
print("Multiclass classifier training finished")
multiclass_classifier.eval()
test_loss = 0.0
test_labels = []
test_pridections = []
test_probs = []
with torch.no_grad():
for data in testloader:
images, labels = data
images, labels = images.to(DEVICE), labels.to(DEVICE)
outputs = multiclass_classifier(images)
loss = criterion(outputs.logits, labels)
test_loss += loss.item() * BATCH_SIZE
_, predicted = torch.max(outputs.logits, 1)
test_labels.extend(labels.cpu())
test_pridections.extend(predicted.cpu())
test_probs.extend(outputs.probabilities.cpu())
test_loss = test_loss / len(testset)
logging.info(f"Multiclass classifier test loss {test_loss:.3f}")
test_labels = torch.stack(test_labels)
test_predictions = torch.stack(test_pridections)
test_probs = torch.stack(test_probs)
accuracy = Accuracy(task="multiclass", num_classes=multiclass_num_classes)
precision = Precision(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
recall = Recall(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
f1 = F1Score(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
auroc = AUROC(task="multiclass", num_classes=multiclass_num_classes, average="weighted")
# Calculate metrics
acc = accuracy(test_predictions, test_labels)
prec = precision(test_predictions, test_labels)
rec = recall(test_predictions, test_labels)
f1_score = f1(test_predictions, test_labels)
auroc_score = auroc(test_probs, test_labels)
logging.info(f"Accuracy: {acc:.2f}")
logging.info(f"Precision: {prec:.2f}")
logging.info(f"Recall: {rec:.2f}")
logging.info(f"F1 Score: {f1_score:.2f}")
logging.info(f"AUROC Score: {auroc_score:.2f}")
def train_evaluate_binary_classifier():
logging.info(
f"Start binary classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}"
)
target_binary_class = 3
def one_vs_rest(dataset, target_class):
new_targets = []
for _, label in dataset:
new_label = float(1.0) if label == target_class else float(0.0)
new_targets.append(new_label)
dataset.targets = new_targets # Replace the original labels with the binary ones
return dataset
binary_train_dataset = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
binary_test_dataset = CIFAR10(root="data", train=False, download=True, transform=ToTensor())
# Apply one-vs-rest labeling
binary_train_dataset = one_vs_rest(binary_train_dataset, target_binary_class)
binary_test_dataset = one_vs_rest(binary_test_dataset, target_binary_class)
binary_trainloader = DataLoader(binary_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
binary_testloader = DataLoader(binary_test_dataset, batch_size=BATCH_SIZE, shuffle=False)
binary_epoch = 1
binary_config = ClassifierConfig(model_name="microsoft/resnet-50", device=DEVICE)
binary_classifier = Classifier(binary_config)
class_counts = np.bincount(binary_train_dataset.targets)
n = len(binary_train_dataset)
w0 = n / (2.0 * class_counts[0])
w1 = n / (2.0 * class_counts[1])
binary_criterion = BCEWithLogitsLoss(pos_weight=torch.tensor(w1 / w0))
binary_optimizer = Adam(binary_classifier.parameters(), lr=LR)
binary_classifier.train()
logging.info("Start binary classifier training")
# Training loop
while binary_epoch < EPOCH_NUM: # loop over the dataset multiple times
for i, data in enumerate(binary_trainloader):
inputs, labels = data
inputs, labels = inputs.to(DEVICE), labels.to(torch.float32).to(DEVICE)
# Zero the parameter gradients
binary_optimizer.zero_grad()
# Forward pass
outputs = binary_classifier(inputs)
loss = binary_criterion(outputs.logits, labels)
loss.backward()
binary_optimizer.step()
if i % 10 == 0: # print every 10 mini-batches
print(f"[Epoch {binary_epoch}, Batch {i}] loss: {loss.item():.3f}")
binary_epoch += 1
logging.info("Binary classifier training finished")
logging.info("Start binary classifier evaluation")
binary_classifier.eval()
test_loss = 0.0
test_labels = []
test_pridections = []
test_probs = []
with torch.no_grad():
for data in binary_testloader:
images, labels = data
images, labels = images.to(DEVICE), labels.to(torch.float32).to(DEVICE)
outputs = binary_classifier(images)
loss = binary_criterion(outputs.logits, labels)
test_loss += loss.item() * BATCH_SIZE
test_labels.extend(labels.cpu())
test_pridections.extend(outputs.logits.cpu())
test_probs.extend(outputs.probabilities.cpu())
test_loss = test_loss / len(binary_test_dataset)
logging.info(f"Binary classifier test loss {test_loss:.3f}")
test_labels = torch.stack(test_labels)
test_predictions = torch.stack(test_pridections)
test_probs = torch.stack(test_probs)
# Calculate metrics
acc = Accuracy(task="binary")(test_predictions, test_labels)
prec = Precision(task="binary", average="weighted")(test_predictions, test_labels)
rec = Recall(task="binary", average="weighted")(test_predictions, test_labels)
f1_score = F1Score(task="binary", average="weighted")(test_predictions, test_labels)
auroc_score = AUROC(task="binary", average="weighted")(test_probs, test_labels)
logging.info(f"Accuracy: {acc:.2f}")
logging.info(f"Precision: {prec:.2f}")
logging.info(f"Recall: {rec:.2f}")
logging.info(f"F1 Score: {f1_score:.2f}")
logging.info(f"AUROC Score: {auroc_score:.2f}")
if __name__ == "__main__":
train_evaluate_multiclass_classifier()
train_evaluate_binary_classifier()

View File

@@ -0,0 +1,78 @@
import torch
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
Classifier,
ClassifierConfig,
ClassifierOutput,
)
from tests.utils import require_package
def test_classifier_output():
output = ClassifierOutput(
logits=torch.tensor([1, 2, 3]), probabilities=torch.tensor([0.1, 0.2, 0.3]), hidden_states=None
)
assert (
f"{output}"
== "ClassifierOutput(logits=tensor([1, 2, 3]), probabilities=tensor([0.1000, 0.2000, 0.3000]), hidden_states=None)"
)
@require_package("transformers")
def test_binary_classifier_with_default_params():
config = ClassifierConfig()
classifier = Classifier(config)
batch_size = 10
input = torch.rand(batch_size, 3, 224, 224)
output = classifier(input)
assert output is not None
assert output.logits.shape == torch.Size([batch_size])
assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
assert output.probabilities.shape == torch.Size([batch_size])
assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
assert output.hidden_states.shape == torch.Size([batch_size, 2048])
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
@require_package("transformers")
def test_multiclass_classifier():
num_classes = 5
config = ClassifierConfig(num_classes=num_classes)
classifier = Classifier(config)
batch_size = 10
input = torch.rand(batch_size, 3, 224, 224)
output = classifier(input)
assert output is not None
assert output.logits.shape == torch.Size([batch_size, num_classes])
assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
assert output.probabilities.shape == torch.Size([batch_size, num_classes])
assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
assert output.hidden_states.shape == torch.Size([batch_size, 2048])
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
@require_package("transformers")
def test_default_device():
config = ClassifierConfig()
assert config.device == "cpu"
classifier = Classifier(config)
for p in classifier.parameters():
assert p.device == torch.device("cpu")
@require_package("transformers")
def test_explicit_device_setup():
config = ClassifierConfig(device="meta")
assert config.device == "meta"
classifier = Classifier(config)
for p in classifier.parameters():
assert p.device == torch.device("meta")