[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-04 13:38:47 +00:00
parent d8a1758122
commit 584cad808e
108 changed files with 3894 additions and 1189 deletions

View File

@@ -51,7 +51,13 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")
# save 2 frames at the middle of first episode
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
i = int(
(
dataset.episode_data_index["to"][0].item()
- dataset.episode_data_index["from"][0].item()
)
/ 2
)
save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")
@@ -87,4 +93,6 @@ if __name__ == "__main__":
"lerobot/nyu_franka_play_dataset",
"lerobot/cmu_stretch",
]:
save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors", repo_id=dataset)
save_dataset_to_safetensors(
"tests/data/save_dataset_to_safetensors", repo_id=dataset
)

View File

@@ -67,7 +67,9 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
param_stats = {}
for key, param in policy.named_parameters():
param_stats[f"{key}_mean"] = param.mean()
param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(float(0.0))
param_stats[f"{key}_std"] = (
param.std() if param.numel() > 1 else torch.tensor(float(0.0))
)
optimizer.zero_grad()
policy.reset()
@@ -85,11 +87,15 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
else:
actions_queue = cfg.policy.n_action_repeats
actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
actions = {
str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)
}
return output_dict, grad_stats, param_stats, actions
def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_overrides, file_name_extra):
def save_policy_to_safetensors(
output_dir, env_name, policy_name, extra_overrides, file_name_extra
):
env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}{file_name_extra}"
if env_policy_dir.exists():
@@ -99,7 +105,9 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override
shutil.rmtree(env_policy_dir)
env_policy_dir.mkdir(parents=True, exist_ok=True)
output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
output_dict, grad_stats, param_stats, actions = get_policy_stats(
env_name, policy_name, extra_overrides
)
save_file(output_dict, env_policy_dir / "output_dict.safetensors")
save_file(grad_stats, env_policy_dir / "grad_stats.safetensors")
save_file(param_stats, env_policy_dir / "param_stats.safetensors")
@@ -129,5 +137,9 @@ if __name__ == "__main__":
raise RuntimeError("No policies were provided!")
for env, policy, extra_overrides, file_name_extra in env_policies:
save_policy_to_safetensors(
"tests/data/save_policy_to_safetensors", env, policy, extra_overrides, file_name_extra
"tests/data/save_policy_to_safetensors",
env,
policy,
extra_overrides,
file_name_extra,
)