diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 273f4f75..e0482143 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -239,10 +239,8 @@ class DiffusionModel(nn.Module): global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1) # run sampling - sample = self.conditional_sample(batch_size, global_cond=global_cond) + actions = self.conditional_sample(batch_size, global_cond=global_cond) - # `horizon` steps worth of actions (from the first observation). - actions = sample[..., : self.config.output_shapes["action"][0]] # Extract `n_action_steps` steps worth of actions (from the current observation). start = n_obs_steps - 1 end = start + self.config.n_action_steps