forked from tangger/lerobot
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
committed by Michel Aractingi
parent cdcf346061
commit 1c8daf11fd
@@ -34,7 +34,13 @@ def make_optimizer_and_scheduler(
     Returns:
         tuple[Optimizer, LRScheduler | None]: The couple (Optimizer, Scheduler). Scheduler can be `None`.
     """
-    params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters()
+    params = (
+        policy.get_optim_params()
+        if cfg.use_policy_training_preset
+        else policy.parameters()
+    )
     optimizer = cfg.optimizer.build(params)
-    lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
+    lr_scheduler = (
+        cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
+    )
     return optimizer, lr_scheduler
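Note: the wrapping above only changes layout, not behavior. As a minimal sketch of the pattern, with TinyPolicy and the preset flag standing in for lerobot's real policy and cfg objects (illustrative assumptions; only torch is required):

    import torch
    from torch import nn

    class TinyPolicy(nn.Module):
        """Stand-in for a lerobot policy that exposes preset optimizer params."""
        def __init__(self):
            super().__init__()
            self.net = nn.Linear(4, 2)

        def get_optim_params(self):
            # Real policies can return custom parameter groups here.
            return self.net.parameters()

    policy = TinyPolicy()
    use_policy_training_preset = True  # mirrors cfg.use_policy_training_preset

    params = (
        policy.get_optim_params()
        if use_policy_training_preset
        else policy.parameters()
    )
    optimizer = torch.optim.AdamW(params, lr=1e-4)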
@@ -102,7 +102,9 @@ def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
     write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)


-def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
+def load_optimizer_state(
+    optimizer: torch.optim.Optimizer, save_dir: Path
+) -> torch.optim.Optimizer:
     current_state_dict = optimizer.state_dict()
     flat_state = load_file(save_dir / OPTIMIZER_STATE)
     state = unflatten_dict(flat_state)
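Note: load_optimizer_state works because safetensors files store a flat {str: tensor} mapping, so the nested optimizer state is flattened on save and rebuilt with unflatten_dict on load. A simplified round-trip sketch; these flatten/unflatten helpers are stand-ins for the repo's utilities, not the actual implementations:

    import torch
    from safetensors.torch import load_file, save_file

    def flatten_dict(d: dict, sep: str = "/", prefix: str = "") -> dict:
        # Collapse nested dicts into {"a/b/c": tensor} keys.
        out = {}
        for k, v in d.items():
            key = f"{prefix}{sep}{k}" if prefix else str(k)
            if isinstance(v, dict):
                out.update(flatten_dict(v, sep, key))
            else:
                out[key] = v
        return out

    def unflatten_dict(d: dict, sep: str = "/") -> dict:
        # Rebuild the nesting from the separator in each key.
        out: dict = {}
        for key, v in d.items():
            *parents, leaf = key.split(sep)
            node = out
            for p in parents:
                node = node.setdefault(p, {})
            node[leaf] = v
        return out

    state = {"state": {"0": {"exp_avg": torch.zeros(3)}}}
    save_file(flatten_dict(state), "optimizer_state.safetensors")
    restored = unflatten_dict(load_file("optimizer_state.safetensors"))
    assert torch.equal(restored["state"]["0"]["exp_avg"], torch.zeros(3))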
@@ -36,7 +36,9 @@ class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
         return self.get_choice_name(self.__class__)

     @abc.abstractmethod
-    def build(self, optimizer: Optimizer, num_training_steps: int) -> LRScheduler | None:
+    def build(
+        self, optimizer: Optimizer, num_training_steps: int
+    ) -> LRScheduler | None:
         raise NotImplementedError

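Note: to see the reformatted abstract signature in use, here is a hedged sketch of a concrete subclass; ConstantWithWarmupConfig is invented for illustration and is not a class in this repo (Python 3.10+ assumed for the `LRScheduler | None` union):

    import torch
    from dataclasses import dataclass
    from torch.optim import Optimizer
    from torch.optim.lr_scheduler import LambdaLR, LRScheduler

    @dataclass
    class ConstantWithWarmupConfig:
        num_warmup_steps: int = 100

        def build(
            self, optimizer: Optimizer, num_training_steps: int
        ) -> LRScheduler | None:
            # Ramp linearly to 1.0 over the warmup, then hold constant.
            def lr_lambda(step: int) -> float:
                return min(1.0, step / max(1, self.num_warmup_steps))
            return LambdaLR(optimizer, lr_lambda, -1)

    opt = torch.optim.SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)
    sched = ConstantWithWarmupConfig(num_warmup_steps=10).build(opt, num_training_steps=1_000)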
@@ -49,7 +51,11 @@ class DiffuserSchedulerConfig(LRSchedulerConfig):
     def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
         from diffusers.optimization import get_scheduler

-        kwargs = {**asdict(self), "num_training_steps": num_training_steps, "optimizer": optimizer}
+        kwargs = {
+            **asdict(self),
+            "num_training_steps": num_training_steps,
+            "optimizer": optimizer,
+        }
         return get_scheduler(**kwargs)

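Note: after the expansion, the call that build() makes is roughly the following; the name and num_warmup_steps fields are assumptions about what asdict(self) contributes for DiffuserSchedulerConfig, not confirmed by this diff (requires the diffusers package):

    import torch
    from diffusers.optimization import get_scheduler

    optimizer = torch.optim.AdamW(torch.nn.Linear(2, 2).parameters(), lr=1e-4)
    kwargs = {
        "name": "cosine",            # assumed field from asdict(self)
        "num_warmup_steps": 500,     # assumed field from asdict(self)
        "num_training_steps": 100_000,
        "optimizer": optimizer,
    }
    scheduler = get_scheduler(**kwargs)  # returns a LambdaLR-style scheduler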
@@ -71,7 +77,14 @@ class VQBeTSchedulerConfig(LRSchedulerConfig):
             progress = float(adjusted_step - self.num_warmup_steps) / float(
                 max(1, num_training_steps - self.num_warmup_steps)
             )
-            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)))
+            return max(
+                0.0,
+                0.5
+                * (
+                    1.0
+                    + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)
+                ),
+            )

         return LambdaLR(optimizer, lr_lambda, -1)
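Note: the reformatted return statement is the standard cosine factor 0.5 * (1 + cos(pi * num_cycles * 2 * progress)), clamped at zero. Extracted as a plain function for readability, with warmup handling simplified and the config's adjusted_step offset omitted (values are illustrative):

    import math

    def cosine_factor(
        step: int,
        num_warmup_steps: int = 500,
        num_training_steps: int = 100_000,
        num_cycles: float = 0.5,
    ) -> float:
        if step < num_warmup_steps:
            # Linear warmup from 0 to 1.
            return float(step) / float(max(1, num_warmup_steps))
        progress = float(step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps)
        )
        return max(
            0.0,
            0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
        )

    assert cosine_factor(0) == 0.0              # start of warmup
    assert abs(cosine_factor(100_000)) < 1e-12  # fully decayed at the end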
@@ -98,7 +111,9 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):

         def cosine_decay_schedule(current_step):
             step = min(current_step, self.num_decay_steps)
-            cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps))
+            cosine_decay = 0.5 * (
+                1 + math.cos(math.pi * step / self.num_decay_steps)
+            )
             alpha = self.decay_lr / self.peak_lr
             decayed = (1 - alpha) * cosine_decay + alpha
             return decayed
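Note: cosine_decay_schedule interpolates the learning-rate factor from 1.0 at step 0 down to a floor of alpha = decay_lr / peak_lr at num_decay_steps, then holds it there. A quick standalone check with illustrative values:

    import math

    peak_lr, decay_lr, num_decay_steps = 1e-4, 1e-6, 10_000
    alpha = decay_lr / peak_lr  # floor of 0.01, relative to the peak

    def factor(current_step: int) -> float:
        step = min(current_step, num_decay_steps)  # clamp past the decay window
        cosine_decay = 0.5 * (
            1 + math.cos(math.pi * step / num_decay_steps)
        )
        return (1 - alpha) * cosine_decay + alpha

    assert abs(factor(0) - 1.0) < 1e-12
    assert abs(factor(num_decay_steps) - alpha) < 1e-12
    assert factor(2 * num_decay_steps) == factor(num_decay_steps)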
@@ -117,6 +132,8 @@ def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:


 def load_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> LRScheduler:
-    state_dict = deserialize_json_into_object(save_dir / SCHEDULER_STATE, scheduler.state_dict())
+    state_dict = deserialize_json_into_object(
+        save_dir / SCHEDULER_STATE, scheduler.state_dict()
+    )
     scheduler.load_state_dict(state_dict)
     return scheduler
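Note: behind save/load_scheduler_state is a plain JSON round trip; for a LambdaLR built from Python lambdas, state_dict() is JSON-serializable as-is (the lr_lambdas entry degrades to null placeholders). A stand-in sketch using json directly instead of the repo's deserialize_json_into_object helper:

    import json
    import torch
    from torch.optim.lr_scheduler import LambdaLR

    opt = torch.optim.SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)
    sched = LambdaLR(opt, lambda step: 0.95 ** step)
    for _ in range(3):
        opt.step()
        sched.step()

    blob = json.dumps(sched.state_dict())  # what save_scheduler_state would persist
    fresh = LambdaLR(opt, lambda step: 0.95 ** step)
    fresh.load_state_dict(json.loads(blob))  # what load_scheduler_state restores
    assert fresh.last_epoch == sched.last_epoch == 3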