forked from tangger/lerobot
grad_clip_norm as arg of update policy
This commit is contained in:
@@ -24,7 +24,7 @@ from lerobot.common.utils.utils import (
|
|||||||
from lerobot.scripts.eval import eval_policy
|
from lerobot.scripts.eval import eval_policy
|
||||||
|
|
||||||
|
|
||||||
def update_policy(cfg, policy, batch, optimizer, lr_scheduler=None):
|
def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
model = policy.diffusion if hasattr(policy, "diffusion") else policy # TODO: hacky, remove this line
|
model = policy.diffusion if hasattr(policy, "diffusion") else policy # TODO: hacky, remove this line
|
||||||
@@ -41,7 +41,7 @@ def update_policy(cfg, policy, batch, optimizer, lr_scheduler=None):
|
|||||||
model = policy.diffusion if hasattr(policy, "diffusion") else policy # TODO: hacky, remove this line
|
model = policy.diffusion if hasattr(policy, "diffusion") else policy # TODO: hacky, remove this line
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
model.parameters(),
|
model.parameters(),
|
||||||
cfg.grad_clip_norm,
|
grad_clip_norm,
|
||||||
error_if_nonfinite=False,
|
error_if_nonfinite=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -275,8 +275,9 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||||||
logging.info("make_policy")
|
logging.info("make_policy")
|
||||||
policy = make_policy(cfg, dataset_stats=offline_dataset.stats)
|
policy = make_policy(cfg, dataset_stats=offline_dataset.stats)
|
||||||
|
|
||||||
|
# Create optimizer and scheduler
|
||||||
# Temporary hack to move optimizer out of policy
|
# Temporary hack to move optimizer out of policy
|
||||||
if isinstance(policy, ActPolicy):
|
if cfg.policy.name == "act":
|
||||||
optimizer_params_dicts = [
|
optimizer_params_dicts = [
|
||||||
{"params": [p for n, p in policy.named_parameters() if not n.startswith("backbone") and p.requires_grad]},
|
{"params": [p for n, p in policy.named_parameters() if not n.startswith("backbone") and p.requires_grad]},
|
||||||
{
|
{
|
||||||
@@ -286,7 +287,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||||||
]
|
]
|
||||||
optimizer = torch.optim.AdamW(optimizer_params_dicts, lr=cfg.lr, weight_decay=cfg.weight_decay)
|
optimizer = torch.optim.AdamW(optimizer_params_dicts, lr=cfg.lr, weight_decay=cfg.weight_decay)
|
||||||
lr_scheduler = None
|
lr_scheduler = None
|
||||||
elif isinstance(policy, DiffusionPolicy):
|
elif cfg.policy.name == "diffusion":
|
||||||
optimizer = torch.optim.Adam(
|
optimizer = torch.optim.Adam(
|
||||||
policy.diffusion.parameters(), cfg.lr, cfg.adam_betas, cfg.adam_eps, cfg.adam_weight_decay
|
policy.diffusion.parameters(), cfg.lr, cfg.adam_betas, cfg.adam_eps, cfg.adam_weight_decay
|
||||||
)
|
)
|
||||||
@@ -362,7 +363,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||||||
for key in batch:
|
for key in batch:
|
||||||
batch[key] = batch[key].to(cfg.device, non_blocking=True)
|
batch[key] = batch[key].to(cfg.device, non_blocking=True)
|
||||||
|
|
||||||
train_info = update_policy(cfg, policy, batch, optimizer, lr_scheduler)
|
train_info = update_policy(policy, batch, optimizer, cfg.grad_clip_norm, lr_scheduler)
|
||||||
|
|
||||||
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
|
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
|
||||||
if step % cfg.log_freq == 0:
|
if step % cfg.log_freq == 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user