Add download_and_upload_dataset.py in script, update all datasets, update online training

This commit is contained in:
Cadene
2024-04-15 21:26:33 +00:00
parent c6aca7fe44
commit 67d79732f9
8 changed files with 493 additions and 519 deletions

View File

@@ -1,72 +1,17 @@
import logging
from pathlib import Path
import einops
import gdown
import h5py
import torch
import tqdm
from datasets import load_dataset
from lerobot.common.datasets.utils import load_data_with_delta_timestamps
FOLDER_URLS = {
"aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
"aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N",
"aloha_sim_transfer_cube_human": "https://drive.google.com/drive/folders/1sc-E4QYW7A0o23m1u2VWNGVq5smAsfCo",
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/drive/folders/1aRyoOhQwxhyt1J8XgEig4s6kzaw__LXj",
}
EP48_URLS = {
"aloha_sim_insertion_human": "https://drive.google.com/file/d/18Cudl6nikDtgRolea7je8iF_gGKzynOP/view?usp=drive_link",
"aloha_sim_insertion_scripted": "https://drive.google.com/file/d/1wfMSZ24oOh5KR_0aaP3Cnu_c4ZCveduB/view?usp=drive_link",
"aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/18smMymtr8tIxaNUQ61gW6dG50pt3MvGq/view?usp=drive_link",
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1pnGIOd-E4-rhz2P3VxpknMKRZCoKt6eI/view?usp=drive_link",
}
EP49_URLS = {
"aloha_sim_insertion_human": "https://drive.google.com/file/d/1C1kZYyROzs-PrLc0SkDgUgMi4-L3lauE/view?usp=drive_link",
"aloha_sim_insertion_scripted": "https://drive.google.com/file/d/17EuCUWS6uCCr6yyNzpXdcdE-_TTNCKtf/view?usp=drive_link",
"aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/1Nk7l53d9sJoGDBKAOnNrExX5nLacATc6/view?usp=drive_link",
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1GKReZHrXU73NMiC5zKCq_UtqPVtYq8eo/view?usp=drive_link",
}
NUM_EPISODES = {
"aloha_sim_insertion_human": 50,
"aloha_sim_insertion_scripted": 50,
"aloha_sim_transfer_cube_human": 50,
"aloha_sim_transfer_cube_scripted": 50,
}
EPISODE_LEN = {
"aloha_sim_insertion_human": 500,
"aloha_sim_insertion_scripted": 400,
"aloha_sim_transfer_cube_human": 400,
"aloha_sim_transfer_cube_scripted": 400,
}
CAMERAS = {
"aloha_sim_insertion_human": ["top"],
"aloha_sim_insertion_scripted": ["top"],
"aloha_sim_transfer_cube_human": ["top"],
"aloha_sim_transfer_cube_scripted": ["top"],
}
def download(data_dir, dataset_id):
assert dataset_id in FOLDER_URLS
assert dataset_id in EP48_URLS
assert dataset_id in EP49_URLS
data_dir.mkdir(parents=True, exist_ok=True)
gdown.download_folder(FOLDER_URLS[dataset_id], output=str(data_dir))
# because of the 50 files limit per directory, two files episode 48 and 49 were missing
gdown.download(EP48_URLS[dataset_id], output=str(data_dir / "episode_48.hdf5"), fuzzy=True)
gdown.download(EP49_URLS[dataset_id], output=str(data_dir / "episode_49.hdf5"), fuzzy=True)
from lerobot.common.datasets.utils import load_previous_and_future_frames
class AlohaDataset(torch.utils.data.Dataset):
"""
https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human
https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted
https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human
https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted
"""
available_datasets = [
"aloha_sim_insertion_human",
"aloha_sim_insertion_scripted",
@@ -79,129 +24,40 @@ class AlohaDataset(torch.utils.data.Dataset):
def __init__(
self,
dataset_id: str,
version: str | None = "v1.2",
root: Path | None = None,
version: str | None = "v1.0",
transform: callable = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
self.dataset_id = dataset_id
self.version = version
self.root = root
self.transform = transform
self.delta_timestamps = delta_timestamps
self.data_dir = self.root / f"{self.dataset_id}"
if (self.data_dir / "data_dict.pth").exists() and (
self.data_dir / "data_ids_per_episode.pth"
).exists():
self.data_dict = torch.load(self.data_dir / "data_dict.pth")
self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
else:
self._download_and_preproc_obsolete()
self.data_dir.mkdir(parents=True, exist_ok=True)
torch.save(self.data_dict, self.data_dir / "data_dict.pth")
torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", revision=self.version, split="train")
self.data_dict = self.data_dict.with_format("torch")
@property
def num_samples(self) -> int:
return len(self.data_dict["index"]) if "index" in self.data_dict else 0
return len(self.data_dict)
@property
def num_episodes(self) -> int:
return len(self.data_ids_per_episode)
return len(self.data_dict.unique("episode_id"))
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
item = {}
item = self.data_dict[idx]
# get episode id and timestamp of the sampled frame
current_ts = self.data_dict["timestamp"][idx].item()
episode = self.data_dict["episode"][idx].item()
for key in self.data_dict:
if self.delta_timestamps is not None and key in self.delta_timestamps:
data, is_pad = load_data_with_delta_timestamps(
self.data_dict,
self.data_ids_per_episode,
self.delta_timestamps,
key,
current_ts,
episode,
)
item[key] = data
item[f"{key}_is_pad"] = is_pad
else:
item[key] = self.data_dict[key][idx]
if self.delta_timestamps is not None:
item = load_previous_and_future_frames(
item,
self.data_dict,
self.delta_timestamps,
)
if self.transform is not None:
item = self.transform(item)
return item
def _download_and_preproc_obsolete(self):
assert self.root is not None
raw_dir = self.root / f"{self.dataset_id}_raw"
if not raw_dir.is_dir():
download(raw_dir, self.dataset_id)
total_frames = 0
logging.info("Compute total number of frames to initialize offline buffer")
for ep_id in range(NUM_EPISODES[self.dataset_id]):
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
with h5py.File(ep_path, "r") as ep:
total_frames += ep["/action"].shape[0] - 1
logging.info(f"{total_frames=}")
self.data_ids_per_episode = {}
ep_dicts = []
frame_idx = 0
for ep_id in tqdm.tqdm(range(NUM_EPISODES[self.dataset_id])):
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
with h5py.File(ep_path, "r") as ep:
num_frames = ep["/action"].shape[0]
# last step of demonstration is considered done
done = torch.zeros(num_frames, dtype=torch.bool)
done[-1] = True
state = torch.from_numpy(ep["/observations/qpos"][:])
action = torch.from_numpy(ep["/action"][:])
ep_dict = {
"observation.state": state,
"action": action,
"episode": torch.tensor([ep_id] * num_frames),
"frame_id": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
# "next.observation.state": state,
# TODO(rcadene): compute reward and success
# "next.reward": reward[1:],
"next.done": done[1:],
# "next.success": success[1:],
}
for cam in CAMERAS[self.dataset_id]:
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])
image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
ep_dict[f"observation.images.{cam}"] = image[:-1]
# ep_dict[f"next.observation.images.{cam}"] = image[1:]
assert isinstance(ep_id, int)
self.data_ids_per_episode[ep_id] = torch.arange(frame_idx, frame_idx + num_frames, 1)
assert len(self.data_ids_per_episode[ep_id]) == num_frames
ep_dicts.append(ep_dict)
frame_idx += num_frames
self.data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
self.data_dict["index"] = torch.arange(0, total_frames, 1)

View File

@@ -1,5 +1,4 @@
import logging
import os
from pathlib import Path
import torch
@@ -8,11 +7,6 @@ from torchvision.transforms import v2
from lerobot.common.datasets.utils import compute_stats
from lerobot.common.transforms import NormalizeTransform, Prod
# DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and
# we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data`
# to load a subset of our datasets for faster continuous integration.
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
def make_dataset(
cfg,
@@ -57,12 +51,11 @@ def make_dataset(
# instantiate a one frame dataset with light transform
stats_dataset = clsfunc(
dataset_id=cfg.dataset_id,
root=DATA_DIR,
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
)
# load stats if the file exists already or compute stats and save it
precomputed_stats_path = stats_dataset.data_dir / "stats.pth"
precomputed_stats_path = Path("data") / cfg.dataset_id / "stats.pth"
if precomputed_stats_path.exists():
stats = torch.load(precomputed_stats_path)
else:
@@ -94,7 +87,6 @@ def make_dataset(
dataset = clsfunc(
dataset_id=cfg.dataset_id,
root=DATA_DIR,
delta_timestamps=delta_timestamps,
transform=transforms,
)

View File

@@ -1,25 +1,12 @@
from pathlib import Path
import einops
import numpy as np
import torch
import tqdm
from datasets import Dataset, load_dataset
from datasets import load_dataset
from lerobot.common.datasets._diffusion_policy_replay_buffer import (
ReplayBuffer as DiffusionPolicyReplayBuffer,
)
from lerobot.common.datasets.utils import download_and_extract_zip, load_previous_and_future_frames
# as define in env
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr")
from lerobot.common.datasets.utils import load_previous_and_future_frames
class PushtDataset(torch.utils.data.Dataset):
"""
https://huggingface.co/datasets/lerobot/pusht
Arguments
----------
@@ -35,32 +22,19 @@ class PushtDataset(torch.utils.data.Dataset):
def __init__(
self,
dataset_id: str,
version: str | None = "v1.2",
root: Path | None = None,
version: str | None = "v1.0",
transform: callable = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
self.dataset_id = dataset_id
self.version = version
self.root = root
self.transform = transform
self.delta_timestamps = delta_timestamps
# self.data_dir = self.root / f"{self.dataset_id}"
# if (self.data_dir / "data_dict.pth").exists() and (
# self.data_dir / "data_ids_per_episode.pth"
# ).exists():
# self.data_dict = torch.load(self.data_dir / "data_dict.pth")
# self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
# else:
# self._download_and_preproc_obsolete()
# self.data_dir.mkdir(parents=True, exist_ok=True)
# torch.save(self.data_dict, self.data_dir / "data_dict.pth")
# torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
self.data_dict = load_dataset("lerobot/pusht", split="train")
# self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", revision=self.version, split="train")
self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", split="train")
self.data_dict = self.data_dict.with_format("torch")
self.data_dict.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
@property
def num_samples(self) -> int:
@@ -87,135 +61,3 @@ class PushtDataset(torch.utils.data.Dataset):
item = self.transform(item)
return item
def _download_and_preproc_obsolete(self):
try:
import pymunk
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
except ModuleNotFoundError as e:
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
raise e
assert self.root is not None
raw_dir = self.root / f"{self.dataset_id}_raw"
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
if not zarr_path.is_dir():
raw_dir.mkdir(parents=True, exist_ok=True)
download_and_extract_zip(PUSHT_URL, raw_dir)
# load
dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(
zarr_path
) # , keys=['img', 'state', 'action'])
episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
total_frames = dataset_dict["action"].shape[0]
# to create test artifact
# num_episodes = 1
# total_frames = 50
assert len(
{dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118
), "Some data type dont have the same number of total frames."
# TODO: verify that goal pose is expected to be fixed
goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)
imgs = torch.from_numpy(dataset_dict["img"])
imgs = einops.rearrange(imgs, "b h w c -> b c h w")
states = torch.from_numpy(dataset_dict["state"])
actions = torch.from_numpy(dataset_dict["action"])
self.data_ids_per_episode = {}
ep_dicts = []
idx0 = 0
for episode_id in tqdm.tqdm(range(num_episodes)):
idx1 = dataset_dict.meta["episode_ends"][episode_id]
num_frames = idx1 - idx0
assert (episode_ids[idx0:idx1] == episode_id).all()
image = imgs[idx0:idx1]
assert image.min() >= 0.0
assert image.max() <= 255.0
image = image.type(torch.uint8)
state = states[idx0:idx1]
agent_pos = state[:, :2]
block_pos = state[:, 2:4]
block_angle = state[:, 4]
reward = torch.zeros(num_frames)
success = torch.zeros(num_frames, dtype=torch.bool)
done = torch.zeros(num_frames, dtype=torch.bool)
for i in range(num_frames):
space = pymunk.Space()
space.gravity = 0, 0
space.damping = 0
# Add walls.
walls = [
PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
]
space.add(*walls)
block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
intersection_area = goal_geom.intersection(block_geom).area
goal_area = goal_geom.area
coverage = intersection_area / goal_area
reward[i] = np.clip(coverage / SUCCESS_THRESHOLD, 0, 1)
success[i] = coverage > SUCCESS_THRESHOLD
# last step of demonstration is considered done
done[-1] = True
ep_dict = {
"observation.image": image,
"observation.state": agent_pos,
"action": actions[idx0:idx1],
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_id": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
# "next.observation.image": image[1:],
# "next.observation.state": agent_pos[1:],
# TODO(rcadene): verify that reward and done are aligned with image and agent_pos
"next.reward": torch.cat([reward[1:], reward[[-1]]]),
"next.done": torch.cat([done[1:], done[[-1]]]),
"next.success": torch.cat([success[1:], success[[-1]]]),
}
ep_dicts.append(ep_dict)
assert isinstance(episode_id, int)
self.data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1)
assert len(self.data_ids_per_episode[episode_id]) == num_frames
idx0 = idx1
self.data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
self.data_dict["index"] = torch.arange(0, total_frames, 1)
dataset = Dataset.from_dict(self.data_dict)
dataset = dataset.with_format("torch")
def add_episode_data_id_from_to(frame):
ep_id = frame["episode"].item()
frame["episode_data_id_from"] = self.data_ids_per_episode[ep_id][0]
frame["episode_data_id_to"] = self.data_ids_per_episode[ep_id][-1]
return frame
dataset = dataset.map(add_episode_data_id_from_to, num_proc=4)
dataset = dataset.rename_column("episode", "episode_id")
dataset.push_to_hub("lerobot/pusht", token=True)

View File

@@ -1,39 +1,11 @@
import io
import zipfile
from copy import deepcopy
from math import ceil
from pathlib import Path
import einops
import requests
import torch
import tqdm
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
print(f"downloading from {url}")
response = requests.get(url, stream=True)
if response.status_code == 200:
total_size = int(response.headers.get("content-length", 0))
progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)
zip_file = io.BytesIO()
for chunk in response.iter_content(chunk_size=1024):
if chunk:
zip_file.write(chunk)
progress_bar.update(len(chunk))
progress_bar.close()
zip_file.seek(0)
with zipfile.ZipFile(zip_file, "r") as zip_ref:
zip_ref.extractall(destination_folder)
return True
else:
return False
def load_previous_and_future_frames(
item: dict[torch.Tensor],
data_dict: dict[torch.Tensor],

View File

@@ -1,30 +1,14 @@
import pickle
import zipfile
from pathlib import Path
import torch
import tqdm
from datasets import load_dataset
from lerobot.common.datasets.utils import load_data_with_delta_timestamps
def download(raw_dir):
import gdown
raw_dir.mkdir(parents=True, exist_ok=True)
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
zip_path = raw_dir / "data.zip"
gdown.download(url, str(zip_path), quiet=False)
print("Extracting...")
with zipfile.ZipFile(str(zip_path), "r") as zip_f:
for member in zip_f.namelist():
if member.startswith("data/xarm") and member.endswith(".pkl"):
print(member)
zip_f.extract(member=member)
zip_path.unlink()
from lerobot.common.datasets.utils import load_previous_and_future_frames
class XarmDataset(torch.utils.data.Dataset):
"""
https://huggingface.co/datasets/lerobot/xarm_lift_medium
"""
available_datasets = [
"xarm_lift_medium",
]
@@ -34,130 +18,40 @@ class XarmDataset(torch.utils.data.Dataset):
def __init__(
self,
dataset_id: str,
version: str | None = "v1.1",
root: Path | None = None,
version: str | None = "v1.0",
transform: callable = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
self.dataset_id = dataset_id
self.version = version
self.root = root
self.transform = transform
self.delta_timestamps = delta_timestamps
self.data_dir = self.root / f"{self.dataset_id}"
if (self.data_dir / "data_dict.pth").exists() and (
self.data_dir / "data_ids_per_episode.pth"
).exists():
self.data_dict = torch.load(self.data_dir / "data_dict.pth")
self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
else:
self._download_and_preproc_obsolete()
self.data_dir.mkdir(parents=True, exist_ok=True)
torch.save(self.data_dict, self.data_dir / "data_dict.pth")
torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", revision=self.version, split="train")
self.data_dict = self.data_dict.with_format("torch")
@property
def num_samples(self) -> int:
return len(self.data_dict["index"]) if "index" in self.data_dict else 0
return len(self.data_dict)
@property
def num_episodes(self) -> int:
return len(self.data_ids_per_episode)
return len(self.data_dict.unique("episode_id"))
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
item = {}
item = self.data_dict[idx]
# get episode id and timestamp of the sampled frame
current_ts = self.data_dict["timestamp"][idx].item()
episode = self.data_dict["episode"][idx].item()
for key in self.data_dict:
if self.delta_timestamps is not None and key in self.delta_timestamps:
data, is_pad = load_data_with_delta_timestamps(
self.data_dict,
self.data_ids_per_episode,
self.delta_timestamps,
key,
current_ts,
episode,
)
item[key] = data
item[f"{key}_is_pad"] = is_pad
else:
item[key] = self.data_dict[key][idx]
if self.delta_timestamps is not None:
item = load_previous_and_future_frames(
item,
self.data_dict,
self.delta_timestamps,
)
if self.transform is not None:
item = self.transform(item)
return item
def _download_and_preproc_obsolete(self):
assert self.root is not None
raw_dir = self.root / f"{self.dataset_id}_raw"
if not raw_dir.exists():
download(raw_dir)
dataset_path = self.root / f"{self.dataset_id}" / "buffer.pkl"
print(f"Using offline dataset '{dataset_path}'")
with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f)
total_frames = dataset_dict["actions"].shape[0]
self.data_ids_per_episode = {}
ep_dicts = []
idx0 = 0
idx1 = 0
episode_id = 0
for i in tqdm.tqdm(range(total_frames)):
idx1 += 1
if not dataset_dict["dones"][i]:
continue
num_frames = idx1 - idx0
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
action = torch.tensor(dataset_dict["actions"][idx0:idx1])
# TODO(rcadene): we have a missing last frame which is the observation when the env is done
# it is critical to have this frame for tdmpc to predict a "done observation/state"
# next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
# next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])
ep_dict = {
"observation.image": image,
"observation.state": state,
"action": action,
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_id": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
# "next.observation.image": next_image,
# "next.observation.state": next_state,
"next.reward": next_reward,
"next.done": next_done,
}
ep_dicts.append(ep_dict)
assert isinstance(episode_id, int)
self.data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1)
assert len(self.data_ids_per_episode[episode_id]) == num_frames
idx0 = idx1
episode_id += 1
self.data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
self.data_dict["index"] = torch.arange(0, total_frames, 1)