[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-04 13:38:47 +00:00
committed by AdilZouitine
parent 76df8a31b3
commit 38f5fa4523
79 changed files with 2782 additions and 788 deletions

View File

@@ -9,7 +9,9 @@ from hydra import compose, initialize_config_dir
from torch import nn
from torch.utils.data import Dataset
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
ClassifierConfig,
)
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.scripts.train_hilserl_classifier import (
create_balanced_sampler,
@@ -34,7 +36,9 @@ class MockDataset(Dataset):
def make_dummy_model():
model_config = ClassifierConfig(
num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel", num_cameras=1
num_classes=2,
model_name="hf-tiny-model-private/tiny-random-ResNetModel",
num_cameras=1,
)
model = Classifier(config=model_config)
return model
@@ -65,7 +69,9 @@ def test_create_balanced_sampler():
labels = [item["label"] for item in data]
class_counts = torch.tensor([labels.count(0), labels.count(1)], dtype=torch.float32)
class_weights = 1.0 / class_counts
expected_weights = torch.tensor([class_weights[label] for label in labels], dtype=torch.float32)
expected_weights = torch.tensor(
[class_weights[label] for label in labels], dtype=torch.float32
)
# Test that the weights are correct
assert torch.allclose(weights, expected_weights)
@@ -149,7 +155,9 @@ def test_validate():
def test_train_epoch_multiple_cameras():
model_config = ClassifierConfig(
num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel", num_cameras=2
num_classes=2,
model_name="hf-tiny-model-private/tiny-random-ResNetModel",
num_cameras=2,
)
model = Classifier(config=model_config)
@@ -216,10 +224,16 @@ def test_resume_function(
):
# Initialize Hydra
test_file_dir = os.path.dirname(os.path.abspath(__file__))
config_dir = os.path.abspath(os.path.join(test_file_dir, "..", "lerobot", "configs", "policy"))
assert os.path.exists(config_dir), f"Config directory does not exist at {config_dir}"
config_dir = os.path.abspath(
os.path.join(test_file_dir, "..", "lerobot", "configs", "policy")
)
assert os.path.exists(
config_dir
), f"Config directory does not exist at {config_dir}"
with initialize_config_dir(config_dir=config_dir, job_name="test_app", version_base="1.2"):
with initialize_config_dir(
config_dir=config_dir, job_name="test_app", version_base="1.2"
):
cfg = compose(
config_name="hilserl_classifier",
overrides=[
@@ -244,7 +258,9 @@ def test_resume_function(
mock_init_hydra_config.return_value = cfg
# Mock dataset
dataset = MockDataset([{"image": torch.rand(3, 224, 224), "label": i % 2} for i in range(10)])
dataset = MockDataset(
[{"image": torch.rand(3, 224, 224), "label": i % 2} for i in range(10)]
)
mock_dataset.return_value = dataset
# Mock checkpoint handling