[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Michel Aractingi
parent
2abbd60a0d
commit
0ea27704f6
@@ -52,13 +52,7 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
|
||||
save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")
|
||||
|
||||
# save 2 frames at the middle of first episode
|
||||
i = int(
|
||||
(
|
||||
dataset.episode_data_index["to"][0].item()
|
||||
- dataset.episode_data_index["from"][0].item()
|
||||
)
|
||||
/ 2
|
||||
)
|
||||
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
|
||||
save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
|
||||
save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")
|
||||
|
||||
|
||||
@@ -51,9 +51,7 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
batch = next(iter(dataloader))
|
||||
loss, output_dict = policy.forward(batch)
|
||||
if output_dict is not None:
|
||||
output_dict = {
|
||||
k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)
|
||||
}
|
||||
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
|
||||
output_dict["loss"] = loss
|
||||
else:
|
||||
output_dict = {"loss": loss}
|
||||
@@ -71,9 +69,7 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
param_stats = {}
|
||||
for key, param in policy.named_parameters():
|
||||
param_stats[f"{key}_mean"] = param.mean()
|
||||
param_stats[f"{key}_std"] = (
|
||||
param.std() if param.numel() > 1 else torch.tensor(float(0.0))
|
||||
)
|
||||
param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(float(0.0))
|
||||
|
||||
optimizer.zero_grad()
|
||||
policy.reset()
|
||||
@@ -100,15 +96,11 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
else:
|
||||
actions_queue = train_cfg.policy.n_action_repeats
|
||||
|
||||
actions = {
|
||||
str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)
|
||||
}
|
||||
actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
|
||||
return output_dict, grad_stats, param_stats, actions
|
||||
|
||||
|
||||
def save_policy_to_safetensors(
|
||||
output_dir: Path, ds_repo_id: str, policy_name: str, policy_kwargs: dict
|
||||
):
|
||||
def save_policy_to_safetensors(output_dir: Path, ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
if output_dir.exists():
|
||||
print(f"Overwrite existing safetensors in '{output_dir}':")
|
||||
print(f" - Validate with: `git add {output_dir}`")
|
||||
@@ -116,9 +108,7 @@ def save_policy_to_safetensors(
|
||||
shutil.rmtree(output_dir)
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(
|
||||
ds_repo_id, policy_name, policy_kwargs
|
||||
)
|
||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs)
|
||||
save_file(output_dict, output_dir / "output_dict.safetensors")
|
||||
save_file(grad_stats, output_dir / "grad_stats.safetensors")
|
||||
save_file(param_stats, output_dir / "param_stats.safetensors")
|
||||
@@ -151,7 +141,5 @@ if __name__ == "__main__":
|
||||
raise RuntimeError("No policies were provided!")
|
||||
for ds_repo_id, policy, policy_kwargs, file_name_extra in artifacts_cfg:
|
||||
ds_name = ds_repo_id.split("/")[-1]
|
||||
output_dir = (
|
||||
Path("tests/artifacts/policies") / f"{ds_name}_{policy}_{file_name_extra}"
|
||||
)
|
||||
output_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy}_{file_name_extra}"
|
||||
save_policy_to_safetensors(output_dir, ds_repo_id, policy, policy_kwargs)
|
||||
|
||||
Reference in New Issue
Block a user