[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Michel Aractingi
parent
cdcf346061
commit
1c8daf11fd
@@ -54,7 +54,9 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.n
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def synced_timestamps_factory(hf_dataset_factory):
|
||||
def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
def _create_synced_timestamps(
|
||||
fps: int = 30,
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
hf_dataset = hf_dataset_factory(fps=fps)
|
||||
timestamps = torch.stack(hf_dataset["timestamp"]).numpy()
|
||||
episode_indices = torch.stack(hf_dataset["episode_index"]).numpy()
|
||||
@@ -69,8 +71,12 @@ def unsynced_timestamps_factory(synced_timestamps_factory):
|
||||
def _create_unsynced_timestamps(
|
||||
fps: int = 30, tolerance_s: float = 1e-4
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps)
|
||||
timestamps[30] += tolerance_s * 1.1 # Modify a single timestamp just outside tolerance
|
||||
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(
|
||||
fps=fps
|
||||
)
|
||||
timestamps[30] += (
|
||||
tolerance_s * 1.1
|
||||
) # Modify a single timestamp just outside tolerance
|
||||
return timestamps, episode_indices, episode_data_index
|
||||
|
||||
return _create_unsynced_timestamps
|
||||
@@ -81,8 +87,12 @@ def slightly_off_timestamps_factory(synced_timestamps_factory):
|
||||
def _create_slightly_off_timestamps(
|
||||
fps: int = 30, tolerance_s: float = 1e-4
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps)
|
||||
timestamps[30] += tolerance_s * 0.9 # Modify a single timestamp just inside tolerance
|
||||
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(
|
||||
fps=fps
|
||||
)
|
||||
timestamps[30] += (
|
||||
tolerance_s * 0.9
|
||||
) # Modify a single timestamp just inside tolerance
|
||||
return timestamps, episode_indices, episode_data_index
|
||||
|
||||
return _create_slightly_off_timestamps
|
||||
@@ -91,9 +101,13 @@ def slightly_off_timestamps_factory(synced_timestamps_factory):
|
||||
@pytest.fixture(scope="module")
|
||||
def valid_delta_timestamps_factory():
|
||||
def _create_valid_delta_timestamps(
|
||||
fps: int = 30, keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)
|
||||
fps: int = 30,
|
||||
keys: list = DUMMY_MOTOR_FEATURES,
|
||||
min_max_range: tuple[int, int] = (-10, 10),
|
||||
) -> dict:
|
||||
delta_timestamps = {key: [i * (1 / fps) for i in range(*min_max_range)] for key in keys}
|
||||
delta_timestamps = {
|
||||
key: [i * (1 / fps) for i in range(*min_max_range)] for key in keys
|
||||
}
|
||||
return delta_timestamps
|
||||
|
||||
return _create_valid_delta_timestamps
|
||||
@@ -130,7 +144,9 @@ def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def delta_indices_factory():
|
||||
def _delta_indices(keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)) -> dict:
|
||||
def _delta_indices(
|
||||
keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)
|
||||
) -> dict:
|
||||
return {key: list(range(*min_max_range)) for key in keys}
|
||||
|
||||
return _delta_indices
|
||||
@@ -182,7 +198,9 @@ def test_check_timestamps_sync_unsynced_no_exception(unsynced_timestamps_factory
|
||||
def test_check_timestamps_sync_slightly_off(slightly_off_timestamps_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
timestamps, ep_idx, ep_data_index = slightly_off_timestamps_factory(fps, tolerance_s)
|
||||
timestamps, ep_idx, ep_data_index = slightly_off_timestamps_factory(
|
||||
fps, tolerance_s
|
||||
)
|
||||
result = check_timestamps_sync(
|
||||
timestamps=timestamps,
|
||||
episode_indices=ep_idx,
|
||||
@@ -223,7 +241,9 @@ def test_check_delta_timestamps_valid(valid_delta_timestamps_factory):
|
||||
def test_check_delta_timestamps_slightly_off(slightly_off_delta_timestamps_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
slightly_off_delta_timestamps = slightly_off_delta_timestamps_factory(fps, tolerance_s)
|
||||
slightly_off_delta_timestamps = slightly_off_delta_timestamps_factory(
|
||||
fps, tolerance_s
|
||||
)
|
||||
result = check_delta_timestamps(
|
||||
delta_timestamps=slightly_off_delta_timestamps,
|
||||
fps=fps,
|
||||
|
||||
Reference in New Issue
Block a user