From 1df2a7b2dac14d0d5d2d7479057cc2aa87e1aab0 Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Fri, 16 May 2025 14:25:21 +0200 Subject: [PATCH] Add review feedback --- lerobot/common/policies/sac/configuration_sac.py | 16 ++++++++++++++-- lerobot/common/policies/sac/modeling_sac.py | 9 +++++---- lerobot/scripts/server/buffer.py | 6 ++++-- lerobot/scripts/server/network_utils.py | 4 +++- 4 files changed, 26 insertions(+), 9 deletions(-) diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index 1f2e9bb8..bf6c923e 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -22,6 +22,18 @@ from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import NormalizationMode +def is_image_feature(key: str) -> bool: + """Check if a feature key represents an image feature. + + Args: + key: The feature key to check + + Returns: + True if the key represents an image feature, False otherwise + """ + return key.startswith("observation.image") + + @dataclass class ConcurrencyConfig: """Configuration for the concurrency of the actor and learner. 
@@ -203,7 +215,7 @@ class SACConfig(PreTrainedConfig): return None def validate_features(self) -> None: - has_image = any(key.startswith("observation.image") for key in self.input_features) + has_image = any(is_image_feature(key) for key in self.input_features) has_state = "observation.state" in self.input_features if not (has_state or has_image): @@ -216,7 +228,7 @@ class SACConfig(PreTrainedConfig): @property def image_features(self) -> list[str]: - return [key for key in self.input_features if "image" in key] + return [key for key in self.input_features if is_image_feature(key)] @property def observation_delta_indices(self) -> list: diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 257f37cb..a15974a9 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -29,7 +29,7 @@ from torch.distributions import MultivariateNormal, TanhTransform, Transform, Tr from lerobot.common.policies.normalize import NormalizeBuffer, UnnormalizeBuffer from lerobot.common.policies.pretrained import PreTrainedPolicy -from lerobot.common.policies.sac.configuration_sac import SACConfig +from lerobot.common.policies.sac.configuration_sac import SACConfig, is_image_feature from lerobot.common.policies.utils import get_device_from_parameters DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension @@ -264,6 +264,7 @@ class SACPolicy( ) # subsample critics to prevent overfitting if use high UTD (update to date) + # TODO: Get indices before forward pass to avoid unnecessary computation if self.config.num_subsample_critics is not None: indices = torch.randperm(self.config.num_critics) indices = indices[: self.config.num_subsample_critics] @@ -468,6 +469,7 @@ class SACPolicy( def _init_actor(self, continuous_action_dim): """Initialize policy actor network and default target entropy.""" + # NOTE: The actor selects only the continuous action part self.actor = Policy( 
encoder=self.encoder_actor, network=MLP(input_dim=self.encoder_actor.output_dim, **asdict(self.config.actor_network_kwargs)), @@ -500,7 +502,7 @@ class SACObservationEncoder(nn.Module): self._compute_output_dim() def _init_image_layers(self) -> None: - self.image_keys = [k for k in self.config.input_features if k.startswith("observation.image")] + self.image_keys = [k for k in self.config.input_features if is_image_feature(k)] self.has_images = bool(self.image_keys) if not self.has_images: return @@ -928,7 +930,7 @@ class Policy(nn.Module): class DefaultImageEncoder(nn.Module): def __init__(self, config: SACConfig): super().__init__() - image_key = next(key for key in config.input_features.keys() if key.startswith("observation.image")) # noqa: SIM118 + image_key = next(key for key in config.input_features if is_image_feature(key)) self.image_enc_layers = nn.Sequential( nn.Conv2d( in_channels=config.input_features[image_key].shape[0], @@ -959,7 +961,6 @@ class DefaultImageEncoder(nn.Module): ), nn.ReLU(), ) - # Get first image key from input features def forward(self, x): x = self.image_enc_layers(x) diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py index 8da5cecf..45d0d089 100644 --- a/lerobot/scripts/server/buffer.py +++ b/lerobot/scripts/server/buffer.py @@ -271,9 +271,11 @@ class ReplayBuffer: # Split the augmented images back to their sources for i, key in enumerate(image_keys): - # State images are at even indices (0, 2, 4...) + # Calculate offsets for the current image key: + # For each key, we have 2*batch_size images (batch_size for states, batch_size for next_states) + # States start at index i*2*batch_size and take up batch_size slots batch_state[key] = augmented_images[i * 2 * batch_size : (i * 2 + 1) * batch_size] - # Next state images are at odd indices (1, 3, 5...) 
+ # Next states start after the states at index (i*2+1)*batch_size and also take up batch_size slots batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size] # Sample other tensors diff --git a/lerobot/scripts/server/network_utils.py b/lerobot/scripts/server/network_utils.py index c62f6cbd..1b1d8044 100644 --- a/lerobot/scripts/server/network_utils.py +++ b/lerobot/scripts/server/network_utils.py @@ -111,7 +111,9 @@ def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes: def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]: buffer = io.BytesIO(buffer) buffer.seek(0) - return torch.load(buffer, weights_only=False) # nosec B614: Safe usage of torch.load + return torch.load(buffer, weights_only=False) # nosec B614: Using weights_only=False relies on pickle which has security implications. + # This is currently safe as we only deserialize trusted internal data. + # TODO: Verify if weights_only=True would work for our use case (safer default in torch 2.6+) def python_object_to_bytes(python_object: Any) -> bytes: