From cf15cba5fc932fcbb03e2ca490b1d3870da0f8db Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Mon, 3 Jun 2024 13:04:24 +0100 Subject: [PATCH] Remove redundant slicing operation in Diffusion Policy (#240) --- lerobot/common/policies/diffusion/modeling_diffusion.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 273f4f758..e0482143d 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -239,10 +239,8 @@ class DiffusionModel(nn.Module): global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1) # run sampling - sample = self.conditional_sample(batch_size, global_cond=global_cond) + actions = self.conditional_sample(batch_size, global_cond=global_cond) - # `horizon` steps worth of actions (from the first observation). - actions = sample[..., : self.config.output_shapes["action"][0]] # Extract `n_action_steps` steps worth of actions (from the current observation). start = n_obs_steps - 1 end = start + self.config.n_action_steps