rename reward classifier
This commit is contained in:
@@ -24,10 +24,10 @@ from lerobot.common.envs.configs import EnvConfig
|
||||
from lerobot.common.envs.utils import env_to_policy_features
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
@@ -64,8 +64,8 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
|
||||
return SACPolicy
|
||||
elif name == "hilserl_classifier":
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||
elif name == "reward_classifier":
|
||||
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
|
||||
|
||||
return Classifier
|
||||
else:
|
||||
@@ -85,8 +85,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "pi0fast":
|
||||
return PI0FASTConfig(**kwargs)
|
||||
elif policy_type == "hilserl_classifier":
|
||||
return ClassifierConfig(**kwargs)
|
||||
elif policy_type == "reward_classifier":
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
|
||||
|
||||
@@ -80,7 +80,7 @@ def create_stats_buffers(
|
||||
|
||||
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
|
||||
if stats and key in stats:
|
||||
# NOTE:(maractingi, azouitine): Change the order of these conditions becuase in online environments we don't have dataset stats
|
||||
# NOTE:(maractingi, azouitine): Change the order of these conditions because in online environments we don't have dataset stats
|
||||
# Therefore, we don't have access to the full stats of the data; some elements have either min-max or mean-std only
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
if "mean" not in stats[key] or "std" not in stats[key]:
|
||||
|
||||
@@ -7,12 +7,12 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass(name="hilserl_classifier")
|
||||
@PreTrainedConfig.register_subclass(name="reward_classifier")
|
||||
@dataclass
|
||||
class ClassifierConfig(PreTrainedConfig):
|
||||
"""Configuration for the Classifier model."""
|
||||
class RewardClassifierConfig(PreTrainedConfig):
|
||||
"""Configuration for the Reward Classifier model."""
|
||||
|
||||
name: str = "hilserl_classifier"
|
||||
name: str = "reward_classifier"
|
||||
num_classes: int = 2
|
||||
hidden_dim: int = 256
|
||||
dropout_rate: float = 0.1
|
||||
@@ -5,11 +5,9 @@ import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.common.constants import OBS_IMAGE
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
|
||||
ClassifierConfig,
|
||||
)
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -39,12 +37,12 @@ class ClassifierOutput:
|
||||
class Classifier(PreTrainedPolicy):
|
||||
"""Image classifier built on top of a pre-trained encoder."""
|
||||
|
||||
name = "hilserl_classifier"
|
||||
config_class = ClassifierConfig
|
||||
name = "reward_classifier"
|
||||
config_class = RewardClassifierConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ClassifierConfig,
|
||||
config: RewardClassifierConfig,
|
||||
dataset_stats: Dict[str, Dict[str, Tensor]] | None = None,
|
||||
):
|
||||
from transformers import AutoModel
|
||||
@@ -1284,7 +1284,7 @@ def init_reward_classifier(cfg):
|
||||
if cfg.reward_classifier_pretrained_path is None:
|
||||
return None
|
||||
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
|
||||
|
||||
# Get device from config or default to CPU
|
||||
device = getattr(cfg, "device", "cpu")
|
||||
|
||||
@@ -1,247 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
|
||||
from torch.optim import Adam
|
||||
from torch.utils.data import DataLoader
|
||||
from torchmetrics import AUROC, Accuracy, F1Score, Precision, Recall
|
||||
from torchvision.datasets import CIFAR10
|
||||
from torchvision.transforms import ToTensor
|
||||
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
|
||||
Classifier,
|
||||
ClassifierConfig,
|
||||
)
|
||||
|
||||
# Hyper-parameters shared by both training/evaluation routines below.
BATCH_SIZE = 1000
LR = 0.1
EPOCH_NUM = 2

# Pick the compute device once at import time: CUDA first, then Apple MPS,
# and CPU when neither accelerator reports itself as available.
DEVICE = torch.device(
    "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
)
|
||||
|
||||
|
||||
def train_evaluate_multiclass_classifier():
    """Train a 10-class classifier on CIFAR-10 and log test-set metrics.

    Builds a ``Classifier`` from a ``ClassifierConfig`` using the
    ``microsoft/resnet-18`` model name, trains it for ``EPOCH_NUM`` epochs
    with Adam and cross-entropy, then logs the test loss together with
    accuracy, weighted precision/recall/F1 and weighted AUROC.

    The datasets are requested with ``root="data"`` and ``download=True``.
    Nothing is returned; all results are reported through ``logging``.
    """
    # Single source of truth for the class count (config and metrics).
    multiclass_num_classes = 10

    logging.info(
        f"Start multiclass classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}"
    )
    multiclass_config = ClassifierConfig(
        model_name="microsoft/resnet-18", device=DEVICE, num_classes=multiclass_num_classes
    )
    multiclass_classifier = Classifier(multiclass_config)

    trainset = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
    testset = CIFAR10(root="data", train=False, download=True, transform=ToTensor())

    trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
    testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)

    criterion = CrossEntropyLoss()
    optimizer = Adam(multiclass_classifier.parameters(), lr=LR)

    multiclass_classifier.train()

    logging.info("Start multiclass classifier training")

    # Training loop. Epochs are numbered from 1 for logging, so the range is
    # inclusive of EPOCH_NUM; the previous `while epoch < EPOCH_NUM` form ran
    # one epoch fewer than EPOCH_NUM.
    for epoch in range(1, EPOCH_NUM + 1):
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = multiclass_classifier(inputs)

            loss = criterion(outputs.logits, labels)
            loss.backward()
            optimizer.step()

            if i % 10 == 0:  # log every 10 mini-batches
                logging.info(f"[Epoch {epoch}, Batch {i}] loss: {loss.item():.3f}")

    logging.info("Multiclass classifier training finished")

    multiclass_classifier.eval()

    test_loss = 0.0
    label_batches = []
    prediction_batches = []
    prob_batches = []

    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = multiclass_classifier(images)
            loss = criterion(outputs.logits, labels)
            # Weight the (mean) batch loss by the real batch size so a short
            # final batch is not over-counted.
            test_loss += loss.item() * labels.size(0)

            _, predicted = torch.max(outputs.logits, 1)
            label_batches.append(labels.cpu())
            prediction_batches.append(predicted.cpu())
            prob_batches.append(outputs.probabilities.cpu())

    test_loss = test_loss / len(testset)

    logging.info(f"Multiclass classifier test loss {test_loss:.3f}")

    # Concatenate whole batches instead of stacking one tensor per sample.
    test_labels = torch.cat(label_batches)
    test_predictions = torch.cat(prediction_batches)
    test_probs = torch.cat(prob_batches)

    accuracy = Accuracy(task="multiclass", num_classes=multiclass_num_classes)
    precision = Precision(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
    recall = Recall(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
    f1 = F1Score(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
    auroc = AUROC(task="multiclass", num_classes=multiclass_num_classes, average="weighted")

    # Calculate metrics
    acc = accuracy(test_predictions, test_labels)
    prec = precision(test_predictions, test_labels)
    rec = recall(test_predictions, test_labels)
    f1_score = f1(test_predictions, test_labels)
    auroc_score = auroc(test_probs, test_labels)

    logging.info(f"Accuracy: {acc:.2f}")
    logging.info(f"Precision: {prec:.2f}")
    logging.info(f"Recall: {rec:.2f}")
    logging.info(f"F1 Score: {f1_score:.2f}")
    logging.info(f"AUROC Score: {auroc_score:.2f}")
|
||||
|
||||
|
||||
def train_evaluate_binary_classifier():
    """Train a one-vs-rest binary classifier on CIFAR-10 and log test metrics.

    CIFAR-10 labels are remapped so that class ``3`` becomes the positive
    label (1.0) and every other class the negative label (0.0). A
    ``Classifier`` built with the ``microsoft/resnet-50`` model name is
    trained for ``EPOCH_NUM`` epochs with Adam and a class-balanced
    ``BCEWithLogitsLoss``; the test loss, accuracy, precision, recall, F1 and
    AUROC are then logged.

    The datasets are requested with ``root="data"`` and ``download=True``.
    Nothing is returned; all results are reported through ``logging``.
    """
    logging.info(
        f"Start binary classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}"
    )

    target_binary_class = 3

    def one_vs_rest(dataset, target_class):
        """Replace ``dataset.targets`` in place with 1.0/0.0 labels and return the dataset."""
        # Read the stored labels directly; iterating the dataset itself would
        # load and transform every image only to discard it.
        dataset.targets = [1.0 if label == target_class else 0.0 for label in dataset.targets]
        return dataset

    binary_train_dataset = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
    binary_test_dataset = CIFAR10(root="data", train=False, download=True, transform=ToTensor())

    # Apply one-vs-rest labeling
    binary_train_dataset = one_vs_rest(binary_train_dataset, target_binary_class)
    binary_test_dataset = one_vs_rest(binary_test_dataset, target_binary_class)

    binary_trainloader = DataLoader(binary_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    binary_testloader = DataLoader(binary_test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    binary_config = ClassifierConfig(model_name="microsoft/resnet-50", device=DEVICE)
    binary_classifier = Classifier(binary_config)

    # np.bincount only accepts integer input, and the remapped targets are
    # floats, so cast first. minlength=2 guarantees both class slots exist.
    class_counts = np.bincount(np.asarray(binary_train_dataset.targets, dtype=np.int64), minlength=2)
    n = len(binary_train_dataset)
    w0 = n / (2.0 * class_counts[0])
    w1 = n / (2.0 * class_counts[1])

    # w1 / w0 is the negative-to-positive count ratio. Create the weight as
    # float32 on DEVICE so it matches the logits it is combined with.
    pos_weight = torch.tensor(w1 / w0, dtype=torch.float32, device=DEVICE)
    binary_criterion = BCEWithLogitsLoss(pos_weight=pos_weight)
    binary_optimizer = Adam(binary_classifier.parameters(), lr=LR)

    binary_classifier.train()

    logging.info("Start binary classifier training")

    # Training loop. Epochs are numbered from 1 for logging, so the range is
    # inclusive of EPOCH_NUM; the previous `while epoch < EPOCH_NUM` form ran
    # one epoch fewer than EPOCH_NUM.
    for binary_epoch in range(1, EPOCH_NUM + 1):
        for i, (inputs, labels) in enumerate(binary_trainloader):
            inputs, labels = inputs.to(DEVICE), labels.to(torch.float32).to(DEVICE)

            # Zero the parameter gradients
            binary_optimizer.zero_grad()

            # Forward pass
            outputs = binary_classifier(inputs)
            loss = binary_criterion(outputs.logits, labels)
            loss.backward()
            binary_optimizer.step()

            if i % 10 == 0:  # log every 10 mini-batches
                logging.info(f"[Epoch {binary_epoch}, Batch {i}] loss: {loss.item():.3f}")

    logging.info("Binary classifier training finished")
    logging.info("Start binary classifier evaluation")

    binary_classifier.eval()

    test_loss = 0.0
    label_batches = []
    prediction_batches = []
    prob_batches = []

    with torch.no_grad():
        for images, labels in binary_testloader:
            images, labels = images.to(DEVICE), labels.to(torch.float32).to(DEVICE)
            outputs = binary_classifier(images)
            loss = binary_criterion(outputs.logits, labels)
            # Weight the (mean) batch loss by the real batch size so a short
            # final batch is not over-counted.
            test_loss += loss.item() * labels.size(0)

            label_batches.append(labels.cpu())
            prediction_batches.append(outputs.logits.cpu())
            prob_batches.append(outputs.probabilities.cpu())

    test_loss = test_loss / len(binary_test_dataset)

    logging.info(f"Binary classifier test loss {test_loss:.3f}")

    # torchmetrics documents integer ground-truth targets for binary tasks,
    # so convert the float labels used by the loss back to integers here.
    test_labels = torch.cat(label_batches).long()
    # NOTE(review): raw logits are passed as predictions, as before; this
    # relies on the metric turning logits into probabilities - confirm.
    test_predictions = torch.cat(prediction_batches)
    test_probs = torch.cat(prob_batches)

    # Calculate metrics
    acc = Accuracy(task="binary")(test_predictions, test_labels)
    prec = Precision(task="binary", average="weighted")(test_predictions, test_labels)
    rec = Recall(task="binary", average="weighted")(test_predictions, test_labels)
    f1_score = F1Score(task="binary", average="weighted")(test_predictions, test_labels)
    auroc_score = AUROC(task="binary", average="weighted")(test_probs, test_labels)

    logging.info(f"Accuracy: {acc:.2f}")
    logging.info(f"Precision: {prec:.2f}")
    logging.info(f"Recall: {rec:.2f}")
    logging.info(f"F1 Score: {f1_score:.2f}")
    logging.info(f"AUROC Score: {auroc_score:.2f}")
|
||||
|
||||
|
||||
# Script entry point: run the multiclass experiment first, then the binary one.
if __name__ == "__main__":
    train_evaluate_multiclass_classifier()
    train_evaluate_binary_classifier()
|
||||
@@ -1,9 +1,7 @@
|
||||
import torch
|
||||
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
|
||||
ClassifierConfig,
|
||||
ClassifierOutput,
|
||||
)
|
||||
from lerobot.common.policies.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.common.policies.reward_model.modeling_classifier import ClassifierOutput
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from tests.utils import require_package
|
||||
|
||||
@@ -23,11 +21,9 @@ def test_classifier_output():
|
||||
|
||||
@require_package("transformers")
|
||||
def test_binary_classifier_with_default_params():
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
|
||||
Classifier,
|
||||
)
|
||||
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
|
||||
|
||||
config = ClassifierConfig()
|
||||
config = RewardClassifierConfig()
|
||||
config.input_features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
@@ -66,12 +62,10 @@ def test_binary_classifier_with_default_params():
|
||||
|
||||
@require_package("transformers")
|
||||
def test_multiclass_classifier():
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
|
||||
Classifier,
|
||||
)
|
||||
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
|
||||
|
||||
num_classes = 5
|
||||
config = ClassifierConfig()
|
||||
config = RewardClassifierConfig()
|
||||
config.input_features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
@@ -107,11 +101,9 @@ def test_multiclass_classifier():
|
||||
|
||||
@require_package("transformers")
|
||||
def test_default_device():
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
|
||||
Classifier,
|
||||
)
|
||||
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
|
||||
|
||||
config = ClassifierConfig()
|
||||
config = RewardClassifierConfig()
|
||||
assert config.device == "cpu"
|
||||
|
||||
classifier = Classifier(config)
|
||||
@@ -121,11 +113,9 @@ def test_default_device():
|
||||
|
||||
@require_package("transformers")
|
||||
def test_explicit_device_setup():
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
|
||||
Classifier,
|
||||
)
|
||||
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
|
||||
|
||||
config = ClassifierConfig(device="cpu")
|
||||
config = RewardClassifierConfig(device="cpu")
|
||||
assert config.device == "cpu"
|
||||
|
||||
classifier = Classifier(config)
|
||||
Reference in New Issue
Block a user