Improve dataset v2 (#498)
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import textwrap
|
||||
import warnings
|
||||
from itertools import accumulate
|
||||
from pathlib import Path
|
||||
@@ -139,6 +140,8 @@ def load_info(local_dir: Path) -> dict:
|
||||
|
||||
|
||||
def load_stats(local_dir: Path) -> dict:
|
||||
if not (local_dir / STATS_PATH).exists():
|
||||
return None
|
||||
stats = load_json(local_dir / STATS_PATH)
|
||||
stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
|
||||
return unflatten_dict(stats)
|
||||
@@ -186,17 +189,37 @@ def _get_major_minor(version: str) -> tuple[int]:
|
||||
return int(split[0]), int(split[1])
|
||||
|
||||
|
||||
class BackwardCompatibilityError(Exception):
|
||||
def __init__(self, repo_id, version):
|
||||
message = textwrap.dedent(f"""
|
||||
BackwardCompatibilityError: The dataset you requested ({repo_id}) is in {version} format.
|
||||
|
||||
We introduced a new format since v2.0 which is not backward compatible with v1.x.
|
||||
Please, use our conversion script. Modify the following command with your own task description:
|
||||
```
|
||||
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
|
||||
--repo-id {repo_id} \\
|
||||
--single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
|
||||
```
|
||||
|
||||
A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.",
|
||||
"Insert the peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.",
|
||||
"Open the top cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped target.",
|
||||
"Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the sweatshirt.", ...
|
||||
|
||||
If you encounter a problem, contact LeRobot maintainers on Discord ('https://discord.com/invite/s3KuuzsPFb')
|
||||
or open an issue on GitHub.
|
||||
""")
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
def check_version_compatibility(
|
||||
repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True
|
||||
) -> None:
|
||||
current_major, _ = _get_major_minor(current_version)
|
||||
major_to_check, _ = _get_major_minor(version_to_check)
|
||||
if major_to_check < current_major and enforce_breaking_major:
|
||||
raise ValueError(
|
||||
f"""The dataset you requested ({repo_id}) is in {version_to_check} format. We introduced a new
|
||||
format with v2.0 that is not backward compatible. Please use our conversion script
|
||||
first (convert_dataset_v1_to_v2.py) to convert your dataset to this new format."""
|
||||
)
|
||||
raise BackwardCompatibilityError(repo_id, version_to_check)
|
||||
elif float(version_to_check.strip("v")) < float(current_version.strip("v")):
|
||||
warnings.warn(
|
||||
f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the
|
||||
@@ -207,18 +230,16 @@ def check_version_compatibility(
|
||||
)
|
||||
|
||||
|
||||
def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) -> str:
|
||||
num_version = float(version.strip("v"))
|
||||
if num_version < 2 and enforce_v2:
|
||||
raise ValueError(
|
||||
f"""The dataset you requested ({repo_id}) is in {version} format. We introduced a new
|
||||
format with v2.0 that is not backward compatible. Please use our conversion script
|
||||
first (convert_dataset_v1_to_v2.py) to convert your dataset to this new format."""
|
||||
)
|
||||
def get_hub_safe_version(repo_id: str, version: str) -> str:
|
||||
api = HfApi()
|
||||
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
|
||||
branches = [b.name for b in dataset_info.branches]
|
||||
if version not in branches:
|
||||
num_version = float(version.strip("v"))
|
||||
hub_num_versions = [float(v.strip("v")) for v in branches if v.startswith("v")]
|
||||
if num_version >= 2.0 and all(v < 2.0 for v in hub_num_versions):
|
||||
raise BackwardCompatibilityError(repo_id, version)
|
||||
|
||||
warnings.warn(
|
||||
f"""You are trying to load a dataset from {repo_id} created with a previous version of the
|
||||
codebase. The following versions are available: {branches}.
|
||||
@@ -461,6 +482,7 @@ def create_lerobot_dataset_card(
|
||||
}
|
||||
]
|
||||
card.data.task_categories = ["robotics"]
|
||||
card.data.license = license
|
||||
card.data.tags = ["LeRobot"]
|
||||
if license:
|
||||
card.data.license = license
|
||||
|
||||
Reference in New Issue
Block a user