Improve dataset v2 (#498)

This commit is contained in:
Remi
2024-11-19 12:31:47 +01:00
committed by GitHub
parent acae4b49d2
commit 1f13bda25b
9 changed files with 393 additions and 70 deletions

View File

@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import textwrap
import warnings
from itertools import accumulate
from pathlib import Path
@@ -139,6 +140,8 @@ def load_info(local_dir: Path) -> dict:
def load_stats(local_dir: Path) -> dict:
if not (local_dir / STATS_PATH).exists():
return None
stats = load_json(local_dir / STATS_PATH)
stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
return unflatten_dict(stats)
@@ -186,17 +189,37 @@ def _get_major_minor(version: str) -> tuple[int]:
return int(split[0]), int(split[1])
class BackwardCompatibilityError(Exception):
def __init__(self, repo_id, version):
message = textwrap.dedent(f"""
BackwardCompatibilityError: The dataset you requested ({repo_id}) is in {version} format.
We introduced a new format since v2.0 which is not backward compatible with v1.x.
Please, use our conversion script. Modify the following command with your own task description:
```
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
--repo-id {repo_id} \\
--single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
```
A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.",
"Insert the peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.",
"Open the top cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped target.",
"Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the sweatshirt.", ...
If you encounter a problem, contact LeRobot maintainers on Discord ('https://discord.com/invite/s3KuuzsPFb')
or open an issue on GitHub.
""")
super().__init__(message)
def check_version_compatibility(
repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True
) -> None:
current_major, _ = _get_major_minor(current_version)
major_to_check, _ = _get_major_minor(version_to_check)
if major_to_check < current_major and enforce_breaking_major:
raise ValueError(
f"""The dataset you requested ({repo_id}) is in {version_to_check} format. We introduced a new
format with v2.0 that is not backward compatible. Please use our conversion script
first (convert_dataset_v1_to_v2.py) to convert your dataset to this new format."""
)
raise BackwardCompatibilityError(repo_id, version_to_check)
elif float(version_to_check.strip("v")) < float(current_version.strip("v")):
warnings.warn(
f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the
@@ -207,18 +230,16 @@ def check_version_compatibility(
)
def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) -> str:
num_version = float(version.strip("v"))
if num_version < 2 and enforce_v2:
raise ValueError(
f"""The dataset you requested ({repo_id}) is in {version} format. We introduced a new
format with v2.0 that is not backward compatible. Please use our conversion script
first (convert_dataset_v1_to_v2.py) to convert your dataset to this new format."""
)
def get_hub_safe_version(repo_id: str, version: str) -> str:
api = HfApi()
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
branches = [b.name for b in dataset_info.branches]
if version not in branches:
num_version = float(version.strip("v"))
hub_num_versions = [float(v.strip("v")) for v in branches if v.startswith("v")]
if num_version >= 2.0 and all(v < 2.0 for v in hub_num_versions):
raise BackwardCompatibilityError(repo_id, version)
warnings.warn(
f"""You are trying to load a dataset from {repo_id} created with a previous version of the
codebase. The following versions are available: {branches}.
@@ -461,6 +482,7 @@ def create_lerobot_dataset_card(
}
]
card.data.task_categories = ["robotics"]
card.data.license = license
card.data.tags = ["LeRobot"]
if license:
card.data.license = license