Remove local_files_only and use codebase_version instead of branches (#734)
This commit is contained in:
14
tests/fixtures/dataset_factories.py
vendored
14
tests/fixtures/dataset_factories.py
vendored
@@ -309,7 +309,6 @@ def lerobot_dataset_metadata_factory(
|
||||
episodes_stats: list[dict] | None = None,
|
||||
tasks: list[dict] | None = None,
|
||||
episodes: list[dict] | None = None,
|
||||
local_files_only: bool = False,
|
||||
) -> LeRobotDatasetMetadata:
|
||||
if not info:
|
||||
info = info_factory()
|
||||
@@ -335,16 +334,16 @@ def lerobot_dataset_metadata_factory(
|
||||
)
|
||||
with (
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.get_hub_safe_version"
|
||||
) as mock_get_hub_safe_version_patch,
|
||||
"lerobot.common.datasets.lerobot_dataset.get_safe_revision"
|
||||
) as mock_get_safe_revision_patch,
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
|
||||
) as mock_snapshot_download_patch,
|
||||
):
|
||||
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version: version
|
||||
mock_get_safe_revision_patch.side_effect = lambda repo_id, version: version
|
||||
mock_snapshot_download_patch.side_effect = mock_snapshot_download
|
||||
|
||||
return LeRobotDatasetMetadata(repo_id=repo_id, root=root, local_files_only=local_files_only)
|
||||
return LeRobotDatasetMetadata(repo_id=repo_id, root=root)
|
||||
|
||||
return _create_lerobot_dataset_metadata
|
||||
|
||||
@@ -411,15 +410,18 @@ def lerobot_dataset_factory(
|
||||
episodes_stats=episodes_stats,
|
||||
tasks=tasks,
|
||||
episodes=episode_dicts,
|
||||
local_files_only=kwargs.get("local_files_only", False),
|
||||
)
|
||||
with (
|
||||
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.get_safe_revision"
|
||||
) as mock_get_safe_revision_patch,
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
|
||||
) as mock_snapshot_download_patch,
|
||||
):
|
||||
mock_metadata_patch.return_value = mock_metadata
|
||||
mock_get_safe_revision_patch.side_effect = lambda repo_id, version: version
|
||||
mock_snapshot_download_patch.side_effect = mock_snapshot_download
|
||||
|
||||
return LeRobotDataset(repo_id=repo_id, root=root, **kwargs)
|
||||
|
||||
@@ -167,9 +167,7 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
|
||||
assert dataset.meta.total_episodes == 2
|
||||
assert len(dataset) == 2
|
||||
|
||||
replay_cfg = ReplayControlConfig(
|
||||
episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False, local_files_only=True
|
||||
)
|
||||
replay_cfg = ReplayControlConfig(episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False)
|
||||
replay(robot, replay_cfg)
|
||||
|
||||
policy_cfg = ACTConfig()
|
||||
@@ -266,7 +264,6 @@ def test_resume_record(tmp_path, request, robot_type, mock):
|
||||
video=False,
|
||||
display_cameras=False,
|
||||
play_sounds=False,
|
||||
local_files_only=True,
|
||||
num_episodes=1,
|
||||
)
|
||||
|
||||
|
||||
@@ -77,10 +77,6 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
|
||||
root_init = tmp_path / "init"
|
||||
dataset_init = lerobot_dataset_factory(root=root_init)
|
||||
|
||||
# Access the '_hub_version' cached_property in both instances to force its creation
|
||||
_ = dataset_init.meta._hub_version
|
||||
_ = dataset_create.meta._hub_version
|
||||
|
||||
init_attr = set(vars(dataset_init).keys())
|
||||
create_attr = set(vars(dataset_create).keys())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user