Add add_frame, empty dataset creation
This commit is contained in:
@@ -13,7 +13,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
@@ -26,15 +25,17 @@ from datasets import load_dataset
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.common.datasets.image_writer import ImageWriter
|
||||
from lerobot.common.datasets.utils import (
|
||||
check_delta_timestamps,
|
||||
check_timestamps_sync,
|
||||
create_dataset_info,
|
||||
create_empty_dataset_info,
|
||||
get_delta_indices,
|
||||
get_episode_data_index,
|
||||
get_hub_safe_version,
|
||||
hf_transform_to_torch,
|
||||
load_metadata,
|
||||
write_json,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_torchvision
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
@@ -55,6 +56,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
tolerance_s: float = 1e-4,
|
||||
download_videos: bool = True,
|
||||
video_backend: str | None = None,
|
||||
image_writer: ImageWriter | None = None,
|
||||
):
|
||||
"""LeRobotDataset encapsulates 3 main things:
|
||||
- metadata:
|
||||
@@ -156,6 +158,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.tolerance_s = tolerance_s
|
||||
self.download_videos = download_videos
|
||||
self.video_backend = video_backend if video_backend is not None else "pyav"
|
||||
self.image_writer = image_writer
|
||||
self.episode_buffer = {}
|
||||
self.delta_indices = None
|
||||
|
||||
# Load metadata
|
||||
@@ -296,9 +300,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
"""Number of samples/frames."""
|
||||
"""Number of samples/frames in selected episodes."""
|
||||
return len(self.hf_dataset)
|
||||
|
||||
@property
|
||||
def total_frames(self) -> int:
|
||||
"""Total number of frames saved in this dataset."""
|
||||
return self.info["total_frames"]
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
"""Number of episodes selected."""
|
||||
@@ -423,10 +432,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
return item
|
||||
|
||||
def write_info(self) -> None:
|
||||
with open(self.root / "meta/info.json", "w") as f:
|
||||
json.dump(self.info, f, indent=4, ensure_ascii=False)
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"{self.__class__.__name__}(\n"
|
||||
@@ -442,6 +447,49 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
f")"
|
||||
)
|
||||
|
||||
def _create_episode_buffer(self) -> dict:
|
||||
# TODO(aliberts): Handle resume
|
||||
return {
|
||||
"chunk": self.total_chunks,
|
||||
"episode_index": self.total_episodes,
|
||||
"size": 0,
|
||||
"frame_index": [],
|
||||
"timestamp": [],
|
||||
"next.done": [],
|
||||
**{key: [] for key in self.keys},
|
||||
}
|
||||
|
||||
def add_frame(self, frame: dict) -> None:
|
||||
frame_index = self.episode_buffer["size"]
|
||||
self.episode_buffer["frame_index"].append(frame_index)
|
||||
self.episode_buffer["timestamp"].append(frame_index / self.fps)
|
||||
self.episode_buffer["next.done"].append(False)
|
||||
|
||||
# Save all observed modalities except images
|
||||
for key in self.keys:
|
||||
self.episode_buffer[key].append(frame[key])
|
||||
|
||||
self.episode_buffer["size"] += 1
|
||||
|
||||
if self.image_writer is None:
|
||||
return
|
||||
|
||||
# Save images
|
||||
for cam_key in self.camera_keys:
|
||||
img_path = self.image_writer.get_image_file_path(
|
||||
episode_index=self.episode_buffer["episode_index"],
|
||||
image_key=cam_key,
|
||||
frame_index=frame_index,
|
||||
return_str=False,
|
||||
)
|
||||
if frame_index == 0:
|
||||
img_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.image_writer.async_save_image(
|
||||
image=frame[cam_key],
|
||||
file_path=img_path,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
@@ -450,24 +498,29 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
robot: Robot,
|
||||
root: Path | None = None,
|
||||
tolerance_s: float = 1e-4,
|
||||
image_writer: ImageWriter | None = None,
|
||||
use_videos: bool = True,
|
||||
) -> "LeRobotDataset":
|
||||
"""Create a LeRobot Dataset from scratch in order to record data."""
|
||||
obj = cls.__new__(cls)
|
||||
obj.repo_id = repo_id
|
||||
obj.root = root if root is not None else LEROBOT_HOME / repo_id
|
||||
obj._version = CODEBASE_VERSION
|
||||
obj.tolerance_s = tolerance_s
|
||||
obj.image_writer = image_writer
|
||||
|
||||
obj.root.mkdir(exist_ok=True, parents=True)
|
||||
obj.info = create_dataset_info(obj._version, fps, robot)
|
||||
obj.write_info()
|
||||
obj.fps = fps
|
||||
|
||||
if not all(cam.fps == fps for cam in robot.cameras):
|
||||
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
||||
logging.warn(
|
||||
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
|
||||
"In this case, frames from lower fps cameras will be repeated to fill in the blanks"
|
||||
)
|
||||
|
||||
obj.info = create_empty_dataset_info(obj._version, fps, robot, use_videos)
|
||||
write_json(obj.info, obj.root / "meta/info.json")
|
||||
|
||||
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
|
||||
obj.episode_buffer = obj._create_episode_buffer()
|
||||
|
||||
# obj.episodes = None
|
||||
# obj.image_transforms = None
|
||||
# obj.delta_timestamps = None
|
||||
|
||||
Reference in New Issue
Block a user