[PORT HIL-SERL] Add unit tests for SAC modeling (#999)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Eugene Mironov
2025-05-05 14:27:42 +07:00
committed by GitHub
parent fb7c288c94
commit 6fa7df35df
3 changed files with 267 additions and 5 deletions


@@ -70,9 +70,10 @@ class SACConfig(PreTrainedConfig):
hyperparameters.
Args:
-actor_network: Configuration for the actor network architecture.
-critic_network: Configuration for the critic network architecture.
-policy: Configuration for the policy parameters.
+actor_network_kwargs: Configuration for the actor network architecture.
+critic_network_kwargs: Configuration for the critic network architecture.
+discrete_critic_network_kwargs: Configuration for the discrete critic network.
+policy_kwargs: Configuration for the policy parameters.
n_obs_steps: Number of observation steps to consider.
normalization_mapping: Mapping of feature types to normalization modes.
dataset_stats: Statistics for normalizing different types of inputs.
@@ -88,7 +89,7 @@ class SACConfig(PreTrainedConfig):
num_discrete_actions: Number of discrete actions, eg for gripper actions.
image_embedding_pooling_dim: Dimension of the image embedding pooling.
concurrency: Configuration for concurrency settings.
-actor_learner: Configuration for actor-learner architecture.
+actor_learner_config: Configuration for actor-learner architecture.
online_steps: Number of steps for online training.
online_env_seed: Seed for the online environment.
online_buffer_capacity: Capacity of the online replay buffer.
@@ -140,7 +141,7 @@ class SACConfig(PreTrainedConfig):
# Architecture specifics
camera_number: int = 1
-device: str = "cuda"
+device: str = "cpu"
storage_device: str = "cpu"
# Set to "helper2424/resnet10" for hil serl
vision_encoder_name: str | None = None
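
The hunk above changes the default `device` from "cuda" to "cpu". A minimal sketch of what that means for callers, assuming a plain dataclass: `SACConfigSketch` is a hypothetical stand-in that mirrors only the four fields visible in this hunk. It is not the real `SACConfig`, which subclasses `PreTrainedConfig` and has many more fields.

```python
from __future__ import annotations

from dataclasses import dataclass, replace


@dataclass
class SACConfigSketch:
    """Hypothetical stand-in with only the fields shown in the hunk."""

    camera_number: int = 1
    device: str = "cpu"  # new default after this commit (was "cuda")
    storage_device: str = "cpu"
    vision_encoder_name: str | None = None


# With the new default, a config built with no arguments targets the CPU.
default_cfg = SACConfigSketch()

# Callers that want a GPU now have to ask for it explicitly.
gpu_cfg = replace(default_cfg, device="cuda")
```

Under this reading, code that builds the config without arguments no longer requests a CUDA device, which would suit unit tests on machines without a GPU. The commit message does not state that motivation, so treat it as an assumption.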