forked from tangger/lerobot
format
This commit is contained in:
@@ -280,21 +280,13 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
# Temporary hack to move optimizer out of policy
|
||||
if isinstance(policy, ActPolicy):
|
||||
optimizer_params_dicts = [
|
||||
{"params": [p for n, p in policy.named_parameters() if not n.startswith("backbone") and p.requires_grad]},
|
||||
{
|
||||
"params": [
|
||||
p for n, p in policy.named_parameters() if not n.startswith("backbone") and p.requires_grad
|
||||
]
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad
|
||||
],
|
||||
"params": [p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad],
|
||||
"lr": policy.cfg.lr_backbone,
|
||||
},
|
||||
]
|
||||
optimizer = torch.optim.AdamW(
|
||||
optimizer_params_dicts, lr=policy.cfg.lr, weight_decay=policy.cfg.weight_decay
|
||||
)
|
||||
optimizer = torch.optim.AdamW(optimizer_params_dicts, lr=policy.cfg.lr, weight_decay=policy.cfg.weight_decay)
|
||||
lr_scheduler = None
|
||||
elif isinstance(policy, DiffusionPolicy):
|
||||
optimizer = torch.optim.Adam(
|
||||
@@ -313,8 +305,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
last_epoch=global_step - 1,
|
||||
)
|
||||
|
||||
|
||||
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
|
||||
@@ -397,9 +387,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
# create dataloader for online training
|
||||
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
|
||||
weights = [1.0] * len(concat_dataset)
|
||||
sampler = torch.utils.data.WeightedRandomSampler(
|
||||
weights, num_samples=len(concat_dataset), replacement=True
|
||||
)
|
||||
sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=len(concat_dataset), replacement=True)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
concat_dataset,
|
||||
num_workers=4,
|
||||
|
||||
Reference in New Issue
Block a user