Update constants

This commit is contained in:
Simon Alibert
2025-03-04 11:07:15 +01:00
parent a13e49073c
commit 2b24feb604
4 changed files with 17 additions and 17 deletions

View File

@@ -35,7 +35,7 @@ import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import Tensor
from lerobot.common.constants import OBS_ENV, OBS_ROBOT
from lerobot.common.constants import OBS_ENV_STATE, OBS_STATE
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
@@ -753,9 +753,9 @@ class TDMPCObservationEncoder(nn.Module):
)
)
if self.config.env_state_feature:
feat.append(self.env_state_enc_layers(obs_dict[OBS_ENV]))
feat.append(self.env_state_enc_layers(obs_dict[OBS_ENV_STATE]))
if self.config.robot_state_feature:
feat.append(self.state_enc_layers(obs_dict[OBS_ROBOT]))
feat.append(self.state_enc_layers(obs_dict[OBS_STATE]))
return torch.stack(feat, dim=0).mean(0)