ready for review

This commit is contained in:
Alexander Soare
2024-04-12 11:36:52 +01:00
parent 5666ec3ec7
commit 6d0a45a97d
7 changed files with 11 additions and 42 deletions

View File

@@ -10,7 +10,6 @@ logger = logging.getLogger(__name__)
class _SinusoidalPosEmb(nn.Module):
# TODO(now): consolidate?
def __init__(self, dim):
super().__init__()
self.dim = dim

View File

@@ -10,18 +10,6 @@ from lerobot.common.policies.utils import get_device_from_parameters, get_dtype_
class DiffusionUnetImagePolicy(nn.Module):
"""
TODO(now): Add DDIM scheduler.
Changes: TODO(now)
- Use single image encoder for now instead of generic obs_encoder. We may generalize again when/if
needed. Code for a general observation encoder can be found at:
https://github.com/huggingface/lerobot/blob/920e0d118b493e4cc3058a9b1b764f38ae145d8e/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py
- Uses the observation as global conditioning for the Unet by default.
- Does not do any inpainting (which would be applicable if the observation were not used to condition
the Unet).
"""
def __init__(
self,
cfg,
@@ -87,7 +75,7 @@ class DiffusionUnetImagePolicy(nn.Module):
torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
global_cond=global_cond,
)
# Compute previous image: x_t -> x_t-1 # TODO(now): Is this right?
# Compute previous image: x_t -> x_t-1
sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
return sample