import io
import logging
import zipfile
from copy import deepcopy
from math import ceil
from pathlib import Path

import einops
import requests
import torch
import tqdm


def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
    print(f"downloading from {url}")
    response = requests.get(url, stream=True)
    if response.status_code == 200:
        total_size = int(response.headers.get("content-length", 0))
        progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)

        zip_file = io.BytesIO()
        for chunk in response.iter_content(chunk_size=1024):
            if chunk:
                zip_file.write(chunk)
                progress_bar.update(len(chunk))
        progress_bar.close()

        zip_file.seek(0)
        with zipfile.ZipFile(zip_file, "r") as zip_ref:
            zip_ref.extractall(destination_folder)
        return True
    else:
        return False


def euclidean_distance_matrix(mat0, mat1):
    # Compute the squared distance matrix.
    sq0 = torch.sum(mat0**2, dim=1, keepdim=True)
    sq1 = torch.sum(mat1**2, dim=1, keepdim=True)
    distance_sq = sq0 + sq1.transpose(0, 1) - 2 * mat0 @ mat1.transpose(0, 1)

    # Take the square root to get the Euclidean distance, clamping to guard against
    # small negative values caused by floating point error.
    distance = torch.sqrt(torch.clamp(distance_sq, min=0))
    return distance


def is_contiguously_true_or_false(bool_vector):
    assert bool_vector.ndim == 1
    assert bool_vector.dtype == torch.bool

    # Compare each element with its neighbor to find changes.
    changes = bool_vector[1:] != bool_vector[:-1]

    # Count the number of changes.
    num_changes = changes.sum().item()

    # If there's more than one change, the vector is not contiguous.
    return num_changes <= 1


# examples = [
#     ([True, False, True, False, False, False], False),
#     ([True, True, True, False, False, False], True),
#     ([False, False, False, False, False, False], True),
# ]
# for bool_list, expected in examples:
#     result = is_contiguously_true_or_false(torch.tensor(bool_list))
#     assert result == expected


def load_data_with_delta_timestamps(
    data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode
):
    # Get the indices of the frames associated with the episode, and their timestamps.
    ep_data_ids = data_ids_per_episode[episode]
    ep_timestamps = data_dict["timestamp"][ep_data_ids]

    # Get the timestamps used as queries to retrieve data from previous/future frames.
    delta_ts = delta_timestamps[key]
    query_ts = current_ts + torch.tensor(delta_ts)

    # Compute the distance between each query timestamp and the timestamps of all frames in the episode.
    dist = euclidean_distance_matrix(query_ts[:, None], ep_timestamps[:, None])
    min_, argmin_ = dist.min(1)

    # Get the indices of the frames closest to the query timestamps.
    data_ids = ep_data_ids[argmin_]
    # closest_ts = ep_timestamps[argmin_]

    # Get the data.
    data = data_dict[key][data_ids].clone()

    # TODO(rcadene): synchronize timestamps + interpolation if needed

    # Flag queries that fall farther than the tolerance from any frame as padding.
    tol = 0.02
    is_pad = min_ > tol

    assert is_contiguously_true_or_false(is_pad), (
        "One or several timestamps unexpectedly violate the tolerance. "
        "This might be due to synchronization issues with timestamps during data collection."
    )

    return data, is_pad
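
# A minimal commented sketch of calling `load_data_with_delta_timestamps` on synthetic
# data, in the spirit of the commented examples above. The key name, shapes, and
# sampling rate are illustrative assumptions, not the real dataset layout: a 5-frame
# episode at 10 Hz with a 2-dimensional state.
#
# data_dict = {
#     "timestamp": torch.tensor([0.0, 0.1, 0.2, 0.3, 0.4]),
#     "observation.state": torch.randn(5, 2),
# }
# data_ids_per_episode = {0: torch.arange(5)}
# delta_timestamps = {"observation.state": [-0.1, 0.0, 0.1]}
# # Query the previous, current, and next frames relative to current_ts=0.2.
# data, is_pad = load_data_with_delta_timestamps(
#     data_dict, data_ids_per_episode, delta_timestamps,
#     key="observation.state", current_ts=0.2, episode=0,
# )
# # data has shape (3, 2); is_pad flags queries farther than `tol` from any frame.
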
def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
    stats_path = dataset.data_dir / "stats.pth"
    if stats_path.exists():
        return torch.load(stats_path)

    logging.info(f"compute_stats and save to {stats_path}")

    if max_num_samples is None:
        max_num_samples = len(dataset)
    else:
        raise NotImplementedError("We need to set shuffle=True, but this violates an assert for now.")

    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=4,
        batch_size=batch_size,
        shuffle=False,
        # pin_memory=cfg.device != "cpu",
        drop_last=False,
    )

    # These einops patterns will be used to aggregate batches and compute statistics.
    stats_patterns = {
        "action": "b c -> c",
        "observation.state": "b c -> c",
    }
    for key in dataset.image_keys:
        stats_patterns[key] = "b c h w -> c 1 1"

    # Mean and std will be computed incrementally, while max and min will track the running value.
    mean, std, max, min = {}, {}, {}, {}
    for key in stats_patterns:
        mean[key] = torch.tensor(0.0).float()
        std[key] = torch.tensor(0.0).float()
        max[key] = torch.tensor(-float("inf")).float()
        min[key] = torch.tensor(float("inf")).float()

    # Note: due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
    # surprises when rerunning the sampler.
    first_batch = None
    running_item_count = 0  # for online mean computation
    for i, batch in enumerate(
        tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
    ):
        this_batch_size = len(batch["index"])
        running_item_count += this_batch_size
        if first_batch is None:
            first_batch = deepcopy(batch)
        for key, pattern in stats_patterns.items():
            batch[key] = batch[key].float()
            # Numerically stable update step for mean computation.
            batch_mean = einops.reduce(batch[key], pattern, "mean")
            # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
            # the update step, N is the running item count, B is this batch size, x̄ is the running mean,
            # and x is the current batch mean. Some rearrangement is then required to avoid risking
            # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
            # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
            mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
            max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
            min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))

        if i == ceil(max_num_samples / batch_size) - 1:
            break

    first_batch_ = None
    running_item_count = 0  # for online std computation
    for i, batch in enumerate(
        tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
    ):
        this_batch_size = len(batch["index"])
        running_item_count += this_batch_size
        # Sanity check to make sure the batches are still in the same order as before.
        if first_batch_ is None:
            first_batch_ = deepcopy(batch)
            for key in stats_patterns:
                assert torch.equal(first_batch_[key], first_batch[key])
        for key, pattern in stats_patterns.items():
            batch[key] = batch[key].float()
            # Numerically stable update step for mean computation (where the mean is over squared
            # residuals). See the notes in the mean computation loop above.
            batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
            std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count

        if i == ceil(max_num_samples / batch_size) - 1:
            break

    for key in stats_patterns:
        std[key] = torch.sqrt(std[key])

    stats = {}
    for key in stats_patterns:
        stats[key] = {
            "mean": mean[key],
            "std": std[key],
            "max": max[key],
            "min": min[key],
        }

    torch.save(stats, stats_path)
    return stats
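
# A small commented sanity check (not part of the pipeline) for the streaming mean
# update used above, x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ: feeding two uneven batches
# through the update rule should reproduce the mean over all items.
#
# x = torch.arange(10).float()  # 10 items, true mean 4.5
# running_mean, count = torch.tensor(0.0), 0
# for batch in (x[:4], x[4:]):  # batch sizes B = 4 and 6
#     count += len(batch)
#     running_mean = running_mean + len(batch) * (batch.mean() - running_mean) / count
# assert torch.isclose(running_mean, x.mean())
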
def cycle(iterable):
    iterator = iter(iterable)
    while True:
        try:
            yield next(iterator)
        except StopIteration:
            iterator = iter(iterable)
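
# Commented usage sketch for `cycle`: unlike `itertools.cycle`, which caches and replays
# the elements of the first pass, this re-creates the iterator on exhaustion, so a
# shuffling DataLoader gets reshuffled at every epoch boundary. `num_steps` below is
# illustrative.
#
# loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
# batches = cycle(loader)
# for _ in range(num_steps):
#     batch = next(batches)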