completed losses
@@ -89,11 +89,17 @@ class SACPolicy(
         Returns a dictionary with loss as a tensor, and other information as native floats.
         """
-        observation_batch =
-        next_obaservation_batch =
-        action_batch =
-        reward_batch =
-        dones_batch =
+        batch = self.normalize_inputs(batch)
+        # batch shape is (b, 2, ...): index 0 is the current observation and
+        # index 1 the next observation, used for calculating the right TD target.
+        actions = batch["action"][:, 0]
+        rewards = batch["next.reward"][:, 0]
+        observations = {}
+        next_observations = {}
+        for k in batch:
+            if k.startswith("observation."):
+                observations[k] = batch[k][:, 0]
+                next_observations[k] = batch[k][:, 1]

         # perform image augmentation

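The hunk above replaces the stubbed-out unpacking with real indexing: every key in the batch stacks two frames along dim 1, and indices 0 and 1 select the current and next step. A minimal standalone sketch of that convention (dummy shapes and keys, not the repo's code):

import torch

batch = {
    "observation.state": torch.randn(4, 2, 6),          # (b, 2, state_dim)
    "observation.image": torch.randn(4, 2, 3, 84, 84),  # (b, 2, c, h, w)
    "action": torch.randn(4, 2, 3),
    "next.reward": torch.randn(4, 2),
}

actions = batch["action"][:, 0]       # action taken at the current step
rewards = batch["next.reward"][:, 0]  # reward received for that action

observations, next_observations = {}, {}
for k in batch:
    if k.startswith("observation."):
        observations[k] = batch[k][:, 0]       # index 0: current frame
        next_observations[k] = batch[k][:, 1]  # index 1: next frame, feeds the TD target

assert observations["observation.state"].shape == (4, 6)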
@@ -104,34 +110,51 @@ class SACPolicy(

         # calculate critics loss
         # 1- compute actions from policy
-        next_actions = ..
+        action_preds, log_probs = self.actor_network(observations)
         # 2- compute q targets
-        q_targets = self.target_qs(next_obaservation_batch, next_actions)
+        q_targets = self.target_qs(next_observations, action_preds)

         # critics subsample size
         min_q = q_targets.min(dim=0)

         # backup entropy
-        td_target = reward_batch + self.discount * min_q
+        td_target = rewards + self.discount * min_q

         # 3- compute predicted qs
-        q_preds = self.critic_ensemble(observation_batch, action_batch)
+        q_preds = self.critic_ensemble(observations, actions)

         # 4- Calculate loss
-        critics_loss = F.mse_loss(q_preds,
-            einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]))  # dones masks
+        # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
+        critics_loss = (
+            F.mse_loss(
+                q_preds,
+                einops.repeat(td_target, "t b -> e t b", e=q_preds.shape[0]),
+                reduction="none",
+            ).sum(0)  # sum over ensemble
+            # `q_preds_ensemble` depends on the first observation and the actions.
+            * ~batch["observation.state_is_pad"][0]
+            * ~batch["action_is_pad"]
+            # q_targets depends on the reward and the next observations.
+            * ~batch["next.reward_is_pad"]
+            * ~batch["observation.state_is_pad"][1:]
+        ).sum(0).mean()

         # calculate actors loss
         # 1- temperature
         temperature = self.temperature()

         # 2- get actions (batch_size, action_dim) and log probs (batch_size,)
-        actions, log_probs = self.actor_network(observation_batch)
+        actions, log_probs = self.actor_network(observations)

         # 3- get q-value predictions
         with torch.no_grad():
-            q_preds = self.critic_ensemble(observation_batch, actions, return_type="mean")
-        actor_loss = -(q_preds - temperature * log_probs).mean()
+            q_preds = self.critic_ensemble(observations, actions, return_type="mean")
+        actor_loss = (
+            -(q_preds - temperature * log_probs).mean()
+            * ~batch["observation.state_is_pad"][0]
+            * ~batch["action_is_pad"]
+        ).mean()

         # calculate temperature loss
         # 1- calculate entropy
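For reference, the critic update above follows the usual clipped-double-Q recipe: take the minimum over the target ensemble, build a one-step TD target, and regress every online critic onto it. The sketch below reproduces that pattern with dummy tensors standing in for the networks; it drops the diff's padding masks and its extra time axis ("t b -> e t b" becomes "b -> e b"), and it takes .values from min(dim=0), which the context line in the hunk leaves out.

import torch
import torch.nn.functional as F
import einops

e, b = 2, 4      # ensemble size, batch size
discount = 0.99

q_targets = torch.randn(e, b)   # target critics on (next_obs, sampled next action)
rewards = torch.randn(b)

min_q = q_targets.min(dim=0).values      # pessimistic target across the ensemble
td_target = rewards + discount * min_q   # 1-step backup (done/padding masks omitted)

q_preds = torch.randn(e, b, requires_grad=True)  # online critics on (obs, dataset action)
critics_loss = F.mse_loss(
    q_preds,
    einops.repeat(td_target, "b -> e b", e=e),
    reduction="none",
).sum(0).mean()  # sum over critics, mean over the batch
critics_loss.backward()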
@@ -141,8 +164,8 @@ class SACPolicy(
         loss = critics_loss + actor_loss + temperature_loss

         return {
-            "Q_value_loss": critics_loss.item(),
-            "pi_loss": actor_loss.item(),
+            "critics_loss": critics_loss.item(),
+            "actor_loss": actor_loss.item(),
             "temperature_loss": temperature_loss.item(),
             "temperature": temperature.item(),
             "entropy": entropy.item(),
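The scalars returned above come from the three objectives summed into loss. The temperature branch is cut off by this diff, so the sketch below uses the standard SAC parameterization, a learned log_alpha pushed toward a target entropy; log_alpha and target_entropy are assumed names, not identifiers from this commit.

import torch

b, action_dim = 4, 3
log_alpha = torch.zeros(1, requires_grad=True)  # assumed: temperature = exp(log_alpha)
target_entropy = -float(action_dim)             # common SAC heuristic: -|A|

log_probs = torch.randn(b)  # log pi(a|s) for freshly sampled actions
q_preds = torch.randn(b)    # mean critic value for those actions (computed without grad)

temperature = log_alpha.exp()
# actor: maximize Q - alpha * log_prob, holding the temperature fixed
actor_loss = -(q_preds - temperature.detach() * log_probs).mean()
# temperature: raise alpha when the policy entropy (-log_prob) falls below the target
temperature_loss = (-log_alpha.exp() * (log_probs + target_entropy).detach()).mean()
entropy = -log_probs.mean()  # the "entropy" value logged in the return dict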
@@ -154,3 +177,11 @@ class SACPolicy(
         self.critic_target.lerp_(self.critic_ensemble, self.config.critic_target_update_weight)
         # for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()):
         #     target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.critic_target_update_weight)
+
+
+class SACObservationEncoder(nn.Module):
+    """Encode image and/or state vector observations."""
+
+    def __init__(self, config: SACConfig):
+        super().__init__()
+        self.config = config
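The final hunk keeps the old explicit Polyak loop as a comment next to its lerp_ replacement. The equivalence is easy to check parameter by parameter: tp.lerp_(p, tau) computes tp + tau * (p - tp), i.e. (1 - tau) * tp + tau * p, exactly what the loop did. (The diff applies lerp_ to the module containers directly, which presumes tensor-like parameter handling; the sketch below uses plain nn.Linear modules and per-parameter lerp_.)

import torch
import torch.nn as nn

tau = 0.005  # stands in for config.critic_target_update_weight
online = nn.Linear(4, 4)
target_a = nn.Linear(4, 4)
target_b = nn.Linear(4, 4)
target_b.load_state_dict(target_a.state_dict())  # identical starting points

# the commented-out explicit loop from the hunk
for tp, p in zip(target_a.parameters(), online.parameters()):
    tp.data.copy_(tp.data * (1.0 - tau) + p.data * tau)

# the vectorized form: tp <- tp + tau * (p - tp)
with torch.no_grad():
    for tp, p in zip(target_b.parameters(), online.parameters()):
        tp.lerp_(p, tau)

for pa, pb in zip(target_a.parameters(), target_b.parameters()):
    assert torch.allclose(pa, pb)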