forked from tangger/lerobot
Added gripper control mechanism to gym_manipulator
Moved HilSerl env config to configs/env/configs.py fixes in actor_server and modeling_sac and configuration_sac added the possibility of ignoring missing keys in env_cfg in get_features_from_env_config function
This commit is contained in:
committed by
AdilZouitine
parent
79e0f6e06c
commit
02b9ea9446
@@ -20,7 +20,7 @@ from typing import Any, Optional
|
||||
|
||||
from lerobot.common.optim.optimizers import MultiAdamConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode, PolicyFeature, FeatureType
|
||||
from lerobot.configs.types import NormalizationMode, PolicyFeature, FeatureType
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -29,7 +29,6 @@ class ConcurrencyConfig:
|
||||
learner: str = "threads"
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActorLearnerConfig:
|
||||
learner_host: str = "127.0.0.1"
|
||||
@@ -110,6 +109,7 @@ class SACConfig(PreTrainedConfig):
|
||||
use_backup_entropy: Whether to use backup entropy for the SAC algorithm.
|
||||
grad_clip_norm: Gradient clipping norm for the SAC algorithm.
|
||||
"""
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
@@ -152,8 +152,8 @@ class SACConfig(PreTrainedConfig):
|
||||
camera_number: int = 1
|
||||
device: str = "cuda"
|
||||
storage_device: str = "cpu"
|
||||
# Set to "helper2424/resnet10" for hil serl
|
||||
vision_encoder_name: str | None = None
|
||||
# Set to "helper2424/resnet10" for hil serl
|
||||
vision_encoder_name: str | None = None
|
||||
freeze_vision_encoder: bool = True
|
||||
image_encoder_hidden_dim: int = 32
|
||||
shared_encoder: bool = True
|
||||
@@ -163,7 +163,7 @@ class SACConfig(PreTrainedConfig):
|
||||
online_env_seed: int = 10000
|
||||
online_buffer_capacity: int = 100000
|
||||
offline_buffer_capacity: int = 100000
|
||||
online_step_before_learning: int = 100
|
||||
online_step_before_learning: int = 100
|
||||
policy_update_freq: int = 1
|
||||
|
||||
# SAC algorithm parameters
|
||||
@@ -181,24 +181,14 @@ class SACConfig(PreTrainedConfig):
|
||||
target_entropy: float | None = None
|
||||
use_backup_entropy: bool = True
|
||||
grad_clip_norm: float = 40.0
|
||||
|
||||
|
||||
# Network configuration
|
||||
critic_network_kwargs: CriticNetworkConfig = field(
|
||||
default_factory=CriticNetworkConfig
|
||||
)
|
||||
actor_network_kwargs: ActorNetworkConfig = field(
|
||||
default_factory=ActorNetworkConfig
|
||||
)
|
||||
policy_kwargs: PolicyConfig = field(
|
||||
default_factory=PolicyConfig
|
||||
)
|
||||
|
||||
actor_learner_config: ActorLearnerConfig = field(
|
||||
default_factory=ActorLearnerConfig
|
||||
)
|
||||
concurrency: ConcurrencyConfig = field(
|
||||
default_factory=ConcurrencyConfig
|
||||
)
|
||||
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
|
||||
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
|
||||
|
||||
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
|
||||
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
@@ -218,18 +208,20 @@ class SACConfig(PreTrainedConfig):
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if "observation.image" not in self.input_features:
|
||||
raise ValueError("You must provide 'observation.image' in the input features")
|
||||
|
||||
if "observation.state" not in self.input_features:
|
||||
raise ValueError("You must provide 'observation.state' in the input features")
|
||||
|
||||
has_image = any(key.startswith("observation.image") for key in self.input_features)
|
||||
has_state = "observation.state" in self.input_features
|
||||
|
||||
if not (has_state or has_image):
|
||||
raise ValueError(
|
||||
"You must provide either 'observation.state' or an image observation (key starting with 'observation.image') in the input features"
|
||||
)
|
||||
|
||||
if "action" not in self.output_features:
|
||||
raise ValueError("You must provide 'action' in the output features")
|
||||
|
||||
@property
|
||||
def image_features(self) -> list[str]:
|
||||
return [key for key in self.input_features.keys() if 'image' in key]
|
||||
return [key for key in self.input_features.keys() if "image" in key]
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
@@ -243,9 +235,13 @@ class SACConfig(PreTrainedConfig):
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import draccus
|
||||
|
||||
config = SACConfig()
|
||||
draccus.set_config_type("json")
|
||||
draccus.dump(config=config, stream=open(file='run_config.json', mode='w'), )
|
||||
|
||||
draccus.dump(
|
||||
config=config,
|
||||
stream=open(file="run_config.json", mode="w"),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user