forked from tangger/lerobot
Compare commits
25 Commits
temp_branc
...
user/adil-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
278b56bce9 | ||
|
|
0ffc0a7170 | ||
|
|
43d9890489 | ||
|
|
963be41003 | ||
|
|
9edae4a8de | ||
|
|
89d8189d8b | ||
|
|
8b70b129dc | ||
|
|
db3925df28 | ||
|
|
f99e670976 | ||
|
|
eec28baa63 | ||
|
|
f1f04eb4f9 | ||
|
|
77a7f92139 | ||
|
|
35de91ef2b | ||
|
|
ee306e2f9b | ||
|
|
bae3b02928 | ||
|
|
5b4adc00bb | ||
|
|
22fbc9ea4a | ||
|
|
ca74a13d61 | ||
|
|
18a4598986 | ||
|
|
dc54d357ca | ||
|
|
08ec971086 | ||
|
|
b53d6e0ff2 | ||
|
|
70b652f791 | ||
|
|
7b68bfb73b | ||
|
|
7e0f20fbf2 |
@@ -611,11 +611,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
return query_timestamps
|
||||
|
||||
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
||||
return {
|
||||
key: torch.stack(self.hf_dataset.select(q_idx)[key])
|
||||
for key, q_idx in query_indices.items()
|
||||
if key not in self.meta.video_keys
|
||||
}
|
||||
# Step 1: Combine all unique indices
|
||||
all_indices = sorted({idx for indices in query_indices.values() for idx in indices})
|
||||
|
||||
# Step 2: Select all required data at once
|
||||
selected_dataset = self.hf_dataset.select(all_indices).to_dict()
|
||||
selected_dataset = {key: torch.tensor(values) for key, values in selected_dataset.items()}
|
||||
|
||||
# Step 3: Map original indices to their positions in the selected dataset
|
||||
index_map = {original_idx: i for i, original_idx in enumerate(all_indices)}
|
||||
|
||||
# Step 4: Build the result for each key
|
||||
results = {}
|
||||
for key, q_indices in query_indices.items():
|
||||
if key not in self.meta.video_keys:
|
||||
mapped_indices = [index_map[idx] for idx in q_indices]
|
||||
results[key] = torch.stack([selected_dataset[key][i] for i in mapped_indices])
|
||||
|
||||
return results
|
||||
|
||||
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
|
||||
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
||||
|
||||
@@ -25,13 +25,13 @@ from glob import glob
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import wandb
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from termcolor import colored
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
import wandb
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
|
||||
|
||||
|
||||
@@ -66,6 +66,12 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
|
||||
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy
|
||||
|
||||
return VQBeTPolicy, VQBeTConfig
|
||||
elif name == "sac":
|
||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
|
||||
return SACPolicy, SACConfig
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
|
||||
@@ -2,8 +2,6 @@ import json
|
||||
import os
|
||||
from dataclasses import asdict, dataclass
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClassifierConfig:
|
||||
@@ -13,7 +11,7 @@ class ClassifierConfig:
|
||||
hidden_dim: int = 256
|
||||
dropout_rate: float = 0.1
|
||||
model_name: str = "microsoft/resnet-50"
|
||||
device: str = "cuda" if torch.cuda.is_available() else "mps"
|
||||
device: str = "cpu"
|
||||
model_type: str = "cnn" # "transformer" or "cnn"
|
||||
|
||||
def save_pretrained(self, save_dir):
|
||||
|
||||
@@ -22,6 +22,13 @@ class ClassifierOutput:
|
||||
self.probabilities = probabilities
|
||||
self.hidden_states = hidden_states
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"ClassifierOutput(logits={self.logits}, "
|
||||
f"probabilities={self.probabilities}, "
|
||||
f"hidden_states={self.hidden_states})"
|
||||
)
|
||||
|
||||
|
||||
class Classifier(
|
||||
nn.Module,
|
||||
@@ -70,6 +77,8 @@ class Classifier(
|
||||
else:
|
||||
raise ValueError("Unsupported CNN architecture")
|
||||
|
||||
self.encoder = self.encoder.to(self.config.device)
|
||||
|
||||
def _freeze_encoder(self) -> None:
|
||||
"""Freeze the encoder parameters."""
|
||||
for param in self.encoder.parameters():
|
||||
@@ -93,6 +102,7 @@ class Classifier(
|
||||
nn.ReLU(),
|
||||
nn.Linear(self.config.hidden_dim, 1 if self.config.num_classes == 2 else self.config.num_classes),
|
||||
)
|
||||
self.classifier_head = self.classifier_head.to(self.config.device)
|
||||
|
||||
def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Extract the appropriate output from the encoder."""
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -26,4 +26,4 @@ class HILSerlPolicy(
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "hilserl"],
|
||||
):
|
||||
pass
|
||||
pass
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -15,25 +15,59 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class SACConfig:
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": [3, 84, 84],
|
||||
"observation.state": [4],
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [2],
|
||||
}
|
||||
)
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "min_max",
|
||||
"observation.environment_state": "min_max",
|
||||
}
|
||||
)
|
||||
output_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {"action": "min_max"},
|
||||
)
|
||||
|
||||
discount = 0.99
|
||||
temperature_init = 1.0
|
||||
num_critics = 2
|
||||
num_subsample_critics = None
|
||||
critic_lr = 3e-4
|
||||
actor_lr = 3e-4
|
||||
temperature_lr = 3e-4
|
||||
critic_target_update_weight = 0.005
|
||||
utd_ratio = 2
|
||||
state_encoder_hidden_dim = 256
|
||||
latent_dim = 256
|
||||
target_entropy = None
|
||||
# backup_entropy = False
|
||||
use_backup_entropy = True
|
||||
critic_network_kwargs = {
|
||||
"hidden_dims": [256, 256],
|
||||
"activate_final": True,
|
||||
}
|
||||
"hidden_dims": [256, 256],
|
||||
"activate_final": True,
|
||||
}
|
||||
actor_network_kwargs = {
|
||||
"hidden_dims": [256, 256],
|
||||
"activate_final": True,
|
||||
}
|
||||
"hidden_dims": [256, 256],
|
||||
"activate_final": True,
|
||||
}
|
||||
policy_kwargs = {
|
||||
"tanh_squash_distribution": True,
|
||||
"std_parameterization": "uniform",
|
||||
}
|
||||
"use_tanh_squash": True,
|
||||
"log_std_min": -5,
|
||||
"log_std_max": 2,
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -18,22 +18,18 @@
|
||||
# TODO: (1) better device management
|
||||
|
||||
from collections import deque
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Callable, Optional, Sequence, Tuple, Union
|
||||
|
||||
import einops
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor
|
||||
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||
import numpy as np
|
||||
from typing import Callable, Optional, Tuple, Sequence
|
||||
|
||||
|
||||
|
||||
class SACPolicy(
|
||||
@@ -43,11 +39,11 @@ class SACPolicy(
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "RL", "SAC"],
|
||||
):
|
||||
|
||||
name = "sac"
|
||||
|
||||
def __init__(
|
||||
self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
if config is None:
|
||||
@@ -60,33 +56,61 @@ class SACPolicy(
|
||||
)
|
||||
else:
|
||||
self.normalize_inputs = nn.Identity()
|
||||
# HACK: we need to pass the dataset_stats to the normalization functions
|
||||
dataset_stats = dataset_stats or {
|
||||
"action": {
|
||||
"min": torch.tensor([-1.0, -1.0, -1.0, -1.0]),
|
||||
"max": torch.tensor([1.0, 1.0, 1.0, 1.0]),
|
||||
}
|
||||
}
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
encoder = SACObservationEncoder(config)
|
||||
|
||||
encoder_critic = SACObservationEncoder(config)
|
||||
encoder_actor = SACObservationEncoder(config)
|
||||
# Define networks
|
||||
critic_nets = []
|
||||
for _ in range(config.num_critics):
|
||||
critic_net = Critic(
|
||||
encoder=encoder,
|
||||
network=MLP(**config.critic_network_kwargs)
|
||||
encoder=encoder_critic,
|
||||
network=MLP(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs,
|
||||
),
|
||||
)
|
||||
critic_nets.append(critic_net)
|
||||
|
||||
|
||||
target_critic_nets = []
|
||||
for _ in range(config.num_critics):
|
||||
target_critic_net = Critic(
|
||||
encoder=encoder_critic,
|
||||
network=MLP(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs,
|
||||
),
|
||||
)
|
||||
target_critic_nets.append(target_critic_net)
|
||||
|
||||
self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
|
||||
self.critic_target = deepcopy(self.critic_ensemble)
|
||||
self.critic_target = create_critic_ensemble(target_critic_nets, config.num_critics)
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
self.actor_network = Policy(
|
||||
encoder=encoder,
|
||||
network=MLP(**config.actor_network_kwargs),
|
||||
self.actor = Policy(
|
||||
encoder=encoder_actor,
|
||||
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
|
||||
action_dim=config.output_shapes["action"][0],
|
||||
**config.policy_kwargs
|
||||
**config.policy_kwargs,
|
||||
)
|
||||
|
||||
self.temperature = LagrangeMultiplier(init_value=config.temperature_init)
|
||||
if config.target_entropy is None:
|
||||
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
|
||||
# TODO: fix later device
|
||||
# TODO: Handle the case where the temparameter is a fixed
|
||||
self.log_alpha = torch.zeros(1, requires_grad=True, device="cpu")
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
@@ -98,24 +122,49 @@ class SACPolicy(
|
||||
"observation.state": deque(maxlen=1),
|
||||
"action": deque(maxlen=1),
|
||||
}
|
||||
if self._use_image:
|
||||
if "observation.image" in self.config.input_shapes:
|
||||
self._queues["observation.image"] = deque(maxlen=1)
|
||||
if self._use_env_state:
|
||||
if "observation.environment_state" in self.config.input_shapes:
|
||||
self._queues["observation.environment_state"] = deque(maxlen=1)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
actions, _ = self.actor_network(batch['observations'])###
|
||||
"""Select action for inference/evaluation"""
|
||||
actions, _, _ = self.actor(batch)
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
return actions
|
||||
|
||||
def critic_forward(
|
||||
self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False
|
||||
) -> Tensor:
|
||||
"""Forward pass through a critic network ensemble
|
||||
|
||||
Args:
|
||||
observations: Dictionary of observations
|
||||
actions: Action tensor
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
|
||||
Returns:
|
||||
Tensor of Q-values from all critics
|
||||
"""
|
||||
critics = self.critic_target if use_target else self.critic_ensemble
|
||||
q_values = torch.stack([critic(observations, actions) for critic in critics])
|
||||
return q_values
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
|
||||
"""Run the batch through the model and compute the loss.
|
||||
|
||||
Returns a dictionary with loss as a tensor, and other information as native floats.
|
||||
"""
|
||||
# We have to actualize the value of the temperature because in the previous
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
temperature = self.temperature
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
# batch shape is (b, 2, ...) where index 1 returns the current observation and
|
||||
# the next observation for caluculating the right td index.
|
||||
actions = batch["action"][:, 0]
|
||||
# batch shape is (b, 2, ...) where index 1 returns the current observation and
|
||||
# the next observation for calculating the right td index.
|
||||
# actions = batch["action"][:, 0]
|
||||
actions = batch["action"]
|
||||
rewards = batch["next.reward"][:, 0]
|
||||
observations = {}
|
||||
next_observations = {}
|
||||
@@ -123,186 +172,227 @@ class SACPolicy(
|
||||
if k.startswith("observation."):
|
||||
observations[k] = batch[k][:, 0]
|
||||
next_observations[k] = batch[k][:, 1]
|
||||
|
||||
# perform image augmentation
|
||||
done = batch["next.done"]
|
||||
|
||||
# reward bias
|
||||
# from HIL-SERL code base
|
||||
# add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
|
||||
|
||||
with torch.no_grad():
|
||||
next_action_preds, next_log_probs, _ = self.actor(next_observations)
|
||||
|
||||
# calculate critics loss
|
||||
# 1- compute actions from policy
|
||||
action_preds, log_probs = self.actor_network(observations)
|
||||
# 2- compute q targets
|
||||
q_targets = self.target_qs(next_observations, action_preds)
|
||||
# 2- compute q targets
|
||||
q_targets = self.critic_forward(next_observations, next_action_preds, use_target=True)
|
||||
|
||||
# critics subsample size
|
||||
min_q = q_targets.min(dim=0)
|
||||
# subsample critics to prevent overfitting if use high UTD (update to date)
|
||||
if self.config.num_subsample_critics is not None:
|
||||
indices = torch.randperm(self.config.num_critics)
|
||||
indices = indices[: self.config.num_subsample_critics]
|
||||
q_targets = q_targets[indices]
|
||||
|
||||
# backup entropy
|
||||
td_target = rewards + self.discount * min_q
|
||||
# critics subsample size
|
||||
min_q, _ = q_targets.min(dim=0) # Get values from min operation
|
||||
if self.config.use_backup_entropy:
|
||||
min_q -= self.temperature * next_log_probs
|
||||
td_target = rewards + self.config.discount * min_q * ~done
|
||||
|
||||
# 3- compute predicted qs
|
||||
q_preds = self.critic_ensemble(observations, actions)
|
||||
q_preds = self.critic_forward(observations, actions, use_target=False)
|
||||
|
||||
# 4- Calculate loss
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
critics_loss = (
|
||||
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
|
||||
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
|
||||
critics_loss = (
|
||||
F.mse_loss(
|
||||
q_preds,
|
||||
einops.repeat(td_target, "t b -> e t b", e=q_preds.shape[0]),
|
||||
reduction="none",
|
||||
).sum(0) # sum over ensemble
|
||||
# `q_preds_ensemble` depends on the first observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
# q_targets depends on the reward and the next observations.
|
||||
* ~batch["next.reward_is_pad"]
|
||||
* ~batch["observation.state_is_pad"][1:]
|
||||
).sum(0).mean()
|
||||
|
||||
# calculate actors loss
|
||||
# 1- temperature
|
||||
temperature = self.temperature()
|
||||
input=q_preds,
|
||||
target=td_target_duplicate,
|
||||
reduction="none",
|
||||
).mean(1)
|
||||
).sum()
|
||||
|
||||
# 2- get actions (batch_size, action_dim) and log probs (batch_size,)
|
||||
actions, log_probs = self.actor_network(observations) \
|
||||
|
||||
# 3- get q-value predictions
|
||||
with torch.no_grad():
|
||||
q_preds = self.critic_ensemble(observations, actions, return_type="mean")
|
||||
actor_loss = (
|
||||
-(q_preds - temperature * log_probs).mean()
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
).mean()
|
||||
actions_pi, log_probs, _ = self.actor(observations)
|
||||
with torch.inference_mode():
|
||||
q_preds = self.critic_forward(observations, actions_pi, use_target=False)
|
||||
min_q_preds = q_preds.min(dim=0)[0]
|
||||
|
||||
actor_loss = ((temperature * log_probs) - min_q_preds).mean()
|
||||
|
||||
# calculate temperature loss
|
||||
# 1- calculate entropy
|
||||
entropy = -log_probs.mean()
|
||||
temperature_loss = temperature * (entropy - self.target_entropy).mean()
|
||||
with torch.no_grad():
|
||||
_, log_probs, _ = self.actor(observations)
|
||||
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
|
||||
|
||||
loss = critics_loss + actor_loss + temperature_loss
|
||||
|
||||
return {
|
||||
"critics_loss": critics_loss.item(),
|
||||
"actor_loss": actor_loss.item(),
|
||||
"temperature_loss": temperature_loss.item(),
|
||||
"temperature": temperature.item(),
|
||||
"entropy": entropy.item(),
|
||||
"loss": loss,
|
||||
"critics_loss": critics_loss.item(),
|
||||
"actor_loss": actor_loss.item(),
|
||||
"mean_q_predicts": min_q_preds.mean().item(),
|
||||
"min_q_predicts": min_q_preds.min().item(),
|
||||
"max_q_predicts": min_q_preds.max().item(),
|
||||
"temperature_loss": temperature_loss.item(),
|
||||
"temperature": temperature,
|
||||
"mean_log_probs": log_probs.mean().item(),
|
||||
"min_log_probs": log_probs.min().item(),
|
||||
"max_log_probs": log_probs.max().item(),
|
||||
"td_target_mean": td_target.mean().item(),
|
||||
"td_target_max": td_target.max().item(),
|
||||
"action_mean": actions.mean().item(),
|
||||
"entropy": log_probs.mean().item(),
|
||||
"loss": loss,
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
def update(self):
|
||||
self.critic_target.lerp_(self.critic_ensemble, self.config.critic_target_update_weight)
|
||||
#for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()):
|
||||
# target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.critic_target_update_weight)
|
||||
def update_target_networks(self):
|
||||
"""Update target networks with exponential moving average"""
|
||||
for target_critic, critic in zip(self.critic_target, self.critic_ensemble, strict=False):
|
||||
for target_param, param in zip(target_critic.parameters(), critic.parameters(), strict=False):
|
||||
target_param.data.copy_(
|
||||
param.data * self.config.critic_target_update_weight
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
|
||||
def compute_loss_critic(self, observations, actions, rewards, next_observations, done) -> Tensor:
|
||||
temperature = self.log_alpha.exp().item()
|
||||
with torch.no_grad():
|
||||
next_action_preds, next_log_probs, _ = self.actor(next_observations)
|
||||
|
||||
# 2- compute q targets
|
||||
q_targets = self.critic_forward(
|
||||
observations=next_observations, actions=next_action_preds, use_target=True
|
||||
)
|
||||
|
||||
# subsample critics to prevent overfitting if use high UTD (update to date)
|
||||
if self.config.num_subsample_critics is not None:
|
||||
indices = torch.randperm(self.config.num_critics)
|
||||
indices = indices[: self.config.num_subsample_critics]
|
||||
q_targets = q_targets[indices]
|
||||
|
||||
# critics subsample size
|
||||
min_q, _ = q_targets.min(dim=0) # Get values from min operation
|
||||
if self.config.use_backup_entropy:
|
||||
min_q = min_q - (temperature * next_log_probs)
|
||||
|
||||
td_target = rewards + (1 - done) * self.config.discount * min_q
|
||||
|
||||
# 3- compute predicted qs
|
||||
q_preds = self.critic_forward(observations, actions, use_target=False)
|
||||
|
||||
# 4- Calculate loss
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
|
||||
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
|
||||
critics_loss = (
|
||||
F.mse_loss(
|
||||
input=q_preds,
|
||||
target=td_target_duplicate,
|
||||
reduction="none",
|
||||
).mean(1)
|
||||
).sum()
|
||||
return critics_loss
|
||||
|
||||
def compute_loss_temperature(self, observations) -> Tensor:
|
||||
"""Compute the temperature loss"""
|
||||
# calculate temperature loss
|
||||
with torch.no_grad():
|
||||
_, log_probs, _ = self.actor(observations)
|
||||
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
|
||||
return temperature_loss
|
||||
|
||||
def compute_loss_actor(self, observations) -> Tensor:
|
||||
temperature = self.log_alpha.exp().item()
|
||||
|
||||
actions_pi, log_probs, _ = self.actor(observations)
|
||||
|
||||
q_preds = self.critic_forward(observations, actions_pi, use_target=False)
|
||||
min_q_preds = q_preds.min(dim=0)[0]
|
||||
|
||||
actor_loss = ((temperature * log_probs) - min_q_preds).mean()
|
||||
return actor_loss
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: SACConfig,
|
||||
input_dim: int,
|
||||
hidden_dims: list[int],
|
||||
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
|
||||
activate_final: bool = False,
|
||||
dropout_rate: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.activate_final = config.activate_final
|
||||
self.activate_final = activate_final
|
||||
layers = []
|
||||
|
||||
for i, size in enumerate(config.network_hidden_dims):
|
||||
layers.append(nn.Linear(config.network_hidden_dims[i-1] if i > 0 else config.network_hidden_dims[0], size))
|
||||
|
||||
if i + 1 < len(config.network_hidden_dims) or activate_final:
|
||||
|
||||
# First layer uses input_dim
|
||||
layers.append(nn.Linear(input_dim, hidden_dims[0]))
|
||||
|
||||
# Add activation after first layer
|
||||
if dropout_rate is not None and dropout_rate > 0:
|
||||
layers.append(nn.Dropout(p=dropout_rate))
|
||||
layers.append(nn.LayerNorm(hidden_dims[0]))
|
||||
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
|
||||
|
||||
# Rest of the layers
|
||||
for i in range(1, len(hidden_dims)):
|
||||
layers.append(nn.Linear(hidden_dims[i - 1], hidden_dims[i]))
|
||||
|
||||
if i + 1 < len(hidden_dims) or activate_final:
|
||||
if dropout_rate is not None and dropout_rate > 0:
|
||||
layers.append(nn.Dropout(p=dropout_rate))
|
||||
layers.append(nn.LayerNorm(size))
|
||||
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
|
||||
|
||||
layers.append(nn.LayerNorm(hidden_dims[i]))
|
||||
layers.append(
|
||||
activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
|
||||
)
|
||||
|
||||
self.net = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x: torch.Tensor, train: bool = False) -> torch.Tensor:
|
||||
# in training mode or not. TODO: find better way to do this
|
||||
self.train(train)
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.net(x)
|
||||
|
||||
|
||||
|
||||
|
||||
class Critic(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder: Optional[nn.Module],
|
||||
network: nn.Module,
|
||||
init_final: Optional[float] = None,
|
||||
activate_final: bool = False,
|
||||
device: str = "cuda"
|
||||
device: str = "cpu",
|
||||
):
|
||||
super().__init__()
|
||||
self.device = torch.device(device)
|
||||
self.encoder = encoder
|
||||
self.network = network
|
||||
self.init_final = init_final
|
||||
self.activate_final = activate_final
|
||||
|
||||
|
||||
# Find the last Linear layer's output dimension
|
||||
for layer in reversed(network.net):
|
||||
if isinstance(layer, nn.Linear):
|
||||
out_features = layer.out_features
|
||||
break
|
||||
|
||||
# Output layer
|
||||
if init_final is not None:
|
||||
if self.activate_final:
|
||||
self.output_layer = nn.Linear(network.net[-3].out_features, 1)
|
||||
else:
|
||||
self.output_layer = nn.Linear(network.net[-2].out_features, 1)
|
||||
self.output_layer = nn.Linear(out_features, 1)
|
||||
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
if self.activate_final:
|
||||
self.output_layer = nn.Linear(network.net[-3].out_features, 1)
|
||||
else:
|
||||
self.output_layer = nn.Linear(network.net[-2].out_features, 1)
|
||||
self.output_layer = nn.Linear(out_features, 1)
|
||||
orthogonal_init()(self.output_layer.weight)
|
||||
|
||||
|
||||
self.to(self.device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
observations: torch.Tensor,
|
||||
self,
|
||||
observations: dict[str, torch.Tensor],
|
||||
actions: torch.Tensor,
|
||||
train: bool = False
|
||||
) -> torch.Tensor:
|
||||
self.train(train)
|
||||
|
||||
observations = observations.to(self.device)
|
||||
# Move each tensor in observations to device
|
||||
observations = {k: v.to(self.device) for k, v in observations.items()}
|
||||
actions = actions.to(self.device)
|
||||
|
||||
if self.encoder is not None:
|
||||
obs_enc = self.encoder(observations)
|
||||
else:
|
||||
obs_enc = observations
|
||||
|
||||
|
||||
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
||||
|
||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||
x = self.network(inputs)
|
||||
value = self.output_layer(x)
|
||||
return value.squeeze(-1)
|
||||
|
||||
def q_value_ensemble(
|
||||
self,
|
||||
observations: torch.Tensor,
|
||||
actions: torch.Tensor,
|
||||
train: bool = False
|
||||
) -> torch.Tensor:
|
||||
observations = observations.to(self.device)
|
||||
actions = actions.to(self.device)
|
||||
|
||||
if len(actions.shape) == 3: # [batch_size, num_actions, action_dim]
|
||||
batch_size, num_actions = actions.shape[:2]
|
||||
obs_expanded = observations.unsqueeze(1).expand(-1, num_actions, -1)
|
||||
obs_flat = obs_expanded.reshape(-1, observations.shape[-1])
|
||||
actions_flat = actions.reshape(-1, actions.shape[-1])
|
||||
q_values = self(obs_flat, actions_flat, train)
|
||||
return q_values.reshape(batch_size, num_actions)
|
||||
else:
|
||||
return self(observations, actions, train)
|
||||
|
||||
|
||||
class Policy(nn.Module):
|
||||
@@ -311,115 +401,93 @@ class Policy(nn.Module):
|
||||
encoder: Optional[nn.Module],
|
||||
network: nn.Module,
|
||||
action_dim: int,
|
||||
std_parameterization: str = "exp",
|
||||
std_min: float = 1e-5,
|
||||
std_max: float = 10.0,
|
||||
tanh_squash_distribution: bool = False,
|
||||
log_std_min: float = -5,
|
||||
log_std_max: float = 2,
|
||||
fixed_std: Optional[torch.Tensor] = None,
|
||||
init_final: Optional[float] = None,
|
||||
activate_final: bool = False,
|
||||
device: str = "cuda"
|
||||
use_tanh_squash: bool = False,
|
||||
device: str = "cpu",
|
||||
):
|
||||
super().__init__()
|
||||
self.device = torch.device(device)
|
||||
self.encoder = encoder
|
||||
self.network = network
|
||||
self.action_dim = action_dim
|
||||
self.std_parameterization = std_parameterization
|
||||
self.std_min = std_min
|
||||
self.std_max = std_max
|
||||
self.tanh_squash_distribution = tanh_squash_distribution
|
||||
self.log_std_min = log_std_min
|
||||
self.log_std_max = log_std_max
|
||||
self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
|
||||
self.activate_final = activate_final
|
||||
|
||||
self.use_tanh_squash = use_tanh_squash
|
||||
|
||||
# Find the last Linear layer's output dimension
|
||||
for layer in reversed(network.net):
|
||||
if isinstance(layer, nn.Linear):
|
||||
out_features = layer.out_features
|
||||
break
|
||||
|
||||
# Mean layer
|
||||
if self.activate_final:
|
||||
self.mean_layer = nn.Linear(network.net[-3].out_features, action_dim)
|
||||
else:
|
||||
self.mean_layer = nn.Linear(network.net[-2].out_features, action_dim)
|
||||
self.mean_layer = nn.Linear(out_features, action_dim)
|
||||
if init_final is not None:
|
||||
nn.init.uniform_(self.mean_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.mean_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.mean_layer.weight)
|
||||
|
||||
|
||||
# Standard deviation layer or parameter
|
||||
if fixed_std is None:
|
||||
if std_parameterization == "uniform":
|
||||
self.log_stds = nn.Parameter(torch.zeros(action_dim, device=self.device))
|
||||
self.std_layer = nn.Linear(out_features, action_dim)
|
||||
if init_final is not None:
|
||||
nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
if self.activate_final:
|
||||
self.std_layer = nn.Linear(network.net[-3].out_features, action_dim)
|
||||
else:
|
||||
self.std_layer = nn.Linear(network.net[-2].out_features, action_dim)
|
||||
if init_final is not None:
|
||||
nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.std_layer.weight)
|
||||
|
||||
orthogonal_init()(self.std_layer.weight)
|
||||
|
||||
self.to(self.device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
self,
|
||||
observations: torch.Tensor,
|
||||
temperature: float = 1.0,
|
||||
train: bool = False,
|
||||
non_squash_distribution: bool = False
|
||||
) -> torch.distributions.Distribution:
|
||||
self.train(train)
|
||||
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Encode observations if encoder exists
|
||||
if self.encoder is not None:
|
||||
with torch.set_grad_enabled(train):
|
||||
obs_enc = self.encoder(observations, train=train)
|
||||
else:
|
||||
obs_enc = observations
|
||||
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
||||
|
||||
# Get network outputs
|
||||
outputs = self.network(obs_enc)
|
||||
means = self.mean_layer(outputs)
|
||||
|
||||
|
||||
# Compute standard deviations
|
||||
if self.fixed_std is None:
|
||||
if self.std_parameterization == "exp":
|
||||
log_stds = self.std_layer(outputs)
|
||||
stds = torch.exp(log_stds)
|
||||
elif self.std_parameterization == "softplus":
|
||||
stds = torch.nn.functional.softplus(self.std_layer(outputs))
|
||||
elif self.std_parameterization == "uniform":
|
||||
stds = torch.exp(self.log_stds).expand_as(means)
|
||||
log_std = self.std_layer(outputs)
|
||||
assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"
|
||||
|
||||
if self.use_tanh_squash:
|
||||
log_std = torch.tanh(log_std)
|
||||
log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid std_parameterization: {self.std_parameterization}"
|
||||
)
|
||||
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
|
||||
else:
|
||||
assert self.std_parameterization == "fixed"
|
||||
stds = self.fixed_std.expand_as(means)
|
||||
log_std = self.fixed_std.expand_as(means)
|
||||
|
||||
# Clip standard deviations and scale with temperature
|
||||
temperature = torch.tensor(temperature, device=self.device)
|
||||
stds = torch.clamp(stds, self.std_min, self.std_max) * torch.sqrt(temperature)
|
||||
# uses tanh activation function to squash the action to be in the range of [-1, 1]
|
||||
normal = torch.distributions.Normal(means, torch.exp(log_std))
|
||||
x_t = normal.rsample() # Reparameterization trick (mean + std * N(0,1))
|
||||
log_probs = normal.log_prob(x_t) # Base log probability before Tanh
|
||||
|
||||
# Create distribution
|
||||
if self.tanh_squash_distribution and not non_squash_distribution:
|
||||
distribution = TanhMultivariateNormalDiag(
|
||||
loc=means,
|
||||
scale_diag=stds,
|
||||
)
|
||||
if self.use_tanh_squash:
|
||||
actions = torch.tanh(x_t)
|
||||
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) # Adjust log-probs for Tanh
|
||||
else:
|
||||
distribution = torch.distributions.Normal(
|
||||
loc=means,
|
||||
scale=stds,
|
||||
)
|
||||
actions = x_t # No Tanh; raw Gaussian sample
|
||||
|
||||
log_probs = log_probs.sum(-1) # Sum over action dimensions
|
||||
means = torch.tanh(means) if self.use_tanh_squash else means
|
||||
return actions, log_probs, means
|
||||
|
||||
return distribution
|
||||
|
||||
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
|
||||
"""Get encoded features from observations"""
|
||||
observations = observations.to(self.device)
|
||||
if self.encoder is not None:
|
||||
with torch.no_grad():
|
||||
return self.encoder(observations, train=False)
|
||||
with torch.inference_mode():
|
||||
return self.encoder(observations)
|
||||
return observations
|
||||
|
||||
|
||||
@@ -461,19 +529,13 @@ class SACObservationEncoder(nn.Module):
|
||||
)
|
||||
if "observation.state" in config.input_shapes:
|
||||
self.state_enc_layers = nn.Sequential(
|
||||
nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim),
|
||||
nn.ELU(),
|
||||
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
|
||||
nn.Linear(config.input_shapes["observation.state"][0], config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
self.env_state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim
|
||||
),
|
||||
nn.ELU(),
|
||||
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
|
||||
nn.Linear(config.input_shapes["observation.environment_state"][0], config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
@@ -493,175 +555,25 @@ class SACObservationEncoder(nn.Module):
|
||||
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
|
||||
if "observation.state" in self.config.input_shapes:
|
||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||
# TODO(ke-wang): currently average over all features, concatenate all features maybe a better way
|
||||
return torch.stack(feat, dim=0).mean(0)
|
||||
|
||||
|
||||
class LagrangeMultiplier(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
init_value: float = 1.0,
|
||||
constraint_shape: Sequence[int] = (),
|
||||
device: str = "cuda"
|
||||
):
|
||||
super().__init__()
|
||||
self.device = torch.device(device)
|
||||
init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1)
|
||||
|
||||
# Initialize the Lagrange multiplier as a parameter
|
||||
self.lagrange = nn.Parameter(
|
||||
torch.full(constraint_shape, init_value, dtype=torch.float32, device=self.device)
|
||||
)
|
||||
|
||||
self.to(self.device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
lhs: Optional[torch.Tensor] = None,
|
||||
rhs: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
# Get the multiplier value based on parameterization
|
||||
multiplier = torch.nn.functional.softplus(self.lagrange)
|
||||
|
||||
# Return the raw multiplier if no constraint values provided
|
||||
if lhs is None:
|
||||
return multiplier
|
||||
|
||||
# Move inputs to device
|
||||
lhs = lhs.to(self.device)
|
||||
if rhs is not None:
|
||||
rhs = rhs.to(self.device)
|
||||
|
||||
# Use the multiplier to compute the Lagrange penalty
|
||||
if rhs is None:
|
||||
rhs = torch.zeros_like(lhs, device=self.device)
|
||||
|
||||
diff = lhs - rhs
|
||||
|
||||
assert diff.shape == multiplier.shape, f"Shape mismatch: {diff.shape} vs {multiplier.shape}"
|
||||
|
||||
return multiplier * diff
|
||||
|
||||
|
||||
# The TanhMultivariateNormalDiag is a probability distribution that represents a transformed normal (Gaussian) distribution where:
|
||||
# 1. The base distribution is a diagonal multivariate normal distribution
|
||||
# 2. The samples from this normal distribution are transformed through a tanh function, which squashes the values to be between -1 and 1
|
||||
# 3. Optionally, the values can be further transformed to fit within arbitrary bounds [low, high] using an affine transformation
|
||||
# This type of distribution is commonly used in reinforcement learning, particularly for continuous action spaces
|
||||
class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
|
||||
def __init__(
|
||||
self,
|
||||
loc: torch.Tensor,
|
||||
scale_diag: torch.Tensor,
|
||||
low: Optional[torch.Tensor] = None,
|
||||
high: Optional[torch.Tensor] = None,
|
||||
):
|
||||
# Create base normal distribution
|
||||
base_distribution = torch.distributions.Normal(loc=loc, scale=scale_diag)
|
||||
|
||||
# Create list of transforms
|
||||
transforms = []
|
||||
|
||||
# Add tanh transform
|
||||
transforms.append(torch.distributions.transforms.TanhTransform())
|
||||
|
||||
# Add rescaling transform if bounds are provided
|
||||
if low is not None and high is not None:
|
||||
transforms.append(
|
||||
torch.distributions.transforms.AffineTransform(
|
||||
loc=(high + low) / 2,
|
||||
scale=(high - low) / 2
|
||||
)
|
||||
)
|
||||
|
||||
# Initialize parent class
|
||||
super().__init__(
|
||||
base_distribution=base_distribution,
|
||||
transforms=transforms
|
||||
)
|
||||
|
||||
# Store parameters
|
||||
self.loc = loc
|
||||
self.scale_diag = scale_diag
|
||||
self.low = low
|
||||
self.high = high
|
||||
|
||||
def mode(self) -> torch.Tensor:
|
||||
"""Get the mode of the transformed distribution"""
|
||||
# The mode of a normal distribution is its mean
|
||||
mode = self.loc
|
||||
|
||||
# Apply transforms
|
||||
for transform in self.transforms:
|
||||
mode = transform(mode)
|
||||
|
||||
return mode
|
||||
|
||||
def rsample(self, sample_shape=torch.Size()) -> torch.Tensor:
|
||||
"""
|
||||
Reparameterized sample from the distribution
|
||||
"""
|
||||
# Sample from base distribution
|
||||
x = self.base_dist.rsample(sample_shape)
|
||||
|
||||
# Apply transforms
|
||||
for transform in self.transforms:
|
||||
x = transform(x)
|
||||
|
||||
return x
|
||||
|
||||
def log_prob(self, value: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Compute log probability of a value
|
||||
Includes the log det jacobian for the transforms
|
||||
"""
|
||||
# Initialize log prob
|
||||
log_prob = torch.zeros_like(value[..., 0])
|
||||
|
||||
# Inverse transforms to get back to normal distribution
|
||||
q = value
|
||||
for transform in reversed(self.transforms):
|
||||
q = transform.inv(q)
|
||||
log_prob = log_prob - transform.log_abs_det_jacobian(q, transform(q))
|
||||
|
||||
# Add base distribution log prob
|
||||
log_prob = log_prob + self.base_dist.log_prob(q).sum(-1)
|
||||
|
||||
return log_prob
|
||||
|
||||
def sample_and_log_prob(self, sample_shape=torch.Size()) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Sample from the distribution and compute log probability
|
||||
"""
|
||||
x = self.rsample(sample_shape)
|
||||
log_prob = self.log_prob(x)
|
||||
return x, log_prob
|
||||
|
||||
def entropy(self) -> torch.Tensor:
|
||||
"""
|
||||
Compute entropy of the distribution
|
||||
"""
|
||||
# Start with base distribution entropy
|
||||
entropy = self.base_dist.entropy().sum(-1)
|
||||
|
||||
# Add log det jacobian for each transform
|
||||
x = self.rsample()
|
||||
for transform in self.transforms:
|
||||
entropy = entropy + transform.log_abs_det_jacobian(x, transform(x))
|
||||
x = transform(x)
|
||||
|
||||
return entropy
|
||||
|
||||
|
||||
def create_critic_ensemble(critic_class, num_critics: int, device: str = "cuda") -> nn.ModuleList:
|
||||
"""Creates an ensemble of critic networks"""
|
||||
critics = nn.ModuleList([critic_class() for _ in range(num_critics)])
|
||||
return critics.to(device)
|
||||
@property
|
||||
def output_dim(self) -> int:
|
||||
"""Returns the dimension of the encoder output"""
|
||||
return self.config.latent_dim
|
||||
|
||||
|
||||
def orthogonal_init():
|
||||
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
|
||||
|
||||
|
||||
def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cpu") -> nn.ModuleList:
|
||||
"""Creates an ensemble of critic networks"""
|
||||
assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
|
||||
return nn.ModuleList(critics).to(device)
|
||||
|
||||
|
||||
# borrowed from tdmpc
|
||||
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
|
||||
"""Helper to temporarily flatten extra dims at the start of the image tensor.
|
||||
@@ -669,7 +581,7 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
|
||||
Args:
|
||||
fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
|
||||
(B, *), where * is any number of dimensions.
|
||||
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
|
||||
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
|
||||
can be more than 1 dimensions, generally different from *.
|
||||
Returns:
|
||||
A return value from the callable reshaped to (**, *).
|
||||
@@ -680,4 +592,3 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
|
||||
inp = torch.flatten(image_tensor, end_dim=-4)
|
||||
flat_out = fn(inp)
|
||||
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
|
||||
|
||||
|
||||
89
lerobot/configs/policy/sac_pusht_keypoints.yaml
Normal file
89
lerobot/configs/policy/sac_pusht_keypoints.yaml
Normal file
@@ -0,0 +1,89 @@
|
||||
# @package _global_
|
||||
|
||||
# Train with:
|
||||
#
|
||||
# python lerobot/scripts/train.py \
|
||||
# env=pusht \
|
||||
# +dataset=lerobot/pusht_keypoints
|
||||
|
||||
seed: 1
|
||||
dataset_repo_id: lerobot/pusht_keypoints
|
||||
|
||||
training:
|
||||
offline_steps: 0
|
||||
|
||||
# Offline training dataloader
|
||||
num_workers: 4
|
||||
|
||||
batch_size: 256
|
||||
grad_clip_norm: 10.0
|
||||
lr: 3e-4
|
||||
|
||||
eval_freq: 2500
|
||||
log_freq: 500
|
||||
save_freq: 50000
|
||||
|
||||
online_steps: 1000000
|
||||
online_rollout_n_episodes: 10
|
||||
online_rollout_batch_size: 10
|
||||
online_steps_between_rollouts: 1000
|
||||
online_sampling_ratio: 1.0
|
||||
online_env_seed: 10000
|
||||
online_buffer_capacity: 40000
|
||||
online_buffer_seed_size: 0
|
||||
do_online_rollout_async: false
|
||||
|
||||
delta_timestamps:
|
||||
observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
action: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||
next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||
|
||||
policy:
|
||||
name: sac
|
||||
|
||||
pretrained_model_path:
|
||||
|
||||
# Input / output structure.
|
||||
n_action_repeats: 1
|
||||
horizon: 2
|
||||
n_action_steps: 2
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.environment_state: [16]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.environment_state: min_max
|
||||
observation.state: min_max
|
||||
output_normalization_modes:
|
||||
action: min_max
|
||||
|
||||
# Architecture / modeling.
|
||||
# Neural networks.
|
||||
# image_encoder_hidden_dim: 32
|
||||
discount: 0.99
|
||||
temperature_init: 1.0
|
||||
num_critics: 2
|
||||
num_subsample_critics: None
|
||||
critic_lr: 3e-4
|
||||
actor_lr: 3e-4
|
||||
temperature_lr: 3e-4
|
||||
critic_target_update_weight: 0.005
|
||||
utd_ratio: 2
|
||||
|
||||
|
||||
# # Loss coefficients.
|
||||
# reward_coeff: 0.5
|
||||
# expectile_weight: 0.9
|
||||
# value_coeff: 0.1
|
||||
# consistency_coeff: 20.0
|
||||
# advantage_scaling: 3.0
|
||||
# pi_coeff: 0.5
|
||||
# temporal_decay_coeff: 0.5
|
||||
# # Target model.
|
||||
# target_model_momentum: 0.995
|
||||
@@ -24,7 +24,7 @@ python lerobot/scripts/eval_on_robot.py \
|
||||
```
|
||||
|
||||
**NOTE** (michel-aractingi): This script is incomplete and it is being prepared
|
||||
for running training on the real robot.
|
||||
for running training on the real robot.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -47,7 +47,7 @@ from lerobot.common.utils.utils import (
|
||||
|
||||
|
||||
def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20, use_amp: bool = True) -> dict:
|
||||
"""Run a batched policy rollout on the real robot.
|
||||
"""Run a batched policy rollout on the real robot.
|
||||
|
||||
The return dictionary contains:
|
||||
"robot": A a dictionary of (batch, sequence + 1, *) tensors mapped to observation
|
||||
@@ -64,7 +64,7 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20,
|
||||
extraneous elements from the sequences above.
|
||||
|
||||
Args:
|
||||
robot: The robot class that defines the interface with the real robot.
|
||||
robot: The robot class that defines the interface with the real robot.
|
||||
policy: The policy. Must be a PyTorch nn module.
|
||||
|
||||
Returns:
|
||||
@@ -77,7 +77,7 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20,
|
||||
listener, events = init_keyboard_listener()
|
||||
|
||||
# Reset the policy. TODO (michel-aractingi) add real policy evaluation once the code is ready.
|
||||
# policy.reset()
|
||||
# policy.reset()
|
||||
|
||||
# Get observation from real robot
|
||||
observation = robot.capture_observation()
|
||||
|
||||
@@ -93,6 +93,18 @@ def make_optimizer_and_scheduler(cfg, policy):
|
||||
elif policy.name == "tdmpc":
|
||||
optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
|
||||
lr_scheduler = None
|
||||
|
||||
elif policy.name == "sac":
|
||||
optimizer = torch.optim.Adam(
|
||||
[
|
||||
{"params": policy.actor.parameters(), "lr": policy.config.actor_lr},
|
||||
{"params": policy.critic_ensemble.parameters(), "lr": policy.config.critic_lr},
|
||||
# We wrap policy log temperature in list because this is a torch tensor and not a nn.Module
|
||||
{"params": [policy.log_alpha], "lr": policy.config.temperature_lr},
|
||||
]
|
||||
)
|
||||
lr_scheduler = None
|
||||
|
||||
elif cfg.policy.name == "vqbet":
|
||||
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
|
||||
|
||||
@@ -311,6 +323,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
|
||||
logging.info("make_dataset")
|
||||
offline_dataset = make_dataset(cfg)
|
||||
# TODO (michel-aractingi): temporary fix to avoid datasets with task_index key that doesn't exist in online environment
|
||||
# i.e., pusht
|
||||
if "task_index" in offline_dataset.hf_dataset[0]:
|
||||
offline_dataset.hf_dataset = offline_dataset.hf_dataset.remove_columns(["task_index"])
|
||||
|
||||
if isinstance(offline_dataset, MultiLeRobotDataset):
|
||||
logging.info(
|
||||
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
|
||||
|
||||
@@ -22,7 +22,6 @@ from pprint import pformat
|
||||
import hydra
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import wandb
|
||||
from deepdiff import DeepDiff
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from termcolor import colored
|
||||
@@ -31,6 +30,7 @@ from torch.cuda.amp import GradScaler
|
||||
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
|
||||
from tqdm import tqdm
|
||||
|
||||
import wandb
|
||||
from lerobot.common.datasets.factory import resolve_delta_timestamps
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.logger import Logger
|
||||
|
||||
1112
lerobot/scripts/train_sac.py
Normal file
1112
lerobot/scripts/train_sac.py
Normal file
File diff suppressed because it is too large
Load Diff
153
poetry.lock
generated
153
poetry.lock
generated
@@ -3139,6 +3139,27 @@ dev = ["changelist (==0.5)"]
|
||||
lint = ["pre-commit (==3.7.0)"]
|
||||
test = ["pytest (>=7.4)", "pytest-cov (>=4.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "lightning-utilities"
|
||||
version = "0.11.9"
|
||||
description = "Lightning toolbox for across the our ecosystem."
|
||||
optional = true
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "lightning_utilities-0.11.9-py3-none-any.whl", hash = "sha256:ac6d4e9e28faf3ff4be997876750fee10dc604753dbc429bf3848a95c5d7e0d2"},
|
||||
{file = "lightning_utilities-0.11.9.tar.gz", hash = "sha256:f5052b81344cc2684aa9afd74b7ce8819a8f49a858184ec04548a5a109dfd053"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
packaging = ">=17.1"
|
||||
setuptools = "*"
|
||||
typing-extensions = "*"
|
||||
|
||||
[package.extras]
|
||||
cli = ["fire"]
|
||||
docs = ["requests (>=2.0.0)"]
|
||||
typing = ["mypy (>=1.0.0)", "types-setuptools"]
|
||||
|
||||
[[package]]
|
||||
name = "llvmlite"
|
||||
version = "0.43.0"
|
||||
@@ -6798,6 +6819,38 @@ webencodings = ">=0.4"
|
||||
doc = ["sphinx", "sphinx_rtd_theme"]
|
||||
test = ["pytest", "ruff"]
|
||||
|
||||
[[package]]
|
||||
name = "tokenizers"
|
||||
version = "0.21.0"
|
||||
description = ""
|
||||
optional = true
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:3c4c93eae637e7d2aaae3d376f06085164e1660f89304c0ab2b1d08a406636b2"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:f53ea537c925422a2e0e92a24cce96f6bc5046bbef24a1652a5edc8ba975f62e"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b177fb54c4702ef611de0c069d9169f0004233890e0c4c5bd5508ae05abf193"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6b43779a269f4629bebb114e19c3fca0223296ae9fea8bb9a7a6c6fb0657ff8e"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9aeb255802be90acfd363626753fda0064a8df06031012fe7d52fd9a905eb00e"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d8b09dbeb7a8d73ee204a70f94fc06ea0f17dcf0844f16102b9f414f0b7463ba"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:400832c0904f77ce87c40f1a8a27493071282f785724ae62144324f171377273"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e84ca973b3a96894d1707e189c14a774b701596d579ffc7e69debfc036a61a04"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:eb7202d231b273c34ec67767378cd04c767e967fda12d4a9e36208a34e2f137e"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:089d56db6782a73a27fd8abf3ba21779f5b85d4a9f35e3b493c7bbcbbf0d539b"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:c87ca3dc48b9b1222d984b6b7490355a6fdb411a2d810f6f05977258400ddb74"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:4145505a973116f91bc3ac45988a92e618a6f83eb458f49ea0790df94ee243ff"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-win32.whl", hash = "sha256:eb1702c2f27d25d9dd5b389cc1f2f51813e99f8ca30d9e25348db6585a97e24a"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-win_amd64.whl", hash = "sha256:87841da5a25a3a5f70c102de371db120f41873b854ba65e52bccd57df5a3780c"},
|
||||
{file = "tokenizers-0.21.0.tar.gz", hash = "sha256:ee0894bf311b75b0c03079f33859ae4b2334d675d4e93f5a4132e1eae2834fe4"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
huggingface-hub = ">=0.16.4,<1.0"
|
||||
|
||||
[package.extras]
|
||||
dev = ["tokenizers[testing]"]
|
||||
docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"]
|
||||
testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests", "ruff"]
|
||||
|
||||
[[package]]
|
||||
name = "tomli"
|
||||
version = "2.0.2"
|
||||
@@ -6863,6 +6916,34 @@ typing-extensions = ">=4.8.0"
|
||||
opt-einsum = ["opt-einsum (>=3.3)"]
|
||||
optree = ["optree (>=0.11.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "torchmetrics"
|
||||
version = "1.6.0"
|
||||
description = "PyTorch native Metrics"
|
||||
optional = true
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "torchmetrics-1.6.0-py3-none-any.whl", hash = "sha256:a508cdd87766cedaaf55a419812bf9f493aff8fffc02cc19df5a8e2e7ccb942a"},
|
||||
{file = "torchmetrics-1.6.0.tar.gz", hash = "sha256:aebba248708fb90def20cccba6f55bddd134a58de43fb22b0c5ca0f3a89fa984"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
lightning-utilities = ">=0.8.0"
|
||||
numpy = ">1.20.0"
|
||||
packaging = ">17.1"
|
||||
torch = ">=2.0.0"
|
||||
|
||||
[package.extras]
|
||||
all = ["SciencePlots (>=2.0.0)", "gammatone (>=1.0.0)", "ipadic (>=1.0.0)", "librosa (>=0.10.0)", "matplotlib (>=3.6.0)", "mecab-python3 (>=1.0.6)", "mypy (==1.13.0)", "nltk (>3.8.1)", "numpy (<2.0)", "onnxruntime (>=1.12.0)", "pesq (>=0.0.4)", "piq (<=0.8.0)", "pycocotools (>2.0.0)", "pystoi (>=0.4.0)", "regex (>=2021.9.24)", "requests (>=2.19.0)", "scipy (>1.0.0)", "sentencepiece (>=0.2.0)", "torch (==2.5.1)", "torch-fidelity (<=0.4.0)", "torchaudio (>=2.0.1)", "torchvision (>=0.15.1)", "tqdm (<4.68.0)", "transformers (>4.4.0)", "transformers (>=4.42.3)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"]
|
||||
audio = ["gammatone (>=1.0.0)", "librosa (>=0.10.0)", "numpy (<2.0)", "onnxruntime (>=1.12.0)", "pesq (>=0.0.4)", "pystoi (>=0.4.0)", "requests (>=2.19.0)", "torchaudio (>=2.0.1)"]
|
||||
detection = ["pycocotools (>2.0.0)", "torchvision (>=0.15.1)"]
|
||||
dev = ["PyTDC (==0.4.1)", "SciencePlots (>=2.0.0)", "bert-score (==0.3.13)", "dython (==0.7.6)", "dython (>=0.7.8,<0.8.0)", "fairlearn", "fast-bss-eval (>=0.1.0)", "faster-coco-eval (>=1.6.3)", "gammatone (>=1.0.0)", "huggingface-hub (<0.27)", "ipadic (>=1.0.0)", "jiwer (>=2.3.0)", "kornia (>=0.6.7)", "librosa (>=0.10.0)", "lpips (<=0.1.4)", "matplotlib (>=3.6.0)", "mecab-ko (>=1.0.0,<1.1.0)", "mecab-ko-dic (>=1.0.0)", "mecab-python3 (>=1.0.6)", "mir-eval (>=0.6)", "monai (==1.3.2)", "monai (==1.4.0)", "mypy (==1.13.0)", "netcal (>1.0.0)", "nltk (>3.8.1)", "numpy (<2.0)", "numpy (<2.2.0)", "onnxruntime (>=1.12.0)", "pandas (>1.4.0)", "permetrics (==2.0.0)", "pesq (>=0.0.4)", "piq (<=0.8.0)", "pycocotools (>2.0.0)", "pystoi (>=0.4.0)", "pytorch-msssim (==1.0.0)", "regex (>=2021.9.24)", "requests (>=2.19.0)", "rouge-score (>0.1.0)", "sacrebleu (>=2.3.0)", "scikit-image (>=0.19.0)", "scipy (>1.0.0)", "sentencepiece (>=0.2.0)", "sewar (>=0.4.4)", "statsmodels (>0.13.5)", "torch (==2.5.1)", "torch-complex (<0.5.0)", "torch-fidelity (<=0.4.0)", "torchaudio (>=2.0.1)", "torchvision (>=0.15.1)", "tqdm (<4.68.0)", "transformers (>4.4.0)", "transformers (>=4.42.3)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"]
|
||||
image = ["scipy (>1.0.0)", "torch-fidelity (<=0.4.0)", "torchvision (>=0.15.1)"]
|
||||
multimodal = ["piq (<=0.8.0)", "transformers (>=4.42.3)"]
|
||||
text = ["ipadic (>=1.0.0)", "mecab-python3 (>=1.0.6)", "nltk (>3.8.1)", "regex (>=2021.9.24)", "sentencepiece (>=0.2.0)", "tqdm (<4.68.0)", "transformers (>4.4.0)"]
|
||||
typing = ["mypy (==1.13.0)", "torch (==2.5.1)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"]
|
||||
visual = ["SciencePlots (>=2.0.0)", "matplotlib (>=3.6.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "torchvision"
|
||||
version = "0.19.1"
|
||||
@@ -6956,6 +7037,75 @@ files = [
|
||||
docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"]
|
||||
test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<8.2)", "pytest-mock", "pytest-mypy-testing"]
|
||||
|
||||
[[package]]
|
||||
name = "transformers"
|
||||
version = "4.47.0"
|
||||
description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow"
|
||||
optional = true
|
||||
python-versions = ">=3.9.0"
|
||||
files = [
|
||||
{file = "transformers-4.47.0-py3-none-any.whl", hash = "sha256:a8e1bafdaae69abdda3cad638fe392e37c86d2ce0ecfcae11d60abb8f949ff4d"},
|
||||
{file = "transformers-4.47.0.tar.gz", hash = "sha256:f8ead7a5a4f6937bb507e66508e5e002dc5930f7b6122a9259c37b099d0f3b19"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
filelock = "*"
|
||||
huggingface-hub = ">=0.24.0,<1.0"
|
||||
numpy = ">=1.17"
|
||||
packaging = ">=20.0"
|
||||
pyyaml = ">=5.1"
|
||||
regex = "!=2019.12.17"
|
||||
requests = "*"
|
||||
safetensors = ">=0.4.1"
|
||||
tokenizers = ">=0.21,<0.22"
|
||||
tqdm = ">=4.27"
|
||||
|
||||
[package.extras]
|
||||
accelerate = ["accelerate (>=0.26.0)"]
|
||||
agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch"]
|
||||
all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch", "torchaudio", "torchvision"]
|
||||
audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
|
||||
benchmark = ["optimum-benchmark (>=0.3.0)"]
|
||||
codecarbon = ["codecarbon (==1.2.0)"]
|
||||
deepspeed = ["accelerate (>=0.26.0)", "deepspeed (>=0.9.3)"]
|
||||
deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.26.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
|
||||
dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
|
||||
dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.21,<0.22)", "urllib3 (<2.0.0)"]
|
||||
dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "libcst", "librosa", "nltk (<=3.8.1)", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
|
||||
flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)", "scipy (<1.13.0)"]
|
||||
flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
|
||||
ftfy = ["ftfy"]
|
||||
integrations = ["optuna", "ray[tune] (>=2.7.0)", "sigopt"]
|
||||
ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"]
|
||||
modelcreation = ["cookiecutter (==1.7.3)"]
|
||||
natten = ["natten (>=0.14.6,<0.15.0)"]
|
||||
onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"]
|
||||
onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"]
|
||||
optuna = ["optuna"]
|
||||
quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "libcst", "rich", "ruff (==0.5.1)", "urllib3 (<2.0.0)"]
|
||||
ray = ["ray[tune] (>=2.7.0)"]
|
||||
retrieval = ["datasets (!=2.5.0)", "faiss-cpu"]
|
||||
ruff = ["ruff (==0.5.1)"]
|
||||
sagemaker = ["sagemaker (>=2.31.0)"]
|
||||
sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"]
|
||||
serving = ["fastapi", "pydantic", "starlette", "uvicorn"]
|
||||
sigopt = ["sigopt"]
|
||||
sklearn = ["scikit-learn"]
|
||||
speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
|
||||
testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
|
||||
tf = ["keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"]
|
||||
tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<0.24)", "tensorflow-text (<2.16)", "tf2onnx"]
|
||||
tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
|
||||
tiktoken = ["blobfile", "tiktoken"]
|
||||
timm = ["timm (<=1.0.11)"]
|
||||
tokenizers = ["tokenizers (>=0.21,<0.22)"]
|
||||
torch = ["accelerate (>=0.26.0)", "torch"]
|
||||
torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
|
||||
torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"]
|
||||
torchhub = ["filelock", "huggingface-hub (>=0.24.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.21,<0.22)", "torch", "tqdm (>=4.27)"]
|
||||
video = ["av (==9.2.0)"]
|
||||
vision = ["Pillow (>=10.0.1,<=15.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "transforms3d"
|
||||
version = "0.4.2"
|
||||
@@ -7558,6 +7708,7 @@ dev = ["debugpy", "pre-commit"]
|
||||
dora = ["gym-dora"]
|
||||
dynamixel = ["dynamixel-sdk", "pynput"]
|
||||
feetech = ["feetech-servo-sdk", "pynput"]
|
||||
hilserl = ["torchmetrics", "transformers"]
|
||||
intelrealsense = ["pyrealsense2"]
|
||||
pusht = ["gym-pusht"]
|
||||
stretch = ["hello-robot-stretch-body", "pynput", "pyrealsense2", "pyrender"]
|
||||
@@ -7569,4 +7720,4 @@ xarm = ["gym-xarm"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.10,<3.13"
|
||||
content-hash = "41344f0eb2d06d9a378abcd10df8205aa3926ff0a08ac5ab1a0b1bcae7440fd8"
|
||||
content-hash = "b9d299916ced6af1d243f961a32b0a4aacbef18e0b95337a5224e8511f5d6dda"
|
||||
|
||||
@@ -71,6 +71,8 @@ pyrender = {git = "https://github.com/mmatl/pyrender.git", markers = "sys_platfo
|
||||
hello-robot-stretch-body = {version = ">=0.7.27", markers = "sys_platform == 'linux'", optional = true}
|
||||
pyserial = {version = ">=3.5", optional = true}
|
||||
jsonlines = ">=4.0.0"
|
||||
transformers = {version = "^4.47.0", optional = true}
|
||||
torchmetrics = {version = "^1.6.0", optional = true}
|
||||
|
||||
|
||||
[tool.poetry.extras]
|
||||
@@ -86,6 +88,7 @@ dynamixel = ["dynamixel-sdk", "pynput"]
|
||||
feetech = ["feetech-servo-sdk", "pynput"]
|
||||
intelrealsense = ["pyrealsense2"]
|
||||
stretch = ["hello-robot-stretch-body", "pyrender", "pyrealsense2", "pynput"]
|
||||
hilserl = ["transformers", "torchmetrics"]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 110
|
||||
|
||||
@@ -14,9 +14,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import traceback
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from serial import SerialException
|
||||
|
||||
from lerobot import available_cameras, available_motors, available_robots
|
||||
@@ -124,3 +126,14 @@ def patch_builtins_input(monkeypatch):
|
||||
print(text)
|
||||
|
||||
monkeypatch.setattr("builtins.input", print_text)
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption("--seed", action="store", default="42", help="Set random seed for reproducibility")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def set_random_seed(request):
|
||||
seed = int(request.config.getoption("--seed"))
|
||||
random.seed(seed) # Python random
|
||||
torch.manual_seed(seed) # PyTorch
|
||||
|
||||
@@ -0,0 +1,244 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
|
||||
from torch.optim import Adam
|
||||
from torch.utils.data import DataLoader
|
||||
from torchmetrics import AUROC, Accuracy, F1Score, Precision, Recall
|
||||
from torchvision.datasets import CIFAR10
|
||||
from torchvision.transforms import ToTensor
|
||||
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier, ClassifierConfig
|
||||
|
||||
BATCH_SIZE = 1000
|
||||
LR = 0.1
|
||||
EPOCH_NUM = 2
|
||||
|
||||
if torch.cuda.is_available():
|
||||
DEVICE = torch.device("cuda")
|
||||
elif torch.backends.mps.is_available():
|
||||
DEVICE = torch.device("mps")
|
||||
else:
|
||||
DEVICE = torch.device("cpu")
|
||||
|
||||
|
||||
def train_evaluate_multiclass_classifier():
|
||||
logging.info(
|
||||
f"Start multiclass classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}"
|
||||
)
|
||||
multiclass_config = ClassifierConfig(model_name="microsoft/resnet-18", device=DEVICE, num_classes=10)
|
||||
multiclass_classifier = Classifier(multiclass_config)
|
||||
|
||||
trainset = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
|
||||
testset = CIFAR10(root="data", train=False, download=True, transform=ToTensor())
|
||||
|
||||
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)
|
||||
|
||||
multiclass_num_classes = 10
|
||||
epoch = 1
|
||||
|
||||
criterion = CrossEntropyLoss()
|
||||
optimizer = Adam(multiclass_classifier.parameters(), lr=LR)
|
||||
|
||||
multiclass_classifier.train()
|
||||
|
||||
logging.info("Start multiclass classifier training")
|
||||
|
||||
# Training loop
|
||||
while epoch < EPOCH_NUM: # loop over the dataset multiple times
|
||||
for i, data in enumerate(trainloader):
|
||||
inputs, labels = data
|
||||
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
|
||||
|
||||
# Zero the parameter gradients
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Forward pass
|
||||
outputs = multiclass_classifier(inputs)
|
||||
|
||||
loss = criterion(outputs.logits, labels)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if i % 10 == 0: # print every 10 mini-batches
|
||||
logging.info(f"[Epoch {epoch}, Batch {i}] loss: {loss.item():.3f}")
|
||||
|
||||
epoch += 1
|
||||
|
||||
print("Multiclass classifier training finished")
|
||||
|
||||
multiclass_classifier.eval()
|
||||
|
||||
test_loss = 0.0
|
||||
test_labels = []
|
||||
test_pridections = []
|
||||
test_probs = []
|
||||
|
||||
with torch.no_grad():
|
||||
for data in testloader:
|
||||
images, labels = data
|
||||
images, labels = images.to(DEVICE), labels.to(DEVICE)
|
||||
outputs = multiclass_classifier(images)
|
||||
loss = criterion(outputs.logits, labels)
|
||||
test_loss += loss.item() * BATCH_SIZE
|
||||
|
||||
_, predicted = torch.max(outputs.logits, 1)
|
||||
test_labels.extend(labels.cpu())
|
||||
test_pridections.extend(predicted.cpu())
|
||||
test_probs.extend(outputs.probabilities.cpu())
|
||||
|
||||
test_loss = test_loss / len(testset)
|
||||
|
||||
logging.info(f"Multiclass classifier test loss {test_loss:.3f}")
|
||||
|
||||
test_labels = torch.stack(test_labels)
|
||||
test_predictions = torch.stack(test_pridections)
|
||||
test_probs = torch.stack(test_probs)
|
||||
|
||||
accuracy = Accuracy(task="multiclass", num_classes=multiclass_num_classes)
|
||||
precision = Precision(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
|
||||
recall = Recall(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
|
||||
f1 = F1Score(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
|
||||
auroc = AUROC(task="multiclass", num_classes=multiclass_num_classes, average="weighted")
|
||||
|
||||
# Calculate metrics
|
||||
acc = accuracy(test_predictions, test_labels)
|
||||
prec = precision(test_predictions, test_labels)
|
||||
rec = recall(test_predictions, test_labels)
|
||||
f1_score = f1(test_predictions, test_labels)
|
||||
auroc_score = auroc(test_probs, test_labels)
|
||||
|
||||
logging.info(f"Accuracy: {acc:.2f}")
|
||||
logging.info(f"Precision: {prec:.2f}")
|
||||
logging.info(f"Recall: {rec:.2f}")
|
||||
logging.info(f"F1 Score: {f1_score:.2f}")
|
||||
logging.info(f"AUROC Score: {auroc_score:.2f}")
|
||||
|
||||
|
||||
def train_evaluate_binary_classifier():
|
||||
logging.info(
|
||||
f"Start binary classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}"
|
||||
)
|
||||
|
||||
target_binary_class = 3
|
||||
|
||||
def one_vs_rest(dataset, target_class):
|
||||
new_targets = []
|
||||
for _, label in dataset:
|
||||
new_label = float(1.0) if label == target_class else float(0.0)
|
||||
new_targets.append(new_label)
|
||||
|
||||
dataset.targets = new_targets # Replace the original labels with the binary ones
|
||||
return dataset
|
||||
|
||||
binary_train_dataset = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
|
||||
binary_test_dataset = CIFAR10(root="data", train=False, download=True, transform=ToTensor())
|
||||
|
||||
# Apply one-vs-rest labeling
|
||||
binary_train_dataset = one_vs_rest(binary_train_dataset, target_binary_class)
|
||||
binary_test_dataset = one_vs_rest(binary_test_dataset, target_binary_class)
|
||||
|
||||
binary_trainloader = DataLoader(binary_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
binary_testloader = DataLoader(binary_test_dataset, batch_size=BATCH_SIZE, shuffle=False)
|
||||
|
||||
binary_epoch = 1
|
||||
|
||||
binary_config = ClassifierConfig(model_name="microsoft/resnet-50", device=DEVICE)
|
||||
binary_classifier = Classifier(binary_config)
|
||||
|
||||
class_counts = np.bincount(binary_train_dataset.targets)
|
||||
n = len(binary_train_dataset)
|
||||
w0 = n / (2.0 * class_counts[0])
|
||||
w1 = n / (2.0 * class_counts[1])
|
||||
|
||||
binary_criterion = BCEWithLogitsLoss(pos_weight=torch.tensor(w1 / w0))
|
||||
binary_optimizer = Adam(binary_classifier.parameters(), lr=LR)
|
||||
|
||||
binary_classifier.train()
|
||||
|
||||
logging.info("Start binary classifier training")
|
||||
|
||||
# Training loop
|
||||
while binary_epoch < EPOCH_NUM: # loop over the dataset multiple times
|
||||
for i, data in enumerate(binary_trainloader):
|
||||
inputs, labels = data
|
||||
inputs, labels = inputs.to(DEVICE), labels.to(torch.float32).to(DEVICE)
|
||||
|
||||
# Zero the parameter gradients
|
||||
binary_optimizer.zero_grad()
|
||||
|
||||
# Forward pass
|
||||
outputs = binary_classifier(inputs)
|
||||
loss = binary_criterion(outputs.logits, labels)
|
||||
loss.backward()
|
||||
binary_optimizer.step()
|
||||
|
||||
if i % 10 == 0: # print every 10 mini-batches
|
||||
print(f"[Epoch {binary_epoch}, Batch {i}] loss: {loss.item():.3f}")
|
||||
binary_epoch += 1
|
||||
|
||||
logging.info("Binary classifier training finished")
|
||||
logging.info("Start binary classifier evaluation")
|
||||
|
||||
binary_classifier.eval()
|
||||
|
||||
test_loss = 0.0
|
||||
test_labels = []
|
||||
test_pridections = []
|
||||
test_probs = []
|
||||
|
||||
with torch.no_grad():
|
||||
for data in binary_testloader:
|
||||
images, labels = data
|
||||
images, labels = images.to(DEVICE), labels.to(torch.float32).to(DEVICE)
|
||||
outputs = binary_classifier(images)
|
||||
loss = binary_criterion(outputs.logits, labels)
|
||||
test_loss += loss.item() * BATCH_SIZE
|
||||
|
||||
test_labels.extend(labels.cpu())
|
||||
test_pridections.extend(outputs.logits.cpu())
|
||||
test_probs.extend(outputs.probabilities.cpu())
|
||||
|
||||
test_loss = test_loss / len(binary_test_dataset)
|
||||
|
||||
logging.info(f"Binary classifier test loss {test_loss:.3f}")
|
||||
|
||||
test_labels = torch.stack(test_labels)
|
||||
test_predictions = torch.stack(test_pridections)
|
||||
test_probs = torch.stack(test_probs)
|
||||
|
||||
# Calculate metrics
|
||||
acc = Accuracy(task="binary")(test_predictions, test_labels)
|
||||
prec = Precision(task="binary", average="weighted")(test_predictions, test_labels)
|
||||
rec = Recall(task="binary", average="weighted")(test_predictions, test_labels)
|
||||
f1_score = F1Score(task="binary", average="weighted")(test_predictions, test_labels)
|
||||
auroc_score = AUROC(task="binary", average="weighted")(test_probs, test_labels)
|
||||
|
||||
logging.info(f"Accuracy: {acc:.2f}")
|
||||
logging.info(f"Precision: {prec:.2f}")
|
||||
logging.info(f"Recall: {rec:.2f}")
|
||||
logging.info(f"F1 Score: {f1_score:.2f}")
|
||||
logging.info(f"AUROC Score: {auroc_score:.2f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_evaluate_multiclass_classifier()
|
||||
train_evaluate_binary_classifier()
|
||||
@@ -0,0 +1,78 @@
|
||||
import torch
|
||||
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
|
||||
Classifier,
|
||||
ClassifierConfig,
|
||||
ClassifierOutput,
|
||||
)
|
||||
from tests.utils import require_package
|
||||
|
||||
|
||||
def test_classifier_output():
|
||||
output = ClassifierOutput(
|
||||
logits=torch.tensor([1, 2, 3]), probabilities=torch.tensor([0.1, 0.2, 0.3]), hidden_states=None
|
||||
)
|
||||
|
||||
assert (
|
||||
f"{output}"
|
||||
== "ClassifierOutput(logits=tensor([1, 2, 3]), probabilities=tensor([0.1000, 0.2000, 0.3000]), hidden_states=None)"
|
||||
)
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_binary_classifier_with_default_params():
|
||||
config = ClassifierConfig()
|
||||
classifier = Classifier(config)
|
||||
|
||||
batch_size = 10
|
||||
|
||||
input = torch.rand(batch_size, 3, 224, 224)
|
||||
output = classifier(input)
|
||||
|
||||
assert output is not None
|
||||
assert output.logits.shape == torch.Size([batch_size])
|
||||
assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
|
||||
assert output.probabilities.shape == torch.Size([batch_size])
|
||||
assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
|
||||
assert output.hidden_states.shape == torch.Size([batch_size, 2048])
|
||||
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_multiclass_classifier():
|
||||
num_classes = 5
|
||||
config = ClassifierConfig(num_classes=num_classes)
|
||||
classifier = Classifier(config)
|
||||
|
||||
batch_size = 10
|
||||
|
||||
input = torch.rand(batch_size, 3, 224, 224)
|
||||
output = classifier(input)
|
||||
|
||||
assert output is not None
|
||||
assert output.logits.shape == torch.Size([batch_size, num_classes])
|
||||
assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
|
||||
assert output.probabilities.shape == torch.Size([batch_size, num_classes])
|
||||
assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
|
||||
assert output.hidden_states.shape == torch.Size([batch_size, 2048])
|
||||
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_default_device():
|
||||
config = ClassifierConfig()
|
||||
assert config.device == "cpu"
|
||||
|
||||
classifier = Classifier(config)
|
||||
for p in classifier.parameters():
|
||||
assert p.device == torch.device("cpu")
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_explicit_device_setup():
|
||||
config = ClassifierConfig(device="meta")
|
||||
assert config.device == "meta"
|
||||
|
||||
classifier = Classifier(config)
|
||||
for p in classifier.parameters():
|
||||
assert p.device == torch.device("meta")
|
||||
Reference in New Issue
Block a user