Update constants
This commit is contained in:
@@ -35,7 +35,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.common.constants import OBS_ENV, OBS_ROBOT
|
||||
from lerobot.common.constants import OBS_ENV_STATE, OBS_STATE
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
@@ -753,9 +753,9 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
)
|
||||
)
|
||||
if self.config.env_state_feature:
|
||||
feat.append(self.env_state_enc_layers(obs_dict[OBS_ENV]))
|
||||
feat.append(self.env_state_enc_layers(obs_dict[OBS_ENV_STATE]))
|
||||
if self.config.robot_state_feature:
|
||||
feat.append(self.state_enc_layers(obs_dict[OBS_ROBOT]))
|
||||
feat.append(self.state_enc_layers(obs_dict[OBS_STATE]))
|
||||
return torch.stack(feat, dim=0).mean(0)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user