Add review feedback
@@ -22,6 +22,18 @@ from lerobot.configs.policies import PreTrainedConfig
 from lerobot.configs.types import NormalizationMode
 
 
+def is_image_feature(key: str) -> bool:
+    """Check if a feature key represents an image feature.
+
+    Args:
+        key: The feature key to check
+
+    Returns:
+        True if the key represents an image feature, False otherwise
+    """
+    return key.startswith("observation.image")
+
+
 @dataclass
 class ConcurrencyConfig:
     """Configuration for the concurrency of the actor and learner.
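
As context for reviewers, a minimal sketch of what the new helper matches; the keys below are invented examples, not taken from any actual config:

```python
def is_image_feature(key: str) -> bool:
    """Same predicate as the helper added above."""
    return key.startswith("observation.image")


# Invented feature keys, for illustration only.
keys = ["observation.image", "observation.images.front", "observation.state", "action"]
print([k for k in keys if is_image_feature(k)])
# ['observation.image', 'observation.images.front']
```
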
@@ -203,7 +215,7 @@ class SACConfig(PreTrainedConfig):
         return None
 
     def validate_features(self) -> None:
-        has_image = any(key.startswith("observation.image") for key in self.input_features)
+        has_image = any(is_image_feature(key) for key in self.input_features)
         has_state = "observation.state" in self.input_features
 
         if not (has_state or has_image):
@@ -216,7 +228,7 @@ class SACConfig(PreTrainedConfig):
 
     @property
     def image_features(self) -> list[str]:
-        return [key for key in self.input_features if "image" in key]
+        return [key for key in self.input_features if is_image_feature(key)]
 
     @property
     def observation_delta_indices(self) -> list:
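
The `image_features` change is slightly more than a refactor: the old `"image" in key` substring test matches any key containing the word, while `is_image_feature` requires the `observation.image` prefix. A small illustration with invented keys:

```python
def is_image_feature(key: str) -> bool:
    return key.startswith("observation.image")


# Invented keys to show the difference between the old and new predicates.
keys = ["observation.images.wrist", "observation.state", "goal.image"]

old = [k for k in keys if "image" in k]          # previous substring test
new = [k for k in keys if is_image_feature(k)]   # new prefix test via the helper

print(old)  # ['observation.images.wrist', 'goal.image']
print(new)  # ['observation.images.wrist']
```
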
@@ -29,7 +29,7 @@ from torch.distributions import MultivariateNormal, TanhTransform, Transform, Tr
 
 from lerobot.common.policies.normalize import NormalizeBuffer, UnnormalizeBuffer
 from lerobot.common.policies.pretrained import PreTrainedPolicy
-from lerobot.common.policies.sac.configuration_sac import SACConfig
+from lerobot.common.policies.sac.configuration_sac import SACConfig, is_image_feature
 from lerobot.common.policies.utils import get_device_from_parameters
 
 DISCRETE_DIMENSION_INDEX = -1  # Gripper is always the last dimension
@@ -264,6 +264,7 @@ class SACPolicy(
         )
 
         # subsample critics to prevent overfitting if use high UTD (update to date)
+        # TODO: Get indices before forward pass to avoid unnecessary computation
         if self.config.num_subsample_critics is not None:
            indices = torch.randperm(self.config.num_critics)
            indices = indices[: self.config.num_subsample_critics]
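
A toy sketch of the subsampling that the new TODO refers to, assuming a critic ensemble whose forward pass yields one row of Q-values per critic; the tensor names and shapes are illustrative, not the actual `SACPolicy` internals:

```python
import torch

num_critics, num_subsample_critics, batch_size = 10, 2, 4

# Stand-in for the ensemble output: (num_critics, batch_size) Q-values.
q_values = torch.randn(num_critics, batch_size)

# Current approach in the diff: run every critic, then keep a random subset.
indices = torch.randperm(num_critics)[:num_subsample_critics]
q_subset = q_values[indices]  # (num_subsample_critics, batch_size)

# The TODO suggests drawing `indices` before the forward pass and only evaluating
# those critics, which would skip the computation for the discarded ensemble members.
target_q = q_subset.min(dim=0).values  # pessimistic target over the sampled critics
print(indices.tolist(), target_q.shape)  # e.g. [7, 2] torch.Size([4])
```
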
@@ -468,6 +469,7 @@ class SACPolicy(
 
     def _init_actor(self, continuous_action_dim):
         """Initialize policy actor network and default target entropy."""
+        # NOTE: The actor select only the continuous action part
         self.actor = Policy(
             encoder=self.encoder_actor,
             network=MLP(input_dim=self.encoder_actor.output_dim, **asdict(self.config.actor_network_kwargs)),
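
The NOTE ties back to `DISCRETE_DIMENSION_INDEX = -1` from the import hunk above: the gripper occupies the last action dimension and is handled outside the continuous actor. A hedged sketch of that split, with an invented action layout (the real dimensionality comes from the config):

```python
import torch

DISCRETE_DIMENSION_INDEX = -1  # gripper is assumed to be the last action dimension

action = torch.tensor([[0.1, -0.3, 0.7, 1.0]])     # invented (batch, dim) action
continuous = action[:, :DISCRETE_DIMENSION_INDEX]  # what the continuous actor models
gripper = action[:, DISCRETE_DIMENSION_INDEX]      # handled by a separate discrete head
print(continuous.shape, gripper.shape)  # torch.Size([1, 3]) torch.Size([1])
```
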
@@ -500,7 +502,7 @@ class SACObservationEncoder(nn.Module):
         self._compute_output_dim()
 
     def _init_image_layers(self) -> None:
-        self.image_keys = [k for k in self.config.input_features if k.startswith("observation.image")]
+        self.image_keys = [k for k in self.config.input_features if is_image_feature(k)]
         self.has_images = bool(self.image_keys)
         if not self.has_images:
             return
@@ -928,7 +930,7 @@ class Policy(nn.Module):
 class DefaultImageEncoder(nn.Module):
     def __init__(self, config: SACConfig):
         super().__init__()
-        image_key = next(key for key in config.input_features.keys() if key.startswith("observation.image"))  # noqa: SIM118
+        image_key = next(key for key in config.input_features if is_image_feature(key))
         self.image_enc_layers = nn.Sequential(
             nn.Conv2d(
                 in_channels=config.input_features[image_key].shape[0],
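
In `DefaultImageEncoder`, the selected image key only feeds the conv stack through its channel count. A minimal sketch assuming channel-first `(C, H, W)` feature shapes; `FakeFeature` is a stand-in for the config's feature spec, not the real class:

```python
from dataclasses import dataclass

import torch
from torch import nn


@dataclass
class FakeFeature:  # stand-in for the config's feature spec, for illustration only
    shape: tuple[int, ...]


input_features = {
    "observation.images.front": FakeFeature(shape=(3, 84, 84)),  # invented camera feature
    "observation.state": FakeFeature(shape=(6,)),
}

image_key = next(key for key in input_features if key.startswith("observation.image"))
conv = nn.Conv2d(in_channels=input_features[image_key].shape[0], out_channels=32, kernel_size=3)
print(conv(torch.zeros(1, *input_features[image_key].shape)).shape)  # torch.Size([1, 32, 82, 82])
```
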
@@ -959,7 +961,6 @@ class DefaultImageEncoder(nn.Module):
             ),
             nn.ReLU(),
         )
-        # Get first image key from input features
 
     def forward(self, x):
         x = self.image_enc_layers(x)
@@ -271,9 +271,11 @@ class ReplayBuffer:
 
         # Split the augmented images back to their sources
         for i, key in enumerate(image_keys):
-            # State images are at even indices (0, 2, 4...)
+            # Calculate offsets for the current image key:
+            # For each key, we have 2*batch_size images (batch_size for states, batch_size for next_states)
+            # States start at index i*2*batch_size and take up batch_size slots
             batch_state[key] = augmented_images[i * 2 * batch_size : (i * 2 + 1) * batch_size]
-            # Next state images are at odd indices (1, 3, 5...)
+            # Next states start after the states at index (i*2+1)*batch_size and also take up batch_size slots
             batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]
 
         # Sample other tensors
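
A worked example of the indexing described by the new comments, with invented sizes (two image keys, batch of three), so the augmented stack holds `2 * batch_size` slots per key, states first and next_states second:

```python
import torch

batch_size = 3
image_keys = ["observation.images.front", "observation.images.wrist"]  # invented keys

# Stand-in for the jointly augmented stack: one slot per image, numbered for clarity.
augmented_images = torch.arange(len(image_keys) * 2 * batch_size)

batch_state, batch_next_state = {}, {}
for i, key in enumerate(image_keys):
    # States for key i occupy [i*2*B, (i*2+1)*B); next_states follow in [(i*2+1)*B, (i+1)*2*B).
    batch_state[key] = augmented_images[i * 2 * batch_size : (i * 2 + 1) * batch_size]
    batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]

print(batch_state["observation.images.front"].tolist())       # [0, 1, 2]
print(batch_next_state["observation.images.front"].tolist())  # [3, 4, 5]
print(batch_state["observation.images.wrist"].tolist())       # [6, 7, 8]
print(batch_next_state["observation.images.wrist"].tolist())  # [9, 10, 11]
```
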
@@ -111,7 +111,9 @@ def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
 def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
     buffer = io.BytesIO(buffer)
     buffer.seek(0)
-    return torch.load(buffer, weights_only=False)  # nosec B614: Safe usage of torch.load
+    return torch.load(buffer, weights_only=False)  # nosec B614: Using weights_only=False relies on pickle which has security implications.
+    # This is currently safe as we only deserialize trusted internal data.
+    # TODO: Verify if weights_only=True would work for our use case (safer default in torch 2.6+)
 
 
 def python_object_to_bytes(python_object: Any) -> bytes:
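
A hedged sketch of the round-trip behind `bytes_to_state_dict` and of what the TODO would need to verify: for a payload that is just a dict of tensors, `torch.load(..., weights_only=True)` can already deserialize it, but whether that covers every payload the transport sends is an assumption to check, not something this diff confirms:

```python
import io

import torch

state_dict = {"layer.weight": torch.randn(2, 2), "layer.bias": torch.zeros(2)}

# Serialize into bytes the same way the sending side does conceptually: torch.save into a buffer.
out = io.BytesIO()
torch.save(state_dict, out)
payload = out.getvalue()

# Current path in the diff: pickle-based load, acceptable only for trusted internal data.
unsafe = torch.load(io.BytesIO(payload), weights_only=False)

# What the TODO asks to verify: the restricted unpickler for tensor-only payloads.
safe = torch.load(io.BytesIO(payload), weights_only=True)
assert torch.equal(unsafe["layer.weight"], safe["layer.weight"])
```
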