id -> index, finish moving compute_stats before hf_dataset push_to_hub

This commit is contained in:
Cadene
2024-04-19 10:33:42 +00:00
parent 64b09ea7a7
commit 714a776277
9 changed files with 120 additions and 99 deletions

View File

@@ -19,7 +19,7 @@ from huggingface_hub import HfApi
from PIL import Image as PILImage from PIL import Image as PILImage
from safetensors.torch import save_file from safetensors.torch import save_file
from lerobot.common.datasets.utils import compute_stats, flatten_dict from lerobot.common.datasets.utils import compute_stats, convert_images_to_channel_first_tensors, flatten_dict
def download_and_upload(root, revision, dataset_id): def download_and_upload(root, revision, dataset_id):
@@ -75,28 +75,18 @@ def concatenate_episodes(ep_dicts):
for x in ep_dict[key]: for x in ep_dict[key]:
data_dict[key].append(x) data_dict[key].append(x)
total_frames = data_dict["frame_id"].shape[0] total_frames = data_dict["frame_index"].shape[0]
data_dict["index"] = torch.arange(0, total_frames, 1) data_dict["index"] = torch.arange(0, total_frames, 1)
return data_dict return data_dict
def push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id): def push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id):
hf_dataset = hf_dataset.with_format("torch")
# push to main to indicate latest version # push to main to indicate latest version
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True) hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
# push to version branch # push to version branch
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision=revision) hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision=revision)
# get stats
stats_pth_path = root / dataset_id / "stats.pth"
if stats_pth_path.exists():
stats = torch.load(stats_pth_path)
else:
stats = compute_stats(hf_dataset)
torch.save(stats, stats_pth_path)
# create and store meta_data # create and store meta_data
meta_data_dir = root / dataset_id / "meta_data" meta_data_dir = root / dataset_id / "meta_data"
meta_data_dir.mkdir(parents=True, exist_ok=True) meta_data_dir.mkdir(parents=True, exist_ok=True)
@@ -237,8 +227,8 @@ def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
"observation.image": [PILImage.fromarray(x.numpy()) for x in image], "observation.image": [PILImage.fromarray(x.numpy()) for x in image],
"observation.state": agent_pos, "observation.state": agent_pos,
"action": actions[id_from:id_to], "action": actions[id_from:id_to],
"episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int), "episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_id": torch.arange(0, num_frames, 1), "frame_index": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps, "timestamp": torch.arange(0, num_frames, 1) / fps,
# "next.observation.image": image[1:], # "next.observation.image": image[1:],
# "next.observation.state": agent_pos[1:], # "next.observation.state": agent_pos[1:],
@@ -262,8 +252,8 @@ def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None) length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
), ),
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)), "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
"episode_id": Value(dtype="int64", id=None), "episode_index": Value(dtype="int64", id=None),
"frame_id": Value(dtype="int64", id=None), "frame_index": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None), "timestamp": Value(dtype="float32", id=None),
"next.reward": Value(dtype="float32", id=None), "next.reward": Value(dtype="float32", id=None),
"next.done": Value(dtype="bool", id=None), "next.done": Value(dtype="bool", id=None),
@@ -272,11 +262,14 @@ def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
} }
features = Features(features) features = Features(features)
hf_dataset = Dataset.from_dict(data_dict, features=features) hf_dataset = Dataset.from_dict(data_dict, features=features)
hf_dataset = hf_dataset.with_format("torch")
hf_dataset.set_transform(convert_images_to_channel_first_tensors)
info = { info = {
"fps": fps, "fps": fps,
} }
push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id) stats = compute_stats(hf_dataset)
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
def download_and_upload_xarm(root, revision, dataset_id, fps=15): def download_and_upload_xarm(root, revision, dataset_id, fps=15):
@@ -334,8 +327,8 @@ def download_and_upload_xarm(root, revision, dataset_id, fps=15):
"observation.image": [PILImage.fromarray(x.numpy()) for x in image], "observation.image": [PILImage.fromarray(x.numpy()) for x in image],
"observation.state": state, "observation.state": state,
"action": action, "action": action,
"episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int), "episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_id": torch.arange(0, num_frames, 1), "frame_index": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps, "timestamp": torch.arange(0, num_frames, 1) / fps,
# "next.observation.image": next_image, # "next.observation.image": next_image,
# "next.observation.state": next_state, # "next.observation.state": next_state,
@@ -358,8 +351,8 @@ def download_and_upload_xarm(root, revision, dataset_id, fps=15):
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None) length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
), ),
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)), "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
"episode_id": Value(dtype="int64", id=None), "episode_index": Value(dtype="int64", id=None),
"frame_id": Value(dtype="int64", id=None), "frame_index": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None), "timestamp": Value(dtype="float32", id=None),
"next.reward": Value(dtype="float32", id=None), "next.reward": Value(dtype="float32", id=None),
"next.done": Value(dtype="bool", id=None), "next.done": Value(dtype="bool", id=None),
@@ -368,11 +361,14 @@ def download_and_upload_xarm(root, revision, dataset_id, fps=15):
} }
features = Features(features) features = Features(features)
hf_dataset = Dataset.from_dict(data_dict, features=features) hf_dataset = Dataset.from_dict(data_dict, features=features)
hf_dataset = hf_dataset.with_format("torch")
hf_dataset.set_transform(convert_images_to_channel_first_tensors)
info = { info = {
"fps": fps, "fps": fps,
} }
push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id) stats = compute_stats(hf_dataset)
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
def download_and_upload_aloha(root, revision, dataset_id, fps=50): def download_and_upload_aloha(root, revision, dataset_id, fps=50):
@@ -464,8 +460,8 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
{ {
"observation.state": state, "observation.state": state,
"action": action, "action": action,
"episode_id": torch.tensor([ep_id] * num_frames), "episode_index": torch.tensor([ep_id] * num_frames),
"frame_id": torch.arange(0, num_frames, 1), "frame_index": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps, "timestamp": torch.arange(0, num_frames, 1) / fps,
# "next.observation.state": state, # "next.observation.state": state,
# TODO(rcadene): compute reward and success # TODO(rcadene): compute reward and success
@@ -493,8 +489,8 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None) length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
), ),
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)), "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
"episode_id": Value(dtype="int64", id=None), "episode_index": Value(dtype="int64", id=None),
"frame_id": Value(dtype="int64", id=None), "frame_index": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None), "timestamp": Value(dtype="float32", id=None),
#'next.reward': Value(dtype='float32', id=None), #'next.reward': Value(dtype='float32', id=None),
"next.done": Value(dtype="bool", id=None), "next.done": Value(dtype="bool", id=None),
@@ -503,11 +499,14 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
} }
features = Features(features) features = Features(features)
hf_dataset = Dataset.from_dict(data_dict, features=features) hf_dataset = Dataset.from_dict(data_dict, features=features)
hf_dataset = hf_dataset.with_format("torch")
hf_dataset.set_transform(convert_images_to_channel_first_tensors)
info = { info = {
"fps": fps, "fps": fps,
} }
push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id) stats = compute_stats(hf_dataset)
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -49,7 +49,7 @@ print(f"number of episodes: {len(hf_dataset.unique('episode_id'))=}")
print(f"average number of frames per episode: {len(hf_dataset) / len(hf_dataset.unique('episode_id')):.3f}") print(f"average number of frames per episode: {len(hf_dataset) / len(hf_dataset.unique('episode_id')):.3f}")
# select the frames belonging to episode number 5 # select the frames belonging to episode number 5
hf_dataset = hf_dataset.filter(lambda frame: frame["episode_id"] == 5) hf_dataset = hf_dataset.filter(lambda frame: frame["episode_index"] == 5)
# load all frames of episode 5 in RAM in PIL format # load all frames of episode 5 in RAM in PIL format
frames = hf_dataset["observation.image"] frames = hf_dataset["observation.image"]

View File

@@ -55,7 +55,7 @@ print(f"frames per second used during data collection: {dataset.fps=}")
print(f"keys to access images from cameras: {dataset.image_keys=}") print(f"keys to access images from cameras: {dataset.image_keys=}")
# While the LeRobot dataset adds helpers for working within our library, we still expose the underling Hugging Face dataset. It may be freely replaced or modified in place. Here we use the filtering to keep only frames from episode 5. # While the LeRobot dataset adds helpers for working within our library, we still expose the underling Hugging Face dataset. It may be freely replaced or modified in place. Here we use the filtering to keep only frames from episode 5.
dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_id"] == 5) dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_index"] == 5)
# LeRobot datsets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grap all the image frames. # LeRobot datsets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grap all the image frames.
frames = [sample["observation.image"] for sample in dataset] frames = [sample["observation.image"] for sample in dataset]

View File

@@ -54,7 +54,7 @@ class AlohaDataset(torch.utils.data.Dataset):
@property @property
def num_episodes(self) -> int: def num_episodes(self) -> int:
return len(self.hf_dataset.unique("episode_id")) return len(self.hf_dataset.unique("episode_index"))
def __len__(self): def __len__(self):
return self.num_samples return self.num_samples

View File

@@ -1,11 +1,9 @@
import logging
import os import os
from pathlib import Path from pathlib import Path
import torch import torch
from torchvision.transforms import v2 from torchvision.transforms import v2
from lerobot.common.datasets.utils import compute_stats
from lerobot.common.transforms import NormalizeTransform, Prod from lerobot.common.transforms import NormalizeTransform, Prod
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
@@ -52,26 +50,14 @@ def make_dataset(
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32) stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32) stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
elif stats_path is None: elif stats_path is None:
# load stats if the file exists already or compute stats and save it # load a first dataset to access precomputed stats
if DATA_DIR is None: stats_dataset = clsfunc(
# TODO(rcadene): clean stats dataset_id=cfg.dataset_id,
precomputed_stats_path = Path("data") / cfg.dataset_id / "stats.pth" split="train",
else: root=DATA_DIR,
precomputed_stats_path = DATA_DIR / cfg.dataset_id / "stats.pth" transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
if precomputed_stats_path.exists(): )
stats = torch.load(precomputed_stats_path) stats = stats_dataset.stats
else:
logging.info(f"compute_stats and save to {precomputed_stats_path}")
# Create a dataset for stats computation.
stats_dataset = clsfunc(
dataset_id=cfg.dataset_id,
split="train",
root=DATA_DIR,
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
)
stats = compute_stats(stats_dataset)
precomputed_stats_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(stats, precomputed_stats_path)
else: else:
stats = torch.load(stats_path) stats = torch.load(stats_path)

View File

@@ -10,6 +10,8 @@ from datasets import load_dataset, load_from_disk
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from safetensors.torch import load_file from safetensors.torch import load_file
from lerobot.common.utils.utils import set_global_seed
def flatten_dict(d, parent_key="", sep="/"): def flatten_dict(d, parent_key="", sep="/"):
items = [] items = []
@@ -42,7 +44,9 @@ def load_hf_dataset(dataset_id, version, root, split) -> datasets.Dataset:
else: else:
repo_id = f"lerobot/{dataset_id}" repo_id = f"lerobot/{dataset_id}"
hf_dataset = load_dataset(repo_id, revision=version, split=split) hf_dataset = load_dataset(repo_id, revision=version, split=split)
return hf_dataset.with_format("torch") hf_dataset = hf_dataset.with_format("torch")
hf_dataset.set_transform(convert_images_to_channel_first_tensors)
return hf_dataset
def load_episode_data_index(dataset_id, version, root) -> dict[str, torch.Tensor]: def load_episode_data_index(dataset_id, version, root) -> dict[str, torch.Tensor]:
@@ -126,7 +130,7 @@ def load_previous_and_future_frames(
issues with timestamps during data collection. issues with timestamps during data collection.
""" """
# get indices of the frames associated to the episode, and their timestamps # get indices of the frames associated to the episode, and their timestamps
ep_id = item["episode_id"].item() ep_id = item["episode_index"].item()
ep_data_id_from = episode_data_index["from"][ep_id].item() ep_data_id_from = episode_data_index["from"][ep_id].item()
ep_data_id_to = episode_data_index["to"][ep_id].item() ep_data_id_to = episode_data_index["to"][ep_id].item()
ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1) ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)
@@ -168,34 +172,53 @@ def load_previous_and_future_frames(
return item return item
def get_stats_einops_patterns(dataset): def convert_images_to_channel_first_tensors(examples):
"""These einops patterns will be used to aggregate batches and compute statistics.""" for key in examples:
stats_patterns = { if examples[key].ndim == 3: # we assume it's an image
"action": "b c -> c", # (h w c) -> (c h w)
"observation.state": "b c -> c", h, w, c = examples[key].shape
} assert c < h and c < w, f"expect a channel last image, but instead {examples[key].shape}"
for key in dataset.image_keys: examples[key] = [img.permute((2, 0, 1)) for img in examples[key]]
stats_patterns[key] = "b c h w -> c 1 1" return examples
def get_stats_einops_patterns(hf_dataset):
"""These einops patterns will be used to aggregate batches and compute statistics.
Note: We assume the images are returned in channel first format
"""
dataloader = torch.utils.data.DataLoader(
hf_dataset,
num_workers=0,
batch_size=2,
shuffle=False,
)
batch = next(iter(dataloader))
stats_patterns = {}
for key, feats_type in hf_dataset.features.items():
if batch[key].ndim == 4 and isinstance(feats_type, datasets.features.image.Image):
# sanity check that images are channel first
_, c, h, w = batch[key].shape
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
# convert from (h w c) to (c h w) to fit pytorch convention, then apply reduce
stats_patterns[key] = "b c h w -> c 1 1"
elif batch[key].ndim == 2:
stats_patterns[key] = "b c -> c "
elif batch[key].ndim == 1:
stats_patterns[key] = "b -> 1"
else:
raise ValueError(f"{key}, {feats_type}, {batch[key].shape}")
return stats_patterns return stats_patterns
def compute_stats(dataset, batch_size=32, max_num_samples=None): def compute_stats(hf_dataset, batch_size=32, max_num_samples=None):
if max_num_samples is None: if max_num_samples is None:
max_num_samples = len(dataset) max_num_samples = len(hf_dataset)
else:
raise NotImplementedError("We need to set shuffle=True, but this violate an assert for now.")
dataloader = torch.utils.data.DataLoader( stats_patterns = get_stats_einops_patterns(hf_dataset)
dataset,
num_workers=4,
batch_size=batch_size,
shuffle=False,
# pin_memory=cfg.device != "cpu",
drop_last=False,
)
# get einops patterns to aggregate batches and compute statistics
stats_patterns = get_stats_einops_patterns(dataset)
# mean and std will be computed incrementally while max and min will track the running value. # mean and std will be computed incrementally while max and min will track the running value.
mean, std, max, min = {}, {}, {}, {} mean, std, max, min = {}, {}, {}, {}
@@ -205,10 +228,23 @@ def compute_stats(dataset, batch_size=32, max_num_samples=None):
max[key] = torch.tensor(-float("inf")).float() max[key] = torch.tensor(-float("inf")).float()
min[key] = torch.tensor(float("inf")).float() min[key] = torch.tensor(float("inf")).float()
def create_seeded_dataloader(hf_dataset, batch_size, seed):
set_global_seed(seed)
dataloader = torch.utils.data.DataLoader(
hf_dataset,
num_workers=4,
batch_size=batch_size,
shuffle=False,
# pin_memory=cfg.device != "cpu",
drop_last=False,
)
return dataloader
# Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
# surprises when rerunning the sampler. # surprises when rerunning the sampler.
first_batch = None first_batch = None
running_item_count = 0 # for online mean computation running_item_count = 0 # for online mean computation
dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
for i, batch in enumerate( for i, batch in enumerate(
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max") tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
): ):
@@ -234,6 +270,7 @@ def compute_stats(dataset, batch_size=32, max_num_samples=None):
first_batch_ = None first_batch_ = None
running_item_count = 0 # for online std computation running_item_count = 0 # for online std computation
dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
for i, batch in enumerate( for i, batch in enumerate(
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std") tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
): ):

View File

@@ -46,7 +46,7 @@ class XarmDataset(torch.utils.data.Dataset):
@property @property
def num_episodes(self) -> int: def num_episodes(self) -> int:
return len(self.hf_dataset.unique("episode_id")) return len(self.hf_dataset.unique("episode_index"))
def __len__(self): def __len__(self):
return self.num_samples return self.num_samples

View File

@@ -157,7 +157,7 @@ def add_episodes_inplace(
Raises: Raises:
- AssertionError: If the first episode_id or index in hf_dataset is not 0 - AssertionError: If the first episode_id or index in hf_dataset is not 0
""" """
first_episode_id = hf_dataset.select_columns("episode_id")[0]["episode_id"].item() first_episode_id = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
first_index = hf_dataset.select_columns("index")[0]["index"].item() first_index = hf_dataset.select_columns("index")[0]["index"].item()
assert first_episode_id == 0, f"We expect the first episode_id to be 0 and not {first_episode_id}" assert first_episode_id == 0, f"We expect the first episode_id to be 0 and not {first_episode_id}"
assert first_index == 0, f"We expect the first first_index to be 0 and not {first_index}" assert first_index == 0, f"We expect the first first_index to be 0 and not {first_index}"
@@ -167,12 +167,12 @@ def add_episodes_inplace(
online_dataset.hf_dataset = hf_dataset online_dataset.hf_dataset = hf_dataset
else: else:
# find episode index and data frame indices according to previous episode in online_dataset # find episode index and data frame indices according to previous episode in online_dataset
start_episode = online_dataset.select_columns("episode_id")[-1]["episode_id"].item() + 1 start_episode = online_dataset.select_columns("episode_index")[-1]["episode_index"].item() + 1
start_index = online_dataset.select_columns("index")[-1]["index"].item() + 1 start_index = online_dataset.select_columns("index")[-1]["index"].item() + 1
def shift_indices(example): def shift_indices(example):
# note: we dont shift "frame_id" since it represents the index of the frame in the episode it belongs to # note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to
example["episode_id"] += start_episode example["episode_index"] += start_episode
example["index"] += start_index example["index"] += start_index
example["episode_data_index_from"] += start_index example["episode_data_index_from"] += start_index
example["episode_data_index_to"] += start_index example["episode_data_index_to"] += start_index

View File

@@ -1,22 +1,21 @@
import logging
from copy import deepcopy
import json import json
import logging
import os import os
from copy import deepcopy
from pathlib import Path from pathlib import Path
import einops import einops
import pytest import pytest
import torch import torch
from datasets import Dataset from datasets import Dataset
import lerobot import lerobot
from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
compute_stats, compute_stats,
flatten_dict,
get_stats_einops_patterns, get_stats_einops_patterns,
load_previous_and_future_frames, load_previous_and_future_frames,
flatten_dict,
unflatten_dict, unflatten_dict,
) )
from lerobot.common.transforms import Prod from lerobot.common.transforms import Prod
@@ -44,8 +43,8 @@ def test_factory(env_name, dataset_id, policy_name):
keys_ndim_required = [ keys_ndim_required = [
("action", 1, True), ("action", 1, True),
("episode_id", 0, True), ("episode_index", 0, True),
("frame_id", 0, True), ("frame_index", 0, True),
("timestamp", 0, True), ("timestamp", 0, True),
# TODO(rcadene): should we rename it agent_pos? # TODO(rcadene): should we rename it agent_pos?
("observation.state", 1, True), ("observation.state", 1, True),
@@ -165,13 +164,13 @@ def test_load_previous_and_future_frames_within_tolerance():
{ {
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5], "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4], "index": [0, 1, 2, 3, 4],
"episode_id": [0, 0, 0, 0, 0], "episode_index": [0, 0, 0, 0, 0],
} }
) )
hf_dataset = hf_dataset.with_format("torch") hf_dataset = hf_dataset.with_format("torch")
episode_data_index = { episode_data_index = {
"from": torch.tensor([0]), "from": torch.tensor([0]),
"to": torch.tensor([5]), "to": torch.tensor([5]),
} }
delta_timestamps = {"index": [-0.2, 0, 0.139]} delta_timestamps = {"index": [-0.2, 0, 0.139]}
tol = 0.04 tol = 0.04
@@ -187,13 +186,13 @@ def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range(
{ {
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5], "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4], "index": [0, 1, 2, 3, 4],
"episode_id": [0, 0, 0, 0, 0], "episode_index": [0, 0, 0, 0, 0],
} }
) )
hf_dataset = hf_dataset.with_format("torch") hf_dataset = hf_dataset.with_format("torch")
episode_data_index = { episode_data_index = {
"from": torch.tensor([0]), "from": torch.tensor([0]),
"to": torch.tensor([5]), "to": torch.tensor([5]),
} }
delta_timestamps = {"index": [-0.2, 0, 0.141]} delta_timestamps = {"index": [-0.2, 0, 0.141]}
tol = 0.04 tol = 0.04
@@ -207,13 +206,13 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range
{ {
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5], "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4], "index": [0, 1, 2, 3, 4],
"episode_id": [0, 0, 0, 0, 0], "episode_index": [0, 0, 0, 0, 0],
} }
) )
hf_dataset = hf_dataset.with_format("torch") hf_dataset = hf_dataset.with_format("torch")
episode_data_index = { episode_data_index = {
"from": torch.tensor([0]), "from": torch.tensor([0]),
"to": torch.tensor([5]), "to": torch.tensor([5]),
} }
delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]} delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
tol = 0.04 tol = 0.04
@@ -224,7 +223,7 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range
assert torch.equal( assert torch.equal(
is_pad, torch.tensor([True, False, False, True, True]) is_pad, torch.tensor([True, False, False, True, True])
), "Padding does not match expected values" ), "Padding does not match expected values"
def test_flatten_unflatten_dict(): def test_flatten_unflatten_dict():
d = { d = {