forked from tangger/lerobot
Remove EMA model from Diffusion Policy (#134)
This commit is contained in:
@@ -121,7 +121,7 @@ def rollout(
|
||||
max_steps = env.call("_max_episode_steps")[0]
|
||||
progbar = trange(
|
||||
max_steps,
|
||||
desc=f"Running rollout with {max_steps} steps (maximum) per rollout",
|
||||
desc=f"Running rollout with at most {max_steps} steps",
|
||||
disable=not enable_progbar,
|
||||
leave=False,
|
||||
)
|
||||
|
||||
@@ -89,9 +89,6 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
|
||||
if lr_scheduler is not None:
|
||||
lr_scheduler.step()
|
||||
|
||||
if hasattr(policy, "ema") and policy.ema is not None:
|
||||
policy.ema.step(policy.diffusion)
|
||||
|
||||
if isinstance(policy, PolicyWithUpdate):
|
||||
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
|
||||
policy.update()
|
||||
|
||||
Reference in New Issue
Block a user