Initial commit
This commit is contained in:
19
verl/workers/rollout/__init__.py
Normal file
19
verl/workers/rollout/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .base import BaseRollout
|
||||
from .naive import NaiveRollout
|
||||
from .hf_rollout import HFRollout
|
||||
|
||||
__all__ = ["BaseRollout", "NaiveRollout", "HFRollout"]
|
||||
37
verl/workers/rollout/base.py
Normal file
37
verl/workers/rollout/base.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Iterable, Union
|
||||
|
||||
from verl import DataProto
|
||||
|
||||
__all__ = ['BaseRollout']
|
||||
|
||||
|
||||
class BaseRollout(ABC):
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
|
||||
Args:
|
||||
dataloader: an Iterable of TensorDict that consistently generates prompts. Note that the dataloader
|
||||
should handle when the training stops.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
@abstractmethod
|
||||
def generate_sequences(self, prompts: DataProto) -> DataProto:
|
||||
"""Generate sequences"""
|
||||
pass
|
||||
140
verl/workers/rollout/hf_rollout.py
Normal file
140
verl/workers/rollout/hf_rollout.py
Normal file
@@ -0,0 +1,140 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Rollout with huggingface models.
|
||||
TODO: refactor this class. Currently, it will hang when using FSDP HybridShard. We should actually create a single GPU model.
|
||||
Then, get full state_dict and bind the state_dict to the single GPU model. Then, use the single GPU model to perform generation.
|
||||
"""
|
||||
import contextlib
|
||||
import torch
|
||||
import torch.distributed
|
||||
from tensordict import TensorDict
|
||||
from torch import nn
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
|
||||
from verl import DataProto
|
||||
from verl.utils.torch_functional import get_eos_mask
|
||||
from .base import BaseRollout
|
||||
|
||||
from transformers import GenerationConfig
|
||||
|
||||
__all__ = ['HFRollout']
|
||||
|
||||
|
||||
class HFRollout(BaseRollout):
|
||||
|
||||
def __init__(self, module: nn.Module, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.module = module
|
||||
|
||||
def generate_sequences(self, prompts: DataProto) -> DataProto:
|
||||
batch_size = prompts.batch.batch_size[0]
|
||||
num_chunks = max(batch_size // self.config.get('micro_batch_size', batch_size), 1)
|
||||
batch_prompts = prompts.chunk(chunks=num_chunks)
|
||||
output = [self._generate_minibatch(p) for p in batch_prompts]
|
||||
output = DataProto.concat(output)
|
||||
return output
|
||||
|
||||
@torch.no_grad()
|
||||
def _generate_minibatch(self, prompts: DataProto) -> DataProto:
|
||||
idx = prompts.batch['input_ids'] # (bs, prompt_length)
|
||||
attention_mask = prompts.batch['attention_mask'] # left-padded attention_mask
|
||||
position_ids = prompts.batch['position_ids']
|
||||
|
||||
# used to construct attention_mask
|
||||
eos_token_id = prompts.meta_info['eos_token_id']
|
||||
pad_token_id = prompts.meta_info['pad_token_id']
|
||||
|
||||
batch_size = idx.size(0)
|
||||
prompt_length = idx.size(1)
|
||||
|
||||
self.module.eval()
|
||||
param_ctx = contextlib.nullcontext()
|
||||
|
||||
# make sampling args can be overriden by inputs
|
||||
do_sample = prompts.meta_info.get('do_sample', self.config.do_sample)
|
||||
response_length = prompts.meta_info.get('response_length', self.config.response_length)
|
||||
top_p = prompts.meta_info.get('top_p', self.config.get('top_p', 1.0))
|
||||
top_k = prompts.meta_info.get('top_k', self.config.get('top_k', 0))
|
||||
|
||||
if top_k is None:
|
||||
top_k = 0
|
||||
top_k = max(0, top_k) # to be compatible with vllm
|
||||
|
||||
temperature = prompts.meta_info.get('temperature', self.config.temperature)
|
||||
|
||||
generation_config = GenerationConfig(temperature=temperature, top_p=top_p, top_k=top_k)
|
||||
|
||||
if isinstance(self.module, FSDP):
|
||||
# recurse need to set to False according to https://github.com/pytorch/pytorch/issues/100069
|
||||
param_ctx = FSDP.summon_full_params(self.module, writeback=False, recurse=False)
|
||||
with param_ctx:
|
||||
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
|
||||
output = self.module.generate(
|
||||
input_ids=idx,
|
||||
attention_mask=attention_mask,
|
||||
do_sample=do_sample,
|
||||
max_new_tokens=response_length,
|
||||
# max_length=max_length,
|
||||
eos_token_id=eos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
generation_config=generation_config,
|
||||
# renormalize_logits=True,
|
||||
output_scores=False, # this is potentially very large
|
||||
return_dict_in_generate=True,
|
||||
use_cache=True)
|
||||
# TODO: filter out the seq with no answers like ds-chat
|
||||
seq = output.sequences
|
||||
|
||||
# huggingface generate will stop generating when all the batch reaches [EOS].
|
||||
# We have to pad to response_length
|
||||
sequence_length = prompt_length + self.config.response_length
|
||||
delta_length = sequence_length - seq.shape[1]
|
||||
|
||||
if delta_length > 0:
|
||||
delta_tokens = torch.ones(size=(batch_size, delta_length), device=seq.device, dtype=seq.dtype)
|
||||
delta_tokens = pad_token_id * delta_tokens
|
||||
seq = torch.cat((seq, delta_tokens), dim=1)
|
||||
|
||||
assert seq.shape[1] == sequence_length
|
||||
|
||||
prompt = seq[:, :prompt_length] # (bs, prompt_length)
|
||||
response = seq[:, prompt_length:] # (bs, response_length)
|
||||
|
||||
response_length = response.size(1)
|
||||
delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
|
||||
delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1)
|
||||
|
||||
response_position_ids = position_ids[:, -1:] + delta_position_id
|
||||
position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
|
||||
|
||||
response_attention_mask = get_eos_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype)
|
||||
attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)
|
||||
|
||||
batch = TensorDict(
|
||||
{
|
||||
'prompts': prompt,
|
||||
'responses': response,
|
||||
'input_ids': seq,
|
||||
'attention_mask': attention_mask,
|
||||
'position_ids': position_ids
|
||||
},
|
||||
batch_size=batch_size)
|
||||
|
||||
# empty cache before compute old_log_prob
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
self.module.train()
|
||||
return DataProto(batch=batch)
|
||||
15
verl/workers/rollout/naive/__init__.py
Normal file
15
verl/workers/rollout/naive/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .naive_rollout import NaiveRollout
|
||||
119
verl/workers/rollout/naive/naive_rollout.py
Normal file
119
verl/workers/rollout/naive/naive_rollout.py
Normal file
@@ -0,0 +1,119 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
In single GPU rollout, the sequences are generated directly by sampling from the model.
|
||||
The output will contain
|
||||
1. output_ids
|
||||
2. attention_masks (left padding)
|
||||
3. eos_masks
|
||||
4. log_probs
|
||||
"""
|
||||
from typing import Iterable, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from tensordict import TensorDict
|
||||
from torch import nn
|
||||
|
||||
from verl import DataProto
|
||||
from verl.utils.torch_functional import logprobs_from_logits
|
||||
from ..base import BaseRollout
|
||||
|
||||
__all__ = ['NativeRollout']
|
||||
|
||||
|
||||
class NaiveRollout(BaseRollout):
|
||||
|
||||
def __init__(self, module: nn.Module, config):
|
||||
"""A naive rollout. It requires the module to be compatible with huggingface APIs. That is:
|
||||
The module should define __call__ to receive input_ids, attention_mask and position_ids.
|
||||
It outputs a structure that contains logits field.
|
||||
|
||||
Args:
|
||||
module: module here follows huggingface APIs
|
||||
config: DictConfig
|
||||
"""
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.module = module
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_sequences(self, prompts: DataProto) -> DataProto:
|
||||
"""Generate sequences"""
|
||||
idx = prompts.batch['input_ids'] # (bs, prompt_length)
|
||||
attention_mask = prompts.batch['attention_mask'] # left-padded attention_mask
|
||||
position_ids = prompts.batch['position_ids']
|
||||
|
||||
# used to construct attention_mask
|
||||
eos_token_id = prompts.meta_info['eos_token_id']
|
||||
|
||||
batch_size = idx.size(0)
|
||||
prompt_length = idx.size(1)
|
||||
|
||||
self.module.eval()
|
||||
|
||||
prev_attention_mask = torch.ones(size=(batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device)
|
||||
|
||||
logits_lst = []
|
||||
for _ in range(self.config.response_length):
|
||||
# if the sequence context is growing too long we must crop it at block_size
|
||||
# idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
|
||||
idx_cond = idx
|
||||
# forward the model to get the logits for the index in the sequence
|
||||
# we use huggingface APIs here
|
||||
output = self.module(input_ids=idx_cond, attention_mask=attention_mask, position_ids=position_ids)
|
||||
logits = output.logits
|
||||
# pluck the logits at the final step and scale by desired temperature
|
||||
logits = logits[:, -1, :] / self.config.temperature # (bs, vocab_size)
|
||||
# optionally crop the logits to only the top k options
|
||||
if self.config.top_k is not None:
|
||||
v, _ = torch.topk(logits, min(self.config.top_k, logits.size(-1)))
|
||||
logits[logits < v[:, [-1]]] = -float('Inf')
|
||||
# apply softmax to convert logits to (normalized) probabilities
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
# sample from the distribution
|
||||
if self.config.do_sample:
|
||||
idx_next = torch.multinomial(probs, num_samples=1)
|
||||
else:
|
||||
idx_next = torch.argmax(probs, dim=-1, keepdim=True)
|
||||
|
||||
attention_mask = torch.cat((attention_mask, prev_attention_mask), dim=-1)
|
||||
|
||||
prev_attention_mask = torch.logical_and(idx_next != eos_token_id, prev_attention_mask.bool())
|
||||
prev_attention_mask.to(attention_mask.dtype)
|
||||
|
||||
position_ids = torch.cat((position_ids, position_ids[:, -1:] + 1), dim=-1)
|
||||
|
||||
# append sampled index to the running sequence and continue
|
||||
idx = torch.cat((idx, idx_next), dim=1)
|
||||
logits_lst.append(logits)
|
||||
|
||||
logits = torch.stack(logits_lst, dim=1) # (bs, response_length, vocab_size)
|
||||
prompts = idx[:, :prompt_length] # (bs, prompt_length)
|
||||
response = idx[:, prompt_length:] # (bs, response_length)
|
||||
log_probs = logprobs_from_logits(logits=logits, labels=response)
|
||||
batch = TensorDict(
|
||||
{
|
||||
'input_ids': prompts,
|
||||
'responses': response,
|
||||
'sequences': idx,
|
||||
'old_log_probs': log_probs,
|
||||
'attention_mask': attention_mask,
|
||||
'position_ids': position_ids,
|
||||
},
|
||||
batch_size=batch_size)
|
||||
|
||||
self.module.train()
|
||||
|
||||
return DataProto(batch=batch)
|
||||
162
verl/workers/rollout/tokenizer.py
Normal file
162
verl/workers/rollout/tokenizer.py
Normal file
@@ -0,0 +1,162 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
The base tokenizer class, required for any hybrid engine based rollout or inference with vLLM.
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Union
|
||||
|
||||
__all__ = ['HybridEngineBaseTokenizer']
|
||||
|
||||
|
||||
class HybridEngineBaseTokenizer(ABC):
|
||||
"""the tokenizer property and function name should align with HF's to meet vllm requirement"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def vocab_size(self):
|
||||
"""
|
||||
`int`: Size of the base vocabulary (without the added tokens).
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def pad_token_id(self):
|
||||
"""
|
||||
`Optional[int]`: Id of the padding token in the vocabulary. Returns `None` if the token has not been set.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def eos_token_id(self):
|
||||
"""
|
||||
`Optional[int]`: Id of the end of sentence token in the vocabulary. Returns `None` if the token has not been
|
||||
set.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def all_special_ids(self) -> List[int]:
|
||||
"""
|
||||
`List[int]`: List the ids of the special tokens(`'<unk>'`, `'<cls>'`, etc.) mapped to class attributes.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def all_special_tokens(self) -> List[str]:
|
||||
"""
|
||||
`List[str]`: A list of the unique special tokens (`'<unk>'`, `'<cls>'`, ..., etc.).
|
||||
|
||||
Convert tokens of `tokenizers.AddedToken` type to string.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def encode(self, text):
|
||||
"""
|
||||
Converts a string to a sequence of ids (integer), using the tokenizer and vocabulary.
|
||||
|
||||
Args:
|
||||
text (`str`, `List[str]` or `List[int]`):
|
||||
The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the
|
||||
`tokenize` method) or a list of integers.
|
||||
|
||||
text_pair (`str`, `List[str]` or `List[int]`, *optional*):
|
||||
Optional second sequence to be encoded. This can be a string, a list of strings (tokenized string using
|
||||
the `tokenize` method) or a list of integers.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def decode(
|
||||
self,
|
||||
token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
|
||||
skip_special_tokens: bool = False,
|
||||
clean_up_tokenization_spaces: bool = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special
|
||||
tokens and clean up tokenization spaces.
|
||||
|
||||
Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.
|
||||
|
||||
Args:
|
||||
token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
|
||||
List of tokenized input ids. Can be obtained using the `__call__` method.
|
||||
skip_special_tokens (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to remove special tokens in the decoding.
|
||||
clean_up_tokenization_spaces (`bool`, *optional*):
|
||||
Whether or not to clean up the tokenization spaces. If `None`, will default to
|
||||
`self.clean_up_tokenization_spaces`.
|
||||
kwargs (additional keyword arguments, *optional*):
|
||||
Will be passed to the underlying model specific decode method.
|
||||
|
||||
Returns:
|
||||
`str`: The decoded sentence.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def convert_ids_to_tokens(self,
|
||||
ids: Union[int, List[int]],
|
||||
skip_special_tokens: bool = False) -> Union[str, List[str]]:
|
||||
"""
|
||||
Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and
|
||||
added tokens.
|
||||
|
||||
Args:
|
||||
ids (`int` or `List[int]`):
|
||||
The token id (or token ids) to convert to tokens.
|
||||
skip_special_tokens (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to remove special tokens in the decoding.
|
||||
|
||||
Returns:
|
||||
`str` or `List[str]`: The decoded token(s).
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_added_vocab(self) -> Dict[str, int]:
|
||||
"""
|
||||
Returns the added tokens in the vocabulary as a dictionary of token to index. Results might be different from
|
||||
the fast call because for now we always add the tokens even if they are already in the vocabulary. This is
|
||||
something we should change.
|
||||
|
||||
Returns:
|
||||
`Dict[str, int]`: The added tokens.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
||||
"""
|
||||
Converts a sequence of tokens in a single string. The most simple way to do it is `" ".join(tokens)` but we
|
||||
often want to remove sub-word tokenization artifacts at the same time.
|
||||
|
||||
Args:
|
||||
tokens (`List[str]`): The token to join in a string.
|
||||
|
||||
Returns:
|
||||
`str`: The joined tokens.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def is_fast(self):
|
||||
return False
|
||||
15
verl/workers/rollout/vllm_rollout/__init__.py
Normal file
15
verl/workers/rollout/vllm_rollout/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .vllm_rollout import vLLMRollout
|
||||
226
verl/workers/rollout/vllm_rollout/vllm_rollout.py
Normal file
226
verl/workers/rollout/vllm_rollout/vllm_rollout.py
Normal file
@@ -0,0 +1,226 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
The vllm_rollout that can be applied in different backend
|
||||
When working with FSDP:
|
||||
- Use DTensor weight loader (recommended) or HF weight loader
|
||||
- Utilize state_dict from the FSDP to synchronize the weights among tp ranks in vLLM
|
||||
When working with Megatron:
|
||||
- Use Megatron weight loader
|
||||
- During training, only the current pp stage holds the parameters
|
||||
- Before inference, broadcast the parameters of the current pp rank to all other pp ranks (all pp ranks holds all the parameters)
|
||||
- Bind the parameters to the inference engine
|
||||
- Do inference in tp. pp is treated as additional dp
|
||||
- After inference, all the parameters that doesn't belong to this pp rank is freed.
|
||||
"""
|
||||
from typing import List
|
||||
from contextlib import contextmanager
|
||||
from omegaconf import DictConfig
|
||||
import torch
|
||||
import torch.distributed
|
||||
from tensordict import TensorDict
|
||||
from torch import nn
|
||||
|
||||
from verl import DataProto
|
||||
from verl.utils.torch_functional import get_eos_mask, pad_sequence_to_length
|
||||
from verl.workers.rollout.base import BaseRollout
|
||||
from verl.third_party.vllm import LLM, vllm_version
|
||||
from verl.third_party.vllm import parallel_state as vllm_ps
|
||||
from vllm import SamplingParams
|
||||
|
||||
# TODO
|
||||
# 1. support pp in vllm
|
||||
# 2. passing tokenizer is not necessary? no encoding/decoding is happending here
|
||||
# 3. simplify init logics
|
||||
|
||||
|
||||
# NOTE(sgm): add for verl. We can optimize it by making the dataloader yield List[int] without padding.
|
||||
def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> List[int]:
|
||||
# remove the left padding in the prompt token_id
|
||||
# pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id
|
||||
non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
|
||||
token_ids = prompt_token_ids[non_pad_index:].tolist()
|
||||
return token_ids
|
||||
|
||||
|
||||
class vLLMRollout(BaseRollout):
|
||||
|
||||
def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model_hf_config, **kwargs):
|
||||
"""A vLLM rollout. It requires the module is supported by the vllm.
|
||||
|
||||
Args:
|
||||
module: module here follows huggingface APIs
|
||||
config: DictConfig
|
||||
tokenizer: the task/model tokenizer
|
||||
model_hf_config: the huggingface config to initiallize the generating model in vllm
|
||||
**kwargs: train_tp, for Megatron Backend to initialize hybrid engine (zero redundancy) process group
|
||||
"""
|
||||
super().__init__()
|
||||
self.config = config
|
||||
assert not (not config.enforce_eager and config.free_cache_engine), \
|
||||
"disable CUDA graph (enforce_eager = False) if free cache engine"
|
||||
|
||||
tensor_parallel_size = self.config.get('tensor_model_parallel_size', 1)
|
||||
assert tensor_parallel_size <= torch.distributed.get_world_size(), \
|
||||
"tensor parallel size should be less than or equal to the world size"
|
||||
|
||||
if kwargs.get('train_tp', None) is not None:
|
||||
# deployed with megatron
|
||||
import os
|
||||
os.environ['CUDA_TIMER_STREAM_KAFKA_ENABLE'] = '0'
|
||||
os.environ['MEGATRON_IMPORT_TIMERS'] = '0'
|
||||
train_tp = kwargs.get('train_tp', None)
|
||||
num_tp_per_train_tp = train_tp // tensor_parallel_size
|
||||
if vllm_version in ('0.4.2', '0.5.4', '0.6.3'):
|
||||
vllm_ps.initialize_parallel_state(tensor_model_parallel_size=tensor_parallel_size,
|
||||
num_tp_per_train_tp=num_tp_per_train_tp)
|
||||
|
||||
assert model_hf_config.max_position_embeddings >= config.prompt_length + config.response_length, \
|
||||
"model context length should be greater than total sequence length"
|
||||
self.inference_engine = LLM(actor_module,
|
||||
tokenizer=tokenizer,
|
||||
model_hf_config=model_hf_config,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
dtype=config.dtype,
|
||||
enforce_eager=config.enforce_eager,
|
||||
gpu_memory_utilization=config.gpu_memory_utilization,
|
||||
skip_tokenizer_init=False,
|
||||
max_model_len=config.prompt_length + config.response_length,
|
||||
load_format=config.load_format)
|
||||
|
||||
# Offload vllm model to reduce peak memory usage
|
||||
self.inference_engine.offload_model_weights()
|
||||
|
||||
kwargs = dict(
|
||||
n=1,
|
||||
logprobs=1, # can be set to 0 and let actor to recompute
|
||||
max_tokens=config.response_length,
|
||||
)
|
||||
|
||||
# we may detokenize the result all together later
|
||||
if vllm_version in ('0.4.2', '0.5.4', '0.6.3'):
|
||||
kwargs['detokenize'] = False
|
||||
|
||||
# supporting adding any sampling params from the config file
|
||||
for k in config.keys():
|
||||
if hasattr(SamplingParams(), str(k)):
|
||||
kwargs[k] = config.get(k)
|
||||
|
||||
print(f"kwargs: {kwargs}")
|
||||
self.sampling_params = SamplingParams(**kwargs)
|
||||
|
||||
self.pad_token_id = tokenizer.pad_token_id
|
||||
|
||||
@contextmanager
|
||||
def update_sampling_params(self, **kwargs):
|
||||
# update sampling params
|
||||
old_sampling_params_args = {}
|
||||
if kwargs:
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(self.sampling_params, key):
|
||||
old_value = getattr(self.sampling_params, key)
|
||||
old_sampling_params_args[key] = old_value
|
||||
setattr(self.sampling_params, key, value)
|
||||
yield
|
||||
# roll back to previous sampling params
|
||||
# if len(old_sampling_params_args):
|
||||
for key, value in old_sampling_params_args.items():
|
||||
setattr(self.sampling_params, key, value)
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
|
||||
# rebuild vllm cache engine
|
||||
if self.config.free_cache_engine:
|
||||
self.inference_engine.init_cache_engine()
|
||||
|
||||
idx = prompts.batch['input_ids'] # (bs, prompt_length)
|
||||
# left-padded attention_mask
|
||||
attention_mask = prompts.batch['attention_mask']
|
||||
position_ids = prompts.batch['position_ids']
|
||||
|
||||
# used to construct attention_mask
|
||||
eos_token_id = prompts.meta_info['eos_token_id']
|
||||
|
||||
batch_size = idx.size(0)
|
||||
|
||||
idx_list = []
|
||||
# parse idx from torch.Tensor to List[List[str]]
|
||||
for i in range(batch_size):
|
||||
idx_list.append(_pre_process_inputs(self.pad_token_id, idx[i]))
|
||||
|
||||
do_sample = prompts.meta_info.get('do_sample', True)
|
||||
if not do_sample:
|
||||
kwargs = {
|
||||
'best_of': 1,
|
||||
'top_p': 1.0,
|
||||
'top_k': -1,
|
||||
'min_p': 0.0,
|
||||
'temperature': 0,
|
||||
'n': 1 # if greedy, only 1 response
|
||||
}
|
||||
|
||||
# users can customize different sampling_params at different run
|
||||
with self.update_sampling_params(**kwargs):
|
||||
output = self.inference_engine.generate(
|
||||
prompts=None, # because we have already convert it to prompt token id
|
||||
sampling_params=self.sampling_params,
|
||||
prompt_token_ids=idx_list,
|
||||
use_tqdm=False)
|
||||
|
||||
# TODO(sgm): disable logprob when recompute_log_prob is enable
|
||||
# if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length)
|
||||
response = output[0].to(idx.device)
|
||||
log_probs = output[1].to(idx.device)
|
||||
|
||||
if response.shape[1] < self.config.response_length:
|
||||
response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id)
|
||||
log_probs = pad_sequence_to_length(log_probs, self.config.response_length, self.pad_token_id)
|
||||
|
||||
if self.config.n > 1 and do_sample:
|
||||
idx = idx.repeat_interleave(self.config.n, dim=0)
|
||||
attention_mask = attention_mask.repeat_interleave(self.config.n, dim=0)
|
||||
position_ids = position_ids.repeat_interleave(self.config.n, dim=0)
|
||||
batch_size = batch_size * self.config.n
|
||||
seq = torch.cat([idx, response], dim=-1)
|
||||
|
||||
response_length = response.size(1)
|
||||
delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
|
||||
delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1)
|
||||
|
||||
# TODO(sgm): fix position_ids on right_pad
|
||||
# prompt: left pad + response: right pad
|
||||
# attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
|
||||
# position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]
|
||||
response_position_ids = position_ids[:, -1:] + delta_position_id
|
||||
position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
|
||||
response_attention_mask = get_eos_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype)
|
||||
attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)
|
||||
|
||||
# all the tp ranks should contain the same data here. data in all ranks are valid
|
||||
batch = TensorDict(
|
||||
{
|
||||
'prompts': idx,
|
||||
'responses': response,
|
||||
'input_ids': seq, # here input_ids become the whole sentences
|
||||
# 'old_log_probs': log_probs, # we will recompute old log prob with actor
|
||||
'attention_mask': attention_mask,
|
||||
'position_ids': position_ids
|
||||
},
|
||||
batch_size=batch_size)
|
||||
|
||||
# free vllm cache engine
|
||||
if self.config.free_cache_engine:
|
||||
self.inference_engine.free_cache_engine()
|
||||
|
||||
return DataProto(batch=batch)
|
||||
Reference in New Issue
Block a user