forked from tangger/lerobot
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
AdilZouitine
parent
2945bbb221
commit
7c05755823
@@ -223,18 +223,12 @@ def train(cfg: TrainPipelineConfig):
|
||||
step = 0 # number of policy updates (forward + backward + optim)
|
||||
|
||||
if cfg.resume:
|
||||
step, optimizer, lr_scheduler = load_training_state(
|
||||
cfg.checkpoint_path, optimizer, lr_scheduler
|
||||
)
|
||||
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
|
||||
|
||||
num_learnable_params = sum(
|
||||
p.numel() for p in policy.parameters() if p.requires_grad
|
||||
)
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
|
||||
logging.info(
|
||||
colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}"
|
||||
)
|
||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
|
||||
if cfg.env is not None:
|
||||
logging.info(f"{cfg.env.task=}")
|
||||
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
|
||||
@@ -335,9 +329,7 @@ def train(cfg: TrainPipelineConfig):
|
||||
logging.info(f"Eval policy at step {step}")
|
||||
with (
|
||||
torch.no_grad(),
|
||||
torch.autocast(device_type=device.type)
|
||||
if cfg.policy.use_amp
|
||||
else nullcontext(),
|
||||
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
|
||||
):
|
||||
eval_info = eval_policy(
|
||||
eval_env,
|
||||
|
||||
Reference in New Issue
Block a user