Fix KL loss issue: add the KL penalty to the policy loss instead of subtracting it
This commit is contained in:
@@ -268,7 +268,7 @@ class DataParallelPPOActor(BasePPOActor):
|
||||
kl_penalty=self.config.kl_loss_type)
|
||||
kl_loss = masked_mean(kld, response_mask)
|
||||
|
||||
-policy_loss = policy_loss - kl_loss * self.config.kl_loss_coef
|
||||
+policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
|
||||
metrics['actor/kl_loss'] = kl_loss.detach().item()
|
||||
metrics['actor/kl_coef'] = self.config.kl_loss_coef
|
||||
|
||||
|
||||
Reference in New Issue
Block a user