fix(lerobot/common/datasets): remove lint warnings/errors

This commit is contained in:
Steven Palma
2025-03-07 15:14:06 +01:00
parent 9b380eaf67
commit e59ef036e1
18 changed files with 98 additions and 79 deletions

View File

@@ -108,7 +108,7 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu
def _assert_type_and_shape(stats_list: list[dict[str, dict]]): def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
for i in range(len(stats_list)): for i, _ in enumerate(stats_list):
for fkey in stats_list[i]: for fkey in stats_list[i]:
for k, v in stats_list[i][fkey].items(): for k, v in stats_list[i][fkey].items():
if not isinstance(v, np.ndarray): if not isinstance(v, np.ndarray):

View File

@@ -13,8 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
from pprint import pformat
import torch import torch
@@ -98,17 +96,17 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
) )
else: else:
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.") raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
dataset = MultiLeRobotDataset( # dataset = MultiLeRobotDataset(
cfg.dataset.repo_id, # cfg.dataset.repo_id,
# TODO(aliberts): add proper support for multi dataset # # TODO(aliberts): add proper support for multi dataset
# delta_timestamps=delta_timestamps, # # delta_timestamps=delta_timestamps,
image_transforms=image_transforms, # image_transforms=image_transforms,
video_backend=cfg.dataset.video_backend, # video_backend=cfg.dataset.video_backend,
) # )
logging.info( # logging.info(
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: " # "Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
f"{pformat(dataset.repo_id_to_index, indent=2)}" # f"{pformat(dataset.repo_id_to_index, indent=2)}"
) # )
if cfg.dataset.use_imagenet_stats: if cfg.dataset.use_imagenet_stats:
for key in dataset.meta.camera_keys: for key in dataset.meta.camera_keys:

View File

@@ -81,21 +81,21 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
print(f"Error writing image {fpath}: {e}") print(f"Error writing image {fpath}: {e}")
def worker_thread_loop(queue: queue.Queue): def worker_thread_loop(task_queue: queue.Queue):
while True: while True:
item = queue.get() item = task_queue.get()
if item is None: if item is None:
queue.task_done() task_queue.task_done()
break break
image_array, fpath = item image_array, fpath = item
write_image(image_array, fpath) write_image(image_array, fpath)
queue.task_done() task_queue.task_done()
def worker_process(queue: queue.Queue, num_threads: int): def worker_process(task_queue: queue.Queue, num_threads: int):
threads = [] threads = []
for _ in range(num_threads): for _ in range(num_threads):
t = threading.Thread(target=worker_thread_loop, args=(queue,)) t = threading.Thread(target=worker_thread_loop, args=(task_queue,))
t.daemon = True t.daemon = True
t.start() t.start()
threads.append(t) threads.append(t)

View File

@@ -87,6 +87,7 @@ class LeRobotDatasetMetadata:
self.repo_id = repo_id self.repo_id = repo_id
self.revision = revision if revision else CODEBASE_VERSION self.revision = revision if revision else CODEBASE_VERSION
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
self.stats = None
try: try:
if force_cache_sync: if force_cache_sync:
@@ -102,10 +103,10 @@ class LeRobotDatasetMetadata:
def load_metadata(self): def load_metadata(self):
self.info = load_info(self.root) self.info = load_info(self.root)
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION) check_version_compatibility(self.repo_id, self.version, CODEBASE_VERSION)
self.tasks, self.task_to_task_index = load_tasks(self.root) self.tasks, self.task_to_task_index = load_tasks(self.root)
self.episodes = load_episodes(self.root) self.episodes = load_episodes(self.root)
if self._version < packaging.version.parse("v2.1"): if self.version < packaging.version.parse("v2.1"):
self.stats = load_stats(self.root) self.stats = load_stats(self.root)
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes) self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
else: else:
@@ -127,7 +128,7 @@ class LeRobotDatasetMetadata:
) )
@property @property
def _version(self) -> packaging.version.Version: def version(self) -> packaging.version.Version:
"""Codebase version used to create this dataset.""" """Codebase version used to create this dataset."""
return packaging.version.parse(self.info["codebase_version"]) return packaging.version.parse(self.info["codebase_version"])
@@ -321,8 +322,9 @@ class LeRobotDatasetMetadata:
robot_type = robot.robot_type robot_type = robot.robot_type
if not all(cam.fps == fps for cam in robot.cameras.values()): if not all(cam.fps == fps for cam in robot.cameras.values()):
logging.warning( logging.warning(
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset." "Some cameras in your %s robot don't have an fps matching the fps of your dataset."
"In this case, frames from lower fps cameras will be repeated to fill in the blanks." "In this case, frames from lower fps cameras will be repeated to fill in the blanks.",
robot.robot_type,
) )
elif features is None: elif features is None:
raise ValueError( raise ValueError(
@@ -486,7 +488,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.meta = LeRobotDatasetMetadata( self.meta = LeRobotDatasetMetadata(
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
) )
if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"): if self.episodes is not None and self.meta.version >= packaging.version.parse("v2.1"):
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes] episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
self.stats = aggregate_stats(episodes_stats) self.stats = aggregate_stats(episodes_stats)
@@ -518,7 +520,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self, self,
branch: str | None = None, branch: str | None = None,
tags: list | None = None, tags: list | None = None,
license: str | None = "apache-2.0", dataset_license: str | None = "apache-2.0",
tag_version: bool = True, tag_version: bool = True,
push_videos: bool = True, push_videos: bool = True,
private: bool = False, private: bool = False,
@@ -561,7 +563,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch): if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
card = create_lerobot_dataset_card( card = create_lerobot_dataset_card(
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs tags=tags, dataset_info=self.meta.info, license=dataset_license, **card_kwargs
) )
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch) card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
@@ -842,6 +844,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
None. None.
""" """
episode_buffer = None
if not episode_data: if not episode_data:
episode_buffer = self.episode_buffer episode_buffer = self.episode_buffer
@@ -1086,8 +1089,9 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True): for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
extra_keys = set(ds.features).difference(intersection_features) extra_keys = set(ds.features).difference(intersection_features)
logging.warning( logging.warning(
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the " "keys %s of %s were disabled as they are not contained in all the other datasets.",
"other datasets." extra_keys,
repo_id,
) )
self.disabled_features.update(extra_keys) self.disabled_features.update(extra_keys)

View File

@@ -53,7 +53,7 @@ def rechunk_recompress_array(group, name, chunks=None, chunk_length=None, compre
# rechunk recompress # rechunk recompress
group.move(name, tmp_key) group.move(name, tmp_key)
old_arr = group[tmp_key] old_arr = group[tmp_key]
n_copied, n_skipped, n_bytes_copied = zarr.copy( _n_copied, _n_skipped, _n_bytes_copied = zarr.copy(
source=old_arr, source=old_arr,
dest=group, dest=group,
name=name, name=name,
@@ -192,7 +192,7 @@ class ReplayBuffer:
else: else:
root = zarr.group(store=store) root = zarr.group(store=store)
# copy without recompression # copy without recompression
n_copied, n_skipped, n_bytes_copied = zarr.copy_store( _n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
source=src_store, dest=store, source_path="/meta", dest_path="/meta", if_exists=if_exists source=src_store, dest=store, source_path="/meta", dest_path="/meta", if_exists=if_exists
) )
data_group = root.create_group("data", overwrite=True) data_group = root.create_group("data", overwrite=True)
@@ -205,7 +205,7 @@ class ReplayBuffer:
if cks == value.chunks and cpr == value.compressor: if cks == value.chunks and cpr == value.compressor:
# copy without recompression # copy without recompression
this_path = "/data/" + key this_path = "/data/" + key
n_copied, n_skipped, n_bytes_copied = zarr.copy_store( _n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
source=src_store, source=src_store,
dest=store, dest=store,
source_path=this_path, source_path=this_path,
@@ -214,7 +214,7 @@ class ReplayBuffer:
) )
else: else:
# copy with recompression # copy with recompression
n_copied, n_skipped, n_bytes_copied = zarr.copy( _n_copied, _n_skipped, _n_bytes_copied = zarr.copy(
source=value, source=value,
dest=data_group, dest=data_group,
name=key, name=key,
@@ -275,7 +275,7 @@ class ReplayBuffer:
compressors = {} compressors = {}
if self.backend == "zarr": if self.backend == "zarr":
# recompression free copy # recompression free copy
n_copied, n_skipped, n_bytes_copied = zarr.copy_store( _n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
source=self.root.store, source=self.root.store,
dest=store, dest=store,
source_path="/meta", source_path="/meta",
@@ -297,7 +297,7 @@ class ReplayBuffer:
if cks == value.chunks and cpr == value.compressor: if cks == value.chunks and cpr == value.compressor:
# copy without recompression # copy without recompression
this_path = "/data/" + key this_path = "/data/" + key
n_copied, n_skipped, n_bytes_copied = zarr.copy_store( _n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
source=self.root.store, source=self.root.store,
dest=store, dest=store,
source_path=this_path, source_path=this_path,

View File

@@ -162,9 +162,9 @@ def download_raw(raw_dir: Path, repo_id: str):
) )
raw_dir.mkdir(parents=True, exist_ok=True) raw_dir.mkdir(parents=True, exist_ok=True)
logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}") logging.info("Start downloading from huggingface.co/%s for %s", user_id, dataset_id)
snapshot_download(repo_id, repo_type="dataset", local_dir=raw_dir) snapshot_download(repo_id, repo_type="dataset", local_dir=raw_dir)
logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}") logging.info("Finish downloading from huggingface.co/%s for %s", user_id, dataset_id)
def download_all_raw_datasets(data_dir: Path | None = None): def download_all_raw_datasets(data_dir: Path | None = None):

View File

@@ -72,7 +72,7 @@ def check_format(raw_dir) -> bool:
assert data[f"/observations/images/{camera}"].ndim == 2 assert data[f"/observations/images/{camera}"].ndim == 2
else: else:
assert data[f"/observations/images/{camera}"].ndim == 4 assert data[f"/observations/images/{camera}"].ndim == 4
b, h, w, c = data[f"/observations/images/{camera}"].shape _, h, w, c = data[f"/observations/images/{camera}"].shape
assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided." assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
@@ -103,6 +103,7 @@ def load_from_raw(
state = torch.from_numpy(ep["/observations/qpos"][:]) state = torch.from_numpy(ep["/observations/qpos"][:])
action = torch.from_numpy(ep["/action"][:]) action = torch.from_numpy(ep["/action"][:])
velocity = None
if "/observations/qvel" in ep: if "/observations/qvel" in ep:
velocity = torch.from_numpy(ep["/observations/qvel"][:]) velocity = torch.from_numpy(ep["/observations/qvel"][:])
if "/observations/effort" in ep: if "/observations/effort" in ep:

View File

@@ -96,6 +96,7 @@ def from_raw_to_lerobot_format(
if fps is None: if fps is None:
fps = 30 fps = 30
# TODO(Steven): Is this meant to call cam_png_format.load_from_raw?
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes) data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
hf_dataset = to_hf_dataset(data_dict, video) hf_dataset = to_hf_dataset(data_dict, video)
episode_data_index = calculate_episode_data_index(hf_dataset) episode_data_index = calculate_episode_data_index(hf_dataset)

View File

@@ -42,7 +42,9 @@ def check_format(raw_dir) -> bool:
return True return True
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None): def load_from_raw(
raw_dir: Path, videos_dir: Path, fps: int, _video: bool, _episodes: list[int] | None = None
):
# Load data stream that will be used as reference for the timestamps synchronization # Load data stream that will be used as reference for the timestamps synchronization
reference_files = list(raw_dir.glob("observation.images.cam_*.parquet")) reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
if len(reference_files) == 0: if len(reference_files) == 0:

View File

@@ -55,7 +55,7 @@ def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers
num_images = len(imgs_array) num_images = len(imgs_array)
with ThreadPoolExecutor(max_workers=max_workers) as executor: with ThreadPoolExecutor(max_workers=max_workers) as executor:
[executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)] _ = [executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
def get_default_encoding() -> dict: def get_default_encoding() -> dict:
@@ -92,24 +92,23 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torc
episode_data_index = {"from": [], "to": []} episode_data_index = {"from": [], "to": []}
current_episode = None current_episode = None
""" # The episode_index is a list of integers, each representing the episode index of the corresponding example.
The episode_index is a list of integers, each representing the episode index of the corresponding example. # For instance, the following is a valid episode_index:
For instance, the following is a valid episode_index: # [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
[0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2] #
# Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and # ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this: # {
{ # "from": [0, 3, 7],
"from": [0, 3, 7], # "to": [3, 7, 12]
"to": [3, 7, 12] # }
}
"""
if len(hf_dataset) == 0: if len(hf_dataset) == 0:
episode_data_index = { episode_data_index = {
"from": torch.tensor([]), "from": torch.tensor([]),
"to": torch.tensor([]), "to": torch.tensor([]),
} }
return episode_data_index return episode_data_index
idx = None
for idx, episode_idx in enumerate(hf_dataset["episode_index"]): for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
if episode_idx != current_episode: if episode_idx != current_episode:
# We encountered a new episode, so we append its starting location to the "from" list # We encountered a new episode, so we append its starting location to the "from" list

View File

@@ -23,6 +23,7 @@ from torchvision.transforms.v2 import Transform
from torchvision.transforms.v2 import functional as F # noqa: N812 from torchvision.transforms.v2 import functional as F # noqa: N812
# TODO(Steven): Missing transform() implementation
class RandomSubsetApply(Transform): class RandomSubsetApply(Transform):
"""Apply a random subset of N transformations from a list of transformations. """Apply a random subset of N transformations from a list of transformations.
@@ -218,6 +219,7 @@ def make_transform_from_config(cfg: ImageTransformConfig):
raise ValueError(f"Transform '{cfg.type}' is not valid.") raise ValueError(f"Transform '{cfg.type}' is not valid.")
# TODO(Steven): Missing transform() implementation
class ImageTransforms(Transform): class ImageTransforms(Transform):
"""A class to compose image transforms based on configuration.""" """A class to compose image transforms based on configuration."""

View File

@@ -135,21 +135,21 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
def embed_images(dataset: datasets.Dataset) -> datasets.Dataset: def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
# Embed image bytes into the table before saving to parquet # Embed image bytes into the table before saving to parquet
format = dataset.format ds_format = dataset.format
dataset = dataset.with_format("arrow") dataset = dataset.with_format("arrow")
dataset = dataset.map(embed_table_storage, batched=False) dataset = dataset.map(embed_table_storage, batched=False)
dataset = dataset.with_format(**format) dataset = dataset.with_format(**ds_format)
return dataset return dataset
def load_json(fpath: Path) -> Any: def load_json(fpath: Path) -> Any:
with open(fpath) as f: with open(fpath, encoding="utf-8") as f:
return json.load(f) return json.load(f)
def write_json(data: dict, fpath: Path) -> None: def write_json(data: dict, fpath: Path) -> None:
fpath.parent.mkdir(exist_ok=True, parents=True) fpath.parent.mkdir(exist_ok=True, parents=True)
with open(fpath, "w") as f: with open(fpath, "w", encoding="utf-8") as f:
json.dump(data, f, indent=4, ensure_ascii=False) json.dump(data, f, indent=4, ensure_ascii=False)
@@ -300,7 +300,7 @@ def check_version_compatibility(
if v_check.major < v_current.major and enforce_breaking_major: if v_check.major < v_current.major and enforce_breaking_major:
raise BackwardCompatibilityError(repo_id, v_check) raise BackwardCompatibilityError(repo_id, v_check)
elif v_check.minor < v_current.minor: elif v_check.minor < v_current.minor:
logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_check)) logging.warning("%s", V21_MESSAGE.format(repo_id=repo_id, version=v_check))
def get_repo_versions(repo_id: str) -> list[packaging.version.Version]: def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
@@ -348,7 +348,9 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
if compatibles: if compatibles:
return_version = max(compatibles) return_version = max(compatibles)
if return_version < target_version: if return_version < target_version:
logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}") logging.warning(
"Revision %s for %s not found, using version v%s", version, repo_id, return_version
)
return f"v{return_version}" return f"v{return_version}"
lower_major = [v for v in hub_versions if v.major < target_version.major] lower_major = [v for v in hub_versions if v.major < target_version.major]
@@ -403,7 +405,7 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
for key, ft in features.items(): for key, ft in features.items():
shape = ft["shape"] shape = ft["shape"]
if ft["dtype"] in ["image", "video"]: if ft["dtype"] in ["image", "video"]:
type = FeatureType.VISUAL feature_type = FeatureType.VISUAL
if len(shape) != 3: if len(shape) != 3:
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})") raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
@@ -412,16 +414,16 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w) if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
shape = (shape[2], shape[0], shape[1]) shape = (shape[2], shape[0], shape[1])
elif key == "observation.environment_state": elif key == "observation.environment_state":
type = FeatureType.ENV feature_type = FeatureType.ENV
elif key.startswith("observation"): elif key.startswith("observation"):
type = FeatureType.STATE feature_type = FeatureType.STATE
elif key == "action": elif key == "action":
type = FeatureType.ACTION feature_type = FeatureType.ACTION
else: else:
continue continue
policy_features[key] = PolicyFeature( policy_features[key] = PolicyFeature(
type=type, type=feature_type,
shape=shape, shape=shape,
) )

View File

@@ -871,11 +871,11 @@ def batch_convert():
try: try:
convert_dataset(repo_id, LOCAL_DIR, **kwargs) convert_dataset(repo_id, LOCAL_DIR, **kwargs)
status = f"{repo_id}: success." status = f"{repo_id}: success."
with open(logfile, "a") as file: with open(logfile, "a", encoding="utf-8") as file:
file.write(status + "\n") file.write(status + "\n")
except Exception: except Exception:
status = f"{repo_id}: failed\n {traceback.format_exc()}" status = f"{repo_id}: failed\n {traceback.format_exc()}"
with open(logfile, "a") as file: with open(logfile, "a", encoding="utf-8") as file:
file.write(status + "\n") file.write(status + "\n")
continue continue

View File

@@ -190,11 +190,11 @@ def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
json_path = v2_dir / STATS_PATH json_path = v2_dir / STATS_PATH
json_path.parent.mkdir(exist_ok=True, parents=True) json_path.parent.mkdir(exist_ok=True, parents=True)
with open(json_path, "w") as f: with open(json_path, "w", encoding="utf-8") as f:
json.dump(serialized_stats, f, indent=4) json.dump(serialized_stats, f, indent=4)
# Sanity check # Sanity check
with open(json_path) as f: with open(json_path, encoding="utf-8") as f:
stats_json = json.load(f) stats_json = json.load(f)
stats_json = flatten_dict(stats_json) stats_json = flatten_dict(stats_json)
@@ -213,7 +213,7 @@ def get_features_from_hf_dataset(
dtype = ft.dtype dtype = ft.dtype
shape = (1,) shape = (1,)
names = None names = None
if isinstance(ft, datasets.Sequence): elif isinstance(ft, datasets.Sequence):
assert isinstance(ft.feature, datasets.Value) assert isinstance(ft.feature, datasets.Value)
dtype = ft.feature.dtype dtype = ft.feature.dtype
shape = (ft.length,) shape = (ft.length,)
@@ -232,6 +232,8 @@ def get_features_from_hf_dataset(
dtype = "video" dtype = "video"
shape = None # Add shape later shape = None # Add shape later
names = ["height", "width", "channels"] names = ["height", "width", "channels"]
else:
raise NotImplementedError(f"Feature type {ft._type} not supported.")
features[key] = { features[key] = {
"dtype": dtype, "dtype": dtype,
@@ -358,9 +360,9 @@ def move_videos(
if len(video_dirs) == 1: if len(video_dirs) == 1:
video_path = video_dirs[0] / video_file video_path = video_dirs[0] / video_file
else: else:
for dir in video_dirs: for v_dir in video_dirs:
if (dir / video_file).is_file(): if (v_dir / video_file).is_file():
video_path = dir / video_file video_path = v_dir / video_file
break break
video_path.rename(work_dir / target_path) video_path.rename(work_dir / target_path)
@@ -652,6 +654,7 @@ def main():
if not args.local_dir: if not args.local_dir:
args.local_dir = Path("/tmp/lerobot_dataset_v2") args.local_dir = Path("/tmp/lerobot_dataset_v2")
robot_config = None
if args.robot is not None: if args.robot is not None:
robot_config = make_robot_config(args.robot) robot_config = make_robot_config(args.robot)

View File

@@ -50,7 +50,7 @@ def fix_dataset(repo_id: str) -> str:
return f"{repo_id}: skipped (no diff)" return f"{repo_id}: skipped (no diff)"
if diff_meta_parquet: if diff_meta_parquet:
logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}") logging.warning("In info.json not in parquet: %s", meta_features - parquet_features)
assert diff_meta_parquet == {"language_instruction"} assert diff_meta_parquet == {"language_instruction"}
lerobot_metadata.features.pop("language_instruction") lerobot_metadata.features.pop("language_instruction")
write_info(lerobot_metadata.info, lerobot_metadata.root) write_info(lerobot_metadata.info, lerobot_metadata.root)
@@ -79,7 +79,7 @@ def batch_fix():
status = f"{repo_id}: failed\n {traceback.format_exc()}" status = f"{repo_id}: failed\n {traceback.format_exc()}"
logging.info(status) logging.info(status)
with open(logfile, "a") as file: with open(logfile, "a", encoding="utf-8") as file:
file.write(status + "\n") file.write(status + "\n")

View File

@@ -46,7 +46,7 @@ def batch_convert():
except Exception: except Exception:
status = f"{repo_id}: failed\n {traceback.format_exc()}" status = f"{repo_id}: failed\n {traceback.format_exc()}"
with open(logfile, "a") as file: with open(logfile, "a", encoding="utf-8") as file:
file.write(status + "\n") file.write(status + "\n")

View File

@@ -45,6 +45,9 @@ V21 = "v2.1"
class SuppressWarnings: class SuppressWarnings:
def __init__(self):
self.previous_level = None
def __enter__(self): def __enter__(self):
self.previous_level = logging.getLogger().getEffectiveLevel() self.previous_level = logging.getLogger().getEffectiveLevel()
logging.getLogger().setLevel(logging.ERROR) logging.getLogger().setLevel(logging.ERROR)

View File

@@ -83,7 +83,7 @@ def decode_video_frames_torchvision(
for frame in reader: for frame in reader:
current_ts = frame["pts"] current_ts = frame["pts"]
if log_loaded_timestamps: if log_loaded_timestamps:
logging.info(f"frame loaded at timestamp={current_ts:.4f}") logging.info("frame loaded at timestamp=%.4f", current_ts)
loaded_frames.append(frame["data"]) loaded_frames.append(frame["data"])
loaded_ts.append(current_ts) loaded_ts.append(current_ts)
if current_ts >= last_ts: if current_ts >= last_ts:
@@ -118,7 +118,7 @@ def decode_video_frames_torchvision(
closest_ts = loaded_ts[argmin_] closest_ts = loaded_ts[argmin_]
if log_loaded_timestamps: if log_loaded_timestamps:
logging.info(f"{closest_ts=}") logging.info("closest_ts=%s", closest_ts)
# convert to the pytorch format which is float32 in [0,1] range (and channel first) # convert to the pytorch format which is float32 in [0,1] range (and channel first)
closest_frames = closest_frames.type(torch.float32) / 255 closest_frames = closest_frames.type(torch.float32) / 255
@@ -227,7 +227,9 @@ def get_audio_info(video_path: Path | str) -> dict:
"json", "json",
str(video_path), str(video_path),
] ]
result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) result = subprocess.run(
ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True
)
if result.returncode != 0: if result.returncode != 0:
raise RuntimeError(f"Error running ffprobe: {result.stderr}") raise RuntimeError(f"Error running ffprobe: {result.stderr}")
@@ -263,7 +265,9 @@ def get_video_info(video_path: Path | str) -> dict:
"json", "json",
str(video_path), str(video_path),
] ]
result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) result = subprocess.run(
ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True
)
if result.returncode != 0: if result.returncode != 0:
raise RuntimeError(f"Error running ffprobe: {result.stderr}") raise RuntimeError(f"Error running ffprobe: {result.stderr}")