format
@@ -16,10 +16,10 @@ from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import Logger, log_output_dir
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.utils.utils import (
     format_big_number,
     get_safe_torch_device,
     init_logging,
     set_global_seed,
 )
 from lerobot.scripts.eval import eval_policy
 
@@ -40,7 +40,7 @@ def update_policy(policy, batch, optimizer, lr_scheduler=None):
     loss.backward()
 
     # Diffusion
     model = policy.diffusion if hasattr(policy, "diffusion") else policy  # TODO: hacky, remove this line
     grad_norm = torch.nn.utils.clip_grad_norm_(
         model.parameters(),
         policy.cfg.grad_clip_norm,
@@ -54,7 +54,7 @@ def update_policy(policy, batch, optimizer, lr_scheduler=None):
 
     if hasattr(policy, "ema") and policy.ema is not None:
         policy.ema.step(model)
 
     info = {
         "loss": loss.item(),
         "grad_norm": float(grad_norm),
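The two hunks above cover the core of update_policy: backpropagate, clip the gradient norm on the module that actually owns the parameters, then advance an exponential moving average of the weights if the policy keeps one. A minimal standalone sketch of that pattern follows; the forward pass returning a scalar loss and the optimizer.step()/zero_grad() calls between the hunks are assumptions, not shown in this diff.

import torch

def update_policy(policy, batch, optimizer, lr_scheduler=None):
    loss = policy(batch)  # assumption: the policy's forward pass returns a scalar loss
    loss.backward()

    # Clip on the inner module when the policy wraps one (e.g. a diffusion model).
    model = policy.diffusion if hasattr(policy, "diffusion") else policy
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), policy.cfg.grad_clip_norm)

    optimizer.step()       # assumed to sit between the two hunks above
    optimizer.zero_grad()
    if lr_scheduler is not None:
        lr_scheduler.step()

    # Keep an exponential moving average of the weights if the policy has one.
    if hasattr(policy, "ema") and policy.ema is not None:
        policy.ema.step(model)

    return {"loss": loss.item(), "grad_norm": float(grad_norm)}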
@@ -280,21 +280,13 @@ def train(cfg: dict, out_dir=None, job_name=None):
     # Temporary hack to move optimizer out of policy
     if isinstance(policy, ActPolicy):
         optimizer_params_dicts = [
+            {"params": [p for n, p in policy.named_parameters() if not n.startswith("backbone") and p.requires_grad]},
             {
-                "params": [
-                    p for n, p in policy.named_parameters() if not n.startswith("backbone") and p.requires_grad
-                ]
-            },
-            {
-                "params": [
-                    p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad
-                ],
+                "params": [p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad],
                 "lr": policy.cfg.lr_backbone,
             },
         ]
-        optimizer = torch.optim.AdamW(
-            optimizer_params_dicts, lr=policy.cfg.lr, weight_decay=policy.cfg.weight_decay
-        )
+        optimizer = torch.optim.AdamW(optimizer_params_dicts, lr=policy.cfg.lr, weight_decay=policy.cfg.weight_decay)
         lr_scheduler = None
     elif isinstance(policy, DiffusionPolicy):
         optimizer = torch.optim.Adam(
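This hunk only reflows (without changing behavior) a standard PyTorch pattern: passing parameter groups to the optimizer so a pretrained backbone trains with a smaller learning rate than the rest of the model. A sketch of the same pattern in isolation; the learning rates and weight decay here are placeholders, not lerobot's defaults.

import torch

def build_optimizer(policy: torch.nn.Module, lr=1e-5, lr_backbone=1e-6, weight_decay=1e-4):
    backbone = [p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad]
    rest = [p for n, p in policy.named_parameters() if not n.startswith("backbone") and p.requires_grad]
    return torch.optim.AdamW(
        [{"params": rest}, {"params": backbone, "lr": lr_backbone}],  # per-group lr overrides the default
        lr=lr,
        weight_decay=weight_decay,
    )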
@@ -313,8 +305,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
             last_epoch=global_step - 1,
         )
 
-
-
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())
 
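The counting lines kept by this hunk work for any nn.Module; requires_grad is what separates trainable parameters from frozen ones. For instance:

import torch

def count_params(module: torch.nn.Module) -> tuple[int, int]:
    num_learnable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    num_total = sum(p.numel() for p in module.parameters())
    return num_learnable, num_total

# e.g. a linear layer with a frozen bias: 30 learnable of 40 total
layer = torch.nn.Linear(3, 10)
layer.bias.requires_grad = False
print(count_params(layer))  # (30, 40)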
@@ -397,9 +387,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
         # create dataloader for online training
         concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
         weights = [1.0] * len(concat_dataset)
-        sampler = torch.utils.data.WeightedRandomSampler(
-            weights, num_samples=len(concat_dataset), replacement=True
-        )
+        sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=len(concat_dataset), replacement=True)
         dataloader = torch.utils.data.DataLoader(
             concat_dataset,
             num_workers=4,
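The sampler in this last hunk draws with replacement from the concatenation of the offline and online datasets; with all weights at 1.0 the sampling is uniform, and non-uniform weights could instead up-weight online samples. A self-contained sketch with stand-in datasets (the TensorDataset shapes are illustrative only):

import torch

offline_dataset = torch.utils.data.TensorDataset(torch.randn(100, 3))  # stand-in data
online_dataset = torch.utils.data.TensorDataset(torch.randn(20, 3))

concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
weights = [1.0] * len(concat_dataset)
sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=len(concat_dataset), replacement=True)
dataloader = torch.utils.data.DataLoader(concat_dataset, sampler=sampler, batch_size=8)
for (batch,) in dataloader:
    pass  # each batch mixes offline and online samples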