- Added additional logging information in wandb around the timings of the policy loop and optimization loop.
- Optimized critic design that improves the performance of the learner loop by a factor of 2 - Cleaned the code and fixed style issues - Completed the config with actor_learner_config field that contains host-ip and port elemnts that are necessary for the actor-learner servers. Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
committed by
AdilZouitine
parent
a0a81c0c12
commit
18207d995e
@@ -45,6 +45,14 @@ class SACConfig:
|
||||
"action": {"min": [-1, -1], "max": [1, 1]},
|
||||
}
|
||||
)
|
||||
# TODO: Move it outside of the config
|
||||
actor_learner_config: dict[str, str | int] = field(
|
||||
default_factory=lambda: {
|
||||
"actor_ip": "127.0.0.1",
|
||||
"port": 50051,
|
||||
"learner_ip": "127.0.0.1",
|
||||
}
|
||||
)
|
||||
camera_number: int = 1
|
||||
# Add type annotations for these fields:
|
||||
image_encoder_hidden_dim: int = 32
|
||||
|
||||
@@ -17,8 +17,7 @@
|
||||
|
||||
# TODO: (1) better device management
|
||||
|
||||
from collections import deque
|
||||
from typing import Callable, Optional, Sequence, Tuple, Union
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
@@ -74,43 +73,42 @@ class SACPolicy(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
# NOTE: For images the encoder should be shared between the actor and critic
|
||||
if config.shared_encoder:
|
||||
encoder_critic = SACObservationEncoder(config)
|
||||
encoder_actor: SACObservationEncoder = encoder_critic
|
||||
else:
|
||||
encoder_critic = SACObservationEncoder(config)
|
||||
encoder_actor = SACObservationEncoder(config)
|
||||
# Define networks
|
||||
critic_nets = []
|
||||
for _ in range(config.num_critics):
|
||||
critic_net = Critic(
|
||||
encoder=encoder_critic,
|
||||
network=MLP(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs,
|
||||
),
|
||||
device=device,
|
||||
)
|
||||
critic_nets.append(critic_net)
|
||||
|
||||
target_critic_nets = []
|
||||
for _ in range(config.num_critics):
|
||||
target_critic_net = Critic(
|
||||
encoder=encoder_critic,
|
||||
network=MLP(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs,
|
||||
),
|
||||
device=device,
|
||||
)
|
||||
target_critic_nets.append(target_critic_net)
|
||||
self.critic_ensemble = CriticEnsemble(
|
||||
encoder=encoder_critic,
|
||||
network_list=nn.ModuleList(
|
||||
[
|
||||
MLP(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs,
|
||||
)
|
||||
for _ in range(config.num_critics)
|
||||
]
|
||||
),
|
||||
device=device,
|
||||
)
|
||||
|
||||
self.critic_ensemble = create_critic_ensemble(
|
||||
critics=critic_nets, num_critics=config.num_critics, device=device
|
||||
)
|
||||
self.critic_target = create_critic_ensemble(
|
||||
critics=target_critic_nets, num_critics=config.num_critics, device=device
|
||||
self.critic_target = CriticEnsemble(
|
||||
encoder=encoder_critic,
|
||||
network_list=nn.ModuleList(
|
||||
[
|
||||
MLP(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs,
|
||||
)
|
||||
for _ in range(config.num_critics)
|
||||
]
|
||||
),
|
||||
device=device,
|
||||
)
|
||||
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
self.actor = Policy(
|
||||
@@ -123,7 +121,8 @@ class SACPolicy(
|
||||
)
|
||||
if config.target_entropy is None:
|
||||
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
|
||||
# TODO: Handle the case where the temparameter is a fixed
|
||||
|
||||
# TODO (azouitine): Handle the case where the temparameter is a fixed
|
||||
self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
|
||||
@@ -152,18 +151,19 @@ class SACPolicy(
|
||||
Tensor of Q-values from all critics
|
||||
"""
|
||||
critics = self.critic_target if use_target else self.critic_ensemble
|
||||
q_values = torch.stack([critic(observations, actions) for critic in critics])
|
||||
q_values = critics(observations, actions)
|
||||
return q_values
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ...
|
||||
def update_target_networks(self):
|
||||
"""Update target networks with exponential moving average"""
|
||||
for target_critic, critic in zip(self.critic_target, self.critic_ensemble, strict=False):
|
||||
for target_param, param in zip(target_critic.parameters(), critic.parameters(), strict=False):
|
||||
target_param.data.copy_(
|
||||
param.data * self.config.critic_target_update_weight
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
for target_param, param in zip(
|
||||
self.critic_target.parameters(), self.critic_ensemble.parameters(), strict=False
|
||||
):
|
||||
target_param.data.copy_(
|
||||
param.data * self.config.critic_target_update_weight
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
|
||||
def compute_loss_critic(self, observations, actions, rewards, next_observations, done) -> Tensor:
|
||||
temperature = self.log_alpha.exp().item()
|
||||
@@ -264,34 +264,83 @@ class MLP(nn.Module):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class Critic(nn.Module):
|
||||
class CriticEnsemble(nn.Module):
|
||||
"""
|
||||
┌──────────────────┬─────────────────────────────────────────────────────────┐
|
||||
│ Critic Ensemble │ │
|
||||
├──────────────────┘ │
|
||||
│ │
|
||||
│ ┌────┐ ┌────┐ ┌────┐ │
|
||||
│ │ Q1 │ │ Q2 │ │ Qn │ │
|
||||
│ └────┘ └────┘ └────┘ │
|
||||
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
|
||||
│ │ │ │ │ │ │ │
|
||||
│ │ MLP 1 │ │ MLP 2 │ │ MLP │ │
|
||||
│ │ │ │ │ ... │ num_critics │ │
|
||||
│ │ │ │ │ │ │ │
|
||||
│ └──────────────┘ └──────────────┘ └──────────────┘ │
|
||||
│ ▲ ▲ ▲ │
|
||||
│ └───────────────────┴───────┬────────────────────────────┘ │
|
||||
│ │ │
|
||||
│ │ │
|
||||
│ ┌───────────────────┐ │
|
||||
│ │ Embedding │ │
|
||||
│ │ │ │
|
||||
│ └───────────────────┘ │
|
||||
│ ▲ │
|
||||
│ │ │
|
||||
│ ┌─────────────┴────────────┐ │
|
||||
│ │ │ │
|
||||
│ │ SACObservationEncoder │ │
|
||||
│ │ │ │
|
||||
│ └──────────────────────────┘ │
|
||||
│ ▲ │
|
||||
│ │ │
|
||||
│ │ │
|
||||
│ │ │
|
||||
└───────────────────────────┬────────────────────┬───────────────────────────┘
|
||||
│ Observation │
|
||||
└────────────────────┘
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder: Optional[nn.Module],
|
||||
network: nn.Module,
|
||||
network_list: nn.Module,
|
||||
init_final: Optional[float] = None,
|
||||
device: str = "cpu",
|
||||
):
|
||||
super().__init__()
|
||||
self.device = torch.device(device)
|
||||
self.encoder = encoder
|
||||
self.network = network
|
||||
self.network_list = network_list
|
||||
self.init_final = init_final
|
||||
|
||||
# for network in network_list:
|
||||
# network.to(self.device)
|
||||
|
||||
# Find the last Linear layer's output dimension
|
||||
for layer in reversed(network.net):
|
||||
for layer in reversed(network_list[0].net):
|
||||
if isinstance(layer, nn.Linear):
|
||||
out_features = layer.out_features
|
||||
break
|
||||
|
||||
# Output layer
|
||||
self.output_layers = []
|
||||
if init_final is not None:
|
||||
self.output_layer = nn.Linear(out_features, 1)
|
||||
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
|
||||
for _ in network_list:
|
||||
output_layer = nn.Linear(out_features, 1, device=device)
|
||||
nn.init.uniform_(output_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(output_layer.bias, -init_final, init_final)
|
||||
self.output_layers.append(output_layer)
|
||||
else:
|
||||
self.output_layer = nn.Linear(out_features, 1)
|
||||
orthogonal_init()(self.output_layer.weight)
|
||||
self.output_layers = []
|
||||
for _ in network_list:
|
||||
output_layer = nn.Linear(out_features, 1, device=device)
|
||||
orthogonal_init()(output_layer.weight)
|
||||
self.output_layers.append(output_layer)
|
||||
|
||||
self.output_layers = nn.ModuleList(self.output_layers)
|
||||
|
||||
self.to(self.device)
|
||||
|
||||
@@ -307,9 +356,12 @@ class Critic(nn.Module):
|
||||
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
||||
|
||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||
x = self.network(inputs)
|
||||
value = self.output_layer(x)
|
||||
return value.squeeze(-1)
|
||||
list_q_values = []
|
||||
for network, output_layer in zip(self.network_list, self.output_layers, strict=False):
|
||||
x = network(inputs)
|
||||
value = output_layer(x)
|
||||
list_q_values.append(value.squeeze(-1))
|
||||
return torch.stack(list_q_values)
|
||||
|
||||
|
||||
class Policy(nn.Module):
|
||||
@@ -416,9 +468,7 @@ class Policy(nn.Module):
|
||||
|
||||
|
||||
class SACObservationEncoder(nn.Module):
|
||||
"""Encode image and/or state vector observations.
|
||||
TODO(ke-wang): The original work allows for (1) stacking multiple history frames and (2) using pretrained resnet encoders.
|
||||
"""
|
||||
"""Encode image and/or state vector observations."""
|
||||
|
||||
def __init__(self, config: SACConfig):
|
||||
"""
|
||||
@@ -513,8 +563,7 @@ class SACObservationEncoder(nn.Module):
|
||||
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
|
||||
if "observation.state" in self.config.input_shapes:
|
||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||
# TODO(ke-wang): currently average over all features, concatenate all features maybe a better way
|
||||
# return torch.stack(feat, dim=0).mean(0)
|
||||
|
||||
features = torch.cat(tensors=feat, dim=-1)
|
||||
features = self.aggregation_layer(features)
|
||||
|
||||
@@ -530,12 +579,8 @@ def orthogonal_init():
|
||||
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
|
||||
|
||||
|
||||
def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cpu") -> nn.ModuleList:
|
||||
"""Creates an ensemble of critic networks"""
|
||||
assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
|
||||
return nn.ModuleList(critics).to(device)
|
||||
|
||||
|
||||
# TODO (azouitine): I think in our case this function is not usefull we should remove it
|
||||
# after some investigation
|
||||
# borrowed from tdmpc
|
||||
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
|
||||
"""Helper to temporarily flatten extra dims at the start of the image tensor.
|
||||
|
||||
Reference in New Issue
Block a user