Files
lerobot_piper/lerobot/common/datasets/utils.py
Cadene 1cdfbc8b52 WIP
WIP

WIP train.py works, loss going down

WIP eval.py

Fix

WIP (eval running, TODO: verify results reproduced)

Eval works! (testing reproducibility)

WIP

pretrained model pusht reproduces same results as torchrl

pretrained model pusht reproduces same results as torchrl

Remove AbstractPolicy, Move all queues in select_action

WIP test_datasets passed (TODO: re-enable NormalizeTransform)
2024-04-04 15:31:03 +00:00

100 lines
3.2 KiB
Python

import io
import zipfile
from pathlib import Path
import requests
import torch
import tqdm
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
    """Download a zip archive from `url` and extract it into `destination_folder`.

    The archive is streamed into an in-memory buffer (with a progress bar) and
    then extracted, so the whole archive must fit in memory.

    Args:
        url: Address of the zip archive to download.
        destination_folder: Directory the archive content is extracted into.

    Returns:
        True if the server answered with HTTP 200 and the archive was extracted,
        False for any other status code (nothing is extracted in that case).
    """
    print(f"downloading from {url}")
    # With `stream=True` the connection stays open until the body is fully read or
    # the response is closed; the context manager releases it on every path
    # (non-200 status, exception while downloading, normal completion).
    with requests.get(url, stream=True) as response:
        if response.status_code != 200:
            return False

        # Servers may omit content-length; the progress bar then uses a total of 0.
        total_size = int(response.headers.get("content-length", 0))
        zip_file = io.BytesIO()
        # Using the bar as a context manager closes it even if the download fails midway.
        with tqdm.tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar:
            for chunk in response.iter_content(chunk_size=1024):
                # skip empty chunks
                if chunk:
                    zip_file.write(chunk)
                    progress_bar.update(len(chunk))

    # Rewind the buffer so ZipFile reads the archive from its first byte.
    zip_file.seek(0)
    with zipfile.ZipFile(zip_file, "r") as zip_ref:
        zip_ref.extractall(destination_folder)
    return True
def euclidean_distance_matrix(mat0, mat1):
    """Return the pairwise euclidean distances between the rows of two matrices.

    Both inputs are expected to be 2-D tensors with the same number of columns.
    Entry (i, j) of the result is the distance between row i of `mat0` and
    row j of `mat1`.
    """
    # Expand ||a - b||^2 as ||a||^2 + ||b||^2 - 2 * a.b
    row_norms_sq = (mat0**2).sum(dim=1, keepdim=True)
    col_norms_sq = (mat1**2).sum(dim=1, keepdim=True).transpose(0, 1)
    doubled_dot = 2 * mat0 @ mat1.transpose(0, 1)
    squared = row_norms_sq + col_norms_sq - doubled_dot
    # Rounding can push tiny squared distances slightly below zero; clamp before the root.
    return squared.clamp(min=0).sqrt()
def is_contiguously_true_or_false(bool_vector):
    """Tell whether a 1-D boolean tensor switches value at most once.

    Returns True when the vector is constant, or is made of one run of a value
    followed by one run of the other value; False otherwise.
    """
    assert bool_vector.ndim == 1
    assert bool_vector.dtype == torch.bool
    # A transition is a position whose value differs from the previous one.
    previous, following = bool_vector[:-1], bool_vector[1:]
    transitions = following != previous
    # More than one transition means the runs are interleaved.
    return int(transitions.sum()) <= 1
# Example usage (inputs must be 1-D boolean tensors):
# examples = [
#     (torch.tensor([True, False, True, False, False, False]), False),
#     (torch.tensor([True, True, True, False, False, False]), True),
#     (torch.tensor([False, False, False, False, False, False]), True),
# ]
# for bool_vector, expected in examples:
#     result = is_contiguously_true_or_false(bool_vector)
#     assert result == expected
def load_data_with_delta_timestamps(
    data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol=0.02
):
    """Fetch the frames of `key` closest to `current_ts` shifted by each delta timestamp.

    Args:
        data_dict: Mapping holding a "timestamp" entry and a `key` entry, both
            indexable by the frame indices of an episode.
        data_ids_per_episode: Mapping from an episode to the indices of its frames.
        delta_timestamps: Mapping from `key` to a sequence of time offsets added
            to `current_ts` to build the query timestamps.
        key: Name of the entry of `data_dict` to read the data from.
        current_ts: Timestamp of the current frame.
        episode: Episode the current frame belongs to.
        tol: Largest accepted gap between a query timestamp and its closest frame
            timestamp (presumably in seconds, like the timestamps; confirm against
            the dataset). Defaults to the previously hard-coded 0.02.

    Returns:
        A tuple `(data, is_pad)`: `data` is a copy of the frames closest to each
        query timestamp, and `is_pad` is a boolean tensor flagging the queries
        whose closest frame is further than `tol` away.

    Raises:
        AssertionError: If the flagged queries are interleaved with valid ones
            rather than forming a single run.
    """
    # get indices of the frames associated to the episode, and their timestamps
    ep_data_ids = data_ids_per_episode[episode]
    ep_timestamps = data_dict["timestamp"][ep_data_ids]

    # get timestamps used as query to retrieve data of previous/future frames
    delta_ts = delta_timestamps[key]
    query_ts = current_ts + torch.tensor(delta_ts)

    # compute distances between each query timestamp and all timestamps of all the frames belonging to the episode
    dist = euclidean_distance_matrix(query_ts[:, None], ep_timestamps[:, None])
    min_, argmin_ = dist.min(1)

    # get the indices of the data that are closest to the query timestamps
    data_ids = ep_data_ids[argmin_]
    # closest_ts = ep_timestamps[argmin_]

    # get the data (cloned so callers cannot modify the stored tensor through the result)
    data = data_dict[key][data_ids].clone()

    # TODO(rcadene): synchronize timestamps + interpolation if needed

    # queries with no frame within the tolerance are reported as padding
    is_pad = min_ > tol

    assert is_contiguously_true_or_false(is_pad), (
        "One or several timestamps unexpectedly violate the tolerance. "
        "This might be due to synchronization issues with timestamps during data collection."
    )

    return data, is_pad