Use HWC for images
This commit is contained in:
@@ -13,7 +13,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
@@ -138,6 +137,11 @@ class LeRobotDatasetMetadata:
|
||||
"""Formattable string for the video files."""
|
||||
return self.info["video_path"]
|
||||
|
||||
@property
|
||||
def robot_type(self) -> str | None:
|
||||
"""Robot type used in recording this dataset."""
|
||||
return self.info["robot_type"]
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
"""Frames per second used during data collection."""
|
||||
@@ -258,10 +262,14 @@ class LeRobotDatasetMetadata:
|
||||
write_json(self.info, self.root / INFO_PATH)
|
||||
|
||||
def __repr__(self):
|
||||
feature_keys = list(self.features)
|
||||
return (
|
||||
f"{self.__class__.__name__}\n"
|
||||
f"Repository ID: '{self.repo_id}',\n"
|
||||
f"{json.dumps(self.meta.info, indent=4)}\n"
|
||||
f"{self.__class__.__name__}({{\n"
|
||||
f" Repository ID: '{self.repo_id}',\n"
|
||||
f" Total episodes: '{self.total_episodes}',\n"
|
||||
f" Total frames: '{self.total_frames}',\n"
|
||||
f" Features: '{feature_keys}',\n"
|
||||
"})',\n"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -657,13 +665,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
return item
|
||||
|
||||
def __repr__(self):
|
||||
feature_keys = list(self.features)
|
||||
return (
|
||||
f"{self.__class__.__name__}\n"
|
||||
f" Repository ID: '{self.repo_id}',\n"
|
||||
f" Selected episodes: {self.episodes},\n"
|
||||
f" Number of selected episodes: {self.num_episodes},\n"
|
||||
f" Number of selected samples: {self.num_frames},\n"
|
||||
f"\n{json.dumps(self.meta.info, indent=4)}\n"
|
||||
f"{self.__class__.__name__}({{\n"
|
||||
f" Repository ID: '{self.repo_id}',\n"
|
||||
f" Number of selected episodes: '{self.num_episodes}',\n"
|
||||
f" Number of selected samples: '{self.num_frames}',\n"
|
||||
f" Features: '{feature_keys}',\n"
|
||||
"})',\n"
|
||||
)
|
||||
|
||||
def _create_episode_buffer(self, episode_index: int | None = None) -> dict:
|
||||
|
||||
@@ -468,6 +468,7 @@ def create_lerobot_dataset_card(
|
||||
text: str | None = None,
|
||||
info: dict | None = None,
|
||||
license: str | None = None,
|
||||
url: str | None = None,
|
||||
citation: str | None = None,
|
||||
arxiv: str | None = None,
|
||||
) -> DatasetCard:
|
||||
@@ -488,6 +489,8 @@ def create_lerobot_dataset_card(
|
||||
card.data.license = license
|
||||
if tags:
|
||||
card.data.tags += tags
|
||||
if url:
|
||||
card.text += f"## Homepage:\n{url}\n"
|
||||
if text:
|
||||
card.text += f"{text}\n"
|
||||
if info:
|
||||
|
||||
@@ -222,12 +222,12 @@ def get_features_from_hf_dataset(dataset: Dataset, robot_config: dict | None = N
|
||||
dtype = "image"
|
||||
image = dataset[0][key] # Assuming first row
|
||||
channels = get_image_pixel_channels(image)
|
||||
shape = (image.width, image.height, channels)
|
||||
names = ["width", "height", "channel"]
|
||||
shape = (image.height, image.width, channels)
|
||||
names = ["height", "width", "channel"]
|
||||
elif ft._type == "VideoFrame":
|
||||
dtype = "video"
|
||||
shape = None # Add shape later
|
||||
names = ["width", "height", "channel"]
|
||||
names = ["height", "width", "channel"]
|
||||
|
||||
features[key] = {
|
||||
"dtype": dtype,
|
||||
@@ -437,8 +437,9 @@ def convert_dataset(
|
||||
tasks_col: Path | None = None,
|
||||
robot_config: dict | None = None,
|
||||
license: str | None = None,
|
||||
citation: str | None = None,
|
||||
url: str | None = None,
|
||||
arxiv: str | None = None,
|
||||
citation: str | None = None,
|
||||
test_branch: str | None = None,
|
||||
):
|
||||
v1 = get_hub_safe_version(repo_id, V16)
|
||||
@@ -518,8 +519,8 @@ def convert_dataset(
|
||||
videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
|
||||
for key in video_keys:
|
||||
features[key]["shape"] = (
|
||||
videos_info[key].pop("video.width"),
|
||||
videos_info[key].pop("video.height"),
|
||||
videos_info[key].pop("video.width"),
|
||||
videos_info[key].pop("video.channels"),
|
||||
)
|
||||
features[key]["video_info"] = videos_info[key]
|
||||
@@ -566,7 +567,7 @@ def convert_dataset(
|
||||
write_json(metadata_v2_0, v20_dir / INFO_PATH)
|
||||
convert_stats_to_json(v1x_dir, v20_dir)
|
||||
card = create_lerobot_dataset_card(
|
||||
tags=repo_tags, info=metadata_v2_0, license=license, citation=citation, arxiv=arxiv
|
||||
tags=repo_tags, info=metadata_v2_0, license=license, url=url, citation=citation, arxiv=arxiv
|
||||
)
|
||||
|
||||
with contextlib.suppress(EntryNotFoundError):
|
||||
|
||||
@@ -279,8 +279,8 @@ def get_video_info(video_path: Path | str) -> dict:
|
||||
|
||||
video_info = {
|
||||
"video.fps": fps,
|
||||
"video.width": video_stream_info["width"],
|
||||
"video.height": video_stream_info["height"],
|
||||
"video.width": video_stream_info["width"],
|
||||
"video.channels": pixel_channels,
|
||||
"video.codec": video_stream_info["codec_name"],
|
||||
"video.pix_fmt": video_stream_info["pix_fmt"],
|
||||
|
||||
@@ -235,8 +235,8 @@ class ManipulatorRobot:
|
||||
for cam_key, cam in self.cameras.items():
|
||||
key = f"observation.images.{cam_key}"
|
||||
cam_ft[key] = {
|
||||
"shape": (cam.width, cam.height, cam.channels),
|
||||
"names": ["width", "height", "channels"],
|
||||
"shape": (cam.height, cam.width, cam.channels),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": None,
|
||||
}
|
||||
return cam_ft
|
||||
|
||||
Reference in New Issue
Block a user