fix(tests): remove lint warnings/errors

This commit is contained in:
Steven Palma
2025-03-07 14:45:09 +01:00
parent e59ef036e1
commit 0eb56cec14
24 changed files with 163 additions and 133 deletions

View File

@@ -42,7 +42,9 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.n
table = hf_dataset.data.table
total_episodes = calculate_total_episode(hf_dataset)
for ep_idx in range(total_episodes):
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
ep_table = table.filter(
pc.equal(table["episode_index"], ep_idx)
) # TODO(Steven): What is this check supposed to do?
episode_lengths.insert(ep_idx, len(ep_table))
cumulative_lengths = list(accumulate(episode_lengths))
@@ -52,8 +54,8 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.n
}
@pytest.fixture(scope="module")
def synced_timestamps_factory(hf_dataset_factory):
@pytest.fixture(name="synced_timestamps_factory", scope="module")
def fixture_synced_timestamps_factory(hf_dataset_factory):
def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
hf_dataset = hf_dataset_factory(fps=fps)
timestamps = torch.stack(hf_dataset["timestamp"]).numpy()
@@ -64,8 +66,8 @@ def synced_timestamps_factory(hf_dataset_factory):
return _create_synced_timestamps
@pytest.fixture(scope="module")
def unsynced_timestamps_factory(synced_timestamps_factory):
@pytest.fixture(name="unsynced_timestamps_factory", scope="module")
def fixture_unsynced_timestamps_factory(synced_timestamps_factory):
def _create_unsynced_timestamps(
fps: int = 30, tolerance_s: float = 1e-4
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
@@ -76,8 +78,8 @@ def unsynced_timestamps_factory(synced_timestamps_factory):
return _create_unsynced_timestamps
@pytest.fixture(scope="module")
def slightly_off_timestamps_factory(synced_timestamps_factory):
@pytest.fixture(name="slightly_off_timestamps_factory", scope="module")
def fixture_slightly_off_timestamps_factory(synced_timestamps_factory):
def _create_slightly_off_timestamps(
fps: int = 30, tolerance_s: float = 1e-4
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
@@ -88,22 +90,26 @@ def slightly_off_timestamps_factory(synced_timestamps_factory):
return _create_slightly_off_timestamps
@pytest.fixture(scope="module")
def valid_delta_timestamps_factory():
@pytest.fixture(name="valid_delta_timestamps_factory", scope="module")
def fixture_valid_delta_timestamps_factory():
def _create_valid_delta_timestamps(
fps: int = 30, keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)
fps: int = 30, keys: list | None = None, min_max_range: tuple[int, int] = (-10, 10)
) -> dict:
if keys is None:
keys = DUMMY_MOTOR_FEATURES
delta_timestamps = {key: [i * (1 / fps) for i in range(*min_max_range)] for key in keys}
return delta_timestamps
return _create_valid_delta_timestamps
@pytest.fixture(scope="module")
def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
@pytest.fixture(name="invalid_delta_timestamps_factory", scope="module")
def fixture_invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
def _create_invalid_delta_timestamps(
fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None
) -> dict:
if keys is None:
keys = DUMMY_MOTOR_FEATURES
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
# Modify a single timestamp just outside tolerance
for key in keys:
@@ -113,11 +119,13 @@ def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
return _create_invalid_delta_timestamps
@pytest.fixture(scope="module")
def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
@pytest.fixture(name="slightly_off_delta_timestamps_factory", scope="module")
def fixture_slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
def _create_slightly_off_delta_timestamps(
fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None
) -> dict:
if keys is None:
keys = DUMMY_MOTOR_FEATURES
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
# Modify a single timestamp just inside tolerance
for key in delta_timestamps:
@@ -128,9 +136,11 @@ def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
return _create_slightly_off_delta_timestamps
@pytest.fixture(scope="module")
def delta_indices_factory():
def _delta_indices(keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)) -> dict:
@pytest.fixture(name="delta_indices_factory", scope="module")
def fixture_delta_indices_factory():
def _delta_indices(keys: list | None = None, min_max_range: tuple[int, int] = (-10, 10)) -> dict:
if keys is None:
keys = DUMMY_MOTOR_FEATURES
return {key: list(range(*min_max_range)) for key in keys}
return _delta_indices