Cleanup, fix load_tasks
This commit is contained in:
@@ -80,6 +80,7 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||
if isinstance(first_item, PILImage.Image):
|
||||
to_tensor = transforms.ToTensor()
|
||||
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
|
||||
# TODO(aliberts): remove this part as we'll be using task_index
|
||||
elif isinstance(first_item, str):
|
||||
# TODO(michel-aractingi): add str2embedding via language tokenizer
|
||||
# For now we leave this part up to the user to choose how to address
|
||||
@@ -96,13 +97,13 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||
|
||||
|
||||
@cache
|
||||
def get_hub_safe_version(repo_id: str, version: str) -> str:
|
||||
def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) -> str:
|
||||
num_version = float(version.strip("v"))
|
||||
if num_version < 2:
|
||||
if num_version < 2 and enforce_v2:
|
||||
raise ValueError(
|
||||
f"""The dataset you requested ({repo_id}) is in {version} format. We introduced a new
|
||||
format with v2.0 that is not backward compatible. Please use our conversion script
|
||||
first (convert_dataset_16_to_20.py) to convert your dataset to this new format."""
|
||||
first (convert_dataset_v1_to_v2.py) to convert your dataset to this new format."""
|
||||
)
|
||||
api = HfApi()
|
||||
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
|
||||
@@ -192,7 +193,9 @@ def load_tasks(repo_id: str, version: str, local_dir: Path) -> dict:
|
||||
repo_id, filename="meta/tasks.json", local_dir=local_dir, repo_type="dataset", revision=version
|
||||
)
|
||||
with open(fpath) as f:
|
||||
return json.load(f)
|
||||
tasks = json.load(f)
|
||||
|
||||
return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
|
||||
|
||||
|
||||
def get_episode_data_index(episodes: list, episode_dicts: list[dict]) -> dict[str, torch.Tensor]:
|
||||
|
||||
Reference in New Issue
Block a user