[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
AdilZouitine
parent
76df8a31b3
commit
38f5fa4523
@@ -9,7 +9,9 @@ from hydra import compose, initialize_config_dir
|
||||
from torch import nn
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
|
||||
ClassifierConfig,
|
||||
)
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||
from lerobot.scripts.train_hilserl_classifier import (
|
||||
create_balanced_sampler,
|
||||
@@ -34,7 +36,9 @@ class MockDataset(Dataset):
|
||||
|
||||
def make_dummy_model():
|
||||
model_config = ClassifierConfig(
|
||||
num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel", num_cameras=1
|
||||
num_classes=2,
|
||||
model_name="hf-tiny-model-private/tiny-random-ResNetModel",
|
||||
num_cameras=1,
|
||||
)
|
||||
model = Classifier(config=model_config)
|
||||
return model
|
||||
@@ -65,7 +69,9 @@ def test_create_balanced_sampler():
|
||||
labels = [item["label"] for item in data]
|
||||
class_counts = torch.tensor([labels.count(0), labels.count(1)], dtype=torch.float32)
|
||||
class_weights = 1.0 / class_counts
|
||||
expected_weights = torch.tensor([class_weights[label] for label in labels], dtype=torch.float32)
|
||||
expected_weights = torch.tensor(
|
||||
[class_weights[label] for label in labels], dtype=torch.float32
|
||||
)
|
||||
|
||||
# Test that the weights are correct
|
||||
assert torch.allclose(weights, expected_weights)
|
||||
@@ -149,7 +155,9 @@ def test_validate():
|
||||
|
||||
def test_train_epoch_multiple_cameras():
|
||||
model_config = ClassifierConfig(
|
||||
num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel", num_cameras=2
|
||||
num_classes=2,
|
||||
model_name="hf-tiny-model-private/tiny-random-ResNetModel",
|
||||
num_cameras=2,
|
||||
)
|
||||
model = Classifier(config=model_config)
|
||||
|
||||
@@ -216,10 +224,16 @@ def test_resume_function(
|
||||
):
|
||||
# Initialize Hydra
|
||||
test_file_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
config_dir = os.path.abspath(os.path.join(test_file_dir, "..", "lerobot", "configs", "policy"))
|
||||
assert os.path.exists(config_dir), f"Config directory does not exist at {config_dir}"
|
||||
config_dir = os.path.abspath(
|
||||
os.path.join(test_file_dir, "..", "lerobot", "configs", "policy")
|
||||
)
|
||||
assert os.path.exists(
|
||||
config_dir
|
||||
), f"Config directory does not exist at {config_dir}"
|
||||
|
||||
with initialize_config_dir(config_dir=config_dir, job_name="test_app", version_base="1.2"):
|
||||
with initialize_config_dir(
|
||||
config_dir=config_dir, job_name="test_app", version_base="1.2"
|
||||
):
|
||||
cfg = compose(
|
||||
config_name="hilserl_classifier",
|
||||
overrides=[
|
||||
@@ -244,7 +258,9 @@ def test_resume_function(
|
||||
mock_init_hydra_config.return_value = cfg
|
||||
|
||||
# Mock dataset
|
||||
dataset = MockDataset([{"image": torch.rand(3, 224, 224), "label": i % 2} for i in range(10)])
|
||||
dataset = MockDataset(
|
||||
[{"image": torch.rand(3, 224, 224), "label": i % 2} for i in range(10)]
|
||||
)
|
||||
mock_dataset.return_value = dataset
|
||||
|
||||
# Mock checkpoint handling
|
||||
|
||||
Reference in New Issue
Block a user