forked from tangger/lerobot
In tests: Add use_videos=False by default, Create mp4 file if True, then fix test_datasets and test_aggregate (all passing)
This commit is contained in:
@@ -17,6 +17,7 @@ from lerobot.common.datasets.utils import (
|
||||
concat_video_files,
|
||||
get_parquet_file_size_in_mb,
|
||||
get_video_size_in_mb,
|
||||
safe_write_dataframe_to_parquet,
|
||||
update_chunk_file_indices,
|
||||
write_info,
|
||||
write_stats,
|
||||
@@ -97,6 +98,7 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
|
||||
|
||||
fps, robot_type, features = validate_all_metadata(all_metadata)
|
||||
video_keys = [key for key in features if features[key]["dtype"] == "video"]
|
||||
image_keys = [key for key in features if features[key]["dtype"] == "image"]
|
||||
|
||||
# Create resulting dataset folder
|
||||
aggr_meta = LeRobotDatasetMetadata.create(
|
||||
@@ -259,7 +261,7 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
|
||||
# Update the existing parquet file with new rows
|
||||
aggr_df = pd.read_parquet(aggr_path)
|
||||
df = pd.concat([aggr_df, df], ignore_index=True)
|
||||
df.to_parquet(aggr_path)
|
||||
safe_write_dataframe_to_parquet(df, aggr_path, image_keys)
|
||||
|
||||
num_episodes += meta.total_episodes
|
||||
num_frames += meta.total_frames
|
||||
|
||||
@@ -63,6 +63,7 @@ from lerobot.common.datasets.utils import (
|
||||
load_nested_dataset,
|
||||
load_stats,
|
||||
load_tasks,
|
||||
safe_write_dataframe_to_parquet,
|
||||
update_chunk_file_indices,
|
||||
validate_episode_buffer,
|
||||
validate_frame,
|
||||
@@ -1008,10 +1009,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# Write the resulting dataframe from RAM to disk
|
||||
path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if len(self.meta.image_keys) > 0:
|
||||
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
|
||||
else:
|
||||
df.to_parquet(path)
|
||||
safe_write_dataframe_to_parquet(df, path, self.meta.image_keys)
|
||||
|
||||
# Update the Hugging Face dataset by reloading it.
|
||||
# This process should be fast because only the latest Parquet file has been modified.
|
||||
|
||||
@@ -890,3 +890,11 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
|
||||
f"In episode_buffer not in features: {buffer_keys - set(features)}"
|
||||
f"In features not in episode_buffer: {set(features) - buffer_keys}"
|
||||
)
|
||||
|
||||
|
||||
def safe_write_dataframe_to_parquet(df: pandas.DataFrame, path: Path, image_keys: list[str]):
|
||||
if len(image_keys) > 0:
|
||||
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
|
||||
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
|
||||
else:
|
||||
df.to_parquet(path)
|
||||
|
||||
Reference in New Issue
Block a user