fix(lerobot/common/datasets): remove lint warnings/errors

Steven Palma
2025-03-07 15:14:06 +01:00
parent 9b380eaf67
commit e59ef036e1
18 changed files with 98 additions and 79 deletions

View File

@@ -108,7 +108,7 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu
def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
for i in range(len(stats_list)):
for i in enumerate(stats_list):
for fkey in stats_list[i]:
for k, v in stats_list[i][fkey].items():
if not isinstance(v, np.ndarray):
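For reference, the usual lint-clean shape of this loop keeps both the index and the item from enumerate. The sketch below is hypothetical and not the code in this commit; episode_stats and feature_stats are illustrative names.

import numpy as np

def _assert_type_and_shape(stats_list: list[dict[str, dict]]) -> None:
    # enumerate yields the index and the item, so the item does not need to be
    # re-fetched via stats_list[i] inside the loop body.
    for i, episode_stats in enumerate(stats_list):
        for fkey, feature_stats in episode_stats.items():
            for k, v in feature_stats.items():
                if not isinstance(v, np.ndarray):
                    raise ValueError(
                        f"Stats entry {i}, feature '{fkey}', key '{k}' should be "
                        f"an np.ndarray but is {type(v)} instead."
                    )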

View File

@@ -13,8 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pprint import pformat
import torch
@@ -98,17 +96,17 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
)
else:
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
dataset = MultiLeRobotDataset(
cfg.dataset.repo_id,
# TODO(aliberts): add proper support for multi dataset
# delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
video_backend=cfg.dataset.video_backend,
)
logging.info(
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
f"{pformat(dataset.repo_id_to_index, indent=2)}"
)
# dataset = MultiLeRobotDataset(
# cfg.dataset.repo_id,
# # TODO(aliberts): add proper support for multi dataset
# # delta_timestamps=delta_timestamps,
# image_transforms=image_transforms,
# video_backend=cfg.dataset.video_backend,
# )
# logging.info(
# "Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
# f"{pformat(dataset.repo_id_to_index, indent=2)}"
# )
if cfg.dataset.use_imagenet_stats:
for key in dataset.meta.camera_keys:

View File

@@ -81,21 +81,21 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
print(f"Error writing image {fpath}: {e}")
def worker_thread_loop(queue: queue.Queue):
def worker_thread_loop(task_queue: queue.Queue):
while True:
item = queue.get()
item = task_queue.get()
if item is None:
queue.task_done()
task_queue.task_done()
break
image_array, fpath = item
write_image(image_array, fpath)
queue.task_done()
task_queue.task_done()
def worker_process(queue: queue.Queue, num_threads: int):
def worker_process(task_queue: queue.Queue, num_threads: int):
threads = []
for _ in range(num_threads):
t = threading.Thread(target=worker_thread_loop, args=(queue,))
t = threading.Thread(target=worker_thread_loop, args=(task_queue,))
t.daemon = True
t.start()
threads.append(t)

View File

@@ -87,6 +87,7 @@ class LeRobotDatasetMetadata:
self.repo_id = repo_id
self.revision = revision if revision else CODEBASE_VERSION
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
self.stats = None
try:
if force_cache_sync:
@@ -102,10 +103,10 @@ class LeRobotDatasetMetadata:
def load_metadata(self):
self.info = load_info(self.root)
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
check_version_compatibility(self.repo_id, self.version, CODEBASE_VERSION)
self.tasks, self.task_to_task_index = load_tasks(self.root)
self.episodes = load_episodes(self.root)
if self._version < packaging.version.parse("v2.1"):
if self.version < packaging.version.parse("v2.1"):
self.stats = load_stats(self.root)
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
else:
@@ -127,7 +128,7 @@ class LeRobotDatasetMetadata:
)
@property
def _version(self) -> packaging.version.Version:
def version(self) -> packaging.version.Version:
"""Codebase version used to create this dataset."""
return packaging.version.parse(self.info["codebase_version"])
@@ -321,8 +322,9 @@ class LeRobotDatasetMetadata:
robot_type = robot.robot_type
if not all(cam.fps == fps for cam in robot.cameras.values()):
logging.warning(
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
"In this case, frames from lower fps cameras will be repeated to fill in the blanks."
"Some cameras in your %s robot don't have an fps matching the fps of your dataset."
"In this case, frames from lower fps cameras will be repeated to fill in the blanks.",
robot.robot_type,
)
elif features is None:
raise ValueError(
@@ -486,7 +488,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.meta = LeRobotDatasetMetadata(
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
)
if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
if self.episodes is not None and self.meta.version >= packaging.version.parse("v2.1"):
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
self.stats = aggregate_stats(episodes_stats)
@@ -518,7 +520,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self,
branch: str | None = None,
tags: list | None = None,
license: str | None = "apache-2.0",
dataset_license: str | None = "apache-2.0",
tag_version: bool = True,
push_videos: bool = True,
private: bool = False,
@@ -561,7 +563,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
card = create_lerobot_dataset_card(
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
tags=tags, dataset_info=self.meta.info, license=dataset_license, **card_kwargs
)
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
@@ -842,6 +844,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
None.
"""
episode_buffer = None
if not episode_data:
episode_buffer = self.episode_buffer
@@ -1086,8 +1089,9 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
extra_keys = set(ds.features).difference(intersection_features)
logging.warning(
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
"other datasets."
"keys %s of %s were disabled as they are not contained in all the other datasets.",
extra_keys,
repo_id,
)
self.disabled_features.update(extra_keys)

View File

@@ -53,7 +53,7 @@ def rechunk_recompress_array(group, name, chunks=None, chunk_length=None, compre
# rechunk recompress
group.move(name, tmp_key)
old_arr = group[tmp_key]
n_copied, n_skipped, n_bytes_copied = zarr.copy(
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy(
source=old_arr,
dest=group,
name=name,
@@ -192,7 +192,7 @@ class ReplayBuffer:
else:
root = zarr.group(store=store)
# copy without recompression
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
source=src_store, dest=store, source_path="/meta", dest_path="/meta", if_exists=if_exists
)
data_group = root.create_group("data", overwrite=True)
@@ -205,7 +205,7 @@ class ReplayBuffer:
if cks == value.chunks and cpr == value.compressor:
# copy without recompression
this_path = "/data/" + key
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
source=src_store,
dest=store,
source_path=this_path,
@@ -214,7 +214,7 @@ class ReplayBuffer:
)
else:
# copy with recompression
n_copied, n_skipped, n_bytes_copied = zarr.copy(
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy(
source=value,
dest=data_group,
name=key,
@@ -275,7 +275,7 @@ class ReplayBuffer:
compressors = {}
if self.backend == "zarr":
# recompression free copy
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
source=self.root.store,
dest=store,
source_path="/meta",
@@ -297,7 +297,7 @@ class ReplayBuffer:
if cks == value.chunks and cpr == value.compressor:
# copy without recompression
this_path = "/data/" + key
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
source=self.root.store,
dest=store,
source_path=this_path,

View File

@@ -162,9 +162,9 @@ def download_raw(raw_dir: Path, repo_id: str):
)
raw_dir.mkdir(parents=True, exist_ok=True)
logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}")
logging.info("Start downloading from huggingface.co/%s for %s", user_id, dataset_id)
snapshot_download(repo_id, repo_type="dataset", local_dir=raw_dir)
logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}")
logging.info("Finish downloading from huggingface.co/%s for %s", user_id, dataset_id)
def download_all_raw_datasets(data_dir: Path | None = None):
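The pattern behind these logging changes is the standard library's lazy %-style formatting: arguments are interpolated only if the record is actually emitted. A minimal sketch with hypothetical values, not the commit's code:

import logging

user_id, dataset_id = "lerobot", "aloha_static_coffee"

# Eager: the f-string is built even when INFO logging is disabled.
logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}")

# Lazy: the arguments are passed through and formatted only when the record
# is emitted, which is what most logging linters ask for.
logging.info("Start downloading from huggingface.co/%s for %s", user_id, dataset_id)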

View File

@@ -72,7 +72,7 @@ def check_format(raw_dir) -> bool:
assert data[f"/observations/images/{camera}"].ndim == 2
else:
assert data[f"/observations/images/{camera}"].ndim == 4
b, h, w, c = data[f"/observations/images/{camera}"].shape
_, h, w, c = data[f"/observations/images/{camera}"].shape
assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
@@ -103,6 +103,7 @@ def load_from_raw(
state = torch.from_numpy(ep["/observations/qpos"][:])
action = torch.from_numpy(ep["/action"][:])
velocity = None
if "/observations/qvel" in ep:
velocity = torch.from_numpy(ep["/observations/qvel"][:])
if "/observations/effort" in ep:

View File

@@ -96,6 +96,7 @@ def from_raw_to_lerobot_format(
if fps is None:
fps = 30
# TODO(Steven): Is this meant to call cam_png_format.load_from_raw?
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
hf_dataset = to_hf_dataset(data_dict, video)
episode_data_index = calculate_episode_data_index(hf_dataset)

View File

@@ -42,7 +42,9 @@ def check_format(raw_dir) -> bool:
return True
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
def load_from_raw(
raw_dir: Path, videos_dir: Path, fps: int, _video: bool, _episodes: list[int] | None = None
):
# Load data stream that will be used as reference for the timestamps synchronization
reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
if len(reference_files) == 0:

View File

@@ -55,7 +55,7 @@ def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers
num_images = len(imgs_array)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
[executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
_ = [executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
def get_default_encoding() -> dict:
@@ -92,24 +92,23 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torc
episode_data_index = {"from": [], "to": []}
current_episode = None
"""
The episode_index is a list of integers, each representing the episode index of the corresponding example.
For instance, the following is a valid episode_index:
[0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
{
"from": [0, 3, 7],
"to": [3, 7, 12]
}
"""
# The episode_index is a list of integers, each representing the episode index of the corresponding example.
# For instance, the following is a valid episode_index:
# [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
#
# Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
# ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
# {
# "from": [0, 3, 7],
# "to": [3, 7, 12]
# }
if len(hf_dataset) == 0:
episode_data_index = {
"from": torch.tensor([]),
"to": torch.tensor([]),
}
return episode_data_index
idx = None
for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
if episode_idx != current_episode:
# We encountered a new episode, so we append its starting location to the "from" list

View File

@@ -23,6 +23,7 @@ from torchvision.transforms.v2 import Transform
from torchvision.transforms.v2 import functional as F # noqa: N812
# TODO(Steven): Missing transform() implementation
class RandomSubsetApply(Transform):
"""Apply a random subset of N transformations from a list of transformations.
@@ -218,6 +219,7 @@ def make_transform_from_config(cfg: ImageTransformConfig):
raise ValueError(f"Transform '{cfg.type}' is not valid.")
# TODO(Steven): Missing transform() implementation
class ImageTransforms(Transform):
"""A class to compose image transforms based on configuration."""

View File

@@ -135,21 +135,21 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
# Embed image bytes into the table before saving to parquet
format = dataset.format
ds_format = dataset.format
dataset = dataset.with_format("arrow")
dataset = dataset.map(embed_table_storage, batched=False)
dataset = dataset.with_format(**format)
dataset = dataset.with_format(**ds_format)
return dataset
def load_json(fpath: Path) -> Any:
with open(fpath) as f:
with open(fpath, encoding="utf-8") as f:
return json.load(f)
def write_json(data: dict, fpath: Path) -> None:
fpath.parent.mkdir(exist_ok=True, parents=True)
with open(fpath, "w") as f:
with open(fpath, "w", encoding="utf-8") as f:
json.dump(data, f, indent=4, ensure_ascii=False)
@@ -300,7 +300,7 @@ def check_version_compatibility(
if v_check.major < v_current.major and enforce_breaking_major:
raise BackwardCompatibilityError(repo_id, v_check)
elif v_check.minor < v_current.minor:
logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_check))
logging.warning("%s", V21_MESSAGE.format(repo_id=repo_id, version=v_check))
def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
@@ -348,7 +348,9 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
if compatibles:
return_version = max(compatibles)
if return_version < target_version:
logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}")
logging.warning(
"Revision %s for %s not found, using version v%s", version, repo_id, return_version
)
return f"v{return_version}"
lower_major = [v for v in hub_versions if v.major < target_version.major]
@@ -403,7 +405,7 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
for key, ft in features.items():
shape = ft["shape"]
if ft["dtype"] in ["image", "video"]:
type = FeatureType.VISUAL
feature_type = FeatureType.VISUAL
if len(shape) != 3:
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
@@ -412,16 +414,16 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
shape = (shape[2], shape[0], shape[1])
elif key == "observation.environment_state":
type = FeatureType.ENV
feature_type = FeatureType.ENV
elif key.startswith("observation"):
type = FeatureType.STATE
feature_type = FeatureType.STATE
elif key == "action":
type = FeatureType.ACTION
feature_type = FeatureType.ACTION
else:
continue
policy_features[key] = PolicyFeature(
type=type,
type=feature_type,
shape=shape,
)
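The renames in this file (format to ds_format, type to feature_type) and elsewhere in the commit (license, dir) address builtin shadowing. A small illustration of why linters flag it; hypothetical code, not from the codebase:

def broken() -> None:
    type = "visual"      # shadows the builtin 'type' in this scope
    print(type(3.14))    # TypeError: 'str' object is not callable

def fixed() -> None:
    feature_type = "visual"
    print(type(3.14))    # <class 'float'> -- the builtin is still reachable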

View File

@@ -871,11 +871,11 @@ def batch_convert():
try:
convert_dataset(repo_id, LOCAL_DIR, **kwargs)
status = f"{repo_id}: success."
with open(logfile, "a") as file:
with open(logfile, "a", encoding="utf-8") as file:
file.write(status + "\n")
except Exception:
status = f"{repo_id}: failed\n {traceback.format_exc()}"
with open(logfile, "a") as file:
with open(logfile, "a", encoding="utf-8") as file:
file.write(status + "\n")
continue

View File

@@ -190,11 +190,11 @@ def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
json_path = v2_dir / STATS_PATH
json_path.parent.mkdir(exist_ok=True, parents=True)
with open(json_path, "w") as f:
with open(json_path, "w", encoding="utf-8") as f:
json.dump(serialized_stats, f, indent=4)
# Sanity check
with open(json_path) as f:
with open(json_path, encoding="utf-8") as f:
stats_json = json.load(f)
stats_json = flatten_dict(stats_json)
@@ -213,7 +213,7 @@ def get_features_from_hf_dataset(
dtype = ft.dtype
shape = (1,)
names = None
if isinstance(ft, datasets.Sequence):
elif isinstance(ft, datasets.Sequence):
assert isinstance(ft.feature, datasets.Value)
dtype = ft.feature.dtype
shape = (ft.length,)
@@ -232,6 +232,8 @@ def get_features_from_hf_dataset(
dtype = "video"
shape = None # Add shape later
names = ["height", "width", "channels"]
else:
raise NotImplementedError(f"Feature type {ft._type} not supported.")
features[key] = {
"dtype": dtype,
@@ -358,9 +360,9 @@ def move_videos(
if len(video_dirs) == 1:
video_path = video_dirs[0] / video_file
else:
for dir in video_dirs:
if (dir / video_file).is_file():
video_path = dir / video_file
for v_dir in video_dirs:
if (v_dir / video_file).is_file():
video_path = v_dir / video_file
break
video_path.rename(work_dir / target_path)
@@ -652,6 +654,7 @@ def main():
if not args.local_dir:
args.local_dir = Path("/tmp/lerobot_dataset_v2")
robot_config = None
if args.robot is not None:
robot_config = make_robot_config(args.robot)

View File

@@ -50,7 +50,7 @@ def fix_dataset(repo_id: str) -> str:
return f"{repo_id}: skipped (no diff)"
if diff_meta_parquet:
logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}")
logging.warning("In info.json not in parquet: %s", meta_features - parquet_features)
assert diff_meta_parquet == {"language_instruction"}
lerobot_metadata.features.pop("language_instruction")
write_info(lerobot_metadata.info, lerobot_metadata.root)
@@ -79,7 +79,7 @@ def batch_fix():
status = f"{repo_id}: failed\n {traceback.format_exc()}"
logging.info(status)
with open(logfile, "a") as file:
with open(logfile, "a", encoding="utf-8") as file:
file.write(status + "\n")

View File

@@ -46,7 +46,7 @@ def batch_convert():
except Exception:
status = f"{repo_id}: failed\n {traceback.format_exc()}"
with open(logfile, "a") as file:
with open(logfile, "a", encoding="utf-8") as file:
file.write(status + "\n")

View File

@@ -45,6 +45,9 @@ V21 = "v2.1"
class SuppressWarnings:
def __init__(self):
self.previous_level = None
def __enter__(self):
self.previous_level = logging.getLogger().getEffectiveLevel()
logging.getLogger().setLevel(logging.ERROR)
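Only __init__ and __enter__ appear in this hunk; a complete sketch of such a context manager, assuming it restores the previous level on exit (an assumption, not shown in the diff), would be:

import logging

class SuppressWarnings:
    def __init__(self):
        # Declare the attribute in __init__ so linters don't flag it as being
        # defined outside the constructor.
        self.previous_level = None

    def __enter__(self):
        self.previous_level = logging.getLogger().getEffectiveLevel()
        logging.getLogger().setLevel(logging.ERROR)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Assumed behavior: put the root logger back to its previous level.
        logging.getLogger().setLevel(self.previous_level)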

View File

@@ -83,7 +83,7 @@ def decode_video_frames_torchvision(
for frame in reader:
current_ts = frame["pts"]
if log_loaded_timestamps:
logging.info(f"frame loaded at timestamp={current_ts:.4f}")
logging.info("frame loaded at timestamp=%.4f", current_ts)
loaded_frames.append(frame["data"])
loaded_ts.append(current_ts)
if current_ts >= last_ts:
@@ -118,7 +118,7 @@ def decode_video_frames_torchvision(
closest_ts = loaded_ts[argmin_]
if log_loaded_timestamps:
logging.info(f"{closest_ts=}")
logging.info("closest_ts=%s", closest_ts)
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
closest_frames = closest_frames.type(torch.float32) / 255
@@ -227,7 +227,9 @@ def get_audio_info(video_path: Path | str) -> dict:
"json",
str(video_path),
]
result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
result = subprocess.run(
ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True
)
if result.returncode != 0:
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
@@ -263,7 +265,9 @@ def get_video_info(video_path: Path | str) -> dict:
"json",
str(video_path),
]
result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
result = subprocess.run(
ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True
)
if result.returncode != 0:
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
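With check=True, subprocess.run raises subprocess.CalledProcessError on a nonzero exit code before the returncode test is reached, so callers typically catch the exception instead. A minimal sketch with a hypothetical command, not the commit's code:

import subprocess

try:
    result = subprocess.run(
        ["ffprobe", "-v", "error", "-show_format", "nonexistent.mp4"],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        check=True,
    )
except subprocess.CalledProcessError as e:
    # e.returncode and e.stderr carry what the manual returncode check used to inspect.
    raise RuntimeError(f"Error running ffprobe: {e.stderr}") from e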