forked from tangger/lerobot
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Michel Aractingi
parent
2abbd60a0d
commit
0ea27704f6
@@ -34,13 +34,7 @@ def make_optimizer_and_scheduler(
|
||||
Returns:
|
||||
tuple[Optimizer, LRScheduler | None]: The couple (Optimizer, Scheduler). Scheduler can be `None`.
|
||||
"""
|
||||
params = (
|
||||
policy.get_optim_params()
|
||||
if cfg.use_policy_training_preset
|
||||
else policy.parameters()
|
||||
)
|
||||
params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters()
|
||||
optimizer = cfg.optimizer.build(params)
|
||||
lr_scheduler = (
|
||||
cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
|
||||
)
|
||||
lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
|
||||
return optimizer, lr_scheduler
|
||||
|
||||
@@ -102,9 +102,7 @@ def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> No
|
||||
write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)
|
||||
|
||||
|
||||
def load_optimizer_state(
|
||||
optimizer: torch.optim.Optimizer, save_dir: Path
|
||||
) -> torch.optim.Optimizer:
|
||||
def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
|
||||
current_state_dict = optimizer.state_dict()
|
||||
flat_state = load_file(save_dir / OPTIMIZER_STATE)
|
||||
state = unflatten_dict(flat_state)
|
||||
|
||||
@@ -36,9 +36,7 @@ class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
@abc.abstractmethod
|
||||
def build(
|
||||
self, optimizer: Optimizer, num_training_steps: int
|
||||
) -> LRScheduler | None:
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LRScheduler | None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -79,11 +77,7 @@ class VQBeTSchedulerConfig(LRSchedulerConfig):
|
||||
)
|
||||
return max(
|
||||
0.0,
|
||||
0.5
|
||||
* (
|
||||
1.0
|
||||
+ math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)
|
||||
),
|
||||
0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)),
|
||||
)
|
||||
|
||||
return LambdaLR(optimizer, lr_lambda, -1)
|
||||
@@ -111,9 +105,7 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
|
||||
|
||||
def cosine_decay_schedule(current_step):
|
||||
step = min(current_step, self.num_decay_steps)
|
||||
cosine_decay = 0.5 * (
|
||||
1 + math.cos(math.pi * step / self.num_decay_steps)
|
||||
)
|
||||
cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps))
|
||||
alpha = self.decay_lr / self.peak_lr
|
||||
decayed = (1 - alpha) * cosine_decay + alpha
|
||||
return decayed
|
||||
@@ -132,8 +124,6 @@ def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
|
||||
|
||||
|
||||
def load_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> LRScheduler:
|
||||
state_dict = deserialize_json_into_object(
|
||||
save_dir / SCHEDULER_STATE, scheduler.state_dict()
|
||||
)
|
||||
state_dict = deserialize_json_into_object(save_dir / SCHEDULER_STATE, scheduler.state_dict())
|
||||
scheduler.load_state_dict(state_dict)
|
||||
return scheduler
|
||||
|
||||
Reference in New Issue
Block a user