|
|
|
|
@@ -65,16 +65,21 @@ class SACPolicy(
|
|
|
|
|
else:
|
|
|
|
|
self.normalize_inputs = nn.Identity()
|
|
|
|
|
|
|
|
|
|
output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
|
|
|
|
|
|
|
|
|
|
# HACK: This is hacky and should be removed
|
|
|
|
|
dataset_stats = dataset_stats or output_normalization_params
|
|
|
|
|
self.normalize_targets = Normalize(
|
|
|
|
|
config.output_features, config.normalization_mapping, dataset_stats
|
|
|
|
|
)
|
|
|
|
|
self.unnormalize_outputs = Unnormalize(
|
|
|
|
|
config.output_features, config.normalization_mapping, dataset_stats
|
|
|
|
|
)
|
|
|
|
|
if config.dataset_stats is not None:
|
|
|
|
|
output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
|
|
|
|
|
|
|
|
|
|
# HACK: This is hacky and should be removed
|
|
|
|
|
dataset_stats = dataset_stats or output_normalization_params
|
|
|
|
|
self.normalize_targets = Normalize(
|
|
|
|
|
config.output_features, config.normalization_mapping, dataset_stats
|
|
|
|
|
)
|
|
|
|
|
self.unnormalize_outputs = Unnormalize(
|
|
|
|
|
config.output_features, config.normalization_mapping, dataset_stats
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
self.normalize_targets = nn.Identity()
|
|
|
|
|
self.unnormalize_outputs = nn.Identity()
|
|
|
|
|
|
|
|
|
|
# NOTE: For images the encoder should be shared between the actor and critic
|
|
|
|
|
if config.shared_encoder:
|
|
|
|
|
@@ -192,7 +197,7 @@ class SACPolicy(
|
|
|
|
|
# We cached the encoder output to avoid recomputing it
|
|
|
|
|
observations_features = None
|
|
|
|
|
if self.shared_encoder:
|
|
|
|
|
observations_features = self.actor.encoder.get_image_features(batch)
|
|
|
|
|
observations_features = self.actor.encoder.get_image_features(batch, normalize=True)
|
|
|
|
|
|
|
|
|
|
actions, _, _ = self.actor(batch, observations_features)
|
|
|
|
|
actions = self.unnormalize_outputs({"action": actions})["action"]
|
|
|
|
|
@@ -568,8 +573,7 @@ class SACObservationEncoder(nn.Module):
|
|
|
|
|
feat = []
|
|
|
|
|
obs_dict = self.input_normalization(obs_dict)
|
|
|
|
|
if len(self.all_image_keys) > 0 and vision_encoder_cache is None:
|
|
|
|
|
vision_encoder_cache = self.get_image_features(obs_dict)
|
|
|
|
|
feat.append(vision_encoder_cache)
|
|
|
|
|
vision_encoder_cache = self.get_image_features(obs_dict, normalize=False)
|
|
|
|
|
|
|
|
|
|
if vision_encoder_cache is not None:
|
|
|
|
|
feat.append(vision_encoder_cache)
|
|
|
|
|
@@ -584,8 +588,10 @@ class SACObservationEncoder(nn.Module):
|
|
|
|
|
|
|
|
|
|
return features
|
|
|
|
|
|
|
|
|
|
def get_image_features(self, batch: dict[str, Tensor]) -> torch.Tensor:
|
|
|
|
|
def get_image_features(self, batch: dict[str, Tensor], normalize: bool = True) -> torch.Tensor:
|
|
|
|
|
# [N*B, C, H, W]
|
|
|
|
|
if normalize:
|
|
|
|
|
batch = self.input_normalization(batch)
|
|
|
|
|
if len(self.all_image_keys) > 0:
|
|
|
|
|
# Batch all images along the batch dimension, then encode them.
|
|
|
|
|
images_batched = torch.cat([batch[key] for key in self.all_image_keys], dim=0)
|
|
|
|
|
|