most unit tests passing (TODO: convert datasets)

This commit is contained in:
Remi Cadene
2025-04-16 21:30:58 +02:00
parent c2a05a1fde
commit 6b6a990f4c
22 changed files with 150 additions and 136 deletions

View File

@@ -43,8 +43,8 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.n
def synced_timestamps_factory(hf_dataset_factory):
def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
hf_dataset = hf_dataset_factory(fps=fps)
timestamps = torch.stack(hf_dataset["timestamp"]).numpy()
episode_indices = torch.stack(hf_dataset["episode_index"]).numpy()
timestamps = hf_dataset["timestamp"].numpy()
episode_indices = hf_dataset["episode_index"].numpy()
episode_data_index = calculate_episode_data_index(hf_dataset)
return timestamps, episode_indices, episode_data_index