Cleanup, fix load_tasks

This commit is contained in:
Simon Alibert
2024-10-15 11:05:16 +02:00
parent f96773de10
commit 835ab5a81b
2 changed files with 20 additions and 17 deletions

View File

@@ -80,6 +80,7 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
if isinstance(first_item, PILImage.Image):
to_tensor = transforms.ToTensor()
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
# TODO(aliberts): remove this part as we'll be using task_index
elif isinstance(first_item, str):
# TODO (michel-aractingi): add str2embedding via language tokenizer
# For now we leave this part up to the user to choose how to address
@@ -96,13 +97,13 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
@cache
def get_hub_safe_version(repo_id: str, version: str) -> str:
def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) -> str:
num_version = float(version.strip("v"))
if num_version < 2:
if num_version < 2 and enforce_v2:
raise ValueError(
f"""The dataset you requested ({repo_id}) is in {version} format. We introduced a new
format with v2.0 that is not backward compatible. Please use our conversion script
first (convert_dataset_16_to_20.py) to convert your dataset to this new format."""
first (convert_dataset_v1_to_v2.py) to convert your dataset to this new format."""
)
api = HfApi()
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
@@ -192,7 +193,9 @@ def load_tasks(repo_id: str, version: str, local_dir: Path) -> dict:
repo_id, filename="meta/tasks.json", local_dir=local_dir, repo_type="dataset", revision=version
)
with open(fpath) as f:
return json.load(f)
tasks = json.load(f)
return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
def get_episode_data_index(episodes: list, episode_dicts: list[dict]) -> dict[str, torch.Tensor]: