rename reward classifier

AdilZouitine
2025-04-25 18:38:52 +02:00
parent ea89b29fe5
commit 4257fe5045
7 changed files with 25 additions and 284 deletions

View File

@@ -24,10 +24,10 @@ from lerobot.common.envs.configs import EnvConfig
 from lerobot.common.envs.utils import env_to_policy_features
 from lerobot.common.policies.act.configuration_act import ACTConfig
 from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
-from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
 from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
 from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
 from lerobot.common.policies.pretrained import PreTrainedPolicy
+from lerobot.common.policies.reward_model.configuration_classifier import RewardClassifierConfig
 from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
 from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
 from lerobot.configs.policies import PreTrainedConfig
@@ -64,8 +64,8 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
         from lerobot.common.policies.sac.modeling_sac import SACPolicy

         return SACPolicy
-    elif name == "hilserl_classifier":
-        from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
+    elif name == "reward_classifier":
+        from lerobot.common.policies.reward_model.modeling_classifier import Classifier

         return Classifier
     else:
@@ -85,8 +85,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
         return PI0Config(**kwargs)
     elif policy_type == "pi0fast":
         return PI0FASTConfig(**kwargs)
-    elif policy_type == "hilserl_classifier":
-        return ClassifierConfig(**kwargs)
+    elif policy_type == "reward_classifier":
+        return RewardClassifierConfig(**kwargs)
     else:
         raise ValueError(f"Policy type '{policy_type}' is not available.")
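
With the factory updated, the old "hilserl_classifier" string now falls through to the ValueError branch. A minimal sketch of exercising the renamed entry points, assuming these helpers live in lerobot.common.policies.factory as the hunk context suggests:

from lerobot.common.policies.factory import get_policy_class, make_policy_config

# "hilserl_classifier" would now raise ValueError; only the new name resolves.
config = make_policy_config("reward_classifier", num_classes=2)
policy_cls = get_policy_class("reward_classifier")
print(type(config).__name__, policy_cls.__name__)  # RewardClassifierConfig Classifier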

View File

@@ -80,7 +80,7 @@ def create_stats_buffers(
         # TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
         if stats and key in stats:
-            # NOTE:(maractingi, azouitine): Change the order of these conditions becuase in online environments we don't have dataset stats
+            # NOTE:(maractingi, azouitine): Change the order of these conditions because in online environments we don't have dataset stats
             # Therefore, we don't have access to the full stats of the data; some elements have only min-max or only mean-std
             if norm_mode is NormalizationMode.MEAN_STD:
                 if "mean" not in stats[key] or "std" not in stats[key]:
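
The NOTE above is about robustness to partial statistics: an online environment may provide only min-max or only mean-std for a given key, so each normalization mode has to verify that its own stats exist before using them. A self-contained sketch of that guard pattern (illustrative names, not the actual lerobot buffer code):

import torch

def normalize(value: torch.Tensor, stats: dict, mode: str) -> torch.Tensor:
    # Check for the exact stats the chosen mode needs, since only a subset may exist.
    if mode == "MEAN_STD":
        if "mean" not in stats or "std" not in stats:
            raise ValueError("MEAN_STD normalization requires 'mean' and 'std' stats")
        return (value - stats["mean"]) / (stats["std"] + 1e-8)
    if mode == "MIN_MAX":
        if "min" not in stats or "max" not in stats:
            raise ValueError("MIN_MAX normalization requires 'min' and 'max' stats")
        # Rescale to [-1, 1].
        return (value - stats["min"]) / (stats["max"] - stats["min"] + 1e-8) * 2 - 1
    return value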

View File

@@ -7,12 +7,12 @@ from lerobot.configs.policies import PreTrainedConfig
 from lerobot.configs.types import NormalizationMode


-@PreTrainedConfig.register_subclass(name="hilserl_classifier")
+@PreTrainedConfig.register_subclass(name="reward_classifier")
 @dataclass
-class ClassifierConfig(PreTrainedConfig):
-    """Configuration for the Classifier model."""
+class RewardClassifierConfig(PreTrainedConfig):
+    """Configuration for the Reward Classifier model."""

-    name: str = "hilserl_classifier"
+    name: str = "reward_classifier"
     num_classes: int = 2
     hidden_dim: int = 256
     dropout_rate: float = 0.1
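
A quick sketch of the renamed config in use, assuming the defaults shown in this hunk; note that the name registered via register_subclass matches the "reward_classifier" string the factory dispatches on:

from lerobot.common.policies.reward_model.configuration_classifier import RewardClassifierConfig

config = RewardClassifierConfig()  # num_classes=2, hidden_dim=256, dropout_rate=0.1
assert config.name == "reward_classifier"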

View File

@@ -5,11 +5,9 @@ import torch
 from torch import Tensor, nn

 from lerobot.common.constants import OBS_IMAGE
-from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
-    ClassifierConfig,
-)
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.pretrained import PreTrainedPolicy
+from lerobot.common.policies.reward_model.configuration_classifier import RewardClassifierConfig

 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
 logger = logging.getLogger(__name__)
@@ -39,12 +37,12 @@ class ClassifierOutput:
 class Classifier(PreTrainedPolicy):
     """Image classifier built on top of a pre-trained encoder."""

-    name = "hilserl_classifier"
-    config_class = ClassifierConfig
+    name = "reward_classifier"
+    config_class = RewardClassifierConfig

     def __init__(
         self,
-        config: ClassifierConfig,
+        config: RewardClassifierConfig,
         dataset_stats: Dict[str, Dict[str, Tensor]] | None = None,
     ):
         from transformers import AutoModel
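
Putting the renamed pieces together, a hedged construction sketch that mirrors the updated tests further down (the transformers package is required for the encoder backbone; the input feature shape is the one the tests use):

from lerobot.common.policies.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
from lerobot.configs.types import FeatureType, PolicyFeature

config = RewardClassifierConfig(device="cpu")
config.input_features = {
    "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
classifier = Classifier(config)  # dataset_stats is optional per the signature above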

View File

@@ -1284,7 +1284,7 @@ def init_reward_classifier(cfg):
     if cfg.reward_classifier_pretrained_path is None:
         return None

-    from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
+    from lerobot.common.policies.reward_model.modeling_classifier import Classifier

     # Get device from config, defaulting to CPU
     device = getattr(cfg, "device", "cpu")
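
For context, a condensed sketch of how this helper plausibly continues after the updated import. Only the lines in the hunk are from the diff; the from_pretrained call is an assumption (lerobot's PreTrainedPolicy is a Hub mixin, but the actual loading code is not shown here):

def init_reward_classifier(cfg):
    if cfg.reward_classifier_pretrained_path is None:
        return None

    from lerobot.common.policies.reward_model.modeling_classifier import Classifier

    device = getattr(cfg, "device", "cpu")  # fall back to CPU when cfg carries no device
    classifier = Classifier.from_pretrained(cfg.reward_classifier_pretrained_path)  # assumed API
    return classifier.to(device)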

View File

@@ -1,247 +0,0 @@
-#!/usr/bin/env python
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import logging
-
-import numpy as np
-import torch
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
-from torch.optim import Adam
-from torch.utils.data import DataLoader
-from torchmetrics import AUROC, Accuracy, F1Score, Precision, Recall
-from torchvision.datasets import CIFAR10
-from torchvision.transforms import ToTensor
-
-from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
-    Classifier,
-    ClassifierConfig,
-)
-
-BATCH_SIZE = 1000
-LR = 0.1
-EPOCH_NUM = 2
-
-if torch.cuda.is_available():
-    DEVICE = torch.device("cuda")
-elif torch.backends.mps.is_available():
-    DEVICE = torch.device("mps")
-else:
-    DEVICE = torch.device("cpu")
-
-
-def train_evaluate_multiclass_classifier():
-    logging.info(
-        f"Start multiclass classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}"
-    )
-    multiclass_config = ClassifierConfig(model_name="microsoft/resnet-18", device=DEVICE, num_classes=10)
-    multiclass_classifier = Classifier(multiclass_config)
-
-    trainset = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
-    testset = CIFAR10(root="data", train=False, download=True, transform=ToTensor())
-
-    trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
-    testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)
-
-    multiclass_num_classes = 10
-    epoch = 1
-
-    criterion = CrossEntropyLoss()
-    optimizer = Adam(multiclass_classifier.parameters(), lr=LR)
-
-    multiclass_classifier.train()
-
-    logging.info("Start multiclass classifier training")
-
-    # Training loop
-    while epoch < EPOCH_NUM:  # loop over the dataset multiple times
-        for i, data in enumerate(trainloader):
-            inputs, labels = data
-            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
-
-            # Zero the parameter gradients
-            optimizer.zero_grad()
-
-            # Forward pass
-            outputs = multiclass_classifier(inputs)
-            loss = criterion(outputs.logits, labels)
-            loss.backward()
-            optimizer.step()
-
-            if i % 10 == 0:  # print every 10 mini-batches
-                logging.info(f"[Epoch {epoch}, Batch {i}] loss: {loss.item():.3f}")
-
-        epoch += 1
-
-    print("Multiclass classifier training finished")
-
-    multiclass_classifier.eval()
-
-    test_loss = 0.0
-    test_labels = []
-    test_predictions = []
-    test_probs = []
-
-    with torch.no_grad():
-        for data in testloader:
-            images, labels = data
-            images, labels = images.to(DEVICE), labels.to(DEVICE)
-            outputs = multiclass_classifier(images)
-            loss = criterion(outputs.logits, labels)
-            test_loss += loss.item() * BATCH_SIZE
-
-            _, predicted = torch.max(outputs.logits, 1)
-            test_labels.extend(labels.cpu())
-            test_predictions.extend(predicted.cpu())
-            test_probs.extend(outputs.probabilities.cpu())
-
-    test_loss = test_loss / len(testset)
-    logging.info(f"Multiclass classifier test loss {test_loss:.3f}")
-
-    test_labels = torch.stack(test_labels)
-    test_predictions = torch.stack(test_predictions)
-    test_probs = torch.stack(test_probs)
-
-    accuracy = Accuracy(task="multiclass", num_classes=multiclass_num_classes)
-    precision = Precision(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
-    recall = Recall(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
-    f1 = F1Score(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
-    auroc = AUROC(task="multiclass", num_classes=multiclass_num_classes, average="weighted")
-
-    # Calculate metrics
-    acc = accuracy(test_predictions, test_labels)
-    prec = precision(test_predictions, test_labels)
-    rec = recall(test_predictions, test_labels)
-    f1_score = f1(test_predictions, test_labels)
-    auroc_score = auroc(test_probs, test_labels)
-
-    logging.info(f"Accuracy: {acc:.2f}")
-    logging.info(f"Precision: {prec:.2f}")
-    logging.info(f"Recall: {rec:.2f}")
-    logging.info(f"F1 Score: {f1_score:.2f}")
-    logging.info(f"AUROC Score: {auroc_score:.2f}")
-
-
-def train_evaluate_binary_classifier():
-    logging.info(
-        f"Start binary classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}"
-    )
-
-    target_binary_class = 3
-
-    def one_vs_rest(dataset, target_class):
-        new_targets = []
-        for _, label in dataset:
-            new_label = float(1.0) if label == target_class else float(0.0)
-            new_targets.append(new_label)
-
-        dataset.targets = new_targets  # Replace the original labels with the binary ones
-        return dataset
-
-    binary_train_dataset = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
-    binary_test_dataset = CIFAR10(root="data", train=False, download=True, transform=ToTensor())
-
-    # Apply one-vs-rest labeling
-    binary_train_dataset = one_vs_rest(binary_train_dataset, target_binary_class)
-    binary_test_dataset = one_vs_rest(binary_test_dataset, target_binary_class)
-
-    binary_trainloader = DataLoader(binary_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
-    binary_testloader = DataLoader(binary_test_dataset, batch_size=BATCH_SIZE, shuffle=False)
-
-    binary_epoch = 1
-
-    binary_config = ClassifierConfig(model_name="microsoft/resnet-50", device=DEVICE)
-    binary_classifier = Classifier(binary_config)
-
-    class_counts = np.bincount(binary_train_dataset.targets)
-    n = len(binary_train_dataset)
-    w0 = n / (2.0 * class_counts[0])
-    w1 = n / (2.0 * class_counts[1])
-
-    binary_criterion = BCEWithLogitsLoss(pos_weight=torch.tensor(w1 / w0))
-    binary_optimizer = Adam(binary_classifier.parameters(), lr=LR)
-
-    binary_classifier.train()
-
-    logging.info("Start binary classifier training")
-
-    # Training loop
-    while binary_epoch < EPOCH_NUM:  # loop over the dataset multiple times
-        for i, data in enumerate(binary_trainloader):
-            inputs, labels = data
-            inputs, labels = inputs.to(DEVICE), labels.to(torch.float32).to(DEVICE)
-
-            # Zero the parameter gradients
-            binary_optimizer.zero_grad()
-
-            # Forward pass
-            outputs = binary_classifier(inputs)
-            loss = binary_criterion(outputs.logits, labels)
-            loss.backward()
-            binary_optimizer.step()
-
-            if i % 10 == 0:  # print every 10 mini-batches
-                print(f"[Epoch {binary_epoch}, Batch {i}] loss: {loss.item():.3f}")
-
-        binary_epoch += 1
-
-    logging.info("Binary classifier training finished")
-    logging.info("Start binary classifier evaluation")
-
-    binary_classifier.eval()
-
-    test_loss = 0.0
-    test_labels = []
-    test_predictions = []
-    test_probs = []
-
-    with torch.no_grad():
-        for data in binary_testloader:
-            images, labels = data
-            images, labels = images.to(DEVICE), labels.to(torch.float32).to(DEVICE)
-            outputs = binary_classifier(images)
-            loss = binary_criterion(outputs.logits, labels)
-            test_loss += loss.item() * BATCH_SIZE
-
-            test_labels.extend(labels.cpu())
-            test_predictions.extend(outputs.logits.cpu())
-            test_probs.extend(outputs.probabilities.cpu())
-
-    test_loss = test_loss / len(binary_test_dataset)
-    logging.info(f"Binary classifier test loss {test_loss:.3f}")
-
-    test_labels = torch.stack(test_labels)
-    test_predictions = torch.stack(test_predictions)
-    test_probs = torch.stack(test_probs)
-
-    # Calculate metrics
-    acc = Accuracy(task="binary")(test_predictions, test_labels)
-    prec = Precision(task="binary", average="weighted")(test_predictions, test_labels)
-    rec = Recall(task="binary", average="weighted")(test_predictions, test_labels)
-    f1_score = F1Score(task="binary", average="weighted")(test_predictions, test_labels)
-    auroc_score = AUROC(task="binary", average="weighted")(test_probs, test_labels)
-
-    logging.info(f"Accuracy: {acc:.2f}")
-    logging.info(f"Precision: {prec:.2f}")
-    logging.info(f"Recall: {rec:.2f}")
-    logging.info(f"F1 Score: {f1_score:.2f}")
-    logging.info(f"AUROC Score: {auroc_score:.2f}")
-
-
-if __name__ == "__main__":
-    train_evaluate_multiclass_classifier()
-    train_evaluate_binary_classifier()
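
One detail of the deleted script worth noting is its class weighting. With balanced weights w_k = n / (2 * c_k), the ratio handed to BCEWithLogitsLoss collapses to pos_weight = w1 / w0 = c0 / c1, the negative-to-positive count ratio. A minimal check using the CIFAR-10 one-vs-rest split the script builds (45000 negatives, 5000 positives):

import numpy as np
import torch
from torch.nn import BCEWithLogitsLoss

targets = np.array([0] * 45000 + [1] * 5000)  # one-vs-rest labels on the CIFAR-10 train split
class_counts = np.bincount(targets)
n = len(targets)
w0 = n / (2.0 * class_counts[0])
w1 = n / (2.0 * class_counts[1])
assert np.isclose(w1 / w0, class_counts[0] / class_counts[1])  # pos_weight == 45000/5000 == 9.0
criterion = BCEWithLogitsLoss(pos_weight=torch.tensor(w1 / w0))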

View File

@@ -1,9 +1,7 @@
 import torch

-from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
-    ClassifierConfig,
-    ClassifierOutput,
-)
+from lerobot.common.policies.reward_model.configuration_classifier import RewardClassifierConfig
+from lerobot.common.policies.reward_model.modeling_classifier import ClassifierOutput
 from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
 from tests.utils import require_package
@@ -23,11 +21,9 @@ def test_classifier_output():
 @require_package("transformers")
 def test_binary_classifier_with_default_params():
-    from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
-        Classifier,
-    )
+    from lerobot.common.policies.reward_model.modeling_classifier import Classifier

-    config = ClassifierConfig()
+    config = RewardClassifierConfig()
     config.input_features = {
         "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
     }
@@ -66,12 +62,10 @@ def test_binary_classifier_with_default_params():
 @require_package("transformers")
 def test_multiclass_classifier():
-    from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
-        Classifier,
-    )
+    from lerobot.common.policies.reward_model.modeling_classifier import Classifier

     num_classes = 5

-    config = ClassifierConfig()
+    config = RewardClassifierConfig()
     config.input_features = {
         "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
     }
@@ -107,11 +101,9 @@ def test_multiclass_classifier():
 @require_package("transformers")
 def test_default_device():
-    from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
-        Classifier,
-    )
+    from lerobot.common.policies.reward_model.modeling_classifier import Classifier

-    config = ClassifierConfig()
+    config = RewardClassifierConfig()
     assert config.device == "cpu"

     classifier = Classifier(config)
@@ -121,11 +113,9 @@ def test_default_device():
 @require_package("transformers")
 def test_explicit_device_setup():
-    from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
-        Classifier,
-    )
+    from lerobot.common.policies.reward_model.modeling_classifier import Classifier

-    config = ClassifierConfig(device="cpu")
+    config = RewardClassifierConfig(device="cpu")
     assert config.device == "cpu"

     classifier = Classifier(config)