Use HWC (height, width, channel) order for image and video shapes

This commit is contained in:
Simon Alibert
2024-11-19 18:47:32 +01:00
parent 1f13bda25b
commit 6203641710
8 changed files with 40 additions and 36 deletions

View File

@@ -222,12 +222,12 @@ def get_features_from_hf_dataset(dataset: Dataset, robot_config: dict | None = N
dtype = "image"
image = dataset[0][key] # Assuming first row
channels = get_image_pixel_channels(image)
shape = (image.width, image.height, channels)
names = ["width", "height", "channel"]
shape = (image.height, image.width, channels)
names = ["height", "width", "channel"]
elif ft._type == "VideoFrame":
dtype = "video"
shape = None # Add shape later
names = ["width", "height", "channel"]
names = ["height", "width", "channel"]
features[key] = {
"dtype": dtype,
@@ -437,8 +437,9 @@ def convert_dataset(
tasks_col: Path | None = None,
robot_config: dict | None = None,
license: str | None = None,
citation: str | None = None,
url: str | None = None,
arxiv: str | None = None,
citation: str | None = None,
test_branch: str | None = None,
):
v1 = get_hub_safe_version(repo_id, V16)
@@ -518,8 +519,8 @@ def convert_dataset(
videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
for key in video_keys:
features[key]["shape"] = (
videos_info[key].pop("video.width"),
videos_info[key].pop("video.height"),
videos_info[key].pop("video.width"),
videos_info[key].pop("video.channels"),
)
features[key]["video_info"] = videos_info[key]
@@ -566,7 +567,7 @@ def convert_dataset(
write_json(metadata_v2_0, v20_dir / INFO_PATH)
convert_stats_to_json(v1x_dir, v20_dir)
card = create_lerobot_dataset_card(
tags=repo_tags, info=metadata_v2_0, license=license, citation=citation, arxiv=arxiv
tags=repo_tags, info=metadata_v2_0, license=license, url=url, citation=citation, arxiv=arxiv
)
with contextlib.suppress(EntryNotFoundError):