diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 0d7bab95..48608537 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -111,12 +111,12 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin): Schematically this looks like: ---------------------------------------------------------------------------------------------- (legend: o = n_obs_steps, h = horizon, a = n_action_steps) - |timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... |n-o+1+h| - |observation is used | YES | YES | YES | NO | NO | NO | NO | NO | NO | + |timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... | n-o+h | + |observation is used | YES | YES | YES | YES | NO | NO | NO | NO | NO | |action is generated | YES | YES | YES | YES | YES | YES | YES | YES | YES | |action is used | NO | NO | NO | YES | YES | YES | NO | NO | NO | ---------------------------------------------------------------------------------------------- - Note that this means we require: `n_action_steps < horizon - n_obs_steps + 1`. Also, note that + Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that "horizon" may not the best name to describe what the variable actually means, because this period is actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past. """