Remove EMA model from Diffusion Policy (#134)

This commit is contained in:
Alexander Soare
2024-05-05 11:26:12 +01:00
committed by GitHub
parent d747195c57
commit f3bba0270d
11 changed files with 21 additions and 117 deletions

View File

@@ -121,7 +121,7 @@ def rollout(
max_steps = env.call("_max_episode_steps")[0]
progbar = trange(
max_steps,
desc=f"Running rollout with {max_steps} steps (maximum) per rollout",
desc=f"Running rollout with at most {max_steps} steps",
disable=not enable_progbar,
leave=False,
)

View File

@@ -89,9 +89,6 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
if lr_scheduler is not None:
lr_scheduler.step()
if hasattr(policy, "ema") and policy.ema is not None:
policy.ema.step(policy.diffusion)
if isinstance(policy, PolicyWithUpdate):
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
policy.update()