Several fixes to move the actor_server and learner_server code from the maniskill environment to the real robot environment.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
Michel Aractingi
2025-02-10 16:03:39 +01:00
parent af769abd8d
commit 9784d8a47f
10 changed files with 597 additions and 318 deletions

View File

@@ -39,6 +39,12 @@ class SACConfig:
"observation.environment_state": "min_max",
}
)
input_normalization_params: dict[str, dict[str, list[float]]] = field(
default_factory=lambda: {
"observation.image": {"mean": [[0.485, 0.456, 0.406]], "std": [[0.229, 0.224, 0.225]]},
"observation.state": {"min": [-1, -1, -1, -1], "max": [1, 1, 1, 1]},
}
)
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
output_normalization_params: dict[str, dict[str, list[float]]] = field(
default_factory=lambda: {

View File

@@ -51,18 +51,20 @@ class SACPolicy(
if config is None:
config = SACConfig()
self.config = config
if config.input_normalization_modes is not None:
input_normalization_params = _convert_normalization_params_to_tensor(
config.input_normalization_params
)
self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, dataset_stats
config.input_shapes, config.input_normalization_modes, input_normalization_params
)
else:
self.normalize_inputs = nn.Identity()
output_normalization_params = {}
for outer_key, inner_dict in config.output_normalization_params.items():
output_normalization_params[outer_key] = {}
for key, value in inner_dict.items():
output_normalization_params[outer_key][key] = torch.tensor(value)
output_normalization_params = _convert_normalization_params_to_tensor(
config.output_normalization_params
)
# HACK: This is hacky and should be removed
dataset_stats = dataset_stats or output_normalization_params
@@ -75,7 +77,7 @@ class SACPolicy(
# NOTE: For images the encoder should be shared between the actor and critic
if config.shared_encoder:
encoder_critic = SACObservationEncoder(config)
encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
encoder_actor: SACObservationEncoder = encoder_critic
else:
encoder_critic = SACObservationEncoder(config)
@@ -92,6 +94,7 @@ class SACPolicy(
for _ in range(config.num_critics)
]
),
output_normalization=self.normalize_targets,
)
self.critic_target = CriticEnsemble(
@@ -105,6 +108,7 @@ class SACPolicy(
for _ in range(config.num_critics)
]
),
output_normalization=self.normalize_targets,
)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
@@ -122,7 +126,7 @@ class SACPolicy(
# TODO (azouitine): Handle the case where the temperature parameter is fixed
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
# it triggers "can't optimize a non-leaf Tensor"
self.log_alpha = torch.zeros(1, requires_grad=True, device=torch.device("cuda:0"))
self.log_alpha = torch.tensor([0.0], requires_grad=True, device=torch.device("mps"))
self.temperature = self.log_alpha.exp().item()
def reset(self):
@@ -313,12 +317,14 @@ class CriticEnsemble(nn.Module):
self,
encoder: Optional[nn.Module],
network_list: nn.ModuleList,
output_normalization: nn.Module,
init_final: Optional[float] = None,
):
super().__init__()
self.encoder = encoder
self.network_list = network_list
self.init_final = init_final
self.output_normalization = output_normalization
self.parameters_to_optimize = []
# Handle the case where a part of the encoder is frozen
@@ -358,6 +364,10 @@ class CriticEnsemble(nn.Module):
device = get_device_from_parameters(self)
# Move each tensor in observations to device
observations = {k: v.to(device) for k, v in observations.items()}
# NOTE: We normalize the actions because it helps sample efficiency
actions: dict[str, torch.tensor] = {"action": actions}
# NOTE: The normalization layer takes a dict as input and returns a dict, which is why the actions are wrapped in one
actions = self.output_normalization(actions)["action"]
actions = actions.to(device)
obs_enc = observations if self.encoder is None else self.encoder(observations)
@@ -474,17 +484,18 @@ class Policy(nn.Module):
class SACObservationEncoder(nn.Module):
"""Encode image and/or state vector observations."""
def __init__(self, config: SACConfig):
def __init__(self, config: SACConfig, input_normalizer: nn.Module):
"""
Creates encoders for pixel and/or state modalities.
"""
super().__init__()
self.config = config
self.input_normalization = input_normalizer
self.has_pretrained_vision_encoder = False
self.parameters_to_optimize = []
self.aggregation_size: int = 0
if "observation.image" in config.input_shapes:
if any("observation.image" in key for key in config.input_shapes):
self.camera_number = config.camera_number
if self.config.vision_encoder_name is not None:
@@ -534,8 +545,9 @@ class SACObservationEncoder(nn.Module):
over all features.
"""
feat = []
obs_dict = self.input_normalization(obs_dict)
# Concatenate all images along the channel dimension.
image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")]
image_keys = [k for k in obs_dict if k.startswith("observation.image")]
for image_key in image_keys:
enc_feat = self.image_enc_layers(obs_dict[image_key])
@@ -681,6 +693,18 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
    """Convert nested normalization statistics from Python lists to tensors.

    Args:
        normalization_params: Two-level mapping ``{feature_key: {stat_name: values}}``
            where ``values`` is a (possibly nested) list of numbers, e.g.
            ``{"observation.state": {"min": [-1, -1], "max": [1, 1]}}``.

    Returns:
        A new dict with the same key structure whose leaves are floating-point
        ``torch.Tensor`` objects. For every feature key containing the substring
        ``"image"`` the statistic is reshaped to ``(C, 1, 1)``, where ``C`` is the
        number of values provided. The input mapping is not modified.
    """
    converted_params = {}
    for outer_key, inner_dict in normalization_params.items():
        is_image = "image" in outer_key
        converted_stats = {}
        for key, value in inner_dict.items():
            stat = torch.tensor(value)
            # Integer literals in a config (e.g. [-1, -1]) would otherwise produce
            # int64 tensors; normalization statistics are meant to be floats.
            if not stat.is_floating_point():
                stat = stat.to(torch.get_default_dtype())
            if is_image:
                # One value per channel, shaped (C, 1, 1). The channel count is
                # inferred instead of hard-coding 3 so that non-RGB images
                # (grayscale, depth, RGB-D) are accepted as well.
                stat = stat.view(-1, 1, 1)
            converted_stats[key] = stat
        converted_params[outer_key] = converted_stats
    return converted_params
if __name__ == "__main__":
# Test the SACObservationEncoder
import time

View File

@@ -18,6 +18,7 @@ import os
import os.path as osp
import platform
import subprocess
import time
from copy import copy
from datetime import datetime, timezone
from pathlib import Path
@@ -228,3 +229,28 @@ def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
except TypeError:
# If a TypeError is raised, the string is not a valid dtype
return False
class TimerManager:
def __init__(self, elapsed_time_list: list[float] | None = None, label="Elapsed time", log=True):
self.label = label
self.elapsed_time_list = elapsed_time_list
self.log = log
self.elapsed = 0.0
def __enter__(self):
self.start = time.perf_counter()
return self
def __exit__(self, exc_type, exc_value, traceback):
self.elapsed: float = time.perf_counter() - self.start
if self.elapsed_time_list is not None:
self.elapsed_time_list.append(self.elapsed)
if self.log:
print(f"{self.label}: {self.elapsed:.6f} seconds")
@property
def elapsed_seconds(self):
return self.elapsed