Stable version of rlpd + drq
This commit is contained in:
committed by
Michel Aractingi
parent
1fb03d4cf2
commit
d75b44f89f
@@ -42,37 +42,31 @@ class SACPolicy(
|
||||
name = "sac"
|
||||
|
||||
def __init__(
|
||||
self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
|
||||
self,
|
||||
config: SACConfig | None = None,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
device: str = "cpu",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if config is None:
|
||||
config = SACConfig()
|
||||
self.config = config
|
||||
|
||||
if config.input_normalization_modes is not None:
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes, config.input_normalization_modes, dataset_stats
|
||||
)
|
||||
else:
|
||||
self.normalize_inputs = nn.Identity()
|
||||
# HACK: we need to pass the dataset_stats to the normalization functions
|
||||
|
||||
# NOTE: This is for biwalker environment
|
||||
dataset_stats = dataset_stats or {
|
||||
"action": {
|
||||
"min": torch.tensor([-1.0, -1.0, -1.0, -1.0]),
|
||||
"max": torch.tensor([1.0, 1.0, 1.0, 1.0]),
|
||||
}
|
||||
}
|
||||
output_normalization_params = {}
|
||||
for outer_key, inner_dict in config.output_normalization_params.items():
|
||||
output_normalization_params[outer_key] = {}
|
||||
for key, value in inner_dict.items():
|
||||
output_normalization_params[outer_key][key] = torch.tensor(value)
|
||||
|
||||
# NOTE: This is for pusht environment
|
||||
# dataset_stats = dataset_stats or {
|
||||
# "action": {
|
||||
# "min": torch.tensor([0, 0]),
|
||||
# "max": torch.tensor([512, 512]),
|
||||
# }
|
||||
# }
|
||||
# HACK: This is hacky and should be removed
|
||||
dataset_stats = dataset_stats or output_normalization_params
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
@@ -82,7 +76,7 @@ class SACPolicy(
|
||||
|
||||
if config.shared_encoder:
|
||||
encoder_critic = SACObservationEncoder(config)
|
||||
encoder_actor = encoder_critic
|
||||
encoder_actor: SACObservationEncoder = encoder_critic
|
||||
else:
|
||||
encoder_critic = SACObservationEncoder(config)
|
||||
encoder_actor = SACObservationEncoder(config)
|
||||
@@ -95,6 +89,7 @@ class SACPolicy(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs,
|
||||
),
|
||||
device=device,
|
||||
)
|
||||
critic_nets.append(critic_net)
|
||||
|
||||
@@ -106,40 +101,35 @@ class SACPolicy(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs,
|
||||
),
|
||||
device=device,
|
||||
)
|
||||
target_critic_nets.append(target_critic_net)
|
||||
|
||||
self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
|
||||
self.critic_target = create_critic_ensemble(target_critic_nets, config.num_critics)
|
||||
self.critic_ensemble = create_critic_ensemble(
|
||||
critics=critic_nets, num_critics=config.num_critics, device=device
|
||||
)
|
||||
self.critic_target = create_critic_ensemble(
|
||||
critics=target_critic_nets, num_critics=config.num_critics, device=device
|
||||
)
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
self.actor = Policy(
|
||||
encoder=encoder_actor,
|
||||
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
|
||||
action_dim=config.output_shapes["action"][0],
|
||||
device=device,
|
||||
encoder_is_shared=config.shared_encoder,
|
||||
**config.policy_kwargs,
|
||||
)
|
||||
if config.target_entropy is None:
|
||||
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
|
||||
# TODO: fix later device
|
||||
# TODO: Handle the case where the temparameter is a fixed
|
||||
self.log_alpha = torch.zeros(1, requires_grad=True, device="cpu")
|
||||
self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Clear observation and action queues. Should be called on `env.reset()`
|
||||
queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||
"""
|
||||
|
||||
self._queues = {
|
||||
"observation.state": deque(maxlen=1),
|
||||
"action": deque(maxlen=1),
|
||||
}
|
||||
if "observation.image" in self.config.input_shapes:
|
||||
self._queues["observation.image"] = deque(maxlen=1)
|
||||
if "observation.environment_state" in self.config.input_shapes:
|
||||
self._queues["observation.environment_state"] = deque(maxlen=1)
|
||||
"""Reset the policy"""
|
||||
pass
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
@@ -334,6 +324,7 @@ class Policy(nn.Module):
|
||||
init_final: Optional[float] = None,
|
||||
use_tanh_squash: bool = False,
|
||||
device: str = "cpu",
|
||||
encoder_is_shared: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.device = torch.device(device)
|
||||
@@ -344,7 +335,12 @@ class Policy(nn.Module):
|
||||
self.log_std_max = log_std_max
|
||||
self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
|
||||
self.use_tanh_squash = use_tanh_squash
|
||||
self.parameters_to_optimize = []
|
||||
|
||||
self.parameters_to_optimize += list(self.network.parameters())
|
||||
|
||||
if self.encoder is not None and not encoder_is_shared:
|
||||
self.parameters_to_optimize += list(self.encoder.parameters())
|
||||
# Find the last Linear layer's output dimension
|
||||
for layer in reversed(network.net):
|
||||
if isinstance(layer, nn.Linear):
|
||||
@@ -358,6 +354,7 @@ class Policy(nn.Module):
|
||||
else:
|
||||
orthogonal_init()(self.mean_layer.weight)
|
||||
|
||||
self.parameters_to_optimize += list(self.mean_layer.parameters())
|
||||
# Standard deviation layer or parameter
|
||||
if fixed_std is None:
|
||||
self.std_layer = nn.Linear(out_features, action_dim)
|
||||
@@ -366,6 +363,7 @@ class Policy(nn.Module):
|
||||
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.std_layer.weight)
|
||||
self.parameters_to_optimize += list(self.std_layer.parameters())
|
||||
|
||||
self.to(self.device)
|
||||
|
||||
@@ -428,44 +426,78 @@ class SACObservationEncoder(nn.Module):
|
||||
"""
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
if "observation.image" in config.input_shapes:
|
||||
self.image_enc_layers = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
config.input_shapes["observation.image"][0], config.image_encoder_hidden_dim, 7, stride=2
|
||||
in_channels=config.input_shapes["observation.image"][0],
|
||||
out_channels=config.image_encoder_hidden_dim,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
|
||||
nn.Conv2d(
|
||||
in_channels=config.image_encoder_hidden_dim,
|
||||
out_channels=config.image_encoder_hidden_dim,
|
||||
kernel_size=5,
|
||||
stride=2,
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
|
||||
nn.Conv2d(
|
||||
in_channels=config.image_encoder_hidden_dim,
|
||||
out_channels=config.image_encoder_hidden_dim,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
|
||||
nn.Conv2d(
|
||||
in_channels=config.image_encoder_hidden_dim,
|
||||
out_channels=config.image_encoder_hidden_dim,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.camera_number = config.camera_number
|
||||
self.aggregation_size: int = 0
|
||||
|
||||
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
|
||||
with torch.inference_mode():
|
||||
out_shape = self.image_enc_layers(dummy_batch).shape[1:]
|
||||
self.image_enc_layers.extend(
|
||||
nn.Sequential(
|
||||
sequential=nn.Sequential(
|
||||
nn.Flatten(),
|
||||
nn.Linear(np.prod(out_shape), config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Linear(
|
||||
in_features=np.prod(out_shape) * self.camera_number, out_features=config.latent_dim
|
||||
),
|
||||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
)
|
||||
|
||||
self.aggregation_size += config.latent_dim * self.camera_number
|
||||
if "observation.state" in config.input_shapes:
|
||||
self.state_enc_layers = nn.Sequential(
|
||||
nn.Linear(config.input_shapes["observation.state"][0], config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Linear(
|
||||
in_features=config.input_shapes["observation.state"][0], out_features=config.latent_dim
|
||||
),
|
||||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
self.aggregation_size += config.latent_dim
|
||||
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
self.env_state_enc_layers = nn.Sequential(
|
||||
nn.Linear(config.input_shapes["observation.environment_state"][0], config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Linear(
|
||||
in_features=config.input_shapes["observation.environment_state"][0],
|
||||
out_features=config.latent_dim,
|
||||
),
|
||||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
self.aggregation_size += config.latent_dim
|
||||
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
|
||||
|
||||
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
|
||||
"""Encode the image and/or state vector.
|
||||
|
||||
@@ -482,7 +514,11 @@ class SACObservationEncoder(nn.Module):
|
||||
if "observation.state" in self.config.input_shapes:
|
||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||
# TODO(ke-wang): currently average over all features, concatenate all features maybe a better way
|
||||
return torch.stack(feat, dim=0).mean(0)
|
||||
# return torch.stack(feat, dim=0).mean(0)
|
||||
features = torch.cat(tensors=feat, dim=-1)
|
||||
features = self.aggregation_layer(features)
|
||||
|
||||
return features
|
||||
|
||||
@property
|
||||
def output_dim(self) -> int:
|
||||
|
||||
Reference in New Issue
Block a user