rename reward classifier
This commit is contained in:
@@ -24,10 +24,10 @@ from lerobot.common.envs.configs import EnvConfig
|
|||||||
from lerobot.common.envs.utils import env_to_policy_features
|
from lerobot.common.envs.utils import env_to_policy_features
|
||||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
|
|
||||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||||
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||||
|
from lerobot.common.policies.reward_model.configuration_classifier import RewardClassifierConfig
|
||||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
@@ -64,8 +64,8 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
|
|||||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||||
|
|
||||||
return SACPolicy
|
return SACPolicy
|
||||||
elif name == "hilserl_classifier":
|
elif name == "reward_classifier":
|
||||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
|
||||||
|
|
||||||
return Classifier
|
return Classifier
|
||||||
else:
|
else:
|
||||||
@@ -85,8 +85,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
|||||||
return PI0Config(**kwargs)
|
return PI0Config(**kwargs)
|
||||||
elif policy_type == "pi0fast":
|
elif policy_type == "pi0fast":
|
||||||
return PI0FASTConfig(**kwargs)
|
return PI0FASTConfig(**kwargs)
|
||||||
elif policy_type == "hilserl_classifier":
|
elif policy_type == "reward_classifier":
|
||||||
return ClassifierConfig(**kwargs)
|
return RewardClassifierConfig(**kwargs)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||||
|
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ def create_stats_buffers(
|
|||||||
|
|
||||||
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
|
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
|
||||||
if stats and key in stats:
|
if stats and key in stats:
|
||||||
# NOTE:(maractingi, azouitine): Change the order of these conditions becuase in online environments we don't have dataset stats
|
# NOTE:(maractingi, azouitine): Change the order of these conditions because in online environments we don't have dataset stats
|
||||||
# Therefore, we don't have access to the full stats of the data; some elements have either min-max or mean-std only
|
# Therefore, we don't have access to the full stats of the data; some elements have either min-max or mean-std only
|
||||||
if norm_mode is NormalizationMode.MEAN_STD:
|
if norm_mode is NormalizationMode.MEAN_STD:
|
||||||
if "mean" not in stats[key] or "std" not in stats[key]:
|
if "mean" not in stats[key] or "std" not in stats[key]:
|
||||||
|
|||||||
@@ -7,12 +7,12 @@ from lerobot.configs.policies import PreTrainedConfig
|
|||||||
from lerobot.configs.types import NormalizationMode
|
from lerobot.configs.types import NormalizationMode
|
||||||
|
|
||||||
|
|
||||||
@PreTrainedConfig.register_subclass(name="hilserl_classifier")
|
@PreTrainedConfig.register_subclass(name="reward_classifier")
|
||||||
@dataclass
|
@dataclass
|
||||||
class ClassifierConfig(PreTrainedConfig):
|
class RewardClassifierConfig(PreTrainedConfig):
|
||||||
"""Configuration for the Classifier model."""
|
"""Configuration for the Reward Classifier model."""
|
||||||
|
|
||||||
name: str = "hilserl_classifier"
|
name: str = "reward_classifier"
|
||||||
num_classes: int = 2
|
num_classes: int = 2
|
||||||
hidden_dim: int = 256
|
hidden_dim: int = 256
|
||||||
dropout_rate: float = 0.1
|
dropout_rate: float = 0.1
|
||||||
@@ -5,11 +5,9 @@ import torch
|
|||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
from lerobot.common.constants import OBS_IMAGE
|
from lerobot.common.constants import OBS_IMAGE
|
||||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
|
|
||||||
ClassifierConfig,
|
|
||||||
)
|
|
||||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||||
|
from lerobot.common.policies.reward_model.configuration_classifier import RewardClassifierConfig
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -39,12 +37,12 @@ class ClassifierOutput:
|
|||||||
class Classifier(PreTrainedPolicy):
|
class Classifier(PreTrainedPolicy):
|
||||||
"""Image classifier built on top of a pre-trained encoder."""
|
"""Image classifier built on top of a pre-trained encoder."""
|
||||||
|
|
||||||
name = "hilserl_classifier"
|
name = "reward_classifier"
|
||||||
config_class = ClassifierConfig
|
config_class = RewardClassifierConfig
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: ClassifierConfig,
|
config: RewardClassifierConfig,
|
||||||
dataset_stats: Dict[str, Dict[str, Tensor]] | None = None,
|
dataset_stats: Dict[str, Dict[str, Tensor]] | None = None,
|
||||||
):
|
):
|
||||||
from transformers import AutoModel
|
from transformers import AutoModel
|
||||||
@@ -1284,7 +1284,7 @@ def init_reward_classifier(cfg):
|
|||||||
if cfg.reward_classifier_pretrained_path is None:
|
if cfg.reward_classifier_pretrained_path is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
|
||||||
|
|
||||||
# Get device from config or default to CPU
|
# Get device from config or default to CPU
|
||||||
device = getattr(cfg, "device", "cpu")
|
device = getattr(cfg, "device", "cpu")
|
||||||
|
|||||||
@@ -1,247 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
|
|
||||||
from torch.optim import Adam
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
from torchmetrics import AUROC, Accuracy, F1Score, Precision, Recall
|
|
||||||
from torchvision.datasets import CIFAR10
|
|
||||||
from torchvision.transforms import ToTensor
|
|
||||||
|
|
||||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
|
|
||||||
Classifier,
|
|
||||||
ClassifierConfig,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Hyper-parameters shared by both demos below.
BATCH_SIZE = 1000  # samples per mini-batch for the train and test loaders
LR = 0.1  # learning rate handed to Adam
EPOCH_NUM = 2  # epoch bound used by the training loops

# Pick the compute device once at import time: CUDA first, then Apple MPS,
# and finally the CPU when neither accelerator reports as available.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
|
|
||||||
|
|
||||||
|
|
||||||
def train_evaluate_multiclass_classifier():
    """Train a 10-class ``Classifier`` on CIFAR-10 and log its test metrics.

    Builds a ``Classifier`` from a ``ClassifierConfig`` that names the
    ``microsoft/resnet-18`` model, trains it with Adam and cross-entropy for
    ``EPOCH_NUM`` epochs, then logs the test loss together with accuracy,
    weighted precision / recall / F1 and weighted AUROC.

    Both CIFAR-10 splits are requested with ``download=True`` under ``data/``.
    Nothing is returned; all results are reported through ``logging``.
    """
    # Single source of truth for the class count (config and metrics must agree).
    num_classes = 10

    logging.info(
        f"Start multiclass classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}"
    )
    config = ClassifierConfig(model_name="microsoft/resnet-18", device=DEVICE, num_classes=num_classes)
    classifier = Classifier(config)

    trainset = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
    testset = CIFAR10(root="data", train=False, download=True, transform=ToTensor())

    trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
    testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)

    criterion = CrossEntropyLoss()
    optimizer = Adam(classifier.parameters(), lr=LR)

    classifier.train()
    logging.info("Start multiclass classifier training")

    # Epochs are numbered from 1 for the log lines; the inclusive upper bound
    # makes the loop run exactly EPOCH_NUM passes over the training set.
    for epoch in range(1, EPOCH_NUM + 1):
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

            # Clear gradients left over from the previous step.
            optimizer.zero_grad()

            outputs = classifier(inputs)
            loss = criterion(outputs.logits, labels)
            loss.backward()
            optimizer.step()

            if i % 10 == 0:  # report every 10 mini-batches
                logging.info(f"[Epoch {epoch}, Batch {i}] loss: {loss.item():.3f}")

    logging.info("Multiclass classifier training finished")

    classifier.eval()

    test_loss = 0.0
    label_batches = []
    prediction_batches = []
    prob_batches = []

    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = classifier(images)
            loss = criterion(outputs.logits, labels)
            # The criterion returns a per-batch mean, so weight it by the real
            # batch size; a short final batch must not count as a full one.
            test_loss += loss.item() * labels.size(0)

            _, predicted = torch.max(outputs.logits, 1)
            label_batches.append(labels.cpu())
            prediction_batches.append(predicted.cpu())
            prob_batches.append(outputs.probabilities.cpu())

    test_loss = test_loss / len(testset)

    logging.info(f"Multiclass classifier test loss {test_loss:.3f}")

    # Concatenate whole batches instead of stacking one tensor per sample.
    test_labels = torch.cat(label_batches)
    test_predictions = torch.cat(prediction_batches)
    test_probs = torch.cat(prob_batches)

    accuracy = Accuracy(task="multiclass", num_classes=num_classes)
    precision = Precision(task="multiclass", average="weighted", num_classes=num_classes)
    recall = Recall(task="multiclass", average="weighted", num_classes=num_classes)
    f1 = F1Score(task="multiclass", average="weighted", num_classes=num_classes)
    auroc = AUROC(task="multiclass", num_classes=num_classes, average="weighted")

    # Calculate metrics
    acc = accuracy(test_predictions, test_labels)
    prec = precision(test_predictions, test_labels)
    rec = recall(test_predictions, test_labels)
    f1_score = f1(test_predictions, test_labels)
    auroc_score = auroc(test_probs, test_labels)

    logging.info(f"Accuracy: {acc:.2f}")
    logging.info(f"Precision: {prec:.2f}")
    logging.info(f"Recall: {rec:.2f}")
    logging.info(f"F1 Score: {f1_score:.2f}")
    logging.info(f"AUROC Score: {auroc_score:.2f}")
|
|
||||||
|
|
||||||
|
|
||||||
def train_evaluate_binary_classifier():
    """Train a one-vs-rest binary ``Classifier`` on CIFAR-10 and log its test metrics.

    CIFAR-10 labels are rewritten to 1.0 for class 3 and 0.0 for every other
    class. A ``Classifier`` built from a ``ClassifierConfig`` naming the
    ``microsoft/resnet-50`` model is trained with Adam and
    ``BCEWithLogitsLoss``; the loss's ``pos_weight`` is the ratio of negative
    to positive training samples. The test loss, accuracy, precision, recall,
    F1 and AUROC are then logged.

    Both CIFAR-10 splits are requested with ``download=True`` under ``data/``.
    Nothing is returned; all results are reported through ``logging``.
    """
    logging.info(
        f"Start binary classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}"
    )

    target_binary_class = 3

    def one_vs_rest(dataset, target_class):
        """Relabel ``dataset`` in place: 1.0 for ``target_class``, 0.0 otherwise."""
        # Read the stored labels directly; iterating the dataset itself would
        # also fetch (and transform) every image only to throw it away.
        dataset.targets = [1.0 if label == target_class else 0.0 for label in dataset.targets]
        return dataset

    binary_train_dataset = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
    binary_test_dataset = CIFAR10(root="data", train=False, download=True, transform=ToTensor())

    # Apply one-vs-rest labeling
    binary_train_dataset = one_vs_rest(binary_train_dataset, target_binary_class)
    binary_test_dataset = one_vs_rest(binary_test_dataset, target_binary_class)

    binary_trainloader = DataLoader(binary_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    binary_testloader = DataLoader(binary_test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    binary_config = ClassifierConfig(model_name="microsoft/resnet-50", device=DEVICE)
    binary_classifier = Classifier(binary_config)

    # np.bincount only accepts integer input, so cast the 0.0/1.0 labels first;
    # minlength=2 keeps both slots present even if one class were absent.
    class_counts = np.bincount(np.asarray(binary_train_dataset.targets, dtype=np.int64), minlength=2)
    # With balanced class weights w_c = n / (2 * count_c), the ratio w1 / w0
    # reduces to count_0 / count_1, i.e. negatives per positive.
    pos_weight = class_counts[0] / class_counts[1]

    binary_criterion = BCEWithLogitsLoss(
        pos_weight=torch.tensor(pos_weight, dtype=torch.float32, device=DEVICE)
    )
    binary_optimizer = Adam(binary_classifier.parameters(), lr=LR)

    binary_classifier.train()

    logging.info("Start binary classifier training")

    # Epochs are numbered from 1 for the log lines; the inclusive upper bound
    # makes the loop run exactly EPOCH_NUM passes over the training set.
    for binary_epoch in range(1, EPOCH_NUM + 1):
        for i, (inputs, labels) in enumerate(binary_trainloader):
            # The relabelled targets are Python floats; cast to float32 for the loss.
            inputs, labels = inputs.to(DEVICE), labels.to(torch.float32).to(DEVICE)

            # Clear gradients left over from the previous step.
            binary_optimizer.zero_grad()

            outputs = binary_classifier(inputs)
            loss = binary_criterion(outputs.logits, labels)
            loss.backward()
            binary_optimizer.step()

            if i % 10 == 0:  # report every 10 mini-batches
                logging.info(f"[Epoch {binary_epoch}, Batch {i}] loss: {loss.item():.3f}")

    logging.info("Binary classifier training finished")
    logging.info("Start binary classifier evaluation")

    binary_classifier.eval()

    test_loss = 0.0
    label_batches = []
    logit_batches = []
    prob_batches = []

    with torch.no_grad():
        for images, labels in binary_testloader:
            images, labels = images.to(DEVICE), labels.to(torch.float32).to(DEVICE)
            outputs = binary_classifier(images)
            loss = binary_criterion(outputs.logits, labels)
            # The criterion returns a per-batch mean, so weight it by the real
            # batch size; a short final batch must not count as a full one.
            test_loss += loss.item() * labels.size(0)

            label_batches.append(labels.cpu())
            logit_batches.append(outputs.logits.cpu())
            prob_batches.append(outputs.probabilities.cpu())

    test_loss = test_loss / len(binary_test_dataset)

    logging.info(f"Binary classifier test loss {test_loss:.3f}")

    # Concatenate whole batches instead of stacking one tensor per sample.
    test_labels = torch.cat(label_batches)
    # NOTE(review): the threshold-based metrics below receive raw logits, as in
    # the original script — confirm the metrics library handles logits as intended.
    test_predictions = torch.cat(logit_batches)
    test_probs = torch.cat(prob_batches)

    # Calculate metrics
    acc = Accuracy(task="binary")(test_predictions, test_labels)
    prec = Precision(task="binary", average="weighted")(test_predictions, test_labels)
    rec = Recall(task="binary", average="weighted")(test_predictions, test_labels)
    f1_score = F1Score(task="binary", average="weighted")(test_predictions, test_labels)
    auroc_score = AUROC(task="binary", average="weighted")(test_probs, test_labels)

    logging.info(f"Accuracy: {acc:.2f}")
    logging.info(f"Precision: {prec:.2f}")
    logging.info(f"Recall: {rec:.2f}")
    logging.info(f"F1 Score: {f1_score:.2f}")
    logging.info(f"AUROC Score: {auroc_score:.2f}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Script entry point: run the multiclass demo first, then the binary
    # one-vs-rest demo.
    train_evaluate_multiclass_classifier()
    train_evaluate_binary_classifier()
|
|
||||||
@@ -1,9 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
|
from lerobot.common.policies.reward_model.configuration_classifier import RewardClassifierConfig
|
||||||
ClassifierConfig,
|
from lerobot.common.policies.reward_model.modeling_classifier import ClassifierOutput
|
||||||
ClassifierOutput,
|
|
||||||
)
|
|
||||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||||
from tests.utils import require_package
|
from tests.utils import require_package
|
||||||
|
|
||||||
@@ -23,11 +21,9 @@ def test_classifier_output():
|
|||||||
|
|
||||||
@require_package("transformers")
|
@require_package("transformers")
|
||||||
def test_binary_classifier_with_default_params():
|
def test_binary_classifier_with_default_params():
|
||||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
|
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
|
||||||
Classifier,
|
|
||||||
)
|
|
||||||
|
|
||||||
config = ClassifierConfig()
|
config = RewardClassifierConfig()
|
||||||
config.input_features = {
|
config.input_features = {
|
||||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||||
}
|
}
|
||||||
@@ -66,12 +62,10 @@ def test_binary_classifier_with_default_params():
|
|||||||
|
|
||||||
@require_package("transformers")
|
@require_package("transformers")
|
||||||
def test_multiclass_classifier():
|
def test_multiclass_classifier():
|
||||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
|
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
|
||||||
Classifier,
|
|
||||||
)
|
|
||||||
|
|
||||||
num_classes = 5
|
num_classes = 5
|
||||||
config = ClassifierConfig()
|
config = RewardClassifierConfig()
|
||||||
config.input_features = {
|
config.input_features = {
|
||||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||||
}
|
}
|
||||||
@@ -107,11 +101,9 @@ def test_multiclass_classifier():
|
|||||||
|
|
||||||
@require_package("transformers")
|
@require_package("transformers")
|
||||||
def test_default_device():
|
def test_default_device():
|
||||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
|
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
|
||||||
Classifier,
|
|
||||||
)
|
|
||||||
|
|
||||||
config = ClassifierConfig()
|
config = RewardClassifierConfig()
|
||||||
assert config.device == "cpu"
|
assert config.device == "cpu"
|
||||||
|
|
||||||
classifier = Classifier(config)
|
classifier = Classifier(config)
|
||||||
@@ -121,11 +113,9 @@ def test_default_device():
|
|||||||
|
|
||||||
@require_package("transformers")
|
@require_package("transformers")
|
||||||
def test_explicit_device_setup():
|
def test_explicit_device_setup():
|
||||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
|
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
|
||||||
Classifier,
|
|
||||||
)
|
|
||||||
|
|
||||||
config = ClassifierConfig(device="cpu")
|
config = RewardClassifierConfig(device="cpu")
|
||||||
assert config.device == "cpu"
|
assert config.device == "cpu"
|
||||||
|
|
||||||
classifier = Classifier(config)
|
classifier = Classifier(config)
|
||||||
Reference in New Issue
Block a user