forked from tangger/lerobot
Compare commits
3 Commits
recovered-
...
user/miche
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
06fc9b89e1 | ||
|
|
3034272229 | ||
|
|
bbce0eaeaf |
@@ -160,6 +160,31 @@ class ACTPolicy(
|
|||||||
|
|
||||||
return loss_dict
|
return loss_dict
|
||||||
|
|
||||||
|
def make_optimizer_and_scheduler(self, cfg):
    """Create the AdamW optimizer (and no LR scheduler) for the ACT policy.

    The vision backbone gets its own learning rate (``cfg.training.lr_backbone``);
    every other trainable parameter uses the main rate ``cfg.training.lr``.

    Args:
        cfg: Hydra config; reads ``cfg.training.lr``, ``cfg.training.lr_backbone``
            and ``cfg.training.weight_decay``.

    Returns:
        Tuple of ``(optimizer, lr_scheduler)`` where the scheduler is ``None``
        because ACT trains with a constant learning rate.
    """
    # Partition trainable parameters into backbone vs. everything else so the
    # two groups can carry different learning rates.
    backbone_params = []
    other_params = []
    for name, param in self.named_parameters():
        if not param.requires_grad:
            continue
        if name.startswith("model.backbone"):
            backbone_params.append(param)
        else:
            other_params.append(param)
    param_groups = [
        {"params": other_params},
        {"params": backbone_params, "lr": cfg.training.lr_backbone},
    ]
    optimizer = torch.optim.AdamW(
        param_groups, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
    )
    return optimizer, None
|
||||||
|
|
||||||
|
|
||||||
class ACTTemporalEnsembler:
|
class ACTTemporalEnsembler:
|
||||||
def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None:
|
def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None:
|
||||||
|
|||||||
@@ -156,6 +156,25 @@ class DiffusionPolicy(
|
|||||||
loss = self.diffusion.compute_loss(batch)
|
loss = self.diffusion.compute_loss(batch)
|
||||||
return {"loss": loss}
|
return {"loss": loss}
|
||||||
|
|
||||||
|
def make_optimizer_and_scheduler(self, cfg):
    """Create the Adam optimizer and warmup LR scheduler for the Diffusion policy.

    Args:
        cfg: Hydra config; reads ``cfg.training.lr``, ``cfg.training.adam_betas``,
            ``cfg.training.adam_eps``, ``cfg.training.adam_weight_decay``,
            ``cfg.training.lr_scheduler``, ``cfg.training.lr_warmup_steps`` and
            ``cfg.training.offline_steps``.

    Returns:
        Tuple of ``(optimizer, lr_scheduler)``.
    """
    # Pass hyperparameters by keyword: the original positional form silently
    # depended on torch.optim.Adam's exact parameter order (lr, betas, eps,
    # weight_decay) and would break if any argument were reordered.
    optimizer = torch.optim.Adam(
        self.diffusion.parameters(),
        lr=cfg.training.lr,
        betas=cfg.training.adam_betas,
        eps=cfg.training.adam_eps,
        weight_decay=cfg.training.adam_weight_decay,
    )
    # Imported lazily so the module does not hard-depend on `diffusers` until a
    # scheduler is actually requested.
    from diffusers.optimization import get_scheduler

    lr_scheduler = get_scheduler(
        cfg.training.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=cfg.training.lr_warmup_steps,
        num_training_steps=cfg.training.offline_steps,
    )
    return optimizer, lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
|
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -534,6 +534,12 @@ class TDMPCPolicy(
|
|||||||
# we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995)
|
# we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995)
|
||||||
update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
|
update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
|
||||||
|
|
||||||
|
def make_optimizer_and_scheduler(self, cfg):
    """Create the Adam optimizer for TD-MPC; no LR scheduler is used.

    Args:
        cfg: Hydra config; reads ``cfg.training.lr``.

    Returns:
        Tuple of ``(optimizer, None)`` — TD-MPC trains at a constant rate.
    """
    learning_rate = cfg.training.lr
    optimizer = torch.optim.Adam(self.parameters(), learning_rate)
    return optimizer, None
|
||||||
|
|
||||||
|
|
||||||
class TDMPCTOLD(nn.Module):
|
class TDMPCTOLD(nn.Module):
|
||||||
"""Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""
|
"""Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""
|
||||||
|
|||||||
@@ -152,6 +152,12 @@ class VQBeTPolicy(
|
|||||||
|
|
||||||
return loss_dict
|
return loss_dict
|
||||||
|
|
||||||
|
def make_optimizer_and_scheduler(self, cfg):
    """Create the VQ-BeT-specific optimizer and its paired LR scheduler.

    Args:
        cfg: Hydra config forwarded to ``VQBeTOptimizer`` / ``VQBeTScheduler``.

    Returns:
        Tuple of ``(optimizer, scheduler)``.
    """
    vqbet_optimizer = VQBeTOptimizer(self, cfg)
    vqbet_scheduler = VQBeTScheduler(vqbet_optimizer, cfg)
    return vqbet_optimizer, vqbet_scheduler
|
||||||
|
|
||||||
|
|
||||||
class SpatialSoftmax(nn.Module):
|
class SpatialSoftmax(nn.Module):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -51,59 +51,6 @@ from lerobot.common.utils.utils import (
|
|||||||
from lerobot.scripts.eval import eval_policy
|
from lerobot.scripts.eval import eval_policy
|
||||||
|
|
||||||
|
|
||||||
def make_optimizer_and_scheduler(cfg, policy):
    """Build the optimizer (and optional LR scheduler) for the configured policy.

    Args:
        cfg: Hydra config; ``cfg.policy.name`` selects the policy family and
            ``cfg.training.*`` holds the optimizer hyperparameters.
        policy: The policy ``nn.Module`` whose parameters will be optimized.

    Returns:
        Tuple of ``(optimizer, lr_scheduler)``; the scheduler may be ``None``.

    Raises:
        NotImplementedError: if ``cfg.policy.name`` is not a known policy.
    """
    if cfg.policy.name == "act":
        # Two parameter groups: the vision backbone trains at its own rate.
        optimizer_params_dicts = [
            {
                "params": [
                    p
                    for n, p in policy.named_parameters()
                    if not n.startswith("model.backbone") and p.requires_grad
                ]
            },
            {
                "params": [
                    p
                    for n, p in policy.named_parameters()
                    if n.startswith("model.backbone") and p.requires_grad
                ],
                "lr": cfg.training.lr_backbone,
            },
        ]
        optimizer = torch.optim.AdamW(
            optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
        )
        lr_scheduler = None
    elif cfg.policy.name == "diffusion":
        optimizer = torch.optim.Adam(
            policy.diffusion.parameters(),
            cfg.training.lr,
            cfg.training.adam_betas,
            cfg.training.adam_eps,
            cfg.training.adam_weight_decay,
        )
        # Lazy import: only the diffusion branch needs `diffusers`.
        from diffusers.optimization import get_scheduler

        lr_scheduler = get_scheduler(
            cfg.training.lr_scheduler,
            optimizer=optimizer,
            num_warmup_steps=cfg.training.lr_warmup_steps,
            num_training_steps=cfg.training.offline_steps,
        )
    # BUGFIX: this branch tested `policy.name` while every other branch tests
    # `cfg.policy.name`; dispatch on the config consistently so a policy module
    # without a `name` attribute cannot crash selection.
    elif cfg.policy.name == "tdmpc":
        optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
        lr_scheduler = None
    elif cfg.policy.name == "vqbet":
        # Lazy import: only the vqbet branch needs these project classes.
        from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler

        optimizer = VQBeTOptimizer(policy, cfg)
        lr_scheduler = VQBeTScheduler(optimizer, cfg)
    else:
        raise NotImplementedError()

    return optimizer, lr_scheduler
|
|
||||||
|
|
||||||
|
|
||||||
def update_policy(
|
def update_policy(
|
||||||
policy,
|
policy,
|
||||||
batch,
|
batch,
|
||||||
@@ -334,7 +281,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||||||
assert isinstance(policy, nn.Module)
|
assert isinstance(policy, nn.Module)
|
||||||
# Create optimizer and scheduler
|
# Create optimizer and scheduler
|
||||||
# Temporary hack to move optimizer out of policy
|
# Temporary hack to move optimizer out of policy
|
||||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
optimizer, lr_scheduler = policy.make_optimizer_and_scheduler(cfg)
|
||||||
grad_scaler = GradScaler(enabled=cfg.use_amp)
|
grad_scaler = GradScaler(enabled=cfg.use_amp)
|
||||||
|
|
||||||
step = 0 # number of policy updates (forward + backward + optim)
|
step = 0 # number of policy updates (forward + backward + optim)
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ from safetensors.torch import save_file
|
|||||||
from lerobot.common.datasets.factory import make_dataset
|
from lerobot.common.datasets.factory import make_dataset
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
from lerobot.common.utils.utils import init_hydra_config, set_global_seed
|
from lerobot.common.utils.utils import init_hydra_config, set_global_seed
|
||||||
from lerobot.scripts.train import make_optimizer_and_scheduler
|
|
||||||
from tests.utils import DEFAULT_CONFIG_PATH
|
from tests.utils import DEFAULT_CONFIG_PATH
|
||||||
|
|
||||||
|
|
||||||
@@ -40,7 +39,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
|
|||||||
dataset = make_dataset(cfg)
|
dataset = make_dataset(cfg)
|
||||||
policy = make_policy(cfg, dataset_stats=dataset.stats)
|
policy = make_policy(cfg, dataset_stats=dataset.stats)
|
||||||
policy.train()
|
policy.train()
|
||||||
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
|
optimizer, _ = policy.make_optimizer_and_scheduler(cfg)
|
||||||
|
|
||||||
dataloader = torch.utils.data.DataLoader(
|
dataloader = torch.utils.data.DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
|
|||||||
@@ -37,7 +37,6 @@ from lerobot.common.policies.factory import (
|
|||||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
from lerobot.common.policies.policy_protocol import Policy
|
from lerobot.common.policies.policy_protocol import Policy
|
||||||
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
||||||
from lerobot.scripts.train import make_optimizer_and_scheduler
|
|
||||||
from tests.scripts.save_policy_to_safetensors import get_policy_stats
|
from tests.scripts.save_policy_to_safetensors import get_policy_stats
|
||||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
|
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
|
||||||
|
|
||||||
@@ -214,7 +213,7 @@ def test_act_backbone_lr():
|
|||||||
|
|
||||||
dataset = make_dataset(cfg)
|
dataset = make_dataset(cfg)
|
||||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
|
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
|
||||||
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
|
optimizer, _ = policy.make_optimizer_and_scheduler(cfg)
|
||||||
assert len(optimizer.param_groups) == 2
|
assert len(optimizer.param_groups) == 2
|
||||||
assert optimizer.param_groups[0]["lr"] == cfg.training.lr
|
assert optimizer.param_groups[0]["lr"] == cfg.training.lr
|
||||||
assert optimizer.param_groups[1]["lr"] == cfg.training.lr_backbone
|
assert optimizer.param_groups[1]["lr"] == cfg.training.lr_backbone
|
||||||
|
|||||||
Reference in New Issue
Block a user