ready for review
This commit is contained in:
@@ -10,7 +10,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _SinusoidalPosEmb(nn.Module):
|
||||
# TODO(now): consolidate?
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
@@ -10,18 +10,6 @@ from lerobot.common.policies.utils import get_device_from_parameters, get_dtype_
|
||||
|
||||
|
||||
class DiffusionUnetImagePolicy(nn.Module):
|
||||
"""
|
||||
TODO(now): Add DDIM scheduler.
|
||||
|
||||
Changes: TODO(now)
|
||||
- Use single image encoder for now instead of generic obs_encoder. We may generalize again when/if
|
||||
needed. Code for a general observation encoder can be found at:
|
||||
https://github.com/huggingface/lerobot/blob/920e0d118b493e4cc3058a9b1b764f38ae145d8e/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py
|
||||
- Uses the observation as global conditioning for the Unet by default.
|
||||
- Does not do any inpainting (which would be applicable if the observation were not used to condition
|
||||
the Unet).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfg,
|
||||
@@ -87,7 +75,7 @@ class DiffusionUnetImagePolicy(nn.Module):
|
||||
torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
|
||||
global_cond=global_cond,
|
||||
)
|
||||
# Compute previous image: x_t -> x_t-1 # TODO(now): Is this right?
|
||||
# Compute previous image: x_t -> x_t-1
|
||||
sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
|
||||
|
||||
return sample
|
||||
|
||||
Reference in New Issue
Block a user