[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
@@ -51,7 +51,13 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
|
||||
save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")
|
||||
|
||||
# save 2 frames at the middle of first episode
|
||||
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
|
||||
i = int(
|
||||
(
|
||||
dataset.episode_data_index["to"][0].item()
|
||||
- dataset.episode_data_index["from"][0].item()
|
||||
)
|
||||
/ 2
|
||||
)
|
||||
save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
|
||||
save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")
|
||||
|
||||
@@ -87,4 +93,6 @@ if __name__ == "__main__":
|
||||
"lerobot/nyu_franka_play_dataset",
|
||||
"lerobot/cmu_stretch",
|
||||
]:
|
||||
save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors", repo_id=dataset)
|
||||
save_dataset_to_safetensors(
|
||||
"tests/data/save_dataset_to_safetensors", repo_id=dataset
|
||||
)
|
||||
|
||||
@@ -67,7 +67,9 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
|
||||
param_stats = {}
|
||||
for key, param in policy.named_parameters():
|
||||
param_stats[f"{key}_mean"] = param.mean()
|
||||
param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(float(0.0))
|
||||
param_stats[f"{key}_std"] = (
|
||||
param.std() if param.numel() > 1 else torch.tensor(float(0.0))
|
||||
)
|
||||
|
||||
optimizer.zero_grad()
|
||||
policy.reset()
|
||||
@@ -85,11 +87,15 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
|
||||
else:
|
||||
actions_queue = cfg.policy.n_action_repeats
|
||||
|
||||
actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
|
||||
actions = {
|
||||
str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)
|
||||
}
|
||||
return output_dict, grad_stats, param_stats, actions
|
||||
|
||||
|
||||
def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_overrides, file_name_extra):
|
||||
def save_policy_to_safetensors(
|
||||
output_dir, env_name, policy_name, extra_overrides, file_name_extra
|
||||
):
|
||||
env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}{file_name_extra}"
|
||||
|
||||
if env_policy_dir.exists():
|
||||
@@ -99,7 +105,9 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override
|
||||
shutil.rmtree(env_policy_dir)
|
||||
|
||||
env_policy_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
|
||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(
|
||||
env_name, policy_name, extra_overrides
|
||||
)
|
||||
save_file(output_dict, env_policy_dir / "output_dict.safetensors")
|
||||
save_file(grad_stats, env_policy_dir / "grad_stats.safetensors")
|
||||
save_file(param_stats, env_policy_dir / "param_stats.safetensors")
|
||||
@@ -129,5 +137,9 @@ if __name__ == "__main__":
|
||||
raise RuntimeError("No policies were provided!")
|
||||
for env, policy, extra_overrides, file_name_extra in env_policies:
|
||||
save_policy_to_safetensors(
|
||||
"tests/data/save_policy_to_safetensors", env, policy, extra_overrides, file_name_extra
|
||||
"tests/data/save_policy_to_safetensors",
|
||||
env,
|
||||
policy,
|
||||
extra_overrides,
|
||||
file_name_extra,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user