Fix KL loss sign: add the KL penalty to the policy loss instead of subtracting it

This commit is contained in:
PeterGriffinJin
2025-03-18 20:07:47 +00:00
parent e85506f143
commit 4b3c09451a

View File

@@ -268,7 +268,7 @@ class DataParallelPPOActor(BasePPOActor):
kl_penalty=self.config.kl_loss_type)
kl_loss = masked_mean(kld, response_mask)
- policy_loss = policy_loss - kl_loss * self.config.kl_loss_coef
+ policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
metrics['actor/kl_loss'] = kl_loss.detach().item()
metrics['actor/kl_coef'] = self.config.kl_loss_coef