[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Michel Aractingi
parent
cdcf346061
commit
1c8daf11fd
29
tests/fixtures/dataset_factories.py
vendored
29
tests/fixtures/dataset_factories.py
vendored
@@ -210,7 +210,10 @@ def tasks_factory():
|
||||
def _create_tasks(total_tasks: int = 3) -> int:
|
||||
tasks = {}
|
||||
for task_index in range(total_tasks):
|
||||
task_dict = {"task_index": task_index, "task": f"Perform action {task_index}."}
|
||||
task_dict = {
|
||||
"task_index": task_index,
|
||||
"task": f"Perform action {task_index}.",
|
||||
}
|
||||
tasks[task_index] = task_dict
|
||||
return tasks
|
||||
|
||||
@@ -297,8 +300,12 @@ def hf_dataset_factory(
|
||||
episode_index_col = np.array([], dtype=np.int64)
|
||||
task_index = np.array([], dtype=np.int64)
|
||||
for ep_dict in episodes.values():
|
||||
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
|
||||
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
|
||||
timestamp_col = np.concatenate(
|
||||
(timestamp_col, np.arange(ep_dict["length"]) / fps)
|
||||
)
|
||||
frame_index_col = np.concatenate(
|
||||
(frame_index_col, np.arange(ep_dict["length"], dtype=int))
|
||||
)
|
||||
episode_index_col = np.concatenate(
|
||||
(
|
||||
episode_index_col,
|
||||
@@ -385,7 +392,9 @@ def lerobot_dataset_metadata_factory(
|
||||
episodes=episodes,
|
||||
)
|
||||
with (
|
||||
patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.get_safe_version"
|
||||
) as mock_get_safe_version_patch,
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
|
||||
) as mock_snapshot_download_patch,
|
||||
@@ -433,7 +442,9 @@ def lerobot_dataset_factory(
|
||||
if not stats:
|
||||
stats = stats_factory(features=info["features"])
|
||||
if not episodes_stats:
|
||||
episodes_stats = episodes_stats_factory(features=info["features"], total_episodes=total_episodes)
|
||||
episodes_stats = episodes_stats_factory(
|
||||
features=info["features"], total_episodes=total_episodes
|
||||
)
|
||||
if not tasks:
|
||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
||||
if not episode_dicts:
|
||||
@@ -466,8 +477,12 @@ def lerobot_dataset_factory(
|
||||
episodes=episode_dicts,
|
||||
)
|
||||
with (
|
||||
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
|
||||
patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata"
|
||||
) as mock_metadata_patch,
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.get_safe_version"
|
||||
) as mock_get_safe_version_patch,
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
|
||||
) as mock_snapshot_download_patch,
|
||||
|
||||
4
tests/fixtures/files.py
vendored
4
tests/fixtures/files.py
vendored
@@ -59,7 +59,9 @@ def stats_path(stats_factory):
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def episodes_stats_path(episodes_stats_factory):
|
||||
def _create_episodes_stats_jsonl_file(dir: Path, episodes_stats: list[dict] | None = None) -> Path:
|
||||
def _create_episodes_stats_jsonl_file(
|
||||
dir: Path, episodes_stats: list[dict] | None = None
|
||||
) -> Path:
|
||||
if not episodes_stats:
|
||||
episodes_stats = episodes_stats_factory()
|
||||
fpath = dir / EPISODES_STATS_PATH
|
||||
|
||||
8
tests/fixtures/hub.py
vendored
8
tests/fixtures/hub.py
vendored
@@ -99,7 +99,13 @@ def mock_snapshot_download_factory(
|
||||
|
||||
# List all possible files
|
||||
all_files = []
|
||||
meta_files = [INFO_PATH, STATS_PATH, EPISODES_STATS_PATH, TASKS_PATH, EPISODES_PATH]
|
||||
meta_files = [
|
||||
INFO_PATH,
|
||||
STATS_PATH,
|
||||
EPISODES_STATS_PATH,
|
||||
TASKS_PATH,
|
||||
EPISODES_PATH,
|
||||
]
|
||||
all_files.extend(meta_files)
|
||||
|
||||
data_files = []
|
||||
|
||||
4
tests/fixtures/optimizers.py
vendored
4
tests/fixtures/optimizers.py
vendored
@@ -35,5 +35,7 @@ def optimizer(model_params):
|
||||
|
||||
@pytest.fixture
|
||||
def scheduler(optimizer):
|
||||
config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5)
|
||||
config = VQBeTSchedulerConfig(
|
||||
num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5
|
||||
)
|
||||
return config.build(optimizer, num_training_steps=100)
|
||||
|
||||
Reference in New Issue
Block a user