Initial commit
This commit is contained in:

274 verl/trainer/ppo/core_algos.py Normal file

@@ -0,0 +1,274 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Core functions to implement PPO algorithms.
The functions implemented in this file should be used by trainers with different distributed strategies to
implement PPO.
"""

import numpy as np
import torch
from collections import defaultdict

import verl.utils.torch_functional as verl_F


class AdaptiveKLController:
    """
    Adaptive KL controller described in the paper:
    https://arxiv.org/pdf/1909.08593.pdf
    """

    def __init__(self, init_kl_coef, target_kl, horizon):
        self.value = init_kl_coef
        self.target = target_kl
        self.horizon = horizon

    def update(self, current_kl, n_steps):
        target = self.target
        proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2)
        mult = 1 + proportional_error * n_steps / self.horizon
        self.value *= mult


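# Illustrative sketch, not part of the original file: how the adaptive controller
# reacts to a single measurement. The concrete numbers below are assumptions.
def _example_adaptive_kl_update():
    ctrl = AdaptiveKLController(init_kl_coef=0.2, target_kl=0.1, horizon=10000)
    # The observed KL is twice the target, so the proportional error saturates at
    # +0.2 and the coefficient is scaled by 1 + 0.2 * 256 / 10000 = 1.00512.
    ctrl.update(current_kl=0.2, n_steps=256)
    return ctrl.value  # 0.2 * 1.00512 ≈ 0.201

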
class FixedKLController:
    """Fixed KL controller."""

    def __init__(self, kl_coef):
        self.value = kl_coef

    def update(self, current_kl, n_steps):
        pass


def get_kl_controller(config):  # seems never used?
    if config.critic.kl_ctrl.type == 'fixed':
        kl_ctrl = FixedKLController(kl_coef=config.critic.kl_ctrl.kl_coef)
    elif config.critic.kl_ctrl.type == 'adaptive':
        assert config.critic.kl_ctrl.horizon > 0, f'horizon must be larger than 0. Got {config.critic.kl_ctrl.horizon}'
        kl_ctrl = AdaptiveKLController(init_kl_coef=config.critic.kl_ctrl.kl_coef,
                                       target_kl=config.critic.kl_ctrl.target_kl,
                                       horizon=config.critic.kl_ctrl.horizon)
    else:
        raise ValueError('Unknown kl_ctrl type')

    return kl_ctrl


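# Illustrative sketch, not part of the original file: the nested config fields that
# get_kl_controller reads, mocked here with SimpleNamespace. The field values are
# assumptions for the example.
def _example_get_kl_controller():
    from types import SimpleNamespace
    kl_ctrl_cfg = SimpleNamespace(type='adaptive', kl_coef=0.2, target_kl=0.1, horizon=10000)
    config = SimpleNamespace(critic=SimpleNamespace(kl_ctrl=kl_ctrl_cfg))
    return get_kl_controller(config)  # -> AdaptiveKLController

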
def compute_gae_advantage_return(token_level_rewards: torch.Tensor, values: torch.Tensor, eos_mask: torch.Tensor,
                                 gamma: torch.Tensor, lam: torch.Tensor):
    """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        values: `(torch.Tensor)`
            shape: (bs, response_length)
        eos_mask: `(torch.Tensor)`
            shape: (bs, response_length). [EOS] mask. Tokens after [EOS] have mask zero.
        gamma: `(float)`
            discount factor used in RL
        lam: `(float)`
            lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)

    """
    with torch.no_grad():
        lastgaelam = 0
        advantages_reversed = []
        gen_len = token_level_rewards.shape[-1]

        for t in reversed(range(gen_len)):
            nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
            delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]
            lastgaelam = delta + gamma * lam * lastgaelam
            advantages_reversed.append(lastgaelam)
        advantages = torch.stack(advantages_reversed[::-1], dim=1)

        returns = advantages + values
        advantages = verl_F.masked_whiten(advantages, eos_mask)
    return advantages, returns


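# Illustrative sketch, not part of the original file: GAE on a tiny batch with a
# single reward at the final token. Shapes and values are assumptions.
def _example_gae():
    rewards = torch.tensor([[0.0, 0.0, 1.0]])    # outcome reward on the last token
    values = torch.tensor([[0.1, 0.2, 0.3]])     # critic predictions
    eos_mask = torch.tensor([[1.0, 1.0, 1.0]])   # all three response tokens are valid
    advantages, returns = compute_gae_advantage_return(
        rewards, values, eos_mask, gamma=1.0, lam=1.0)
    # returns are the raw advantages plus values; the returned advantages are
    # then whitened over the masked tokens.
    return advantages, returns

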
# NOTE(sgm): this implementation only considers outcome supervision, where the reward is a scalar.
def compute_grpo_outcome_advantage(token_level_rewards: torch.Tensor,
                                   eos_mask: torch.Tensor,
                                   index: torch.Tensor,
                                   epsilon: float = 1e-6):
    """
    Compute advantage for GRPO, operating only on Outcome reward
    (with only one scalar reward for each response).
    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        eos_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: `(torch.Tensor)`
            shape: (bs,). Prompt index of each response; responses sharing an index form one group.
        epsilon: `(float)`
            small constant added to the group std for numerical stability

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """
    response_length = token_level_rewards.shape[-1]
    non_zero_mask = (token_level_rewards != 0)
    scores = (token_level_rewards * non_zero_mask).sum(dim=-1)

    id2score = defaultdict(list)
    id2mean = {}
    id2std = {}

    with torch.no_grad():
        bsz = scores.shape[0]
        for i in range(bsz):
            id2score[index[i]].append(scores[i])
        for idx in id2score:
            if len(id2score[idx]) == 1:
                id2mean[idx] = torch.tensor(0.0)
                id2std[idx] = torch.tensor(1.0)
            elif len(id2score[idx]) > 1:
                id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
                id2std[idx] = torch.std(torch.tensor(id2score[idx]))
            else:
                raise ValueError(f"no score in prompt index: {idx}")
        for i in range(bsz):
            scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
        scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask

    return scores, scores


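# Illustrative sketch, not part of the original file: two responses sampled for the
# same prompt with different outcome rewards. The index is mocked as a numpy array
# so that equal indices fall into the same group; all values are assumptions.
def _example_grpo_outcome_advantage():
    token_level_rewards = torch.tensor([[0.0, 1.0], [0.0, 0.0]])  # outcome rewards 1.0 and 0.0
    eos_mask = torch.tensor([[1.0, 1.0], [1.0, 1.0]])
    index = np.array([0, 0])  # both responses belong to the same prompt group
    advantages, returns = compute_grpo_outcome_advantage(token_level_rewards, eos_mask, index)
    # Scores are z-normalized within the group (about +0.707 and -0.707) and
    # broadcast over the response tokens.
    return advantages, returns

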
def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
    kl = old_log_prob - ref_log_prob
    return token_level_scores - kl * kl_ratio


def compute_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cliprange):
    """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122

    Args:
        old_log_prob: `(torch.Tensor)`
            shape: (bs, response_length)
        log_prob: `(torch.Tensor)`
            shape: (bs, response_length)
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        eos_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        cliprange: (float)
            The clip range used in PPO. See https://arxiv.org/abs/1707.06347

    Returns:
        pg_loss: `a scalar torch.Tensor`
            policy gradient loss computed via PPO
        pg_clipfrac: (float)
            the fraction of tokens whose policy gradient loss was clipped
        ppo_kl: (float)
            the mean approximate KL divergence between the old and new policy, estimated from log-prob differences

    """
    negative_approx_kl = log_prob - old_log_prob
    ratio = torch.exp(negative_approx_kl)
    ppo_kl = verl_F.masked_mean(-negative_approx_kl, eos_mask)

    pg_losses = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)

    pg_loss = verl_F.masked_mean(torch.max(pg_losses, pg_losses2), eos_mask)
    pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses).float(), eos_mask)
    return pg_loss, pg_clipfrac, ppo_kl


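# Illustrative sketch, not part of the original file: the clipped PPO objective on a
# single two-token response. The log-probs and advantages below are assumptions.
def _example_policy_loss():
    old_log_prob = torch.tensor([[-1.0, -1.0]])
    log_prob = torch.tensor([[-0.5, -1.5]])     # ratios exp(+0.5) ≈ 1.65 and exp(-0.5) ≈ 0.61
    advantages = torch.tensor([[1.0, -1.0]])
    eos_mask = torch.tensor([[1.0, 1.0]])
    pg_loss, pg_clipfrac, ppo_kl = compute_policy_loss(
        old_log_prob, log_prob, advantages, eos_mask, cliprange=0.2)
    # With cliprange=0.2 both ratios are clipped into [0.8, 1.2] and the clipped
    # branch dominates for both tokens, so pg_clipfrac is 1.0 here.
    return pg_loss, pg_clipfrac, ppo_kl

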
def compute_entropy_loss(logits, eos_mask):
    """Compute Categorical entropy loss

    Args:
        logits: `(torch.Tensor)`
            shape: (bs, response_length, vocab_size)
        eos_mask: `(torch.Tensor)`
            shape: (bs, response_length)

    Returns:
        entropy: a scalar torch.Tensor

    """
    # compute entropy
    entropy = verl_F.entropy_from_logits(logits)  # (bs, response_len)
    entropy_loss = verl_F.masked_mean(entropy, mask=eos_mask)
    return entropy_loss


def compute_value_loss(vpreds, returns, values, eos_mask, cliprange_value):
    """Compute the value loss. Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151

    Args:
        vpreds (`torch.FloatTensor`):
            Predicted values of the value head, shape (`batch_size`, `response_length`)
        returns (`torch.FloatTensor`):
            Ground truth returns, shape (`batch_size`, `response_length`)
        values (`torch.FloatTensor`):
            Old values of value head, shape (`batch_size`, `response_length`)
        eos_mask (`torch.FloatTensor`):
            [EOS] mask, shape (`batch_size`, `response_length`)
        cliprange_value (`float`):
            The clip range applied to the value predictions around the old values

    Returns:
        vf_loss: a scalar (`torch.FloatTensor`):
            value function loss
        vf_clipfrac: a float
            The fraction of tokens whose value loss was clipped

    """
    vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value)
    vf_losses1 = (vpreds - returns)**2
    vf_losses2 = (vpredclipped - returns)**2
    vf_loss = 0.5 * verl_F.masked_mean(torch.max(vf_losses1, vf_losses2), eos_mask)
    vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), eos_mask)
    return vf_loss, vf_clipfrac


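# Illustrative sketch, not part of the original file: clipped value loss on one
# two-token response. The numbers are assumptions.
def _example_value_loss():
    vpreds = torch.tensor([[1.0, 0.0]])     # new value predictions
    values = torch.tensor([[0.5, 0.0]])     # old value predictions
    returns = torch.tensor([[1.0, 0.5]])    # target returns
    eos_mask = torch.tensor([[1.0, 1.0]])
    vf_loss, vf_clipfrac = compute_value_loss(vpreds, returns, values, eos_mask, cliprange_value=0.2)
    # The first prediction moved 0.5 away from the old value, beyond the 0.2 clip
    # range, so the clipped branch dominates for that token and vf_clipfrac is 0.5.
    return vf_loss, vf_clipfrac

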
def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
    """Compute KL divergence given logprob and ref_logprob.
    Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104

    Args:
        logprob: `(torch.FloatTensor)`
            log-probabilities under the current policy, shape (bs, response_length)
        ref_logprob: `(torch.FloatTensor)`
            log-probabilities under the reference policy, shape (bs, response_length)
        kl_penalty: `(str)`
            the penalty type: 'kl', 'abs', 'mse', 'low_var_kl' or 'full'

    Returns:
        the token-level KL penalty, shape (bs, response_length)

    """
    if kl_penalty == "kl":
        return logprob - ref_logprob

    if kl_penalty == "abs":
        return (logprob - ref_logprob).abs()

    if kl_penalty == "mse":
        return 0.5 * (logprob - ref_logprob).square()

    # J. Schulman. Approximating KL divergence, 2020.
    # URL http://joschu.net/blog/kl-approx.html
    if kl_penalty == 'low_var_kl':
        kl = ref_logprob - logprob
        ratio = torch.exp(kl)
        kld = (ratio - kl - 1).contiguous()
        return torch.clamp(kld, min=-10, max=10)

    if kl_penalty == "full":
        # here logprob and ref_logprob should contain the logits for every token in the vocabulary
        raise NotImplementedError

    raise NotImplementedError
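

# Illustrative sketch, not part of the original file: the low-variance KL estimator
# branch on made-up log-probs.
def _example_kl_penalty():
    logprob = torch.tensor([[-1.0, -2.0]])       # current policy log-probs
    ref_logprob = torch.tensor([[-1.5, -1.5]])   # reference policy log-probs
    # k3 estimator from http://joschu.net/blog/kl-approx.html: exp(r) - r - 1 with
    # r = ref_logprob - logprob; non-negative and low variance.
    return kl_penalty(logprob, ref_logprob, 'low_var_kl')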