From a60d27b132b91d79ec5d705af78af84dc4509ab5 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Mon, 9 Sep 2024 17:22:46 +0100 Subject: [PATCH] Raise ValueError if horizon is incompatible with downsampling (#422) --- .../common/policies/diffusion/configuration_diffusion.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py index 1e1f9d28..bd3692ac 100644 --- a/lerobot/common/policies/diffusion/configuration_diffusion.py +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -196,3 +196,12 @@ class DiffusionConfig: f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. " f"Got {self.noise_scheduler_type}." ) + + # Check that the horizon size and U-Net downsampling is compatible. + # U-Net downsamples by 2 with each stage. + downsampling_factor = 2 ** len(self.down_dims) + if self.horizon % downsampling_factor != 0: + raise ValueError( + "The horizon should be an integer multiple of the downsampling factor (which is determined " + f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}" + )