Initial commit

This commit is contained in:
PeterGriffinJin
2025-02-28 15:16:19 +00:00
commit 068516be64
207 changed files with 33063 additions and 0 deletions


@@ -0,0 +1,19 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base import BaseRollout
from .naive import NaiveRollout
from .hf_rollout import HFRollout
__all__ = ["BaseRollout", "NaiveRollout", "HFRollout"]


@@ -0,0 +1,37 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Iterable, Union
from verl import DataProto
__all__ = ['BaseRollout']
class BaseRollout(ABC):
def __init__(self):
"""
Args:
dataloader: an Iterable of TensorDict that consistently generates prompts. Note that the dataloader
is responsible for handling termination when training stops.
"""
super().__init__()
@abstractmethod
def generate_sequences(self, prompts: DataProto) -> DataProto:
"""Generate sequences"""
pass
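# Illustrative sketch only (not used by the concrete rollouts in this commit): a trivial
# subclass showing the contract a rollout must satisfy, namely consume a DataProto of
# prompts and return a DataProto of generated sequences.
class _EchoRollout(BaseRollout):

    def __init__(self):
        super().__init__()

    def generate_sequences(self, prompts: DataProto) -> DataProto:
        # A no-op "generation" that simply returns the prompts unchanged.
        return prompts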


@@ -0,0 +1,140 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Rollout with huggingface models.
TODO: refactor this class. Currently, it will hang when using FSDP HybridShard. We should instead create a single-GPU model,
gather the full state_dict, bind it to that single-GPU model, and then use the single-GPU model to perform generation.
"""
import contextlib
import torch
import torch.distributed
from tensordict import TensorDict
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from verl import DataProto
from verl.utils.torch_functional import get_eos_mask
from .base import BaseRollout
from transformers import GenerationConfig
__all__ = ['HFRollout']
class HFRollout(BaseRollout):
def __init__(self, module: nn.Module, config):
super().__init__()
self.config = config
self.module = module
def generate_sequences(self, prompts: DataProto) -> DataProto:
batch_size = prompts.batch.batch_size[0]
num_chunks = max(batch_size // self.config.get('micro_batch_size', batch_size), 1)
batch_prompts = prompts.chunk(chunks=num_chunks)
output = [self._generate_minibatch(p) for p in batch_prompts]
output = DataProto.concat(output)
return output
@torch.no_grad()
def _generate_minibatch(self, prompts: DataProto) -> DataProto:
idx = prompts.batch['input_ids'] # (bs, prompt_length)
attention_mask = prompts.batch['attention_mask'] # left-padded attention_mask
position_ids = prompts.batch['position_ids']
# used to construct attention_mask
eos_token_id = prompts.meta_info['eos_token_id']
pad_token_id = prompts.meta_info['pad_token_id']
batch_size = idx.size(0)
prompt_length = idx.size(1)
self.module.eval()
param_ctx = contextlib.nullcontext()
# allow sampling args to be overridden by inputs
do_sample = prompts.meta_info.get('do_sample', self.config.do_sample)
response_length = prompts.meta_info.get('response_length', self.config.response_length)
top_p = prompts.meta_info.get('top_p', self.config.get('top_p', 1.0))
top_k = prompts.meta_info.get('top_k', self.config.get('top_k', 0))
if top_k is None:
top_k = 0
top_k = max(0, top_k) # to be compatible with vllm
temperature = prompts.meta_info.get('temperature', self.config.temperature)
generation_config = GenerationConfig(temperature=temperature, top_p=top_p, top_k=top_k)
if isinstance(self.module, FSDP):
# recurse need to set to False according to https://github.com/pytorch/pytorch/issues/100069
param_ctx = FSDP.summon_full_params(self.module, writeback=False, recurse=False)
with param_ctx:
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
output = self.module.generate(
input_ids=idx,
attention_mask=attention_mask,
do_sample=do_sample,
max_new_tokens=response_length,
# max_length=max_length,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
generation_config=generation_config,
# renormalize_logits=True,
output_scores=False, # this is potentially very large
return_dict_in_generate=True,
use_cache=True)
# TODO: filter out the seq with no answers like ds-chat
seq = output.sequences
# huggingface generate stops once every sequence in the batch has reached [EOS],
# so we have to pad the output back to response_length
sequence_length = prompt_length + self.config.response_length
delta_length = sequence_length - seq.shape[1]
if delta_length > 0:
delta_tokens = torch.ones(size=(batch_size, delta_length), device=seq.device, dtype=seq.dtype)
delta_tokens = pad_token_id * delta_tokens
seq = torch.cat((seq, delta_tokens), dim=1)
assert seq.shape[1] == sequence_length
prompt = seq[:, :prompt_length] # (bs, prompt_length)
response = seq[:, prompt_length:] # (bs, response_length)
response_length = response.size(1)
delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1)
response_position_ids = position_ids[:, -1:] + delta_position_id
position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
response_attention_mask = get_eos_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype)
attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)
batch = TensorDict(
{
'prompts': prompt,
'responses': response,
'input_ids': seq,
'attention_mask': attention_mask,
'position_ids': position_ids
},
batch_size=batch_size)
# empty cache before computing old_log_prob
torch.cuda.empty_cache()
self.module.train()
return DataProto(batch=batch)
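# Illustrative usage sketch. Assumes a CUDA device, an OmegaConf config holding the
# fields read above, and that DataProto accepts `batch` and `meta_info` keyword
# arguments; the model name is a placeholder for any HF causal LM checkpoint.
if __name__ == '__main__':
    from omegaconf import OmegaConf
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('gpt2', padding_side='left')
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained('gpt2').cuda()
    config = OmegaConf.create({
        'micro_batch_size': 2,
        'do_sample': True,
        'response_length': 16,
        'temperature': 1.0,
    })
    rollout = HFRollout(module=model, config=config)

    enc = tokenizer(['Hello there.', 'The capital of France is'],
                    return_tensors='pt', padding=True).to('cuda')
    # position ids for left-padded prompts: pad positions are clamped to 0
    position_ids = (enc['attention_mask'].cumsum(dim=-1) - 1).clamp(min=0)
    batch = TensorDict({
        'input_ids': enc['input_ids'],
        'attention_mask': enc['attention_mask'],
        'position_ids': position_ids,
    }, batch_size=2)
    prompts = DataProto(batch=batch,
                        meta_info={'eos_token_id': tokenizer.eos_token_id,
                                   'pad_token_id': tokenizer.pad_token_id})
    out = rollout.generate_sequences(prompts)
    print(out.batch['responses'].shape)  # (2, 16)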


@@ -0,0 +1,15 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .naive_rollout import NaiveRollout


@@ -0,0 +1,119 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
In single-GPU rollout, the sequences are generated directly by sampling from the model.
The output will contain
1. output_ids
2. attention_masks (left padding)
3. eos_masks
4. log_probs
"""
from typing import Iterable, Union
import torch
import torch.nn.functional as F
from tensordict import TensorDict
from torch import nn
from verl import DataProto
from verl.utils.torch_functional import logprobs_from_logits
from ..base import BaseRollout
__all__ = ['NaiveRollout']
class NaiveRollout(BaseRollout):
def __init__(self, module: nn.Module, config):
"""A naive rollout. It requires the module to be compatible with huggingface APIs. That is:
The module should define __call__ to receive input_ids, attention_mask and position_ids.
It outputs a structure that contains a `logits` field.
Args:
module: module here follows huggingface APIs
config: DictConfig
"""
super().__init__()
self.config = config
self.module = module
@torch.no_grad()
def generate_sequences(self, prompts: DataProto) -> DataProto:
"""Generate sequences"""
idx = prompts.batch['input_ids'] # (bs, prompt_length)
attention_mask = prompts.batch['attention_mask'] # left-padded attention_mask
position_ids = prompts.batch['position_ids']
# used to construct attention_mask
eos_token_id = prompts.meta_info['eos_token_id']
batch_size = idx.size(0)
prompt_length = idx.size(1)
self.module.eval()
prev_attention_mask = torch.ones(size=(batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device)
logits_lst = []
for _ in range(self.config.response_length):
# if the sequence context is growing too long we must crop it at block_size
# idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
idx_cond = idx
# forward the model to get the logits for the index in the sequence
# we use huggingface APIs here
output = self.module(input_ids=idx_cond, attention_mask=attention_mask, position_ids=position_ids)
logits = output.logits
# pluck the logits at the final step and scale by desired temperature
logits = logits[:, -1, :] / self.config.temperature # (bs, vocab_size)
# optionally crop the logits to only the top k options
if self.config.top_k is not None:
v, _ = torch.topk(logits, min(self.config.top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
# apply softmax to convert logits to (normalized) probabilities
probs = F.softmax(logits, dim=-1)
# sample from the distribution
if self.config.do_sample:
idx_next = torch.multinomial(probs, num_samples=1)
else:
idx_next = torch.argmax(probs, dim=-1, keepdim=True)
attention_mask = torch.cat((attention_mask, prev_attention_mask), dim=-1)
prev_attention_mask = torch.logical_and(idx_next != eos_token_id, prev_attention_mask.bool())
prev_attention_mask = prev_attention_mask.to(attention_mask.dtype)
position_ids = torch.cat((position_ids, position_ids[:, -1:] + 1), dim=-1)
# append sampled index to the running sequence and continue
idx = torch.cat((idx, idx_next), dim=1)
logits_lst.append(logits)
logits = torch.stack(logits_lst, dim=1) # (bs, response_length, vocab_size)
prompts = idx[:, :prompt_length] # (bs, prompt_length)
response = idx[:, prompt_length:] # (bs, response_length)
log_probs = logprobs_from_logits(logits=logits, labels=response)
batch = TensorDict(
{
'input_ids': prompts,
'responses': response,
'sequences': idx,
'old_log_probs': log_probs,
'attention_mask': attention_mask,
'position_ids': position_ids,
},
batch_size=batch_size)
self.module.train()
return DataProto(batch=batch)
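# Standalone sketch of the per-step sampling used in the loop above (temperature
# scaling, optional top-k filtering, multinomial sampling); illustration only.
def _sample_next_token(logits: torch.Tensor, temperature: float = 1.0, top_k: int = None) -> torch.Tensor:
    """logits: (bs, vocab_size) at the last position; returns (bs, 1) sampled token ids."""
    logits = logits / temperature  # creates a new tensor, so the caller's logits are untouched
    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits[logits < v[:, [-1]]] = -float('Inf')  # keep only the top-k logits
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)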


@@ -0,0 +1,162 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The base tokenizer class, required for any hybrid engine based rollout or inference with vLLM.
"""
from abc import ABC, abstractmethod
from typing import Dict, List, Union
__all__ = ['HybridEngineBaseTokenizer']
class HybridEngineBaseTokenizer(ABC):
"""the tokenizer property and function name should align with HF's to meet vllm requirement"""
@property
@abstractmethod
def vocab_size(self):
"""
`int`: Size of the base vocabulary (without the added tokens).
"""
pass
@property
@abstractmethod
def pad_token_id(self):
"""
`Optional[int]`: Id of the padding token in the vocabulary. Returns `None` if the token has not been set.
"""
pass
@property
@abstractmethod
def eos_token_id(self):
"""
`Optional[int]`: Id of the end of sentence token in the vocabulary. Returns `None` if the token has not been
set.
"""
pass
@property
@abstractmethod
def all_special_ids(self) -> List[int]:
"""
`List[int]`: List the ids of the special tokens (`'<unk>'`, `'<cls>'`, etc.) mapped to class attributes.
"""
pass
@property
@abstractmethod
def all_special_tokens(self) -> List[str]:
"""
`List[str]`: A list of the unique special tokens (`'<unk>'`, `'<cls>'`, ..., etc.).
Convert tokens of `tokenizers.AddedToken` type to string.
"""
pass
@abstractmethod
def encode(self, text):
"""
Converts a string to a sequence of ids (integer), using the tokenizer and vocabulary.
Args:
text (`str`, `List[str]` or `List[int]`):
The sequence to be encoded. This can be a string, a list of strings (tokenized string using the
`tokenize` method) or a list of integers.
"""
pass
@abstractmethod
def decode(
self,
token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
skip_special_tokens: bool = False,
clean_up_tokenization_spaces: bool = None,
**kwargs,
) -> str:
"""
Converts a sequence of ids into a string, using the tokenizer and vocabulary, with options to remove special
tokens and clean up tokenization spaces.
Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.
Args:
token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
List of tokenized input ids. Can be obtained using the `__call__` method.
skip_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not to remove special tokens in the decoding.
clean_up_tokenization_spaces (`bool`, *optional*):
Whether or not to clean up the tokenization spaces. If `None`, will default to
`self.clean_up_tokenization_spaces`.
kwargs (additional keyword arguments, *optional*):
Will be passed to the underlying model specific decode method.
Returns:
`str`: The decoded sentence.
"""
pass
@abstractmethod
def convert_ids_to_tokens(self,
ids: Union[int, List[int]],
skip_special_tokens: bool = False) -> Union[str, List[str]]:
"""
Converts a single index or a sequence of indices into a token or a sequence of tokens, using the vocabulary and
added tokens.
Args:
ids (`int` or `List[int]`):
The token id (or token ids) to convert to tokens.
skip_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not to remove special tokens in the decoding.
Returns:
`str` or `List[str]`: The decoded token(s).
"""
pass
@abstractmethod
def get_added_vocab(self) -> Dict[str, int]:
"""
Returns the added tokens in the vocabulary as a dictionary of token to index. Results might be different from
the fast call because for now we always add the tokens even if they are already in the vocabulary. This is
something we should change.
Returns:
`Dict[str, int]`: The added tokens.
"""
pass
@abstractmethod
def convert_tokens_to_string(self, tokens: List[str]) -> str:
"""
Converts a sequence of tokens into a single string. The simplest way to do it is `" ".join(tokens)`, but we
often want to remove sub-word tokenization artifacts at the same time.
Args:
tokens (`List[str]`): The tokens to join into a string.
Returns:
`str`: The joined tokens.
"""
pass
@property
def is_fast(self):
return False
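# Minimal concrete sketch: delegate every required property/method to a HuggingFace
# tokenizer. Illustration only; assumes `hf_tokenizer` is a transformers
# PreTrainedTokenizer, on which all of the delegated attributes exist.
class _HFWrappedTokenizer(HybridEngineBaseTokenizer):

    def __init__(self, hf_tokenizer):
        self._tok = hf_tokenizer

    @property
    def vocab_size(self):
        return self._tok.vocab_size

    @property
    def pad_token_id(self):
        return self._tok.pad_token_id

    @property
    def eos_token_id(self):
        return self._tok.eos_token_id

    @property
    def all_special_ids(self) -> List[int]:
        return self._tok.all_special_ids

    @property
    def all_special_tokens(self) -> List[str]:
        return self._tok.all_special_tokens

    def encode(self, text):
        return self._tok.encode(text)

    def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=None, **kwargs):
        return self._tok.decode(token_ids, skip_special_tokens=skip_special_tokens,
                                clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs)

    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
        return self._tok.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)

    def get_added_vocab(self) -> Dict[str, int]:
        return self._tok.get_added_vocab()

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return self._tok.convert_tokens_to_string(tokens)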


@@ -0,0 +1,15 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .vllm_rollout import vLLMRollout


@@ -0,0 +1,226 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The vllm_rollout that can be applied with different backends
When working with FSDP:
- Use DTensor weight loader (recommended) or HF weight loader
- Utilize state_dict from the FSDP to synchronize the weights among tp ranks in vLLM
When working with Megatron:
- Use Megatron weight loader
- During training, only the current pp stage holds the parameters
- Before inference, broadcast the parameters of the current pp rank to all other pp ranks (so that all pp ranks hold all the parameters)
- Bind the parameters to the inference engine
- Do inference in tp. pp is treated as additional dp
- After inference, all the parameters that don't belong to this pp rank are freed.
"""
from typing import List
from contextlib import contextmanager
from omegaconf import DictConfig
import torch
import torch.distributed
from tensordict import TensorDict
from torch import nn
from verl import DataProto
from verl.utils.torch_functional import get_eos_mask, pad_sequence_to_length
from verl.workers.rollout.base import BaseRollout
from verl.third_party.vllm import LLM, vllm_version
from verl.third_party.vllm import parallel_state as vllm_ps
from vllm import SamplingParams
# TODO
# 1. support pp in vllm
# 2. is passing the tokenizer necessary? no encoding/decoding is happening here
# 3. simplify the init logic
# NOTE(sgm): added for verl. We can optimize it by making the dataloader yield List[int] without padding.
def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> List[int]:
# remove the left padding in the prompt token_id
# pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id
non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
token_ids = prompt_token_ids[non_pad_index:].tolist()
return token_ids
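# Worked example of the helper above: with pad_token_id = 0,
#   _pre_process_inputs(0, torch.tensor([0, 0, 0, 15, 27, 31])) == [15, 27, 31]
# i.e. everything before the first non-pad token is stripped.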
class vLLMRollout(BaseRollout):
def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model_hf_config, **kwargs):
"""A vLLM rollout. It requires the module is supported by the vllm.
Args:
module: module here follows huggingface APIs
config: DictConfig
tokenizer: the task/model tokenizer
model_hf_config: the huggingface config used to initialize the generating model in vLLM
**kwargs: train_tp, for Megatron Backend to initialize hybrid engine (zero redundancy) process group
"""
super().__init__()
self.config = config
assert not (not config.enforce_eager and config.free_cache_engine), \
"CUDA graph must be disabled (enforce_eager=True) when free_cache_engine is enabled"
tensor_parallel_size = self.config.get('tensor_model_parallel_size', 1)
assert tensor_parallel_size <= torch.distributed.get_world_size(), \
"tensor parallel size should be less than or equal to the world size"
if kwargs.get('train_tp', None) is not None:
# deployed with megatron
import os
os.environ['CUDA_TIMER_STREAM_KAFKA_ENABLE'] = '0'
os.environ['MEGATRON_IMPORT_TIMERS'] = '0'
train_tp = kwargs.get('train_tp', None)
num_tp_per_train_tp = train_tp // tensor_parallel_size
if vllm_version in ('0.4.2', '0.5.4', '0.6.3'):
vllm_ps.initialize_parallel_state(tensor_model_parallel_size=tensor_parallel_size,
num_tp_per_train_tp=num_tp_per_train_tp)
assert model_hf_config.max_position_embeddings >= config.prompt_length + config.response_length, \
"model context length should be greater than total sequence length"
self.inference_engine = LLM(actor_module,
tokenizer=tokenizer,
model_hf_config=model_hf_config,
tensor_parallel_size=tensor_parallel_size,
dtype=config.dtype,
enforce_eager=config.enforce_eager,
gpu_memory_utilization=config.gpu_memory_utilization,
skip_tokenizer_init=False,
max_model_len=config.prompt_length + config.response_length,
load_format=config.load_format)
# Offload vllm model to reduce peak memory usage
self.inference_engine.offload_model_weights()
kwargs = dict(
n=1,
logprobs=1, # can be set to 0 and let the actor recompute
max_tokens=config.response_length,
)
# we may detokenize the result altogether later
if vllm_version in ('0.4.2', '0.5.4', '0.6.3'):
kwargs['detokenize'] = False
# support adding any sampling params from the config file
for k in config.keys():
if hasattr(SamplingParams(), str(k)):
kwargs[k] = config.get(k)
print(f"kwargs: {kwargs}")
self.sampling_params = SamplingParams(**kwargs)
self.pad_token_id = tokenizer.pad_token_id
@contextmanager
def update_sampling_params(self, **kwargs):
# update sampling params
old_sampling_params_args = {}
if kwargs:
for key, value in kwargs.items():
if hasattr(self.sampling_params, key):
old_value = getattr(self.sampling_params, key)
old_sampling_params_args[key] = old_value
setattr(self.sampling_params, key, value)
yield
# roll back to the previous sampling params
for key, value in old_sampling_params_args.items():
setattr(self.sampling_params, key, value)
@torch.no_grad()
def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
# rebuild vllm cache engine
if self.config.free_cache_engine:
self.inference_engine.init_cache_engine()
idx = prompts.batch['input_ids'] # (bs, prompt_length)
# left-padded attention_mask
attention_mask = prompts.batch['attention_mask']
position_ids = prompts.batch['position_ids']
# used to construct attention_mask
eos_token_id = prompts.meta_info['eos_token_id']
batch_size = idx.size(0)
idx_list = []
# parse idx from torch.Tensor to List[List[int]]
for i in range(batch_size):
idx_list.append(_pre_process_inputs(self.pad_token_id, idx[i]))
do_sample = prompts.meta_info.get('do_sample', True)
if not do_sample:
kwargs = {
'best_of': 1,
'top_p': 1.0,
'top_k': -1,
'min_p': 0.0,
'temperature': 0,
'n': 1 # if greedy, only 1 response
}
# users can customize sampling_params for different runs
with self.update_sampling_params(**kwargs):
output = self.inference_engine.generate(
prompts=None, # because we have already converted the prompts to token ids
sampling_params=self.sampling_params,
prompt_token_ids=idx_list,
use_tqdm=False)
# TODO(sgm): disable logprob when recompute_log_prob is enabled
# if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length)
response = output[0].to(idx.device)
log_probs = output[1].to(idx.device)
if response.shape[1] < self.config.response_length:
response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id)
log_probs = pad_sequence_to_length(log_probs, self.config.response_length, self.pad_token_id)
if self.config.n > 1 and do_sample:
idx = idx.repeat_interleave(self.config.n, dim=0)
attention_mask = attention_mask.repeat_interleave(self.config.n, dim=0)
position_ids = position_ids.repeat_interleave(self.config.n, dim=0)
batch_size = batch_size * self.config.n
seq = torch.cat([idx, response], dim=-1)
response_length = response.size(1)
delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1)
# TODO(sgm): fix position_ids on right_pad
# prompt: left pad + response: right pad
# attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
# position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]
response_position_ids = position_ids[:, -1:] + delta_position_id
position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
response_attention_mask = get_eos_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype)
attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)
# all tp ranks should contain the same data here; the data in all ranks is valid
batch = TensorDict(
{
'prompts': idx,
'responses': response,
'input_ids': seq, # here input_ids become the whole sentences
# 'old_log_probs': log_probs, # we will recompute old log prob with actor
'attention_mask': attention_mask,
'position_ids': position_ids
},
batch_size=batch_size)
# free vllm cache engine
if self.config.free_cache_engine:
self.inference_engine.free_cache_engine()
return DataProto(batch=batch)
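# Sketch of the response-side mask construction above, for illustration only. The real
# implementation is verl.utils.torch_functional.get_eos_mask; this local version assumes
# the intended semantics: positions up to and including the first EOS are valid (1),
# everything after it is masked out (0).
def _eos_mask_sketch(response_id: torch.Tensor, eos_token: int, dtype=torch.int64) -> torch.Tensor:
    # e.g. response [5, 7, EOS, 9, 9] -> mask [1, 1, 1, 0, 0]
    is_eos = response_id.eq(eos_token)  # (bs, response_length)
    # number of EOS tokens seen strictly before each position
    seen_eos_before = is_eos.long().cumsum(dim=-1) - is_eos.long()
    return (seen_eos_before == 0).to(dtype)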