forked from tangger/lerobot
Fix nightly (#775)
This commit is contained in:
@@ -27,16 +27,13 @@ from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
|
||||
def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs):
|
||||
# TODO(rcadene, aliberts): env_name?
|
||||
def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
set_seed(1337)
|
||||
|
||||
train_cfg = TrainPipelineConfig(
|
||||
# TODO(rcadene, aliberts): remove dataset download
|
||||
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
|
||||
policy=make_policy_config(policy_name, **policy_kwargs),
|
||||
device="cpu",
|
||||
**train_kwargs,
|
||||
)
|
||||
train_cfg.validate() # Needed for auto-setting some parameters
|
||||
|
||||
@@ -54,8 +51,11 @@ def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwa
|
||||
|
||||
batch = next(iter(dataloader))
|
||||
loss, output_dict = policy.forward(batch)
|
||||
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
|
||||
output_dict["loss"] = loss
|
||||
if output_dict is not None:
|
||||
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
|
||||
output_dict["loss"] = loss
|
||||
else:
|
||||
output_dict = {"loss": loss}
|
||||
|
||||
loss.backward()
|
||||
grad_stats = {}
|
||||
@@ -101,30 +101,27 @@ def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwa
|
||||
return output_dict, grad_stats, param_stats, actions
|
||||
|
||||
|
||||
def save_policy_to_safetensors(output_dir, env_name, policy_name, policy_kwargs, file_name_extra):
|
||||
env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}{file_name_extra}"
|
||||
def save_policy_to_safetensors(output_dir: Path, ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
if output_dir.exists():
|
||||
print(f"Overwrite existing safetensors in '{output_dir}':")
|
||||
print(f" - Validate with: `git add {output_dir}`")
|
||||
print(f" - Revert with: `git checkout -- {output_dir}`")
|
||||
shutil.rmtree(output_dir)
|
||||
|
||||
if env_policy_dir.exists():
|
||||
print(f"Overwrite existing safetensors in '{env_policy_dir}':")
|
||||
print(f" - Validate with: `git add {env_policy_dir}`")
|
||||
print(f" - Revert with: `git checkout -- {env_policy_dir}`")
|
||||
shutil.rmtree(env_policy_dir)
|
||||
|
||||
env_policy_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, policy_kwargs)
|
||||
save_file(output_dict, env_policy_dir / "output_dict.safetensors")
|
||||
save_file(grad_stats, env_policy_dir / "grad_stats.safetensors")
|
||||
save_file(param_stats, env_policy_dir / "param_stats.safetensors")
|
||||
save_file(actions, env_policy_dir / "actions.safetensors")
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs)
|
||||
save_file(output_dict, output_dir / "output_dict.safetensors")
|
||||
save_file(grad_stats, output_dir / "grad_stats.safetensors")
|
||||
save_file(param_stats, output_dir / "param_stats.safetensors")
|
||||
save_file(actions, output_dir / "actions.safetensors")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
env_policies = [
|
||||
("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": False}, "use_policy"),
|
||||
("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": True}, "use_mpc"),
|
||||
artifacts_cfg = [
|
||||
("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": False}, "use_policy"),
|
||||
("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"),
|
||||
(
|
||||
"lerobot/pusht",
|
||||
"pusht",
|
||||
"diffusion",
|
||||
{
|
||||
"n_action_steps": 8,
|
||||
@@ -133,18 +130,17 @@ if __name__ == "__main__":
|
||||
},
|
||||
"",
|
||||
),
|
||||
("lerobot/aloha_sim_insertion_human", "aloha", "act", {"n_action_steps": 10}, ""),
|
||||
("lerobot/aloha_sim_insertion_human", "act", {"n_action_steps": 10}, ""),
|
||||
(
|
||||
"lerobot/aloha_sim_insertion_human",
|
||||
"aloha",
|
||||
"act",
|
||||
{"n_action_steps": 1000, "chunk_size": 1000},
|
||||
"_1000_steps",
|
||||
"1000_steps",
|
||||
),
|
||||
]
|
||||
if len(env_policies) == 0:
|
||||
if len(artifacts_cfg) == 0:
|
||||
raise RuntimeError("No policies were provided!")
|
||||
for ds_repo_id, env, policy, policy_kwargs, file_name_extra in env_policies:
|
||||
save_policy_to_safetensors(
|
||||
"tests/data/save_policy_to_safetensors", ds_repo_id, env, policy, policy_kwargs, file_name_extra
|
||||
)
|
||||
for ds_repo_id, policy, policy_kwargs, file_name_extra in artifacts_cfg:
|
||||
ds_name = ds_repo_id.split("/")[-1]
|
||||
output_dir = Path("tests/data/save_policy_to_safetensors") / f"{ds_name}_{policy}_{file_name_extra}"
|
||||
save_policy_to_safetensors(output_dir, ds_repo_id, policy, policy_kwargs)
|
||||
|
||||
Reference in New Issue
Block a user