[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Michel Aractingi
parent
cdcf346061
commit
1c8daf11fd
@@ -18,7 +18,10 @@ from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
|
||||
@pytest.fixture
|
||||
def mock_metrics():
|
||||
return {"loss": AverageMeter("loss", ":.3f"), "accuracy": AverageMeter("accuracy", ":.2f")}
|
||||
return {
|
||||
"loss": AverageMeter("loss", ":.3f"),
|
||||
"accuracy": AverageMeter("accuracy", ":.2f"),
|
||||
}
|
||||
|
||||
|
||||
def test_average_meter_initialization():
|
||||
@@ -58,7 +61,11 @@ def test_average_meter_str():
|
||||
|
||||
def test_metrics_tracker_initialization(mock_metrics):
|
||||
tracker = MetricsTracker(
|
||||
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics, initial_step=10
|
||||
batch_size=32,
|
||||
num_frames=1000,
|
||||
num_episodes=50,
|
||||
metrics=mock_metrics,
|
||||
initial_step=10,
|
||||
)
|
||||
assert tracker.steps == 10
|
||||
assert tracker.samples == 10 * 32
|
||||
@@ -70,7 +77,11 @@ def test_metrics_tracker_initialization(mock_metrics):
|
||||
|
||||
def test_metrics_tracker_step(mock_metrics):
|
||||
tracker = MetricsTracker(
|
||||
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics, initial_step=5
|
||||
batch_size=32,
|
||||
num_frames=1000,
|
||||
num_episodes=50,
|
||||
metrics=mock_metrics,
|
||||
initial_step=5,
|
||||
)
|
||||
tracker.step()
|
||||
assert tracker.steps == 6
|
||||
@@ -80,7 +91,9 @@ def test_metrics_tracker_step(mock_metrics):
|
||||
|
||||
|
||||
def test_metrics_tracker_getattr(mock_metrics):
|
||||
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
|
||||
tracker = MetricsTracker(
|
||||
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics
|
||||
)
|
||||
assert tracker.loss == mock_metrics["loss"]
|
||||
assert tracker.accuracy == mock_metrics["accuracy"]
|
||||
with pytest.raises(AttributeError):
|
||||
@@ -88,13 +101,17 @@ def test_metrics_tracker_getattr(mock_metrics):
|
||||
|
||||
|
||||
def test_metrics_tracker_setattr(mock_metrics):
|
||||
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
|
||||
tracker = MetricsTracker(
|
||||
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics
|
||||
)
|
||||
tracker.loss = 2.0
|
||||
assert tracker.loss.val == 2.0
|
||||
|
||||
|
||||
def test_metrics_tracker_str(mock_metrics):
|
||||
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
|
||||
tracker = MetricsTracker(
|
||||
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics
|
||||
)
|
||||
tracker.loss.update(3.456, 1)
|
||||
tracker.accuracy.update(0.876, 1)
|
||||
output = str(tracker)
|
||||
@@ -103,7 +120,9 @@ def test_metrics_tracker_str(mock_metrics):
|
||||
|
||||
|
||||
def test_metrics_tracker_to_dict(mock_metrics):
|
||||
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
|
||||
tracker = MetricsTracker(
|
||||
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics
|
||||
)
|
||||
tracker.loss.update(5, 2)
|
||||
metrics_dict = tracker.to_dict()
|
||||
assert isinstance(metrics_dict, dict)
|
||||
@@ -112,7 +131,9 @@ def test_metrics_tracker_to_dict(mock_metrics):
|
||||
|
||||
|
||||
def test_metrics_tracker_reset_averages(mock_metrics):
|
||||
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
|
||||
tracker = MetricsTracker(
|
||||
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics
|
||||
)
|
||||
tracker.loss.update(10, 3)
|
||||
tracker.accuracy.update(0.95, 5)
|
||||
tracker.reset_averages()
|
||||
|
||||
Reference in New Issue
Block a user