fix(tests): remove lint warnings/errors
This commit is contained in:
@@ -42,7 +42,9 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.n
|
||||
table = hf_dataset.data.table
|
||||
total_episodes = calculate_total_episode(hf_dataset)
|
||||
for ep_idx in range(total_episodes):
|
||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
||||
ep_table = table.filter(
|
||||
pc.equal(table["episode_index"], ep_idx)
|
||||
) # TODO(Steven): What is this check supposed to do?
|
||||
episode_lengths.insert(ep_idx, len(ep_table))
|
||||
|
||||
cumulative_lengths = list(accumulate(episode_lengths))
|
||||
@@ -52,8 +54,8 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.n
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def synced_timestamps_factory(hf_dataset_factory):
|
||||
@pytest.fixture(name="synced_timestamps_factory", scope="module")
|
||||
def fixture_synced_timestamps_factory(hf_dataset_factory):
|
||||
def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
hf_dataset = hf_dataset_factory(fps=fps)
|
||||
timestamps = torch.stack(hf_dataset["timestamp"]).numpy()
|
||||
@@ -64,8 +66,8 @@ def synced_timestamps_factory(hf_dataset_factory):
|
||||
return _create_synced_timestamps
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def unsynced_timestamps_factory(synced_timestamps_factory):
|
||||
@pytest.fixture(name="unsynced_timestamps_factory", scope="module")
|
||||
def fixture_unsynced_timestamps_factory(synced_timestamps_factory):
|
||||
def _create_unsynced_timestamps(
|
||||
fps: int = 30, tolerance_s: float = 1e-4
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
@@ -76,8 +78,8 @@ def unsynced_timestamps_factory(synced_timestamps_factory):
|
||||
return _create_unsynced_timestamps
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def slightly_off_timestamps_factory(synced_timestamps_factory):
|
||||
@pytest.fixture(name="slightly_off_timestamps_factory", scope="module")
|
||||
def fixture_slightly_off_timestamps_factory(synced_timestamps_factory):
|
||||
def _create_slightly_off_timestamps(
|
||||
fps: int = 30, tolerance_s: float = 1e-4
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
@@ -88,22 +90,26 @@ def slightly_off_timestamps_factory(synced_timestamps_factory):
|
||||
return _create_slightly_off_timestamps
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def valid_delta_timestamps_factory():
|
||||
@pytest.fixture(name="valid_delta_timestamps_factory", scope="module")
|
||||
def fixture_valid_delta_timestamps_factory():
|
||||
def _create_valid_delta_timestamps(
|
||||
fps: int = 30, keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)
|
||||
fps: int = 30, keys: list | None = None, min_max_range: tuple[int, int] = (-10, 10)
|
||||
) -> dict:
|
||||
if keys is None:
|
||||
keys = DUMMY_MOTOR_FEATURES
|
||||
delta_timestamps = {key: [i * (1 / fps) for i in range(*min_max_range)] for key in keys}
|
||||
return delta_timestamps
|
||||
|
||||
return _create_valid_delta_timestamps
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||
@pytest.fixture(name="invalid_delta_timestamps_factory", scope="module")
|
||||
def fixture_invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||
def _create_invalid_delta_timestamps(
|
||||
fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
|
||||
fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None
|
||||
) -> dict:
|
||||
if keys is None:
|
||||
keys = DUMMY_MOTOR_FEATURES
|
||||
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
|
||||
# Modify a single timestamp just outside tolerance
|
||||
for key in keys:
|
||||
@@ -113,11 +119,13 @@ def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||
return _create_invalid_delta_timestamps
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||
@pytest.fixture(name="slightly_off_delta_timestamps_factory", scope="module")
|
||||
def fixture_slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||
def _create_slightly_off_delta_timestamps(
|
||||
fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
|
||||
fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None
|
||||
) -> dict:
|
||||
if keys is None:
|
||||
keys = DUMMY_MOTOR_FEATURES
|
||||
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
|
||||
# Modify a single timestamp just inside tolerance
|
||||
for key in delta_timestamps:
|
||||
@@ -128,9 +136,11 @@ def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||
return _create_slightly_off_delta_timestamps
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def delta_indices_factory():
|
||||
def _delta_indices(keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)) -> dict:
|
||||
@pytest.fixture(name="delta_indices_factory", scope="module")
|
||||
def fixture_delta_indices_factory():
|
||||
def _delta_indices(keys: list | None = None, min_max_range: tuple[int, int] = (-10, 10)) -> dict:
|
||||
if keys is None:
|
||||
keys = DUMMY_MOTOR_FEATURES
|
||||
return {key: list(range(*min_max_range)) for key in keys}
|
||||
|
||||
return _delta_indices
|
||||
|
||||
Reference in New Issue
Block a user