Extend reward classifier for multiple camera views (#626)
This commit is contained in:
@@ -33,7 +33,9 @@ class MockDataset(Dataset):
|
||||
|
||||
|
||||
def make_dummy_model():
|
||||
model_config = ClassifierConfig(num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel")
|
||||
model_config = ClassifierConfig(
|
||||
num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel", num_cameras=1
|
||||
)
|
||||
model = Classifier(config=model_config)
|
||||
return model
|
||||
|
||||
@@ -88,7 +90,7 @@ def test_train_epoch():
|
||||
logger = MagicMock()
|
||||
step = 0
|
||||
cfg = MagicMock()
|
||||
cfg.training.image_key = "image"
|
||||
cfg.training.image_keys = ["image"]
|
||||
cfg.training.label_key = "label"
|
||||
cfg.training.use_amp = False
|
||||
|
||||
@@ -130,7 +132,7 @@ def test_validate():
|
||||
device = torch.device("cpu")
|
||||
logger = MagicMock()
|
||||
cfg = MagicMock()
|
||||
cfg.training.image_key = "image"
|
||||
cfg.training.image_keys = ["image"]
|
||||
cfg.training.label_key = "label"
|
||||
cfg.training.use_amp = False
|
||||
|
||||
@@ -145,6 +147,57 @@ def test_validate():
|
||||
assert isinstance(eval_info, dict)
|
||||
|
||||
|
||||
def test_train_epoch_multiple_cameras():
|
||||
model_config = ClassifierConfig(
|
||||
num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel", num_cameras=2
|
||||
)
|
||||
model = Classifier(config=model_config)
|
||||
|
||||
# Mock components
|
||||
model.train = MagicMock()
|
||||
|
||||
train_loader = [
|
||||
{
|
||||
"image_1": torch.rand(2, 3, 224, 224),
|
||||
"image_2": torch.rand(2, 3, 224, 224),
|
||||
"label": torch.tensor([0.0, 1.0]),
|
||||
}
|
||||
]
|
||||
|
||||
criterion = nn.BCEWithLogitsLoss()
|
||||
optimizer = MagicMock()
|
||||
grad_scaler = MagicMock()
|
||||
device = torch.device("cpu")
|
||||
logger = MagicMock()
|
||||
step = 0
|
||||
cfg = MagicMock()
|
||||
cfg.training.image_keys = ["image_1", "image_2"]
|
||||
cfg.training.label_key = "label"
|
||||
cfg.training.use_amp = False
|
||||
|
||||
# Call the function under test
|
||||
train_epoch(
|
||||
model,
|
||||
train_loader,
|
||||
criterion,
|
||||
optimizer,
|
||||
grad_scaler,
|
||||
device,
|
||||
logger,
|
||||
step,
|
||||
cfg,
|
||||
)
|
||||
|
||||
# Check that model.train() was called
|
||||
model.train.assert_called_once()
|
||||
|
||||
# Check that optimizer.zero_grad() was called
|
||||
optimizer.zero_grad.assert_called()
|
||||
|
||||
# Check that logger.log_dict was called
|
||||
logger.log_dict.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("resume", [True, False])
|
||||
@patch("lerobot.scripts.train_hilserl_classifier.init_hydra_config")
|
||||
@patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_checkpoint_dir")
|
||||
@@ -179,7 +232,7 @@ def test_resume_function(
|
||||
"train_split_proportion=0.8",
|
||||
"training.num_workers=0",
|
||||
"training.batch_size=2",
|
||||
"training.image_key=image",
|
||||
"training.image_keys=[image]",
|
||||
"training.label_key=label",
|
||||
"training.use_amp=False",
|
||||
"training.num_epochs=1",
|
||||
|
||||
Reference in New Issue
Block a user