Improves Type Annotations (#252)

This commit is contained in:
Wael Karkoub
2024-06-10 19:09:48 +01:00
committed by GitHub
parent a06598678c
commit 54c9776bde
7 changed files with 54 additions and 23 deletions

View File

@@ -66,28 +66,31 @@ import gc
import logging
import time
from pathlib import Path
from typing import Iterator
import numpy as np
import rerun as rr
import torch
import torch.utils.data
import tqdm
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
class EpisodeSampler(torch.utils.data.Sampler):
def __init__(self, dataset, episode_index):
def __init__(self, dataset: LeRobotDataset, episode_index: int):
from_idx = dataset.episode_data_index["from"][episode_index].item()
to_idx = dataset.episode_data_index["to"][episode_index].item()
self.frame_ids = range(from_idx, to_idx)
def __iter__(self):
def __iter__(self) -> Iterator:
return iter(self.frame_ids)
def __len__(self):
def __len__(self) -> int:
return len(self.frame_ids)
def to_hwc_uint8_numpy(chw_float32_torch):
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
assert chw_float32_torch.dtype == torch.float32
assert chw_float32_torch.ndim == 3
c, h, w = chw_float32_torch.shape