Add review feedback

AdilZouitine
2025-05-16 14:25:21 +02:00
parent fa72aed5b6
commit 1df2a7b2da
4 changed files with 26 additions and 9 deletions

View File

@@ -22,6 +22,18 @@ from lerobot.configs.policies import PreTrainedConfig
 from lerobot.configs.types import NormalizationMode
 
 
+def is_image_feature(key: str) -> bool:
+    """Check if a feature key represents an image feature.
+
+    Args:
+        key: The feature key to check
+
+    Returns:
+        True if the key represents an image feature, False otherwise
+    """
+    return key.startswith("observation.image")
+
+
 @dataclass
 class ConcurrencyConfig:
     """Configuration for the concurrency of the actor and learner.
@@ -203,7 +215,7 @@ class SACConfig(PreTrainedConfig):
         return None
 
     def validate_features(self) -> None:
-        has_image = any(key.startswith("observation.image") for key in self.input_features)
+        has_image = any(is_image_feature(key) for key in self.input_features)
         has_state = "observation.state" in self.input_features
 
         if not (has_state or has_image):
@@ -216,7 +228,7 @@ class SACConfig(PreTrainedConfig):
 
     @property
     def image_features(self) -> list[str]:
-        return [key for key in self.input_features if "image" in key]
+        return [key for key in self.input_features if is_image_feature(key)]
 
     @property
     def observation_delta_indices(self) -> list:
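For reference outside the diff, a minimal sketch of how the new helper behaves and why image_features now uses it too (the example keys below are illustrative, not taken from a real config):

def is_image_feature(key: str) -> bool:
    return key.startswith("observation.image")

# Illustrative keys, not from an actual config:
assert is_image_feature("observation.image")
assert is_image_feature("observation.images.front")
assert not is_image_feature("observation.state")
# The old image_features property used `"image" in key`, which would also match a
# key like "state_with_image_stats"; the prefix check keeps image_features
# consistent with validate_features.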

View File

@@ -29,7 +29,7 @@ from torch.distributions import MultivariateNormal, TanhTransform, Transform, Tr
 
 from lerobot.common.policies.normalize import NormalizeBuffer, UnnormalizeBuffer
 from lerobot.common.policies.pretrained import PreTrainedPolicy
-from lerobot.common.policies.sac.configuration_sac import SACConfig
+from lerobot.common.policies.sac.configuration_sac import SACConfig, is_image_feature
 from lerobot.common.policies.utils import get_device_from_parameters
 
 DISCRETE_DIMENSION_INDEX = -1  # Gripper is always the last dimension
@@ -264,6 +264,7 @@ class SACPolicy(
         )
 
         # subsample critics to prevent overfitting if use high UTD (update to date)
+        # TODO: Get indices before forward pass to avoid unnecessary computation
        if self.config.num_subsample_critics is not None:
            indices = torch.randperm(self.config.num_critics)
            indices = indices[: self.config.num_subsample_critics]
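The TODO added above is about drawing the random critic subset before the ensemble forward pass. A minimal sketch of the idea, where critic_forward and its head_indices argument are hypothetical stand-ins for the real ensemble API:

import torch

num_critics = 10
num_subsample_critics = 2

# Draw the subset of critic heads first, so only those heads need to be evaluated,
# rather than running the full ensemble and discarding most of its outputs.
indices = torch.randperm(num_critics)[:num_subsample_critics]

# Hypothetical call: evaluate only the selected heads.
# q_values = critic_forward(observations, actions, head_indices=indices)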
@@ -468,6 +469,7 @@ class SACPolicy(
 
     def _init_actor(self, continuous_action_dim):
         """Initialize policy actor network and default target entropy."""
+        # NOTE: The actor selects only the continuous action part
         self.actor = Policy(
             encoder=self.encoder_actor,
             network=MLP(input_dim=self.encoder_actor.output_dim, **asdict(self.config.actor_network_kwargs)),
@@ -500,7 +502,7 @@ class SACObservationEncoder(nn.Module):
         self._compute_output_dim()
 
     def _init_image_layers(self) -> None:
-        self.image_keys = [k for k in self.config.input_features if k.startswith("observation.image")]
+        self.image_keys = [k for k in self.config.input_features if is_image_feature(k)]
         self.has_images = bool(self.image_keys)
         if not self.has_images:
             return
@@ -928,7 +930,7 @@ class Policy(nn.Module):
 class DefaultImageEncoder(nn.Module):
     def __init__(self, config: SACConfig):
         super().__init__()
-        image_key = next(key for key in config.input_features.keys() if key.startswith("observation.image"))  # noqa: SIM118
+        image_key = next(key for key in config.input_features if is_image_feature(key))
         self.image_enc_layers = nn.Sequential(
             nn.Conv2d(
                 in_channels=config.input_features[image_key].shape[0],
@@ -959,7 +961,6 @@ class DefaultImageEncoder(nn.Module):
             ),
             nn.ReLU(),
         )
-        # Get first image key from input features
 
     def forward(self, x):
         x = self.image_enc_layers(x)

View File

@@ -271,9 +271,11 @@ class ReplayBuffer:
 
             # Split the augmented images back to their sources
             for i, key in enumerate(image_keys):
-                # State images are at even indices (0, 2, 4...)
+                # Calculate offsets for the current image key:
+                # For each key, we have 2*batch_size images (batch_size for states, batch_size for next_states)
+                # States start at index i*2*batch_size and take up batch_size slots
                 batch_state[key] = augmented_images[i * 2 * batch_size : (i * 2 + 1) * batch_size]
-                # Next state images are at odd indices (1, 3, 5...)
+                # Next states start after the states at index (i*2+1)*batch_size and also take up batch_size slots
                 batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]
 
         # Sample other tensors
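As a sanity check of the offset arithmetic spelled out in the new comments, a small self-contained example (key names and tensor shapes are made up; the real buffer applies augmentation to the stacked tensor before splitting):

import torch

batch_size = 4
image_keys = ["observation.image", "observation.images.front"]

batch_state = {k: torch.randn(batch_size, 3, 8, 8) for k in image_keys}
batch_next_state = {k: torch.randn(batch_size, 3, 8, 8) for k in image_keys}

# Stack in the order the buffer uses: for each key, states first, then next states,
# giving 2*batch_size images per key.
augmented_images = torch.cat(
    [t for k in image_keys for t in (batch_state[k], batch_next_state[k])], dim=0
)

for i, key in enumerate(image_keys):
    states = augmented_images[i * 2 * batch_size : (i * 2 + 1) * batch_size]
    next_states = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]
    assert torch.equal(states, batch_state[key])
    assert torch.equal(next_states, batch_next_state[key])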

View File

@@ -111,7 +111,9 @@ def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
 def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
     buffer = io.BytesIO(buffer)
     buffer.seek(0)
-    return torch.load(buffer, weights_only=False)  # nosec B614: Safe usage of torch.load
+    return torch.load(buffer, weights_only=False)  # nosec B614: Using weights_only=False relies on pickle which has security implications.
+    # This is currently safe as we only deserialize trusted internal data.
+    # TODO: Verify if weights_only=True would work for our use case (safer default in torch 2.6+)
 
 
 def python_object_to_bytes(python_object: Any) -> bytes:
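Regarding the TODO on weights_only=True: a minimal round-trip sketch, assuming the serialized payload really is a plain dict of tensors (if arbitrary Python objects were ever serialized this way, the restricted loader would reject them):

import io

import torch

state_dict = {"layer.weight": torch.randn(2, 3), "layer.bias": torch.zeros(2)}

buffer = io.BytesIO()
torch.save(state_dict, buffer)
buffer.seek(0)

# weights_only=True restricts unpickling to tensors and a small set of safe types,
# which is sufficient for a pure state_dict round trip.
restored = torch.load(buffer, weights_only=True)
assert torch.equal(restored["layer.weight"], state_dict["layer.weight"])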