completed losses
@@ -89,12 +89,18 @@ class SACPolicy(
        Returns a dictionary with the loss as a tensor, and other information as native floats.
        """
        batch = self.normalize_inputs(batch)
        # batch shape is (b, 2, ...): along dim 1, index 0 is the current observation
        # and index 1 is the next observation, used for computing the TD target.
        actions = batch["action"][:, 0]
        rewards = batch["next.reward"][:, 0]
        observations = {}
        next_observations = {}
        for k in batch:
            if k.startswith("observation."):
                observations[k] = batch[k][:, 0]
                next_observations[k] = batch[k][:, 1]
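        # after this slicing, each observations[k] / next_observations[k] has shape (b, ...)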

        # perform image augmentation

        # reward bias
@@ -104,34 +110,51 @@ class SACPolicy(

        # calculate critics loss
        # 1- compute actions from policy (at the *next* observation, for the TD target)
        action_preds, log_probs = self.actor_network(next_observations)
        # 2- compute q targets
        q_targets = self.target_qs(next_observations, action_preds)

        # critics subsample size
        # take the min over the target ensemble (clipped double-Q, to curb overestimation)
        min_q = q_targets.min(dim=0).values

        # backup entropy
        td_target = rewards + self.discount * min_q
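        # A sketch only (not implemented in this commit): with termination masking and an
        # entropy backup term, the target would instead read
        #   td_target = rewards + (1 - dones) * self.discount * (min_q - temperature * log_probs)
        # assuming a `dones` tensor is available and the temperature is computed before this
        # point; `log_probs` here are the next-step log-probs from the actor call above.
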
        # 3- compute predicted qs
        q_preds = self.critic_ensemble(observations, actions)
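        # q_preds has shape (e, b): one value per critic in the ensemble
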
        # 4- Calculate loss
        # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
        # (batch is laid out (b, 2, ...), so the pad masks are indexed along dim 1 and the
        # TD target is repeated over the ensemble dimension only)
        critics_loss = (
            F.mse_loss(
                q_preds,
                einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),
                reduction="none",
            ).sum(0)  # sum over ensemble
            # `q_preds` depends on the current observation and the actions.
            * ~batch["observation.state_is_pad"][:, 0]
            * ~batch["action_is_pad"][:, 0]
            # `td_target` depends on the reward and the next observations.
            * ~batch["next.reward_is_pad"][:, 0]
            * ~batch["observation.state_is_pad"][:, 1]
        ).mean()
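        # the `*_is_pad` masks zero out transitions whose frames had to be padded when the
        # dataset was sliced into (current, next) pairs, so padded samples do not
        # contribute to the TD loss
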
        # calculate actors loss
        # 1- temperature
        temperature = self.temperature()

        # 2- get actions (batch_size, action_dim) and log probs (batch_size,)
        actions, log_probs = self.actor_network(observations)

        # 3- get q-value predictions
        # (no torch.no_grad() here: the actor's gradient has to flow through the critic
        # into the sampled actions; wrapping this call in no_grad would reduce the loss
        # to pure entropy maximization)
        q_preds = self.critic_ensemble(observations, actions, return_type="mean")
        actor_loss = (
            -(q_preds - temperature * log_probs)
            * ~batch["observation.state_is_pad"][:, 0]
            * ~batch["action_is_pad"][:, 0]
        ).mean()
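        # equivalently, this maximizes E[Q(s, a) - temperature * log_prob(a|s)], the
        # standard SAC policy-improvement objective
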
        # calculate temperature loss
        # 1- calculate entropy
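        # (the entropy and temperature-loss computation is elided by this diff hunk; a
        #  standard SAC form, assuming a learnable log-alpha and a target entropy, is:
        #    entropy = -log_probs.mean()
        #    temperature_loss = (-self.log_alpha * (log_probs + self.target_entropy).detach()).mean()
        #  `log_alpha` and `target_entropy` are assumed names, not from this commit)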
@@ -141,8 +164,8 @@ class SACPolicy(
        loss = critics_loss + actor_loss + temperature_loss

        return {
            "critics_loss": critics_loss.item(),
            "actor_loss": actor_loss.item(),
            "temperature_loss": temperature_loss.item(),
            "temperature": temperature.item(),
            "entropy": entropy.item(),
@@ -153,4 +176,12 @@ class SACPolicy(
    def update(self):
        # Polyak / soft update of the target critic:
        #   target <- (1 - w) * target + w * online, with w = critic_target_update_weight
        self.critic_target.lerp_(self.critic_ensemble, self.config.critic_target_update_weight)
        # equivalent per-parameter form:
        # for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()):
        #     target_param.data.copy_(
        #         target_param.data * (1.0 - self.config.critic_target_update_weight)
        #         + param.data * self.config.critic_target_update_weight
        #     )

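# A minimal usage sketch (assumptions: `policy` is a SACPolicy, `batch` comes from a
# replay buffer, and the returned dict also carries the total loss tensor under "loss";
# none of this wiring is shown in the commit):
#
#   output = policy.forward(batch)
#   optimizer.zero_grad()
#   output["loss"].backward()
#   optimizer.step()
#   policy.update()  # soft-update the target critic after each optimizer step
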
class SACObservationEncoder(nn.Module):
    """Encode image and/or state vector observations."""

    def __init__(self, config: SACConfig):
        super().__init__()
        self.config = config