Update constants

This commit is contained in:
Simon Alibert
2025-03-04 11:07:15 +01:00
parent a13e49073c
commit 2b24feb604
4 changed files with 17 additions and 17 deletions

View File

@@ -33,7 +33,7 @@ from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from torch import Tensor, nn
from lerobot.common.constants import OBS_ENV, OBS_ROBOT
from lerobot.common.constants import OBS_ENV_STATE, OBS_STATE
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
@@ -238,8 +238,8 @@ class DiffusionModel(nn.Module):
def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
"""Encode image features and concatenate them all together along with the state vector."""
batch_size, n_obs_steps = batch[OBS_ROBOT].shape[:2]
global_cond_feats = [batch[OBS_ROBOT]]
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
global_cond_feats = [batch[OBS_STATE]]
# Extract image features.
if self.config.image_features:
if self.config.use_separate_rgb_encoder_per_camera:
@@ -269,7 +269,7 @@ class DiffusionModel(nn.Module):
global_cond_feats.append(img_features)
if self.config.env_state_feature:
global_cond_feats.append(batch[OBS_ENV])
global_cond_feats.append(batch[OBS_ENV_STATE])
# Concatenate features then flatten to (B, global_cond_dim).
return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)