forked from tangger/lerobot
fix(lerobot/common/datasets): remove lint warnings/errors
This commit is contained in:
@@ -108,7 +108,7 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu
|
|||||||
|
|
||||||
|
|
||||||
def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
|
def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
|
||||||
for i in range(len(stats_list)):
|
for i in enumerate(stats_list):
|
||||||
for fkey in stats_list[i]:
|
for fkey in stats_list[i]:
|
||||||
for k, v in stats_list[i][fkey].items():
|
for k, v in stats_list[i][fkey].items():
|
||||||
if not isinstance(v, np.ndarray):
|
if not isinstance(v, np.ndarray):
|
||||||
|
|||||||
@@ -13,8 +13,6 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
|
||||||
from pprint import pformat
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -98,17 +96,17 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
||||||
dataset = MultiLeRobotDataset(
|
# dataset = MultiLeRobotDataset(
|
||||||
cfg.dataset.repo_id,
|
# cfg.dataset.repo_id,
|
||||||
# TODO(aliberts): add proper support for multi dataset
|
# # TODO(aliberts): add proper support for multi dataset
|
||||||
# delta_timestamps=delta_timestamps,
|
# # delta_timestamps=delta_timestamps,
|
||||||
image_transforms=image_transforms,
|
# image_transforms=image_transforms,
|
||||||
video_backend=cfg.dataset.video_backend,
|
# video_backend=cfg.dataset.video_backend,
|
||||||
)
|
# )
|
||||||
logging.info(
|
# logging.info(
|
||||||
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
|
# "Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
|
||||||
f"{pformat(dataset.repo_id_to_index, indent=2)}"
|
# f"{pformat(dataset.repo_id_to_index, indent=2)}"
|
||||||
)
|
# )
|
||||||
|
|
||||||
if cfg.dataset.use_imagenet_stats:
|
if cfg.dataset.use_imagenet_stats:
|
||||||
for key in dataset.meta.camera_keys:
|
for key in dataset.meta.camera_keys:
|
||||||
|
|||||||
@@ -81,21 +81,21 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
|
|||||||
print(f"Error writing image {fpath}: {e}")
|
print(f"Error writing image {fpath}: {e}")
|
||||||
|
|
||||||
|
|
||||||
def worker_thread_loop(queue: queue.Queue):
|
def worker_thread_loop(task_queue: queue.Queue):
|
||||||
while True:
|
while True:
|
||||||
item = queue.get()
|
item = task_queue.get()
|
||||||
if item is None:
|
if item is None:
|
||||||
queue.task_done()
|
task_queue.task_done()
|
||||||
break
|
break
|
||||||
image_array, fpath = item
|
image_array, fpath = item
|
||||||
write_image(image_array, fpath)
|
write_image(image_array, fpath)
|
||||||
queue.task_done()
|
task_queue.task_done()
|
||||||
|
|
||||||
|
|
||||||
def worker_process(queue: queue.Queue, num_threads: int):
|
def worker_process(task_queue: queue.Queue, num_threads: int):
|
||||||
threads = []
|
threads = []
|
||||||
for _ in range(num_threads):
|
for _ in range(num_threads):
|
||||||
t = threading.Thread(target=worker_thread_loop, args=(queue,))
|
t = threading.Thread(target=worker_thread_loop, args=(task_queue,))
|
||||||
t.daemon = True
|
t.daemon = True
|
||||||
t.start()
|
t.start()
|
||||||
threads.append(t)
|
threads.append(t)
|
||||||
|
|||||||
@@ -87,6 +87,7 @@ class LeRobotDatasetMetadata:
|
|||||||
self.repo_id = repo_id
|
self.repo_id = repo_id
|
||||||
self.revision = revision if revision else CODEBASE_VERSION
|
self.revision = revision if revision else CODEBASE_VERSION
|
||||||
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
|
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
|
||||||
|
self.stats = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if force_cache_sync:
|
if force_cache_sync:
|
||||||
@@ -102,10 +103,10 @@ class LeRobotDatasetMetadata:
|
|||||||
|
|
||||||
def load_metadata(self):
|
def load_metadata(self):
|
||||||
self.info = load_info(self.root)
|
self.info = load_info(self.root)
|
||||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
check_version_compatibility(self.repo_id, self.version, CODEBASE_VERSION)
|
||||||
self.tasks, self.task_to_task_index = load_tasks(self.root)
|
self.tasks, self.task_to_task_index = load_tasks(self.root)
|
||||||
self.episodes = load_episodes(self.root)
|
self.episodes = load_episodes(self.root)
|
||||||
if self._version < packaging.version.parse("v2.1"):
|
if self.version < packaging.version.parse("v2.1"):
|
||||||
self.stats = load_stats(self.root)
|
self.stats = load_stats(self.root)
|
||||||
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
|
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
|
||||||
else:
|
else:
|
||||||
@@ -127,7 +128,7 @@ class LeRobotDatasetMetadata:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _version(self) -> packaging.version.Version:
|
def version(self) -> packaging.version.Version:
|
||||||
"""Codebase version used to create this dataset."""
|
"""Codebase version used to create this dataset."""
|
||||||
return packaging.version.parse(self.info["codebase_version"])
|
return packaging.version.parse(self.info["codebase_version"])
|
||||||
|
|
||||||
@@ -321,8 +322,9 @@ class LeRobotDatasetMetadata:
|
|||||||
robot_type = robot.robot_type
|
robot_type = robot.robot_type
|
||||||
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
|
"Some cameras in your %s robot don't have an fps matching the fps of your dataset."
|
||||||
"In this case, frames from lower fps cameras will be repeated to fill in the blanks."
|
"In this case, frames from lower fps cameras will be repeated to fill in the blanks.",
|
||||||
|
robot.robot_type,
|
||||||
)
|
)
|
||||||
elif features is None:
|
elif features is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -486,7 +488,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
self.meta = LeRobotDatasetMetadata(
|
self.meta = LeRobotDatasetMetadata(
|
||||||
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
||||||
)
|
)
|
||||||
if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
|
if self.episodes is not None and self.meta.version >= packaging.version.parse("v2.1"):
|
||||||
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
|
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
|
||||||
self.stats = aggregate_stats(episodes_stats)
|
self.stats = aggregate_stats(episodes_stats)
|
||||||
|
|
||||||
@@ -518,7 +520,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
self,
|
self,
|
||||||
branch: str | None = None,
|
branch: str | None = None,
|
||||||
tags: list | None = None,
|
tags: list | None = None,
|
||||||
license: str | None = "apache-2.0",
|
dataset_license: str | None = "apache-2.0",
|
||||||
tag_version: bool = True,
|
tag_version: bool = True,
|
||||||
push_videos: bool = True,
|
push_videos: bool = True,
|
||||||
private: bool = False,
|
private: bool = False,
|
||||||
@@ -561,7 +563,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
|
if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
|
||||||
card = create_lerobot_dataset_card(
|
card = create_lerobot_dataset_card(
|
||||||
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
|
tags=tags, dataset_info=self.meta.info, license=dataset_license, **card_kwargs
|
||||||
)
|
)
|
||||||
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
|
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
|
||||||
|
|
||||||
@@ -842,6 +844,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
|
save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
|
||||||
None.
|
None.
|
||||||
"""
|
"""
|
||||||
|
episode_buffer = None
|
||||||
if not episode_data:
|
if not episode_data:
|
||||||
episode_buffer = self.episode_buffer
|
episode_buffer = self.episode_buffer
|
||||||
|
|
||||||
@@ -1086,8 +1089,9 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||||||
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
|
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
|
||||||
extra_keys = set(ds.features).difference(intersection_features)
|
extra_keys = set(ds.features).difference(intersection_features)
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
"keys %s of %s were disabled as they are not contained in all the other datasets.",
|
||||||
"other datasets."
|
extra_keys,
|
||||||
|
repo_id,
|
||||||
)
|
)
|
||||||
self.disabled_features.update(extra_keys)
|
self.disabled_features.update(extra_keys)
|
||||||
|
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ def rechunk_recompress_array(group, name, chunks=None, chunk_length=None, compre
|
|||||||
# rechunk recompress
|
# rechunk recompress
|
||||||
group.move(name, tmp_key)
|
group.move(name, tmp_key)
|
||||||
old_arr = group[tmp_key]
|
old_arr = group[tmp_key]
|
||||||
n_copied, n_skipped, n_bytes_copied = zarr.copy(
|
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy(
|
||||||
source=old_arr,
|
source=old_arr,
|
||||||
dest=group,
|
dest=group,
|
||||||
name=name,
|
name=name,
|
||||||
@@ -192,7 +192,7 @@ class ReplayBuffer:
|
|||||||
else:
|
else:
|
||||||
root = zarr.group(store=store)
|
root = zarr.group(store=store)
|
||||||
# copy without recompression
|
# copy without recompression
|
||||||
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
|
||||||
source=src_store, dest=store, source_path="/meta", dest_path="/meta", if_exists=if_exists
|
source=src_store, dest=store, source_path="/meta", dest_path="/meta", if_exists=if_exists
|
||||||
)
|
)
|
||||||
data_group = root.create_group("data", overwrite=True)
|
data_group = root.create_group("data", overwrite=True)
|
||||||
@@ -205,7 +205,7 @@ class ReplayBuffer:
|
|||||||
if cks == value.chunks and cpr == value.compressor:
|
if cks == value.chunks and cpr == value.compressor:
|
||||||
# copy without recompression
|
# copy without recompression
|
||||||
this_path = "/data/" + key
|
this_path = "/data/" + key
|
||||||
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
|
||||||
source=src_store,
|
source=src_store,
|
||||||
dest=store,
|
dest=store,
|
||||||
source_path=this_path,
|
source_path=this_path,
|
||||||
@@ -214,7 +214,7 @@ class ReplayBuffer:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# copy with recompression
|
# copy with recompression
|
||||||
n_copied, n_skipped, n_bytes_copied = zarr.copy(
|
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy(
|
||||||
source=value,
|
source=value,
|
||||||
dest=data_group,
|
dest=data_group,
|
||||||
name=key,
|
name=key,
|
||||||
@@ -275,7 +275,7 @@ class ReplayBuffer:
|
|||||||
compressors = {}
|
compressors = {}
|
||||||
if self.backend == "zarr":
|
if self.backend == "zarr":
|
||||||
# recompression free copy
|
# recompression free copy
|
||||||
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
|
||||||
source=self.root.store,
|
source=self.root.store,
|
||||||
dest=store,
|
dest=store,
|
||||||
source_path="/meta",
|
source_path="/meta",
|
||||||
@@ -297,7 +297,7 @@ class ReplayBuffer:
|
|||||||
if cks == value.chunks and cpr == value.compressor:
|
if cks == value.chunks and cpr == value.compressor:
|
||||||
# copy without recompression
|
# copy without recompression
|
||||||
this_path = "/data/" + key
|
this_path = "/data/" + key
|
||||||
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
|
||||||
source=self.root.store,
|
source=self.root.store,
|
||||||
dest=store,
|
dest=store,
|
||||||
source_path=this_path,
|
source_path=this_path,
|
||||||
|
|||||||
@@ -162,9 +162,9 @@ def download_raw(raw_dir: Path, repo_id: str):
|
|||||||
)
|
)
|
||||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}")
|
logging.info("Start downloading from huggingface.co/%s for %s", user_id, dataset_id)
|
||||||
snapshot_download(repo_id, repo_type="dataset", local_dir=raw_dir)
|
snapshot_download(repo_id, repo_type="dataset", local_dir=raw_dir)
|
||||||
logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}")
|
logging.info("Finish downloading from huggingface.co/%s for %s", user_id, dataset_id)
|
||||||
|
|
||||||
|
|
||||||
def download_all_raw_datasets(data_dir: Path | None = None):
|
def download_all_raw_datasets(data_dir: Path | None = None):
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ def check_format(raw_dir) -> bool:
|
|||||||
assert data[f"/observations/images/{camera}"].ndim == 2
|
assert data[f"/observations/images/{camera}"].ndim == 2
|
||||||
else:
|
else:
|
||||||
assert data[f"/observations/images/{camera}"].ndim == 4
|
assert data[f"/observations/images/{camera}"].ndim == 4
|
||||||
b, h, w, c = data[f"/observations/images/{camera}"].shape
|
_, h, w, c = data[f"/observations/images/{camera}"].shape
|
||||||
assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
|
assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
|
||||||
|
|
||||||
|
|
||||||
@@ -103,6 +103,7 @@ def load_from_raw(
|
|||||||
|
|
||||||
state = torch.from_numpy(ep["/observations/qpos"][:])
|
state = torch.from_numpy(ep["/observations/qpos"][:])
|
||||||
action = torch.from_numpy(ep["/action"][:])
|
action = torch.from_numpy(ep["/action"][:])
|
||||||
|
velocity = None
|
||||||
if "/observations/qvel" in ep:
|
if "/observations/qvel" in ep:
|
||||||
velocity = torch.from_numpy(ep["/observations/qvel"][:])
|
velocity = torch.from_numpy(ep["/observations/qvel"][:])
|
||||||
if "/observations/effort" in ep:
|
if "/observations/effort" in ep:
|
||||||
|
|||||||
@@ -96,6 +96,7 @@ def from_raw_to_lerobot_format(
|
|||||||
if fps is None:
|
if fps is None:
|
||||||
fps = 30
|
fps = 30
|
||||||
|
|
||||||
|
# TODO(Steven): Is this meant to call cam_png_format.load_from_raw?
|
||||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
|
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
|
||||||
hf_dataset = to_hf_dataset(data_dict, video)
|
hf_dataset = to_hf_dataset(data_dict, video)
|
||||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||||
|
|||||||
@@ -42,7 +42,9 @@ def check_format(raw_dir) -> bool:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
|
def load_from_raw(
|
||||||
|
raw_dir: Path, videos_dir: Path, fps: int, _video: bool, _episodes: list[int] | None = None
|
||||||
|
):
|
||||||
# Load data stream that will be used as reference for the timestamps synchronization
|
# Load data stream that will be used as reference for the timestamps synchronization
|
||||||
reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
|
reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
|
||||||
if len(reference_files) == 0:
|
if len(reference_files) == 0:
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers
|
|||||||
|
|
||||||
num_images = len(imgs_array)
|
num_images = len(imgs_array)
|
||||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
[executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
|
_ = [executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
|
||||||
|
|
||||||
|
|
||||||
def get_default_encoding() -> dict:
|
def get_default_encoding() -> dict:
|
||||||
@@ -92,24 +92,23 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torc
|
|||||||
episode_data_index = {"from": [], "to": []}
|
episode_data_index = {"from": [], "to": []}
|
||||||
|
|
||||||
current_episode = None
|
current_episode = None
|
||||||
"""
|
# The episode_index is a list of integers, each representing the episode index of the corresponding example.
|
||||||
The episode_index is a list of integers, each representing the episode index of the corresponding example.
|
# For instance, the following is a valid episode_index:
|
||||||
For instance, the following is a valid episode_index:
|
# [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
|
||||||
[0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
|
#
|
||||||
|
# Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
|
||||||
Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
|
# ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
|
||||||
ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
|
# {
|
||||||
{
|
# "from": [0, 3, 7],
|
||||||
"from": [0, 3, 7],
|
# "to": [3, 7, 12]
|
||||||
"to": [3, 7, 12]
|
# }
|
||||||
}
|
|
||||||
"""
|
|
||||||
if len(hf_dataset) == 0:
|
if len(hf_dataset) == 0:
|
||||||
episode_data_index = {
|
episode_data_index = {
|
||||||
"from": torch.tensor([]),
|
"from": torch.tensor([]),
|
||||||
"to": torch.tensor([]),
|
"to": torch.tensor([]),
|
||||||
}
|
}
|
||||||
return episode_data_index
|
return episode_data_index
|
||||||
|
idx = None
|
||||||
for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
|
for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
|
||||||
if episode_idx != current_episode:
|
if episode_idx != current_episode:
|
||||||
# We encountered a new episode, so we append its starting location to the "from" list
|
# We encountered a new episode, so we append its starting location to the "from" list
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from torchvision.transforms.v2 import Transform
|
|||||||
from torchvision.transforms.v2 import functional as F # noqa: N812
|
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(Steven): Missing transform() implementation
|
||||||
class RandomSubsetApply(Transform):
|
class RandomSubsetApply(Transform):
|
||||||
"""Apply a random subset of N transformations from a list of transformations.
|
"""Apply a random subset of N transformations from a list of transformations.
|
||||||
|
|
||||||
@@ -218,6 +219,7 @@ def make_transform_from_config(cfg: ImageTransformConfig):
|
|||||||
raise ValueError(f"Transform '{cfg.type}' is not valid.")
|
raise ValueError(f"Transform '{cfg.type}' is not valid.")
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(Steven): Missing transform() implementation
|
||||||
class ImageTransforms(Transform):
|
class ImageTransforms(Transform):
|
||||||
"""A class to compose image transforms based on configuration."""
|
"""A class to compose image transforms based on configuration."""
|
||||||
|
|
||||||
|
|||||||
@@ -135,21 +135,21 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
|
|||||||
|
|
||||||
def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
|
def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
|
||||||
# Embed image bytes into the table before saving to parquet
|
# Embed image bytes into the table before saving to parquet
|
||||||
format = dataset.format
|
ds_format = dataset.format
|
||||||
dataset = dataset.with_format("arrow")
|
dataset = dataset.with_format("arrow")
|
||||||
dataset = dataset.map(embed_table_storage, batched=False)
|
dataset = dataset.map(embed_table_storage, batched=False)
|
||||||
dataset = dataset.with_format(**format)
|
dataset = dataset.with_format(**ds_format)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
def load_json(fpath: Path) -> Any:
|
def load_json(fpath: Path) -> Any:
|
||||||
with open(fpath) as f:
|
with open(fpath, encoding="utf-8") as f:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
|
|
||||||
|
|
||||||
def write_json(data: dict, fpath: Path) -> None:
|
def write_json(data: dict, fpath: Path) -> None:
|
||||||
fpath.parent.mkdir(exist_ok=True, parents=True)
|
fpath.parent.mkdir(exist_ok=True, parents=True)
|
||||||
with open(fpath, "w") as f:
|
with open(fpath, "w", encoding="utf-8") as f:
|
||||||
json.dump(data, f, indent=4, ensure_ascii=False)
|
json.dump(data, f, indent=4, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
@@ -300,7 +300,7 @@ def check_version_compatibility(
|
|||||||
if v_check.major < v_current.major and enforce_breaking_major:
|
if v_check.major < v_current.major and enforce_breaking_major:
|
||||||
raise BackwardCompatibilityError(repo_id, v_check)
|
raise BackwardCompatibilityError(repo_id, v_check)
|
||||||
elif v_check.minor < v_current.minor:
|
elif v_check.minor < v_current.minor:
|
||||||
logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_check))
|
logging.warning("%s", V21_MESSAGE.format(repo_id=repo_id, version=v_check))
|
||||||
|
|
||||||
|
|
||||||
def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
|
def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
|
||||||
@@ -348,7 +348,9 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
|
|||||||
if compatibles:
|
if compatibles:
|
||||||
return_version = max(compatibles)
|
return_version = max(compatibles)
|
||||||
if return_version < target_version:
|
if return_version < target_version:
|
||||||
logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}")
|
logging.warning(
|
||||||
|
"Revision %s for %s not found, using version v%s", version, repo_id, return_version
|
||||||
|
)
|
||||||
return f"v{return_version}"
|
return f"v{return_version}"
|
||||||
|
|
||||||
lower_major = [v for v in hub_versions if v.major < target_version.major]
|
lower_major = [v for v in hub_versions if v.major < target_version.major]
|
||||||
@@ -403,7 +405,7 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
|||||||
for key, ft in features.items():
|
for key, ft in features.items():
|
||||||
shape = ft["shape"]
|
shape = ft["shape"]
|
||||||
if ft["dtype"] in ["image", "video"]:
|
if ft["dtype"] in ["image", "video"]:
|
||||||
type = FeatureType.VISUAL
|
feature_type = FeatureType.VISUAL
|
||||||
if len(shape) != 3:
|
if len(shape) != 3:
|
||||||
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
|
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
|
||||||
|
|
||||||
@@ -412,16 +414,16 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
|||||||
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
||||||
shape = (shape[2], shape[0], shape[1])
|
shape = (shape[2], shape[0], shape[1])
|
||||||
elif key == "observation.environment_state":
|
elif key == "observation.environment_state":
|
||||||
type = FeatureType.ENV
|
feature_type = FeatureType.ENV
|
||||||
elif key.startswith("observation"):
|
elif key.startswith("observation"):
|
||||||
type = FeatureType.STATE
|
feature_type = FeatureType.STATE
|
||||||
elif key == "action":
|
elif key == "action":
|
||||||
type = FeatureType.ACTION
|
feature_type = FeatureType.ACTION
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
policy_features[key] = PolicyFeature(
|
policy_features[key] = PolicyFeature(
|
||||||
type=type,
|
type=feature_type,
|
||||||
shape=shape,
|
shape=shape,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -871,11 +871,11 @@ def batch_convert():
|
|||||||
try:
|
try:
|
||||||
convert_dataset(repo_id, LOCAL_DIR, **kwargs)
|
convert_dataset(repo_id, LOCAL_DIR, **kwargs)
|
||||||
status = f"{repo_id}: success."
|
status = f"{repo_id}: success."
|
||||||
with open(logfile, "a") as file:
|
with open(logfile, "a", encoding="utf-8") as file:
|
||||||
file.write(status + "\n")
|
file.write(status + "\n")
|
||||||
except Exception:
|
except Exception:
|
||||||
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
||||||
with open(logfile, "a") as file:
|
with open(logfile, "a", encoding="utf-8") as file:
|
||||||
file.write(status + "\n")
|
file.write(status + "\n")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
@@ -190,11 +190,11 @@ def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
|
|||||||
|
|
||||||
json_path = v2_dir / STATS_PATH
|
json_path = v2_dir / STATS_PATH
|
||||||
json_path.parent.mkdir(exist_ok=True, parents=True)
|
json_path.parent.mkdir(exist_ok=True, parents=True)
|
||||||
with open(json_path, "w") as f:
|
with open(json_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(serialized_stats, f, indent=4)
|
json.dump(serialized_stats, f, indent=4)
|
||||||
|
|
||||||
# Sanity check
|
# Sanity check
|
||||||
with open(json_path) as f:
|
with open(json_path, encoding="utf-8") as f:
|
||||||
stats_json = json.load(f)
|
stats_json = json.load(f)
|
||||||
|
|
||||||
stats_json = flatten_dict(stats_json)
|
stats_json = flatten_dict(stats_json)
|
||||||
@@ -213,7 +213,7 @@ def get_features_from_hf_dataset(
|
|||||||
dtype = ft.dtype
|
dtype = ft.dtype
|
||||||
shape = (1,)
|
shape = (1,)
|
||||||
names = None
|
names = None
|
||||||
if isinstance(ft, datasets.Sequence):
|
elif isinstance(ft, datasets.Sequence):
|
||||||
assert isinstance(ft.feature, datasets.Value)
|
assert isinstance(ft.feature, datasets.Value)
|
||||||
dtype = ft.feature.dtype
|
dtype = ft.feature.dtype
|
||||||
shape = (ft.length,)
|
shape = (ft.length,)
|
||||||
@@ -232,6 +232,8 @@ def get_features_from_hf_dataset(
|
|||||||
dtype = "video"
|
dtype = "video"
|
||||||
shape = None # Add shape later
|
shape = None # Add shape later
|
||||||
names = ["height", "width", "channels"]
|
names = ["height", "width", "channels"]
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Feature type {ft._type} not supported.")
|
||||||
|
|
||||||
features[key] = {
|
features[key] = {
|
||||||
"dtype": dtype,
|
"dtype": dtype,
|
||||||
@@ -358,9 +360,9 @@ def move_videos(
|
|||||||
if len(video_dirs) == 1:
|
if len(video_dirs) == 1:
|
||||||
video_path = video_dirs[0] / video_file
|
video_path = video_dirs[0] / video_file
|
||||||
else:
|
else:
|
||||||
for dir in video_dirs:
|
for v_dir in video_dirs:
|
||||||
if (dir / video_file).is_file():
|
if (v_dir / video_file).is_file():
|
||||||
video_path = dir / video_file
|
video_path = v_dir / video_file
|
||||||
break
|
break
|
||||||
|
|
||||||
video_path.rename(work_dir / target_path)
|
video_path.rename(work_dir / target_path)
|
||||||
@@ -652,6 +654,7 @@ def main():
|
|||||||
if not args.local_dir:
|
if not args.local_dir:
|
||||||
args.local_dir = Path("/tmp/lerobot_dataset_v2")
|
args.local_dir = Path("/tmp/lerobot_dataset_v2")
|
||||||
|
|
||||||
|
robot_config = None
|
||||||
if args.robot is not None:
|
if args.robot is not None:
|
||||||
robot_config = make_robot_config(args.robot)
|
robot_config = make_robot_config(args.robot)
|
||||||
|
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ def fix_dataset(repo_id: str) -> str:
|
|||||||
return f"{repo_id}: skipped (no diff)"
|
return f"{repo_id}: skipped (no diff)"
|
||||||
|
|
||||||
if diff_meta_parquet:
|
if diff_meta_parquet:
|
||||||
logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}")
|
logging.warning("In info.json not in parquet: %s", meta_features - parquet_features)
|
||||||
assert diff_meta_parquet == {"language_instruction"}
|
assert diff_meta_parquet == {"language_instruction"}
|
||||||
lerobot_metadata.features.pop("language_instruction")
|
lerobot_metadata.features.pop("language_instruction")
|
||||||
write_info(lerobot_metadata.info, lerobot_metadata.root)
|
write_info(lerobot_metadata.info, lerobot_metadata.root)
|
||||||
@@ -79,7 +79,7 @@ def batch_fix():
|
|||||||
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
||||||
|
|
||||||
logging.info(status)
|
logging.info(status)
|
||||||
with open(logfile, "a") as file:
|
with open(logfile, "a", encoding="utf-8") as file:
|
||||||
file.write(status + "\n")
|
file.write(status + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ def batch_convert():
|
|||||||
except Exception:
|
except Exception:
|
||||||
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
||||||
|
|
||||||
with open(logfile, "a") as file:
|
with open(logfile, "a", encoding="utf-8") as file:
|
||||||
file.write(status + "\n")
|
file.write(status + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -45,6 +45,9 @@ V21 = "v2.1"
|
|||||||
|
|
||||||
|
|
||||||
class SuppressWarnings:
|
class SuppressWarnings:
|
||||||
|
def __init__(self):
|
||||||
|
self.previous_level = None
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.previous_level = logging.getLogger().getEffectiveLevel()
|
self.previous_level = logging.getLogger().getEffectiveLevel()
|
||||||
logging.getLogger().setLevel(logging.ERROR)
|
logging.getLogger().setLevel(logging.ERROR)
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ def decode_video_frames_torchvision(
|
|||||||
for frame in reader:
|
for frame in reader:
|
||||||
current_ts = frame["pts"]
|
current_ts = frame["pts"]
|
||||||
if log_loaded_timestamps:
|
if log_loaded_timestamps:
|
||||||
logging.info(f"frame loaded at timestamp={current_ts:.4f}")
|
logging.info("frame loaded at timestamp=%.4f", current_ts)
|
||||||
loaded_frames.append(frame["data"])
|
loaded_frames.append(frame["data"])
|
||||||
loaded_ts.append(current_ts)
|
loaded_ts.append(current_ts)
|
||||||
if current_ts >= last_ts:
|
if current_ts >= last_ts:
|
||||||
@@ -118,7 +118,7 @@ def decode_video_frames_torchvision(
|
|||||||
closest_ts = loaded_ts[argmin_]
|
closest_ts = loaded_ts[argmin_]
|
||||||
|
|
||||||
if log_loaded_timestamps:
|
if log_loaded_timestamps:
|
||||||
logging.info(f"{closest_ts=}")
|
logging.info("closest_ts=%s", closest_ts)
|
||||||
|
|
||||||
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
||||||
closest_frames = closest_frames.type(torch.float32) / 255
|
closest_frames = closest_frames.type(torch.float32) / 255
|
||||||
@@ -227,7 +227,9 @@ def get_audio_info(video_path: Path | str) -> dict:
|
|||||||
"json",
|
"json",
|
||||||
str(video_path),
|
str(video_path),
|
||||||
]
|
]
|
||||||
result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
result = subprocess.run(
|
||||||
|
ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True
|
||||||
|
)
|
||||||
if result.returncode != 0:
|
if result.returncode != 0:
|
||||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||||
|
|
||||||
@@ -263,7 +265,9 @@ def get_video_info(video_path: Path | str) -> dict:
|
|||||||
"json",
|
"json",
|
||||||
str(video_path),
|
str(video_path),
|
||||||
]
|
]
|
||||||
result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
result = subprocess.run(
|
||||||
|
ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True
|
||||||
|
)
|
||||||
if result.returncode != 0:
|
if result.returncode != 0:
|
||||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user