Extend reward classifier for multiple camera views (#626)

This commit is contained in:
Michel Aractingi
2025-01-13 13:57:49 +01:00
parent 844bfcf484
commit bbb5ba0adf
9 changed files with 192 additions and 49 deletions

View File

@@ -33,7 +33,9 @@ class MockDataset(Dataset):
def make_dummy_model():
model_config = ClassifierConfig(num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel")
model_config = ClassifierConfig(
num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel", num_cameras=1
)
model = Classifier(config=model_config)
return model
@@ -88,7 +90,7 @@ def test_train_epoch():
logger = MagicMock()
step = 0
cfg = MagicMock()
cfg.training.image_key = "image"
cfg.training.image_keys = ["image"]
cfg.training.label_key = "label"
cfg.training.use_amp = False
@@ -130,7 +132,7 @@ def test_validate():
device = torch.device("cpu")
logger = MagicMock()
cfg = MagicMock()
cfg.training.image_key = "image"
cfg.training.image_keys = ["image"]
cfg.training.label_key = "label"
cfg.training.use_amp = False
@@ -145,6 +147,57 @@ def test_validate():
assert isinstance(eval_info, dict)
def test_train_epoch_multiple_cameras():
model_config = ClassifierConfig(
num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel", num_cameras=2
)
model = Classifier(config=model_config)
# Mock components
model.train = MagicMock()
train_loader = [
{
"image_1": torch.rand(2, 3, 224, 224),
"image_2": torch.rand(2, 3, 224, 224),
"label": torch.tensor([0.0, 1.0]),
}
]
criterion = nn.BCEWithLogitsLoss()
optimizer = MagicMock()
grad_scaler = MagicMock()
device = torch.device("cpu")
logger = MagicMock()
step = 0
cfg = MagicMock()
cfg.training.image_keys = ["image_1", "image_2"]
cfg.training.label_key = "label"
cfg.training.use_amp = False
# Call the function under test
train_epoch(
model,
train_loader,
criterion,
optimizer,
grad_scaler,
device,
logger,
step,
cfg,
)
# Check that model.train() was called
model.train.assert_called_once()
# Check that optimizer.zero_grad() was called
optimizer.zero_grad.assert_called()
# Check that logger.log_dict was called
logger.log_dict.assert_called()
@pytest.mark.parametrize("resume", [True, False])
@patch("lerobot.scripts.train_hilserl_classifier.init_hydra_config")
@patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_checkpoint_dir")
@@ -179,7 +232,7 @@ def test_resume_function(
"train_split_proportion=0.8",
"training.num_workers=0",
"training.batch_size=2",
"training.image_key=image",
"training.image_keys=[image]",
"training.label_key=label",
"training.use_amp=False",
"training.num_epochs=1",