[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-24 13:16:38 +00:00
committed by Michel Aractingi
parent cdcf346061
commit 1c8daf11fd
95 changed files with 1592 additions and 491 deletions

View File

@@ -210,7 +210,10 @@ def tasks_factory():
def _create_tasks(total_tasks: int = 3) -> int:
tasks = {}
for task_index in range(total_tasks):
task_dict = {"task_index": task_index, "task": f"Perform action {task_index}."}
task_dict = {
"task_index": task_index,
"task": f"Perform action {task_index}.",
}
tasks[task_index] = task_dict
return tasks
@@ -297,8 +300,12 @@ def hf_dataset_factory(
episode_index_col = np.array([], dtype=np.int64)
task_index = np.array([], dtype=np.int64)
for ep_dict in episodes.values():
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
timestamp_col = np.concatenate(
(timestamp_col, np.arange(ep_dict["length"]) / fps)
)
frame_index_col = np.concatenate(
(frame_index_col, np.arange(ep_dict["length"], dtype=int))
)
episode_index_col = np.concatenate(
(
episode_index_col,
@@ -385,7 +392,9 @@ def lerobot_dataset_metadata_factory(
episodes=episodes,
)
with (
patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.get_safe_version"
) as mock_get_safe_version_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,
@@ -433,7 +442,9 @@ def lerobot_dataset_factory(
if not stats:
stats = stats_factory(features=info["features"])
if not episodes_stats:
episodes_stats = episodes_stats_factory(features=info["features"], total_episodes=total_episodes)
episodes_stats = episodes_stats_factory(
features=info["features"], total_episodes=total_episodes
)
if not tasks:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episode_dicts:
@@ -466,8 +477,12 @@ def lerobot_dataset_factory(
episodes=episode_dicts,
)
with (
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata"
) as mock_metadata_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.get_safe_version"
) as mock_get_safe_version_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,

View File

@@ -59,7 +59,9 @@ def stats_path(stats_factory):
@pytest.fixture(scope="session")
def episodes_stats_path(episodes_stats_factory):
def _create_episodes_stats_jsonl_file(dir: Path, episodes_stats: list[dict] | None = None) -> Path:
def _create_episodes_stats_jsonl_file(
dir: Path, episodes_stats: list[dict] | None = None
) -> Path:
if not episodes_stats:
episodes_stats = episodes_stats_factory()
fpath = dir / EPISODES_STATS_PATH

View File

@@ -99,7 +99,13 @@ def mock_snapshot_download_factory(
# List all possible files
all_files = []
meta_files = [INFO_PATH, STATS_PATH, EPISODES_STATS_PATH, TASKS_PATH, EPISODES_PATH]
meta_files = [
INFO_PATH,
STATS_PATH,
EPISODES_STATS_PATH,
TASKS_PATH,
EPISODES_PATH,
]
all_files.extend(meta_files)
data_files = []

View File

@@ -35,5 +35,7 @@ def optimizer(model_params):
@pytest.fixture
def scheduler(optimizer):
config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5)
config = VQBeTSchedulerConfig(
num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5
)
return config.build(optimizer, num_training_steps=100)