fix kl loss issue
This commit is contained in:
@@ -268,7 +268,7 @@ class DataParallelPPOActor(BasePPOActor):
|
|||||||
kl_penalty=self.config.kl_loss_type)
|
kl_penalty=self.config.kl_loss_type)
|
||||||
kl_loss = masked_mean(kld, response_mask)
|
kl_loss = masked_mean(kld, response_mask)
|
||||||
|
|
||||||
policy_loss = policy_loss - kl_loss * self.config.kl_loss_coef
|
policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
|
||||||
metrics['actor/kl_loss'] = kl_loss.detach().item()
|
metrics['actor/kl_loss'] = kl_loss.detach().item()
|
||||||
metrics['actor/kl_coef'] = self.config.kl_loss_coef
|
metrics['actor/kl_coef'] = self.config.kl_loss_coef
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user