completed losses

commit 972bac98b4 (parent 458c427e0c)
Author: Michel Aractingi
Date: 2024-12-12 11:45:30 +01:00


@@ -89,12 +89,18 @@ class SACPolicy(
Returns a dictionary with loss as a tensor, and other information as native floats.
"""
observation_batch =
next_observation_batch =
action_batch =
reward_batch =
dones_batch =
batch = self.normalize_inputs(batch)
# batch shape is (b, 2, ...) where index 0 along dim 1 is the current observation and
# index 1 is the next observation, used for computing the TD target.
actions = batch["action"][:, 0]
rewards = batch["next.reward"][:, 0]
observations = {}
next_observations = {}
for k in batch:
if k.startswith("observation."):
observations[k] = batch[k][:, 0]
next_observations[k] = batch[k][:, 1]
# perform image augmentation
# reward bias
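For orientation, here is a minimal sketch (illustrative only, not part of this commit) of how such a stacked (b, 2, ...) batch could be collated from two consecutive frames, so that indexing with [:, 0] and [:, 1] yields the current and next observation; the helper name is hypothetical:

import torch

def stack_transitions(frame_t: dict, frame_tp1: dict) -> dict:
    # Hypothetical collation helper: stack two consecutive frames along dim 1,
    # so that batch[k][:, 0] is the value at time t and batch[k][:, 1] at t + 1.
    return {k: torch.stack([frame_t[k], frame_tp1[k]], dim=1) for k in frame_t}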
@@ -104,34 +110,51 @@ class SACPolicy(
# calculate critics loss
# 1- compute actions from policy
next_actions = ...
action_preds, log_probs = self.actor_network(observations)
# 2- compute q targets
q_targets = self.target_qs(next_observation_batch, next_actions)
q_targets = self.target_qs(next_observations, action_preds)
# critics subsample size
min_q = q_targets.min(dim=0).values
# backup entropy
td_target = reward_batch + self.discount * min_q
td_target = rewards + self.discount * min_q
# 3- compute predicted qs
q_preds = self.critic_ensemble(observation_batch, action_batch)
q_preds = self.critic_ensemble(observations, actions)
# 4- Calculate loss
critics_loss = F.mse_loss(q_preds,
einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])) # dones masks
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
critics_loss = (
F.mse_loss(
q_preds,
einops.repeat(td_target, "t b -> e t b", e=q_preds.shape[0]),
reduction="none",
).sum(0) # sum over ensemble
# `q_preds` depends on the first observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
# q_targets depends on the reward and the next observations.
* ~batch["next.reward_is_pad"]
* ~batch["observation.state_is_pad"][1:]
).sum(0).mean()
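As a standalone reference for the block above, a hedged sketch of the critic objective: the TD target bootstraps the minimum over the target Q ensemble, and every critic regresses to that same target with an MSE summed over the ensemble and averaged over the batch. Function and argument names are illustrative, the padding masks are omitted, and the done-flag term is the standard SAC form rather than something taken from this diff:

import torch
import torch.nn.functional as F

def critic_loss_sketch(critic_ensemble, target_ensemble, obs, actions, rewards, dones,
                       next_obs, next_actions, discount):
    # Illustrative shapes: rewards and dones are (b,), ensemble outputs are (e, b).
    with torch.no_grad():
        next_qs = target_ensemble(next_obs, next_actions)   # (e, b)
        min_next_q = next_qs.min(dim=0).values              # min over the ensemble -> (b,)
        td_target = rewards + discount * (1.0 - dones) * min_next_q
    q_preds = critic_ensemble(obs, actions)                  # (e, b)
    # Every critic regresses to the same target: sum the per-critic MSE, average over the batch.
    return F.mse_loss(q_preds, td_target.expand_as(q_preds), reduction="none").sum(0).mean()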
# calculate actors loss
# 1- temperature
temperature = self.temperature()
# 2- get actions (batch_size, action_dim) and log probs (batch_size,)
actions, log_probs = self.actor_network(observation_batch)
actions, log_probs = self.actor_network(observations)
# 3- get q-value predictions
with torch.no_grad():
q_preds = self.critic_ensemble(observation_batch, actions, return_type="mean")
actor_loss = -(q_preds - temperature * log_probs).mean()
q_preds = self.critic_ensemble(observations, actions, return_type="mean")
actor_loss = (
-(q_preds - temperature * log_probs)
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
).mean()
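Likewise, a minimal sketch of the actor objective computed above, minimizing E[alpha * log pi(a|s) - Q(s, a)] with reparameterized actions; the mean over the ensemble mirrors return_type="mean", and the padding masks are left out for clarity. The sketch keeps the critic forward pass outside torch.no_grad(), since the Q-term is what carries gradient back to the actor:

def actor_loss_sketch(actor, critic_ensemble, obs, temperature):
    # Reparameterized actions so the Q-term carries gradient back to the actor;
    # only the actor optimizer steps here, so the critic call is not detached.
    actions, log_probs = actor(obs)                       # (b, action_dim), (b,)
    q_values = critic_ensemble(obs, actions).mean(dim=0)  # reduce the ensemble -> (b,)
    return (temperature * log_probs - q_values).mean()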
# calculate temperature loss
# 1- calculate entropy
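The temperature update itself falls outside this hunk; in standard SAC it is a dual step on alpha that pushes the policy entropy toward a target value. A minimal sketch, assuming a learned log_alpha and a target entropy conventionally set to -action_dim:

def temperature_loss_sketch(log_alpha, log_probs, target_entropy):
    # Only alpha receives gradient; the policy log-probs are treated as constants.
    return -(log_alpha.exp() * (log_probs + target_entropy).detach()).mean()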
@@ -141,8 +164,8 @@ class SACPolicy(
loss = critics_loss + actor_loss + temperature_loss
return {
"Q_value_loss": critics_loss.item(),
"pi_loss": actor_loss.item(),
"critics_loss": critics_loss.item(),
"actor_loss": actor_loss.item(),
"temperature_loss": temperature_loss.item(),
"temperature": temperature.item(),
"entropy": entropy.item(),
@@ -153,4 +176,12 @@ class SACPolicy(
def update(self):
self.critic_target.lerp_(self.critic_ensemble, self.config.critic_target_update_weight)
#for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()):
# target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.critic_target_update_weight)
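For reference, the soft target update above is Polyak averaging, theta_target <- (1 - tau) * theta_target + tau * theta. A parameter-by-parameter sketch equivalent to the commented-out loop (helper name hypothetical):

import torch

def soft_update_sketch(target_net, online_net, tau):
    # Polyak averaging: theta_target <- (1 - tau) * theta_target + tau * theta,
    # which is exactly what an in-place lerp with weight `tau` computes.
    with torch.no_grad():
        for p_target, p in zip(target_net.parameters(), online_net.parameters()):
            p_target.lerp_(p, tau)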
class SACObservationEncoder(nn.Module):
"""Encode image and/or state vector observations."""
def __init__(self, config: SACConfig):
super().__init__()
self.config = config