Initial commit
202
LICENSE
Normal file
@@ -0,0 +1,202 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
1
Notice.txt
Normal file
@@ -0,0 +1 @@
Copyright 2023-2024 Bytedance Ltd. and/or its affiliates
182
README.md
Normal file
@@ -0,0 +1,182 @@
# Search-R1: Train your LLMs to reason and call a search engine with reinforcement learning

<strong>Search-R1</strong> is a reproduction of <strong>DeepSeek-R1(-Zero)</strong> methods for <em>training reasoning and searching (tool-call) interleaved LLMs</em>. It is built upon [veRL](https://github.com/volcengine/verl).

Through RL (rule-based outcome reward), the 3B **base** LLM (both Qwen2.5-3b-base and Llama3.2-3b-base) develops reasoning and search-engine-calling abilities all on its own.

Twitter thread: xxx

Full experiment log: [link](https://wandb.ai/uiuc-dmg/Search-R1-open/)

The paper will be released soon!



## Links

- [Installation](#installation)
- [Quick start](#quick-start)
- [Preliminary results](#preliminary-results)
- [Use your own dataset](#use-your-own-dataset)
- [Use your own search engine](#use-your-own-search-engine)
- [Acknowledge](#acknowledge)
- [Citations](#citations)

## Installation

### Search-R1 environment
```bash
conda create -n searchr1 python=3.9
conda activate searchr1
# install torch [or you can skip this step and let vLLM install the correct version for you]
pip install torch==2.4.0 --index-url https://download.pytorch.org/whl/cu121
# install vllm
pip3 install vllm==0.6.3 # or you can install 0.5.4, 0.4.2 or 0.3.1

# verl
pip install -e .

# flash attention 2
pip3 install flash-attn --no-build-isolation
pip install wandb
```

### Retriever environment (optional)
If you would like to call a local retriever as the search engine, you can install the environment as follows. (We recommend using a separate environment.)
```bash
conda create -n retriever python=3.10
conda activate retriever

# we recommend installing torch with conda for faiss-gpu
conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=12.1 -c pytorch -c nvidia
pip install transformers datasets

## install the GPU version of faiss to guarantee efficient RL rollout
conda install -c pytorch -c nvidia faiss-gpu=1.8.0

## API server
pip install uvicorn fastapi
```


## Quick start

Train a reasoning-plus-search LLM on the NQ dataset, with E5 as the retriever and Wikipedia as the corpus.

(1) Download the index and corpus.
```bash
save_path=/the/path/to/save
python scripts/download.py --save_path $save_path
cat $save_path/part_* > $save_path/e5_Flat.index
gzip -d $save_path/wiki-18.jsonl.gz
```

(2) Process the NQ dataset.
```bash
python scripts/data_process/nq_search.py
```

(3) Launch a local retrieval server.
```bash
conda activate retriever
bash retrieval_launch.sh
```

(4) Run training with Qwen2.5-3b-Instruct.
```bash
conda activate searchr1
bash train_ppo.sh
```

## Preliminary results

(1) The base model (Llama3.2-3b-base) learns to call the search engine and obtains improved performance.



(2) The base model (Qwen2.5-7b-base) can learn to conduct multi-turn search engine calling and reasoning with RL.




## Use your own dataset

### QA data
Each question-answer sample should be a dictionary containing the following fields:

```python
data = {
    "data_source": data_source,
    "prompt": [{
        "role": "user",
        "content": question,
    }],
    "ability": "fact-reasoning",
    "reward_model": {
        "style": "rule",
        "ground_truth": solution
    },
    "extra_info": {
        'split': split,
        'index': idx,
    }
}
```

You can refer to ```scripts/data_process/nq_search.py``` for a concrete data processing example.
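
If you want to plug in your own QA data, a minimal sketch such as the one below (not part of the repo; the sample questions, ```data_source``` name, and output paths are placeholders) wraps raw question-answer pairs into this schema and writes the parquet splits, mirroring what ```scripts/data_process/nq_search.py``` does:

```python
# Minimal sketch (assumption-based, not part of the repo): wrap raw QA pairs into
# the schema above and save parquet splits. In practice you would also wrap each
# question with the same prompt template as make_prefix() in
# scripts/data_process/nq_search.py.
import os
from datasets import Dataset

raw_train = [{"question": "where is the Eiffel Tower?", "golden_answers": ["Paris"]}]
raw_test = [{"question": "who wrote Hamlet?", "golden_answers": ["William Shakespeare"]}]

def build_split(raw_samples, split, data_source="my_qa"):
    rows = []
    for idx, sample in enumerate(raw_samples):
        rows.append({
            "data_source": data_source,
            "prompt": [{"role": "user", "content": sample["question"]}],
            "ability": "fact-reasoning",
            "reward_model": {
                "style": "rule",
                "ground_truth": {"target": sample["golden_answers"]},
            },
            "extra_info": {"split": split, "index": idx},
        })
    return Dataset.from_list(rows)

local_dir = "./data/my_qa"
os.makedirs(local_dir, exist_ok=True)
build_split(raw_train, "train").to_parquet(os.path.join(local_dir, "train.parquet"))
build_split(raw_test, "test").to_parquet(os.path.join(local_dir, "test.parquet"))
```
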
### Corpora

It is recommended to make your corpus a JSONL file, where each line (a dictionary with an "id" key and a "contents" key) corresponds to one passage. You can refer to ```example/corpus.jsonl``` for an example.

The "id" key corresponds to the passage id, while the "contents" key corresponds to the passage content.
For example:
```
{"id": "0", "contents": "Evan Morris Evan L. Morris (January 26, 1977 \u2013 July 9, 2015) was a lobbyist for Genentech and its parent corporation Roche in Washington."}
...
{"id": "100", "contents": "Three years later, when the United States Exploring Expedition to little-known portions of the globe was organised under Charles Wilkes, Hale was recommended, while yet an undergraduate."}
...
```
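
If your documents are stored in another format, a short sketch like the following (not part of the repo; the passages and output path are placeholders) writes them into this "id"/"contents" JSONL layout. Note that the first line of "contents" is treated as the passage title by the provided processing scripts (see ```format_reference``` in ```scripts/data_process/nq_rag.py```):

```python
# Minimal sketch (assumption-based, not part of the repo): dump your own passages
# into the "id"/"contents" JSONL layout described above. The first line of
# "contents" is used as the passage title by the repo's formatting helper.
import json

passages = [
    "Example Title\nThis is the text of the first passage.",
    "Another Title\nThis is the text of the second passage.",
]

with open("my_corpus.jsonl", "w") as f:
    for idx, text in enumerate(passages):
        f.write(json.dumps({"id": str(idx), "contents": text}) + "\n")
```
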
**Index your corpora (optional).**
If you would like to use a local retriever as the search engine, you can index your own corpus by:
```bash
bash search_r1/search/build_index.sh
```
You can change ```retriever_name``` and ```retriever_model``` to the off-the-shelf retriever you are interested in.

## Use your own search engine

The main philosophy is to launch a local or remote search engine server separately from the main RL training pipeline.

The LLM calls the search engine through the search API (e.g., "http://127.0.0.1:8000/retrieve").

You can refer to ```search_r1/search/retrieval_server.py``` for an example of launching a local retriever server.
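
For reference, a rough client-side sketch of how such an endpoint could be queried from Python is shown below; the request fields (```queries```, ```topk```) are assumptions rather than a documented schema, so check ```search_r1/search/retrieval_server.py``` for the exact interface your server exposes.

```python
# Rough sketch of querying the retrieval endpoint; the "queries"/"topk" request
# fields are assumptions -- see search_r1/search/retrieval_server.py for the
# actual request schema served by the local retriever.
import requests

payload = {"queries": ["Who proposed the theory of evolution?"], "topk": 3}
resp = requests.post("http://127.0.0.1:8000/retrieve", json=payload, timeout=30)
resp.raise_for_status()
print(resp.json())
```
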
## To do

- Support Google Search / Bing Search / Brave Search APIs and others.
- Support LoRA tuning.
- Support supervised fine-tuning.
- Support off-the-shelf rerankers.

## Acknowledge

The concept of Search-R1 is inspired by [Deepseek-R1](https://github.com/deepseek-ai/DeepSeek-R1) and [TinyZero](https://github.com/Jiayi-Pan/TinyZero/tree/main).
Its implementation is built upon [veRL](https://github.com/volcengine/verl) and [RAGEN](https://github.com/ZihanWang314/RAGEN/tree/main).
We sincerely appreciate the efforts of these teams for their contributions to open-source research and development.

## Citations

To be added

<!-- ```md
@misc{Search-R1,
  author       = {xxx},
  title        = {xxx},
  year         = {2025},
  organization = {GitHub},
  url          = {https://github.com/PeterGriffinJin/Search-R1},
}
``` -->
103
VERL_README.md
Normal file
@@ -0,0 +1,103 @@
<h1 style="text-align: center;">veRL: Volcano Engine Reinforcement Learning for LLM</h1>

veRL is a flexible, efficient and production-ready RL training framework designed for large language models (LLMs).

veRL is the open-source version of the **[HybridFlow: A Flexible and Efficient RLHF Framework](https://arxiv.org/abs/2409.19256v2)** paper.

veRL is flexible and easy to use with:

- **Easy extension of diverse RL algorithms**: The Hybrid programming model combines the strengths of single-controller and multi-controller paradigms to enable flexible representation and efficient execution of complex Post-Training dataflows, allowing users to build RL dataflows in a few lines of code.

- **Seamless integration of existing LLM infra with modular APIs**: Decouples computation and data dependencies, enabling seamless integration with existing LLM frameworks, such as PyTorch FSDP, Megatron-LM and vLLM. Moreover, users can easily extend to other LLM training and inference frameworks.

- **Flexible device mapping**: Supports various placement of models onto different sets of GPUs for efficient resource utilization and scalability across different cluster sizes.

- Ready integration with popular HuggingFace models


veRL is fast with:

- **State-of-the-art throughput**: By seamlessly integrating existing SOTA LLM training and inference frameworks, veRL achieves high generation and training throughput.

- **Efficient actor model resharding with 3D-HybridEngine**: Eliminates memory redundancy and significantly reduces communication overhead during transitions between training and generation phases.

<p align="center">
| <a href="https://verl.readthedocs.io/en/latest/index.html"><b>Documentation</b></a> | <a href="https://arxiv.org/abs/2409.19256v2"><b>Paper</b></a> | <a href="https://join.slack.com/t/verlgroup/shared_invite/zt-2w5p9o4c3-yy0x2Q56s_VlGLsJ93A6vA"><b>Slack</b></a> | <a href="https://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/WeChat.JPG"><b>Wechat</b></a> |
<!-- <a href=""><b>Slides</b></a> | -->
</p>

## News

- [2024/12] The team presented <a href="https://neurips.cc/Expo/Conferences/2024/workshop/100677">Post-training LLMs: From Algorithms to Infrastructure</a> at NeurIPS 2024. [Slides](https://github.com/eric-haibin-lin/verl-data/tree/neurips) and [video](https://neurips.cc/Expo/Conferences/2024/workshop/100677) available.
- [2024/10] veRL was presented at Ray Summit. [Youtube video](https://www.youtube.com/watch?v=MrhMcXkXvJU&list=PLzTswPQNepXntmT8jr9WaNfqQ60QwW7-U&index=37) available.
- [2024/08] HybridFlow (verl) has been accepted to EuroSys 2025.

## Key Features

- **FSDP** and **Megatron-LM** for training.
- **vLLM** and **TGI** for rollout generation, **SGLang** support coming soon.
- HuggingFace model support
- Supervised fine-tuning
- Reward model training
- Reinforcement learning from human feedback with PPO
- Flash-attention integration, sequence packing
- Scales up to 70B models and hundreds of GPUs
- Experiment tracking with wandb and mlflow


## Getting Started

Check out this [Jupyter Notebook](https://github.com/volcengine/verl/tree/main/examples/ppo_trainer/verl_getting_started.ipynb) to get started with PPO training with a single 24GB L4 GPU (**FREE** GPU quota provided by [Lightning Studio](https://lightning.ai/hlin-verl/studios/verl-getting-started))!

**Quickstart:**
- [Installation](https://verl.readthedocs.io/en/latest/start/install.html)
- [Quickstart](https://verl.readthedocs.io/en/latest/start/quickstart.html)

**Running a PPO example step-by-step:**
- Data and Reward Preparation
  - [Prepare Data (Parquet) for Post-Training](https://verl.readthedocs.io/en/latest/preparation/prepare_data.html)
  - [Implement Reward Function for Dataset](https://verl.readthedocs.io/en/latest/preparation/reward_function.html)
- Understanding the PPO Example
  - [PPO Example Architecture](https://verl.readthedocs.io/en/latest/examples/ppo_code_architecture.html)
  - [Config Explanation](https://verl.readthedocs.io/en/latest/examples/config.html)
  - [Run GSM8K Example](https://verl.readthedocs.io/en/latest/examples/gsm8k_example.html)

**Reproducible algorithm baselines:**
- [PPO](https://verl.readthedocs.io/en/latest/experiment/ppo.html)

**For code explanation and advanced usage (extension):**
- PPO Trainer and Workers
  - [PPO Ray Trainer](https://verl.readthedocs.io/en/latest/workers/ray_trainer.html)
  - [PyTorch FSDP Backend](https://verl.readthedocs.io/en/latest/workers/fsdp_workers.html)
  - [Megatron-LM Backend](https://verl.readthedocs.io/en/latest/index.html)
- Advanced Usage and Extension
  - [Ray API Design Tutorial](https://verl.readthedocs.io/en/latest/advance/placement.html)
  - [Extend to other RL(HF) algorithms](https://verl.readthedocs.io/en/latest/advance/dpo_extension.html)
  - [Add models with the FSDP backend](https://verl.readthedocs.io/en/latest/advance/fsdp_extension.html)
  - [Add models with the Megatron-LM backend](https://verl.readthedocs.io/en/latest/advance/megatron_extension.html)


## Citation and acknowledgement

If you find the project helpful, please cite:
- [HybridFlow: A Flexible and Efficient RLHF Framework](https://arxiv.org/abs/2409.19256v2)
- [A Framework for Training Large Language Models for Code Generation via Proximal Policy Optimization](https://i.cs.hku.hk/~cwu/papers/gmsheng-NL2Code24.pdf)

```tex
@article{sheng2024hybridflow,
  title   = {HybridFlow: A Flexible and Efficient RLHF Framework},
  author  = {Guangming Sheng and Chi Zhang and Zilingfeng Ye and Xibin Wu and Wang Zhang and Ru Zhang and Yanghua Peng and Haibin Lin and Chuan Wu},
  year    = {2024},
  journal = {arXiv preprint arXiv: 2409.19256}
}
```

verl is inspired by the design of Nemo-Aligner, Deepspeed-chat and OpenRLHF. The project is adopted and supported by Anyscale, Bytedance, LMSys.org, Shanghai AI Lab, Tsinghua University, UC Berkeley, UCLA, UIUC, and University of Hong Kong.

## Publications Using veRL

- [Enhancing Multi-Step Reasoning Abilities of Language Models through Direct Q-Function Optimization](https://arxiv.org/abs/2410.09302)
- [Flaming-hot Initiation with Regular Execution Sampling for Large Language Models](https://arxiv.org/abs/2410.21236)
- [Process Reinforcement Through Implicit Rewards](https://github.com/PRIME-RL/PRIME/)

We are HIRING! Send us an [email](mailto:haibin.lin@bytedance.com) if you are interested in internship/FTE opportunities in MLSys/LLM reasoning/multimodal alignment.
43
example/case.txt
Normal file
@@ -0,0 +1,43 @@
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Answer the given question. You must conduct reasoning inside <think> and </think> first every time you get new information. After reasoning, if you find you lack some knowledge, you can call a search engine by <search> query </search> and it will return the top searched results between <information> and </information>. You can search as many times as your want. If you find no further external knowledge needed, you can directly provide the answer inside <answer> and </answer>, without detailed illustrations. For example, <answer> Beijing </answer>. Question: paul walkers cars in fast and furious list?
<|im_end|>
<|im_start|>assistant
<think>I need to search for Paul Walker's cars in Fast and Furious movies.</think>
<search> Paul Walker's cars in Fast and Furious </search>
<information>Doc 1(Title: "Paul Walker") also had a large car collection of about 30 cars, a portion of which he co-owned along with Roger Rodas. The collection included (but is not limited to): Roger Rodas became Walker's financial adviser in 2007 and helped to establish Reach Out Worldwide. Rodas, a pro-am racer was the CEO of Always Evolving, a Valencia high-end vehicle performance shop owned by Walker. Walker was also a close friend of his ""2 Fast 2 Furious"" co-star Tyrese Gibson. Vin Diesel considered Walker to be like a brother, both on and off screen, and affectionately called him ""Pablo"". Walker's mother referred to
Doc 2(Title: "Paul Walker") Paul Walker Paul William Walker IV (September 12, 1973 – November 30, 2013) was an American actor best known for his role as Brian O'Conner in ""The Fast and the Furious"" franchise. Walker first gained prominence in 1999 with roles in the teen films ""She's All That"" and ""Varsity Blues"". In 2001, he gained international fame for his role in the street racing action film ""The Fast and the Furious"" (2001), a role he reprised in five of the next six installments, but died in 2013 in the middle of filming ""Furious 7"" (2015). Walker began his career guest-starring on
Doc 3(Title: "Paul Walker") of Porsche in a separate lawsuit filed by Roger Rodas' widow, Kristine. The ruling had no bearing on two other cases against Porsche which have been filed by Walker's father, who is also the executor of the actor's estate, and his daughter. Walker's father and daughter both reached an agreement with Porsche. Paul Walker Paul William Walker IV (September 12, 1973 – November 30, 2013) was an American actor best known for his role as Brian O'Conner in ""The Fast and the Furious"" franchise. Walker first gained prominence in 1999 with roles in the teen films ""She's All That"" and</information>
From the information provided, it's clear that Paul Walker was a part of the "Fast and Furious" series, but the specific list of cars is not mentioned. Since I lack this particular detail, I will call a search engine to get the specific list of cars Paul Walker drove in the "Fast and Furious" movies.
<search> list of cars paul walker drove in Fast and Furious </search>
<information>Doc 1(Title: "Paul Walker") also had a large car collection of about 30 cars, a portion of which he co-owned along with Roger Rodas. The collection included (but is not limited to): Roger Rodas became Walker's financial adviser in 2007 and helped to establish Reach Out Worldwide. Rodas, a pro-am racer was the CEO of Always Evolving, a Valencia high-end vehicle performance shop owned by Walker. Walker was also a close friend of his ""2 Fast 2 Furious"" co-star Tyrese Gibson. Vin Diesel considered Walker to be like a brother, both on and off screen, and affectionately called him ""Pablo"". Walker's mother referred to
Doc 2(Title: "Paul Walker") Paul Walker Paul William Walker IV (September 12, 1973 – November 30, 2013) was an American actor best known for his role as Brian O'Conner in ""The Fast and the Furious"" franchise. Walker first gained prominence in 1999 with roles in the teen films ""She's All That"" and ""Varsity Blues"". In 2001, he gained international fame for his role in the street racing action film ""The Fast and the Furious"" (2001), a role he reprised in five of the next six installments, but died in 2013 in the middle of filming ""Furious 7"" (2015). Walker began his career guest-starring on
Doc 3(Title: "Paul Walker") of Porsche in a separate lawsuit filed by Roger Rodas' widow, Kristine. The ruling had no bearing on two other cases against Porsche which have been filed by Walker's father, who is also the executor of the actor's estate, and his daughter. Walker's father and daughter both reached an agreement with Porsche. Paul Walker Paul William Walker IV (September 12, 1973 – November 30, 2013) was an American actor best known for his role as Brian O'Conner in ""The Fast and the Furious"" franchise. Walker first gained prominence in 1999 with roles in the teen films ""She's All That"" and</information>
Based on the updated information, it seems the focus was on his career, financials, and family. However, I am still missing the specific list of cars he drove in the "Fast and Furious" movies. Since it appears that the information might not be contained within the accessed documents, and I have no further search queries to make, I will provide an answer based on the details I have.
<answer> Charger </answer>
<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
Answer the given question. You must conduct reasoning inside <think> and </think> first every time you get new information. After reasoning, if you find you lack some knowledge, you can call a search engine by <search> query </search> and it will return the top searched results between <information> and </information>. You can search as many times as your want. If you find no further external knowledge needed, you can directly provide the answer inside <answer> and </answer>, without detailed illustrations. For example, <answer> Beijing </answer>. Question: the heart muscle is stimulated to contract by electrical impulses which are generated where?
<|im_end|>
<|im_start|>assistant
<think> I remember that the heart's electrical impulses originate from a specific part of the heart, but I need to confirm which part.</think>
<search> origin of heart's electrical impulses </search>
<information>Doc 1(Title: "Cardiac pacemaker") Cardiac pacemaker The contraction of cardiac muscle (heart muscle) in all animals is initiated by electrical impulses known as action potentials. The rate at which these impulses fire controls the rate of cardiac contraction, that is, the heart rate. The cells that create these rhythmic impulses, setting the pace for blood pumping, are called pacemaker cells, and they directly control the heart rate. They make up the cardiac pacemaker, that is, the natural pacemaker of the heart. In most humans, the concentration of pacemaker cells in the sinoatrial (SA) node is the natural pacemaker, and the resultant rhythm is a
Doc 2(Title: "Electrical conduction system of the heart") Electrical conduction system of the heart The electrical conduction system of the heart transmits signals generated usually by the sinoatrial node to cause contraction of the heart muscle. The pacemaking signal generated in the sinoatrial node travels through the right atrium to the atrioventricular node, along the Bundle of His and through bundle branches to cause contraction of the heart muscle. This signal stimulates contraction first of the right and left atrium, and then the right and left ventricles. This process allows blood to be pumped throughout the body. The conduction system consists of specialised heart muscle cells, and is
Doc 3(Title: "Sinoatrial node") Sinoatrial node The sinoatrial node (SA node), also known as sinus node, is a group of cells located in the wall of the right atrium of the heart. These cells have the ability to spontaneously produce an electrical impulse (action potential; see below for more details), that travels through the heart via the electrical conduction system (see figure 1) causing it to contract. In a healthy heart, the SA node continuously produces action potential, setting the rhythm of the heart and so is known as the heart's natural pacemaker. The rate of action potential production (and therefore the heart rate)</information>
From the information provided, it is clear that the heart's electrical impulses are generated in the sinoatrial (SA) node, which is often referred to as the heart's natural pacemaker. </think>
<answer> sinoatrial (SA) node </answer>
10
example/corpus.jsonl
Normal file
@@ -0,0 +1,10 @@
{"id": "0", "contents": "\"Evan Morris\"\nEvan Morris Evan L. Morris (January 26, 1977 \u2013 July 9, 2015) was a lobbyist for Genentech and its parent corporation Roche in Washington, D.C. He began his career in Washington as an intern in the Clinton White House at age 18. He began his lobbying work at Patton Boggs before moving on to Roche in 2005. His early work at Roche involved government policy on Medicare and Medicaid, flu preparedness, and hepatitis C. His work at Genentech is being investigated by the Wall Street Journal, which states that \u201cshaping up to be one of the biggest U.S. investigations into"}
{"id": "1", "contents": "\"Horatio Hale\"\nconsisted of an Algonkin vocabulary, which he gathered from a band of Indians who had camped on the college grounds. Three years later, when the United States Exploring Expedition to little-known portions of the globe was organised under Charles Wilkes, Hale was recommended, while yet an undergraduate, for the post of ethnologist and philologist, and obtained the appointment. From 1838 to 1842, he was employed in the work of the expedition, visiting South America, Australasia, Polynesia, and North-western America, then known as Oregon. From this point he returned overland. The Hale Passages of Puget Sound were named in recognition of"}
{"id": "2", "contents": "\"Dibba Al-Hisn\"\nis believed to be the site where the Portuguese during the Habsburg Dynasty built a fort and a wall around the city. In August 1648, the Arabs besieged Muscat, Oman, and on October 31, 1648, a treaty was signed between the two opponents. The terms required the Portuguese to build the fortress of Kuriyat, Dibba Al-Hisn and Matrah (Oman). Note: There have been land disputes between Dibba Al-Hisn and Dibba Al-Baya, but these were resolved in the 1990s. </div> A branch of the Community College of the University of Sharjah is available in Dibba Al-Hisn. Alkhalidya Secondary School Dibba Al-Hisn"}
{"id": "3", "contents": "\"Ao Oni (film)\"\nthe door as a big blue hand grabs his head. The group hears Takeshi's scream and come back to check on him. As they reach the room, they see the door open and Takeshi, bleeding, is thrown out. He asks for help before being dragged back and the door closes. As the group enters the room, they finds Takeshi's dismembered body. Anna and Hiroshi pressure Mika about Takuro's whereabouts and about Naoki, but Mika refuses to tell them and runs away, chased by Hiroshi. Anna confronts Shun who describes that this event is the same as the game he created."}
{"id": "4", "contents": "\"Pavia Cathedral\"\non March 17, 1989. The cathedral was begun in 1488, under architect Cristoforo Rocchi, who was soon replaced by Giovanni Antonio Amadeo and Gian Giacomo Dolcebuono. The original project, with a nave and two aisles flanked by semicircular niches and a large central dome, was influenced by Bramante, some details of it later appearing in St. Peter's Basilica in Rome. Leonardo da Vinci is also known to have contributed to the project. In 1521, the altar area was completed by Gianpietrino Rizzi, a pupil of Da Vinci. By the 17th century, the presbytery had been completed but only in the"}
{"id": "5", "contents": "\"Pavia Cathedral\"\nfollowing century was the tambour built, while the dome itself and the facade had to wait for the 19th century. The dome was designed by Carlo Maciachini and completed in 1885, but partially collapsed the same year. In 1930, construction continued with the two arms of the transept, for which the original plan was followed, although using reinforced concrete (in order to save the remains of the medieval Santa Maria del Popolo). The arms are still missing part of the internal marble decoration. The church is on the Greek Cross plan: it therefore has the same length and width at"}
{"id": "6", "contents": "\"Iowa Highway 17\"\nWesley. Although Iowa 17 has only been designated since 1969, most of the route has been a part of the primary highway system since the system's inception in 1919. The route was designated Primary Road No. 60 and connected Des Moines and Goldfield via Webster City. By 1930, the only paved section of what was by then called Iowa 60 was the southernmost from Des Moines to south of Madrid. In two years, paving extended to US 30 east of Boone. 1932 saw many changes to Iowa 60. The route was extended to the north where it ended at US"}
{"id": "7", "contents": "\"E.T. the Extra-Terrestrial (video game)\"\nfinancial failure for Atari. By 2004, the cartridges were still very common and offered at very low prices. While reviews of the movie were highly positive, the game was negatively received by critics, with common complaints focused on the plot, gameplay, and visuals. \"\"New York\"\" magazine's Nicholas Pileggi described it as a loser when compared to other games Atari could have released like \"\"Donkey Kong\"\" and \"\"Frogger\"\". \"\"Video Games\"\" called the game \"\"really for kids (the littler ones)\"\". Kevin Bowen of GameSpy's Classic Gaming called the gameplay \"\"convoluted and inane\"\", also criticizing its story for departing from the serious tone"}
{"id": "8", "contents": "\"Ao Oni (film)\"\n(for Takeshi cannot see Shun), and why Anna said Mika will not call him, because he is already dead and cannot be seen by everyone, except Anna. He then flashbacks to the time Takuro bullied him in the riverbank. Takuro hit Shun in the head, killing him. It is then revealed that Takuro went to the empty house to hide Shun's body. Takuro then explained that Shun used Takuro's name in the game and got mad. Shun is terrified with the truth and finally disappears as Anna apologizes to him. Takuro plans on killing Anna since she knows too much,"}
{"id": "9", "contents": "\"4th Airborne Corps (Soviet Union)\"\n4th Airborne Corps (Soviet Union) The 4th Airborne Corps was an airborne corps of the Red Army in World War II. It fought in the Vyazma airborne operation, an unsuccessful landing during the Rzhev-Vyazma Offensive. The corps was formed in the spring of 1941 in the Western Special Military District from the personnel of the 214th Airborne Brigade. The corps was commanded by Aleksey Semenovich Zhadov. On 22 June 1941, the corps was stationed in the Western Front's second echelon in Pukhavichy in Minsk Region. On 26 June, the corps was ordered to conduct an air-assault and ground attack with"}
BIN
public/llama32-3b.png
Normal file
Binary file not shown. (Size: 110 KiB)
BIN
public/multi-turn.png
Normal file
Binary file not shown. (Size: 629 KiB)
BIN
public/single-turn.png
Normal file
Binary file not shown. (Size: 378 KiB)
78
pyproject.toml
Normal file
@@ -0,0 +1,78 @@
# -------------------------------
# build-system
# -------------------------------
[build-system]
requires = [
    "setuptools>=61.0",
    "wheel"
]
build-backend = "setuptools.build_meta"

# -------------------------------
# project (PEP 621 metadata)
# -------------------------------
[project]
name = "verl"
# We'll mark the version as "dynamic" because it's read from the file "verl/version/version"
# (PEP 621 calls this "dynamic version").
# The actual version is specified in the [tool.setuptools.dynamic] section below.
dynamic = ["version"]

description = "veRL: Volcano Engine Reinforcement Learning for LLM"
license = {file = "LICENSE"}  # or "Apache-2.0", if you prefer an SPDX identifier
readme = {file = "README.md", content-type = "text/markdown"}
requires-python = ">=3.8"

authors = [
    { name = "Bytedance - Seed - MLSys", email = "zhangchi.usc1992@bytedance.com" },
    { name = "Bytedance - Seed - MLSys", email = "gmsheng@connect.hku.hk" },
]

# Dependencies corresponding to install_requires in setup.py
dependencies = [
    "accelerate",
    "codetiming",
    "datasets",
    "dill",
    "hydra-core",
    "numpy",
    "pybind11",
    "ray",
    "tensordict",
    "transformers<4.48",
    "vllm<=0.6.3",
]

# Optional dependencies (extras_require in setup.py)
[project.optional-dependencies]
test = [
    "pytest", "yapf"
]

# URLs
[project.urls]
Homepage = "https://github.com/volcengine/verl"

# -------------------------------
# tool.setuptools - Additional config
# -------------------------------
[tool.setuptools]
# True means `setuptools` will attempt to include all relevant files in package_data automatically.
# This corresponds to `include_package_data=True` in setup.py.
include-package-data = true

# We read the version from a file in 'verl/version/version'
[tool.setuptools.dynamic]
version = {file = "verl/version/version"}

# If you need to mimic `package_dir={'': '.'}`:
[tool.setuptools.package-dir]
"" = "."

# If you need to include specific non-Python data (like YAML files or version file):
# This is the rough equivalent of package_data={'': ['version/*'], 'verl': ['trainer/config/*.yaml']}
[tool.setuptools.package-data]
verl = [
    "version/*",
    "trainer/config/*.yaml"
]
16
requirements.txt
Normal file
@@ -0,0 +1,16 @@
accelerate
codetiming
datasets
dill
flash-attn
hydra-core
numpy
pandas
pybind11
ray
tensordict<0.6
transformers<4.48
vllm<=0.6.3
wandb
IPython
matplotlib
10
retrieval_launch.sh
Normal file
@@ -0,0 +1,10 @@
file_path=/the/path/you/save/corpus
index_file=$file_path/e5_Flat.index
corpus_file=$file_path/wiki-18.jsonl
retriever=intfloat/e5-base-v2

python search_r1/search/retrieval_server.py --index_path $index_file \
    --corpus_path $corpus_file \
    --topk 3 \
    --retriever_model $retriever
100
scripts/data_process/nq.py
Normal file
@@ -0,0 +1,100 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess the nq dataset to parquet format
"""

import re
import os
import datasets

from verl.utils.hdfs_io import copy, makedirs
import argparse


def make_prefix(dp, template_type):
    question = dp['question']

    # NOTE: also need to change reward_score/countdown.py
    if template_type == 'base':
        """This works for any base model"""
        prefix = f"""Answer the given question. \
You should first have a reasoning process in mind and then provide the answer. \
Show your reasoning in <think> </think> tags and return the final answer in <answer> </answer> tags, for example <answer> Beijing </answer>. \
Question: {question}\n"""
    else:
        raise NotImplementedError
    return prefix


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_dir', default='./data/nq')
    parser.add_argument('--hdfs_dir', default=None)
    parser.add_argument('--template_type', type=str, default='base')

    args = parser.parse_args()

    data_source = 'nq'

    dataset = datasets.load_dataset('RUC-NLPIR/FlashRAG_datasets', 'nq')

    train_dataset = dataset['train']
    test_dataset = dataset['test']

    # add a row to each data item that represents a unique id
    def make_map_fn(split):

        def process_fn(example, idx):
            example['question'] = example['question'].strip()
            if example['question'][-1] != '?':
                example['question'] += '?'
            question = make_prefix(example, template_type=args.template_type)
            solution = {
                "target": example['golden_answers'],
            }

            data = {
                "data_source": data_source,
                "prompt": [{
                    "role": "user",
                    "content": question,
                }],
                "ability": "fact-reasoning",
                "reward_model": {
                    "style": "rule",
                    "ground_truth": solution
                },
                "extra_info": {
                    'split': split,
                    'index': idx,
                }
            }
            return data

        return process_fn

    train_dataset = train_dataset.map(function=make_map_fn('train'), with_indices=True)
    test_dataset = test_dataset.map(function=make_map_fn('test'), with_indices=True)

    local_dir = args.local_dir
    hdfs_dir = args.hdfs_dir

    train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet'))
    test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet'))

    if hdfs_dir is not None:
        makedirs(hdfs_dir)

        copy(src=local_dir, dst=hdfs_dir)
141
scripts/data_process/nq_rag.py
Normal file
@@ -0,0 +1,141 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess the nq dataset to parquet format
"""

import re
import os
import json
import datasets

from verl.utils.hdfs_io import copy, makedirs
import argparse


def make_prefix(dp, template_type):
    question = dp['question']
    context = dp['context']

    # NOTE: also need to change reward_score/countdown.py
    if template_type == 'base':
        """This works for any base model"""
        prefix = f"""Answer the given question with some potentially useful context. \
You should analyze the question carefully, evaluate the given context (which may or may not be useful), and then generate an accurate and well-reasoned response. \
You should first have a reasoning process in mind and then provide the answer. \
Show your reasoning in <think> </think> tags and return the final answer in <answer> </answer> tags, for example <answer> Beijing </answer>. \
Question: {question} Context: {context} \n"""
    else:
        raise NotImplementedError
    return prefix


def format_reference(retrieval_result):
    format_reference = ''
    for idx, doc_item in enumerate(retrieval_result):
        content = doc_item['contents']
        title = content.split("\n")[0]
        text = "\n".join(content.split("\n")[1:])
        format_reference += f"Doc {idx+1}(Title: {title}) {text}\n"

    return format_reference


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_dir', default='./data/nq_rag')
    parser.add_argument('--hdfs_dir', default=None)
    parser.add_argument('--template_type', type=str, default='base')
    parser.add_argument('--topk', type=int, default=3)
    parser.add_argument('--corpus_path', type=str, default='/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl')
    parser.add_argument('--train_retrieval_cache', type=str, default='/home/peterjin/rag_retrieval_cache/nq/e5_train_retrieval_cache_2048.json')
    parser.add_argument('--test_retrieval_cache', type=str, default='/home/peterjin/rag_retrieval_cache/nq/e5_test_retrieval_cache_10000.json')

    args = parser.parse_args()

    data_source = 'nq'

    dataset = datasets.load_dataset('RUC-NLPIR/FlashRAG_datasets', 'nq')

    train_dataset = dataset['train']
    test_dataset = dataset['test']

    # read retrieval cache
    print('reading retrieval cache...')
    retrieval_cache = json.load(open(args.train_retrieval_cache))
    # test_retrieval_cache = json.load(open(args.test_retrieval_cache))
    retrieval_cache.update(json.load(open(args.test_retrieval_cache)))

    # read corpus
    print('reading corpus...')
    corpus = {}
    with open(args.corpus_path) as f:
        readin = f.readlines()
        for line in readin:
            tmp = json.loads(line)
            corpus[tmp['id']] = tmp

    # add a column for the retrieval context
    def add_context(example):
        example['context'] = format_reference([corpus[docs["id"]] for docs in retrieval_cache[example['question']][:args.topk]])
        return example

    train_dataset = train_dataset.map(function=add_context)
    test_dataset = test_dataset.map(function=add_context)

    # add a row to each data item that represents a unique id
    def make_map_fn(split):

        def process_fn(example, idx):
            example['question'] = example['question'].strip()
            if example['question'][-1] != '?':
                example['question'] += '?'
            question = make_prefix(example, template_type=args.template_type)
            solution = {
                "target": example['golden_answers'],
            }

            data = {
                "data_source": data_source,
                "prompt": [{
                    "role": "user",
                    "content": question,
                }],
                "ability": "fact-reasoning",
                "reward_model": {
                    "style": "rule",
                    "ground_truth": solution
                },
                "extra_info": {
                    'split': split,
                    'index': idx,
                }
            }
            return data

        return process_fn

    train_dataset = train_dataset.map(function=make_map_fn('train'), with_indices=True)
    test_dataset = test_dataset.map(function=make_map_fn('test'), with_indices=True)

    local_dir = args.local_dir
    hdfs_dir = args.hdfs_dir

    train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet'))
    test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet'))

    if hdfs_dir is not None:
        makedirs(hdfs_dir)

        copy(src=local_dir, dst=hdfs_dir)
101
scripts/data_process/nq_search.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Preprocess the nq dataset to parquet format
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
import os
|
||||||
|
import datasets
|
||||||
|
|
||||||
|
from verl.utils.hdfs_io import copy, makedirs
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
|
def make_prefix(dp, template_type):
|
||||||
|
question = dp['question']
|
||||||
|
|
||||||
|
# NOTE: also need to change reward_score/countdown.py
|
||||||
|
if template_type == 'base':
|
||||||
|
"""This works for any base model"""
|
||||||
|
prefix = f"""Answer the given question. \
|
||||||
|
You must conduct reasoning inside <think> and </think> first every time you get new information. \
|
||||||
|
After reasoning, if you find you lack some knowledge, you can call a search engine by <search> query </search> and it will return the top searched results between <information> and </information>. \
|
||||||
|
You can search as many times as you want. \
|
||||||
|
If you find no further external knowledge needed, you can directly provide the answer inside <answer> and </answer>, without detailed illustrations. For example, <answer> Beijing </answer>. Question: {question}\n"""
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
return prefix
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--local_dir', default='./data/nq_search')
|
||||||
|
parser.add_argument('--hdfs_dir', default=None)
|
||||||
|
parser.add_argument('--template_type', type=str, default='base')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
data_source = 'nq'
|
||||||
|
|
||||||
|
dataset = datasets.load_dataset('RUC-NLPIR/FlashRAG_datasets', 'nq')
|
||||||
|
|
||||||
|
train_dataset = dataset['train']
|
||||||
|
test_dataset = dataset['test']
|
||||||
|
|
||||||
|
# add a row to each data item that represents a unique id
|
||||||
|
def make_map_fn(split):
|
||||||
|
|
||||||
|
def process_fn(example, idx):
|
||||||
|
example['question'] = example['question'].strip()
|
||||||
|
if example['question'][-1] != '?':
|
||||||
|
example['question'] += '?'
|
||||||
|
question = make_prefix(example, template_type=args.template_type)
|
||||||
|
solution = {
|
||||||
|
"target": example['golden_answers'],
|
||||||
|
}
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"data_source": data_source,
|
||||||
|
"prompt": [{
|
||||||
|
"role": "user",
|
||||||
|
"content": question,
|
||||||
|
}],
|
||||||
|
"ability": "fact-reasoning",
|
||||||
|
"reward_model": {
|
||||||
|
"style": "rule",
|
||||||
|
"ground_truth": solution
|
||||||
|
},
|
||||||
|
"extra_info": {
|
||||||
|
'split': split,
|
||||||
|
'index': idx,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
|
||||||
|
return process_fn
|
||||||
|
|
||||||
|
train_dataset = train_dataset.map(function=make_map_fn('train'), with_indices=True)
|
||||||
|
test_dataset = test_dataset.map(function=make_map_fn('test'), with_indices=True)
|
||||||
|
|
||||||
|
local_dir = args.local_dir
|
||||||
|
hdfs_dir = args.hdfs_dir
|
||||||
|
|
||||||
|
train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet'))
|
||||||
|
test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet'))
|
||||||
|
|
||||||
|
if hdfs_dir is not None:
|
||||||
|
makedirs(hdfs_dir)
|
||||||
|
|
||||||
|
copy(src=local_dir, dst=hdfs_dir)
|
||||||
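For illustration, a minimal sketch of what make_prefix in this script produces (the question below is a hypothetical example, not taken from the dataset):

# Hypothetical usage of make_prefix from nq_search.py
example = {'question': 'who wrote the opera carmen?'}
print(make_prefix(example, template_type='base'))
# -> "Answer the given question. You must conduct reasoning inside <think> and </think>
#     first every time you get new information. ... Question: who wrote the opera carmen?\n"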
25
scripts/download.py
Normal file
@@ -0,0 +1,25 @@
import argparse
from huggingface_hub import hf_hub_download

parser = argparse.ArgumentParser(description="Download files from a Hugging Face dataset repository.")
parser.add_argument("--repo_id", type=str, default="PeterJinGo/wiki-18-e5-index", help="Hugging Face repository ID")
parser.add_argument("--save_path", type=str, required=True, help="Local directory to save files")

args = parser.parse_args()

repo_id = "PeterJinGo/wiki-18-e5-index"
for file in ["part_aa", "part_ab"]:
    hf_hub_download(
        repo_id=repo_id,
        filename=file,  # e.g., "e5_Flat.index"
        repo_type="dataset",
        local_dir=args.save_path,
    )

repo_id = "PeterJinGo/wiki-18-corpus"
hf_hub_download(
    repo_id=repo_id,
    filename="wiki-18.jsonl.gz",
    repo_type="dataset",
    local_dir=args.save_path,
)
6
scripts/download.sh
Normal file
@@ -0,0 +1,6 @@
save_path=/home/peterjin/debug_cache

python download.py --save_path $save_path

# reassemble the split faiss index parts into a single index file
cat $save_path/part_* > e5_Flat.index
12
scripts/upload.py
Normal file
@@ -0,0 +1,12 @@
import os
from huggingface_hub import upload_file

repo_id = "PeterJinGo/wiki-18-e5-index"
path = "/home/peterjin/mnt/index/wiki-18"
for file in ["part_aa", "part_ab"]:
    upload_file(
        path_or_fileobj=os.path.join(path, file),  # File path
        path_in_repo=file,  # Destination filename in the repo
        repo_id=repo_id,  # Your dataset repo ID
        repo_type="dataset"
    )
6
scripts/upload.sh
Normal file
@@ -0,0 +1,6 @@
index=/home/peterjin/mnt/index/wiki-18/e5_Flat.index

split -b 40G $index part_

python upload.py
0
search_r1/__init__.py
Normal file
0
search_r1/llm_agent/__init__.py
Normal file
416
search_r1/llm_agent/generation.py
Normal file
@@ -0,0 +1,416 @@
|
|||||||
|
import torch
|
||||||
|
import re
|
||||||
|
from collections import defaultdict
|
||||||
|
import os
|
||||||
|
from typing import List, Dict, Any, Tuple
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from .tensor_helper import TensorHelper, TensorConfig
|
||||||
|
# from search_r1.utils import set_seed
|
||||||
|
# from search_r1.utils.plot import (
|
||||||
|
# save_trajectory_to_output,
|
||||||
|
# parse_llm_output
|
||||||
|
# )
|
||||||
|
from verl import DataProto
|
||||||
|
from verl.utils.tracking import Tracking
|
||||||
|
import shutil
|
||||||
|
import requests
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GenerationConfig:
|
||||||
|
max_turns: int
|
||||||
|
max_start_length: int
|
||||||
|
max_prompt_length: int
|
||||||
|
max_response_length: int
|
||||||
|
max_obs_length: int
|
||||||
|
# logging: dict
|
||||||
|
num_gpus: int
|
||||||
|
no_think_rl: bool=False
|
||||||
|
search_url: str = None
|
||||||
|
topk: int = 3
|
||||||
|
|
||||||
|
class LLMGenerationManager:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tokenizer,
|
||||||
|
actor_rollout_wg,
|
||||||
|
config: GenerationConfig,
|
||||||
|
# logger: Tracking,
|
||||||
|
is_validation: bool = False,
|
||||||
|
):
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.actor_rollout_wg = actor_rollout_wg
|
||||||
|
self.config = config
|
||||||
|
# self.logger = logger
|
||||||
|
self.is_validation = is_validation
|
||||||
|
|
||||||
|
self.tensor_fn = TensorHelper(TensorConfig(
|
||||||
|
pad_token_id=tokenizer.pad_token_id,
|
||||||
|
max_prompt_length=config.max_prompt_length,
|
||||||
|
max_obs_length=config.max_obs_length,
|
||||||
|
max_start_length=config.max_start_length
|
||||||
|
))
|
||||||
|
|
||||||
|
def _batch_tokenize(self, responses: List[str]) -> torch.Tensor:
|
||||||
|
"""Tokenize a batch of responses."""
|
||||||
|
return self.tokenizer(
|
||||||
|
responses,
|
||||||
|
add_special_tokens=False,
|
||||||
|
return_tensors='pt',
|
||||||
|
padding="longest"
|
||||||
|
)['input_ids']
|
||||||
|
|
||||||
|
def _postprocess_responses(self, responses: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Process responses to stop at search operation or answer operation."""
|
||||||
|
responses_str = self.tokenizer.batch_decode(
|
||||||
|
responses,
|
||||||
|
skip_special_tokens=True
|
||||||
|
)
|
||||||
|
|
||||||
|
responses_str = [resp.split('</search>')[0] + '</search>'
|
||||||
|
if '</search>' in resp
|
||||||
|
else resp.split('</answer>')[0] + '</answer>'
|
||||||
|
if '</answer>' in resp
|
||||||
|
else resp
|
||||||
|
for resp in responses_str]
|
||||||
|
|
||||||
|
if self.config.no_think_rl:
|
||||||
|
raise ValueError('stop')
|
||||||
|
# if no_think_rl is enabled, only keep action in the str
|
||||||
|
actions, _ = self.env.postprocess_predictions(responses_str)
|
||||||
|
responses_str=[f"<answer>{envs[idx].ACTION_LOOKUP[action]}</answer>" for idx, action in enumerate(actions)]
|
||||||
|
print("RESPONSES:", responses_str)
|
||||||
|
responses = self._batch_tokenize(responses_str)
|
||||||
|
return responses, responses_str
|
||||||
|
|
||||||
|
def _process_next_obs(self, next_obs: List[str]) -> torch.Tensor:
|
||||||
|
"""Process next observations from environment."""
|
||||||
|
|
||||||
|
next_obs_ids = self.tokenizer(
|
||||||
|
next_obs,
|
||||||
|
padding='longest',
|
||||||
|
return_tensors='pt',
|
||||||
|
add_special_tokens=False, # Prevents adding special tokens
|
||||||
|
)['input_ids']
|
||||||
|
|
||||||
|
if next_obs_ids.shape[1] > self.config.max_obs_length:
|
||||||
|
print(f"[WARNING] OBSERVATION TOO LONG, CONSIDER CHANGING YOUR CONFIG, {next_obs_ids.shape[1]} & {self.config.max_obs_length}")
|
||||||
|
next_obs_ids = next_obs_ids[:, :self.config.max_obs_length]
|
||||||
|
|
||||||
|
return next_obs_ids
|
||||||
|
|
||||||
|
def _update_rolling_state(self, rollings, cur_responses: torch.Tensor,
|
||||||
|
next_obs_ids: torch.Tensor) -> Dict:
|
||||||
|
"""Update rolling state with new responses and observations."""
|
||||||
|
# Concatenate and handle padding
|
||||||
|
new_input_ids = self.tensor_fn.concatenate_with_padding([
|
||||||
|
rollings.batch['input_ids'],
|
||||||
|
cur_responses,
|
||||||
|
next_obs_ids
|
||||||
|
])
|
||||||
|
|
||||||
|
# Create attention mask and position ids
|
||||||
|
new_attention_mask = self.tensor_fn.create_attention_mask(new_input_ids)
|
||||||
|
new_position_ids = self.tensor_fn.create_position_ids(new_attention_mask)
|
||||||
|
|
||||||
|
# Cut to appropriate length
|
||||||
|
effective_len = new_attention_mask.sum(dim=1).max()
|
||||||
|
max_len = min(self.config.max_prompt_length, effective_len)
|
||||||
|
|
||||||
|
return DataProto.from_dict({
|
||||||
|
'input_ids': new_input_ids[:, -max_len:],
|
||||||
|
'position_ids': new_position_ids[:, -max_len:],
|
||||||
|
'attention_mask': new_attention_mask[:, -max_len:]
|
||||||
|
})
|
||||||
|
|
||||||
|
def _update_right_side(self, right_side: Dict,
|
||||||
|
cur_responses: torch.Tensor,
|
||||||
|
next_obs_ids: torch.Tensor = None) -> Dict:
|
||||||
|
"""Update right side state."""
|
||||||
|
if next_obs_ids is not None:
|
||||||
|
responses = self.tensor_fn.concatenate_with_padding([
|
||||||
|
right_side['responses'],
|
||||||
|
cur_responses,
|
||||||
|
next_obs_ids
|
||||||
|
], pad_to_left=False)
|
||||||
|
else:
|
||||||
|
responses = self.tensor_fn.concatenate_with_padding([
|
||||||
|
right_side['responses'],
|
||||||
|
cur_responses,
|
||||||
|
], pad_to_left=False)
|
||||||
|
|
||||||
|
effective_len = self.tensor_fn.create_attention_mask(responses).sum(dim=1).max()
|
||||||
|
max_len = min(self.config.max_prompt_length, effective_len)
|
||||||
|
|
||||||
|
return {'responses': responses[:, :max_len]}
|
||||||
|
|
||||||
|
def _generate_with_gpu_padding(self, active_batch: DataProto) -> DataProto:
|
||||||
|
"""
|
||||||
|
Wrapper for generation that handles multi-GPU padding requirements.
|
||||||
|
if num_gpus <= 1, return self.actor_rollout_wg.generate_sequences(active_batch)
|
||||||
|
if active_batch size is not divisible by num_gpus, pad with first sequence
|
||||||
|
then remove padding from output
|
||||||
|
"""
|
||||||
|
num_gpus = self.config.num_gpus
|
||||||
|
if num_gpus <= 1:
|
||||||
|
return self.actor_rollout_wg.generate_sequences(active_batch)
|
||||||
|
|
||||||
|
batch_size = active_batch.batch['input_ids'].shape[0]
|
||||||
|
remainder = batch_size % num_gpus
|
||||||
|
|
||||||
|
if remainder == 0:
|
||||||
|
return self.actor_rollout_wg.generate_sequences(active_batch)
|
||||||
|
|
||||||
|
# Add padding sequences
|
||||||
|
padding_size = num_gpus - remainder
|
||||||
|
padded_batch = {}
|
||||||
|
|
||||||
|
for k, v in active_batch.batch.items():
|
||||||
|
# Use first sequence as padding template
|
||||||
|
pad_sequence = v[0:1].repeat(padding_size, *[1] * (len(v.shape) - 1))
|
||||||
|
padded_batch[k] = torch.cat([v, pad_sequence], dim=0)
|
||||||
|
|
||||||
|
padded_active_batch = DataProto.from_dict(padded_batch)
|
||||||
|
|
||||||
|
# Generate with padded batch
|
||||||
|
padded_output = self.actor_rollout_wg.generate_sequences(padded_active_batch)
|
||||||
|
|
||||||
|
# Remove padding from output
|
||||||
|
trimmed_batch = {k: v[:-padding_size] for k, v in padded_output.batch.items()}
|
||||||
|
|
||||||
|
# Handle meta_info if present
|
||||||
|
if hasattr(padded_output, 'meta_info') and padded_output.meta_info:
|
||||||
|
trimmed_meta = {}
|
||||||
|
for k, v in padded_output.meta_info.items():
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
trimmed_meta[k] = v[:-padding_size]
|
||||||
|
else:
|
||||||
|
trimmed_meta[k] = v
|
||||||
|
padded_output.meta_info = trimmed_meta
|
||||||
|
|
||||||
|
padded_output.batch = trimmed_batch
|
||||||
|
return padded_output
|
||||||
|
|
||||||
|
def run_llm_loop(self, gen_batch, initial_input_ids: torch.Tensor) -> Tuple[Dict, Dict]:
|
||||||
|
"""Run main LLM generation loop."""
|
||||||
|
|
||||||
|
original_left_side = {'input_ids': initial_input_ids[:, -self.config.max_start_length:]}
|
||||||
|
original_right_side = {'responses': initial_input_ids[:, []]}
|
||||||
|
|
||||||
|
active_mask = torch.ones(gen_batch.batch['input_ids'].shape[0], dtype=torch.bool)
|
||||||
|
active_num_list = [active_mask.sum().item()]
|
||||||
|
rollings = gen_batch
|
||||||
|
|
||||||
|
# Main generation loop
|
||||||
|
for step in range(self.config.max_turns):
|
||||||
|
if not active_mask.sum():
|
||||||
|
break
|
||||||
|
rollings.batch = self.tensor_fn.cut_to_effective_len(
|
||||||
|
rollings.batch,
|
||||||
|
keys=['input_ids', 'attention_mask', 'position_ids']
|
||||||
|
)
|
||||||
|
|
||||||
|
# gen_output = self.actor_rollout_wg.generate_sequences(rollings)
|
||||||
|
rollings_active = DataProto.from_dict({
|
||||||
|
k: v[active_mask] for k, v in rollings.batch.items()
|
||||||
|
})
|
||||||
|
gen_output = self._generate_with_gpu_padding(rollings_active)
|
||||||
|
|
||||||
|
meta_info = gen_output.meta_info
|
||||||
|
responses_ids, responses_str = self._postprocess_responses(gen_output.batch['responses'])
|
||||||
|
responses_ids, responses_str = self.tensor_fn._example_level_pad(responses_ids, responses_str, active_mask)
|
||||||
|
|
||||||
|
# Execute in environment and process observations
|
||||||
|
next_obs, dones = self.execute_predictions(
|
||||||
|
responses_str, self.tokenizer.pad_token, active_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
curr_active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
|
||||||
|
active_mask = active_mask * curr_active_mask
|
||||||
|
active_num_list.append(active_mask.sum().item())
|
||||||
|
|
||||||
|
next_obs_ids = self._process_next_obs(next_obs)
|
||||||
|
|
||||||
|
# Update states
|
||||||
|
rollings = self._update_rolling_state(
|
||||||
|
rollings,
|
||||||
|
responses_ids,
|
||||||
|
next_obs_ids
|
||||||
|
)
|
||||||
|
original_right_side = self._update_right_side(
|
||||||
|
original_right_side,
|
||||||
|
responses_ids,
|
||||||
|
next_obs_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
# final LLM rollout
|
||||||
|
if active_mask.sum():
|
||||||
|
rollings.batch = self.tensor_fn.cut_to_effective_len(
|
||||||
|
rollings.batch,
|
||||||
|
keys=['input_ids', 'attention_mask', 'position_ids']
|
||||||
|
)
|
||||||
|
|
||||||
|
# gen_output = self.actor_rollout_wg.generate_sequences(rollings)
|
||||||
|
rollings_active = DataProto.from_dict({
|
||||||
|
k: v[active_mask] for k, v in rollings.batch.items()
|
||||||
|
})
|
||||||
|
gen_output = self._generate_with_gpu_padding(rollings_active)
|
||||||
|
|
||||||
|
meta_info = gen_output.meta_info
|
||||||
|
responses_ids, responses_str = self._postprocess_responses(gen_output.batch['responses'])
|
||||||
|
responses_ids, responses_str = self.tensor_fn._example_level_pad(responses_ids, responses_str, active_mask)
|
||||||
|
|
||||||
|
# # Execute in environment and process observations
|
||||||
|
_, dones = self.execute_predictions(
|
||||||
|
responses_str, self.tokenizer.pad_token, active_mask, do_search=False
|
||||||
|
)
|
||||||
|
|
||||||
|
curr_active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
|
||||||
|
active_mask = active_mask * curr_active_mask
|
||||||
|
active_num_list.append(active_mask.sum().item())
|
||||||
|
|
||||||
|
original_right_side = self._update_right_side(
|
||||||
|
original_right_side,
|
||||||
|
responses_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("ACTIVE_TRAJ_NUM:", active_num_list)
|
||||||
|
|
||||||
|
return self._compose_final_output(original_left_side, original_right_side, meta_info)
|
||||||
|
|
||||||
|
def _compose_final_output(self, left_side: Dict,
|
||||||
|
right_side: Dict,
|
||||||
|
meta_info: Dict) -> Tuple[Dict, Dict]:
|
||||||
|
"""Compose final generation output."""
|
||||||
|
final_output = right_side.copy()
|
||||||
|
final_output['prompts'] = left_side['input_ids']
|
||||||
|
|
||||||
|
# Combine input IDs
|
||||||
|
final_output['input_ids'] = torch.cat([
|
||||||
|
left_side['input_ids'],
|
||||||
|
right_side['responses']
|
||||||
|
], dim=1)
|
||||||
|
|
||||||
|
# Create attention mask and position ids
|
||||||
|
final_output['attention_mask'] = torch.cat([
|
||||||
|
self.tensor_fn.create_attention_mask(left_side['input_ids']),
|
||||||
|
self.tensor_fn.create_attention_mask(final_output['responses'])
|
||||||
|
], dim=1)
|
||||||
|
|
||||||
|
final_output['position_ids'] = self.tensor_fn.create_position_ids(
|
||||||
|
final_output['attention_mask']
|
||||||
|
)
|
||||||
|
|
||||||
|
final_output = DataProto.from_dict(final_output)
|
||||||
|
final_output.meta_info.update(meta_info)
|
||||||
|
|
||||||
|
return final_output
|
||||||
|
|
||||||
|
def execute_predictions(self, predictions: List[str], pad_token: str, active_mask=None, do_search=True) -> List[str]:
|
||||||
|
"""
|
||||||
|
Execute predictions across multiple environments.
|
||||||
|
NOTE: the function is the actual `step` function in the environment
|
||||||
|
NOTE: penalty_for_invalid is not included in the observation shown to the LLM
|
||||||
|
|
||||||
|
Args:
|
||||||
|
envs: List of environment instances
|
||||||
|
predictions: List of action predictions
|
||||||
|
pad_token: Token to use for padding
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of observation strings
|
||||||
|
"""
|
||||||
|
cur_actions, contents = self.postprocess_predictions(predictions)
|
||||||
|
next_obs, dones = [], []
|
||||||
|
|
||||||
|
search_queries = [content for action, content in zip(cur_actions, contents) if action == 'search']
|
||||||
|
if do_search:
|
||||||
|
search_results = self.batch_search(search_queries)
|
||||||
|
assert len(search_results) == sum([1 for action in cur_actions if action == 'search'])
|
||||||
|
else:
|
||||||
|
search_results = [''] * sum([1 for action in cur_actions if action == 'search'])
|
||||||
|
|
||||||
|
for i, (action, active) in enumerate(zip(cur_actions, active_mask)):
|
||||||
|
|
||||||
|
if not active:
|
||||||
|
next_obs.append('')
|
||||||
|
dones.append(1)
|
||||||
|
else:
|
||||||
|
if action == 'answer':
|
||||||
|
next_obs.append('')
|
||||||
|
dones.append(1)
|
||||||
|
elif action == 'search':
|
||||||
|
next_obs.append(f'\n\n<information>{search_results.pop(0).strip()}</information>\n\n')
|
||||||
|
dones.append(0)
|
||||||
|
else:
|
||||||
|
next_obs.append(f'\nMy previous action is invalid. \
|
||||||
|
If I want to search, I should put the query between <search> and </search>. \
|
||||||
|
If I want to give the final answer, I should put the answer between <answer> and </answer>. Let me try again.\n')
|
||||||
|
dones.append(0)
|
||||||
|
|
||||||
|
assert len(search_results) == 0
|
||||||
|
|
||||||
|
return next_obs, dones
|
||||||
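# Illustrative sketch of the turn structure this method implements (hypothetical content):
#   model emits:   <think> ... </think> <search> capital of France </search>
#   env appends:   <information> Doc 1(Title: Paris) Paris is the capital ... </information>
#   model emits:   <think> ... </think> <answer> Paris </answer>  -> done
# Search results are wrapped in <information> tags; an <answer> (or an inactive example)
# marks the trajectory as done, and any other output gets the retry message below.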
|
|
||||||
|
def postprocess_predictions(self, predictions: List[Any]) -> Tuple[List[int], List[bool]]:
|
||||||
|
"""
|
||||||
|
Process (text-based) predictions from llm into actions and validity flags.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
predictions: List of raw predictions
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (actions list, validity flags list)
|
||||||
|
"""
|
||||||
|
actions = []
|
||||||
|
contents = []
|
||||||
|
|
||||||
|
for prediction in predictions:
|
||||||
|
if isinstance(prediction, str): # for llm output
|
||||||
|
pattern = r'<(search|answer)>(.*?)</\1>'
|
||||||
|
match = re.search(pattern, prediction, re.DOTALL)
|
||||||
|
if match:
|
||||||
|
content = match.group(2).strip() # Return only the content inside the tags
|
||||||
|
action = match.group(1)
|
||||||
|
else:
|
||||||
|
content = ''
|
||||||
|
action = None
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid prediction type: {type(prediction)}")
|
||||||
|
|
||||||
|
actions.append(action)
|
||||||
|
contents.append(content)
|
||||||
|
|
||||||
|
return actions, contents
|
||||||
|
|
||||||
|
def batch_search(self, queries: List[str] = None) -> str:
|
||||||
|
"""
|
||||||
|
Batchified search for queries.
|
||||||
|
Args:
|
||||||
|
queries: queries to call the search engine
|
||||||
|
Returns:
|
||||||
|
search results which is concatenated into a string
|
||||||
|
"""
|
||||||
|
results = self._batch_search(queries)['result']
|
||||||
|
|
||||||
|
return [self._passages2string(result) for result in results]
|
||||||
|
|
||||||
|
def _batch_search(self, queries):
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"queries": queries,
|
||||||
|
"topk": self.config.topk,
|
||||||
|
"return_scores": True
|
||||||
|
}
|
||||||
|
|
||||||
|
return requests.post(self.config.search_url, json=payload).json()
|
||||||
|
|
||||||
|
def _passages2string(self, retrieval_result):
|
||||||
|
format_reference = ''
|
||||||
|
for idx, doc_item in enumerate(retrieval_result):
|
||||||
|
|
||||||
|
content = doc_item['document']['contents']
|
||||||
|
title = content.split("\n")[0]
|
||||||
|
text = "\n".join(content.split("\n")[1:])
|
||||||
|
format_reference += f"Doc {idx+1}(Title: {title}) {text}\n"
|
||||||
|
|
||||||
|
return format_reference
|
||||||
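A minimal usage sketch for this class (illustrative only; the config values, the tokenizer, and the actor_rollout_wg worker group are assumptions standing in for whatever the trainer builds):

# Hypothetical wiring of LLMGenerationManager
config = GenerationConfig(
    max_turns=2,
    max_start_length=2048,
    max_prompt_length=4096,
    max_response_length=500,
    max_obs_length=500,
    num_gpus=8,
    search_url="http://127.0.0.1:8000/retrieve",
    topk=3,
)
manager = LLMGenerationManager(tokenizer, actor_rollout_wg, config)
# final_output = manager.run_llm_loop(gen_batch, initial_input_ids)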
75
search_r1/llm_agent/tensor_helper.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
import torch
|
||||||
|
from typing import Dict, Tuple, List
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TensorConfig:
|
||||||
|
pad_token_id: int
|
||||||
|
max_prompt_length: int
|
||||||
|
max_obs_length: int
|
||||||
|
max_start_length: int
|
||||||
|
|
||||||
|
class TensorHelper:
|
||||||
|
def __init__(self, config: TensorConfig):
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def cut_to_effective_len(self, tensor_dict: Dict[str, torch.Tensor],
|
||||||
|
keys: List[str], cut_left: bool = True) -> Dict[str, torch.Tensor]:
|
||||||
|
"""Cut tensors to their effective length based on attention mask."""
|
||||||
|
effective_len = tensor_dict['attention_mask'].sum(dim=1).max()
|
||||||
|
result = tensor_dict.copy()
|
||||||
|
|
||||||
|
for key in keys:
|
||||||
|
if cut_left:
|
||||||
|
result[key] = tensor_dict[key][:, -effective_len:]
|
||||||
|
else:
|
||||||
|
result[key] = tensor_dict[key][:, :effective_len]
|
||||||
|
return result
|
||||||
|
|
||||||
|
def convert_pad_structure(self, tensor: torch.Tensor, pad_to_left: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Convert padding structure and return sorted tensor with indices."""
|
||||||
|
mask = tensor != self.config.pad_token_id if pad_to_left else tensor == self.config.pad_token_id
|
||||||
|
sorted_indices = mask.to(torch.int64).argsort(dim=1, stable=True)
|
||||||
|
return tensor.gather(1, sorted_indices), sorted_indices
|
||||||
|
|
||||||
|
def create_attention_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Create attention mask from input ids."""
|
||||||
|
return torch.where(input_ids != self.config.pad_token_id, 1, 0)
|
||||||
|
|
||||||
|
def create_position_ids(self, attention_mask: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Create position ids from attention mask."""
|
||||||
|
return (torch.cumsum(attention_mask, dim=1) - 1) * attention_mask
|
||||||
|
|
||||||
|
def concatenate_with_padding(self, tensors: List[torch.Tensor],
|
||||||
|
pad_to_left: bool = True) -> torch.Tensor:
|
||||||
|
"""Concatenate tensors and handle padding."""
|
||||||
|
concatenated = torch.cat(tensors, dim=1)
|
||||||
|
padded_tensor, _ = self.convert_pad_structure(concatenated, pad_to_left)
|
||||||
|
return padded_tensor
|
||||||
|
|
||||||
|
def _example_level_pad(self, responses: torch.Tensor,
|
||||||
|
responses_str: List[str],
|
||||||
|
active_mask: torch.Tensor) -> Tuple[torch.Tensor, List[str]]:
|
||||||
|
"""
|
||||||
|
Pad responses for non-active examples with pad tokens.
|
||||||
|
"""
|
||||||
|
assert active_mask.sum() == responses.shape[0]
|
||||||
|
# Create masked responses tensor
|
||||||
|
batch_size = active_mask.shape[0]
|
||||||
|
seq_len = responses.shape[1]
|
||||||
|
padded_responses = torch.full(
|
||||||
|
(batch_size, seq_len), self.config.pad_token_id,
|
||||||
|
dtype=responses.dtype, device=responses.device
|
||||||
|
)
|
||||||
|
padded_responses[active_mask] = responses
|
||||||
|
|
||||||
|
# Create masked response strings
|
||||||
|
padded_responses_str = [""] * batch_size
|
||||||
|
|
||||||
|
s = 0
|
||||||
|
for i, is_active in enumerate(active_mask):
|
||||||
|
if is_active:
|
||||||
|
padded_responses_str[i] = responses_str[s]
|
||||||
|
s += 1
|
||||||
|
|
||||||
|
return padded_responses, padded_responses_str
|
||||||
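A quick worked example of create_position_ids (illustrative only): with left padding, positions restart at 0 on the first real token and stay 0 on pad positions.

import torch
mask = torch.tensor([[0, 0, 1, 1, 1]])
pos = (torch.cumsum(mask, dim=1) - 1) * mask
# pos -> tensor([[0, 0, 0, 1, 2]])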
17
search_r1/search/build_index.sh
Normal file
@@ -0,0 +1,17 @@
corpus_file=/your/corpus/jsonl/file # jsonl
save_dir=/the/path/to/save/index
retriever_name=e5 # this is for indexing naming
retriever_model=intfloat/e5-base-v2

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python index_builder.py \
    --retrieval_method $retriever_name \
    --model_path $retriever_model \
    --corpus_path $corpus_file \
    --save_dir $save_dir \
    --use_fp16 \
    --max_length 256 \
    --batch_size 512 \
    --pooling_method mean \
    --faiss_type Flat \
    --save_embedding
348
search_r1/search/index_builder.py
Normal file
@@ -0,0 +1,348 @@
|
|||||||
|
import os
|
||||||
|
import faiss
|
||||||
|
import json
|
||||||
|
import warnings
|
||||||
|
import numpy as np
|
||||||
|
from typing import cast, List, Dict
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import argparse
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
# from LongRAG.retriever.utils import load_model, load_corpus, pooling
|
||||||
|
import datasets
|
||||||
|
from transformers import AutoTokenizer, AutoModel, AutoConfig
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(
|
||||||
|
model_path: str,
|
||||||
|
use_fp16: bool = False
|
||||||
|
):
|
||||||
|
model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
||||||
|
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
|
||||||
|
model.eval()
|
||||||
|
model.cuda()
|
||||||
|
if use_fp16:
|
||||||
|
model = model.half()
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True)
|
||||||
|
|
||||||
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def pooling(
|
||||||
|
pooler_output,
|
||||||
|
last_hidden_state,
|
||||||
|
attention_mask = None,
|
||||||
|
pooling_method = "mean"
|
||||||
|
):
|
||||||
|
if pooling_method == "mean":
|
||||||
|
last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
|
||||||
|
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
|
||||||
|
elif pooling_method == "cls":
|
||||||
|
return last_hidden_state[:, 0]
|
||||||
|
elif pooling_method == "pooler":
|
||||||
|
return pooler_output
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Pooling method not implemented!")
|
||||||
|
|
||||||
|
|
||||||
|
def load_corpus(corpus_path: str):
|
||||||
|
corpus = datasets.load_dataset(
|
||||||
|
'json',
|
||||||
|
data_files=corpus_path,
|
||||||
|
split="train",
|
||||||
|
num_proc=4)
|
||||||
|
return corpus
|
||||||
|
|
||||||
|
|
||||||
|
class Index_Builder:
|
||||||
|
r"""A tool class used to build an index used in retrieval.
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
retrieval_method,
|
||||||
|
model_path,
|
||||||
|
corpus_path,
|
||||||
|
save_dir,
|
||||||
|
max_length,
|
||||||
|
batch_size,
|
||||||
|
use_fp16,
|
||||||
|
pooling_method,
|
||||||
|
faiss_type=None,
|
||||||
|
embedding_path=None,
|
||||||
|
save_embedding=False,
|
||||||
|
faiss_gpu=False
|
||||||
|
):
|
||||||
|
|
||||||
|
self.retrieval_method = retrieval_method.lower()
|
||||||
|
self.model_path = model_path
|
||||||
|
self.corpus_path = corpus_path
|
||||||
|
self.save_dir = save_dir
|
||||||
|
self.max_length = max_length
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.use_fp16 = use_fp16
|
||||||
|
self.pooling_method = pooling_method
|
||||||
|
self.faiss_type = faiss_type if faiss_type is not None else 'Flat'
|
||||||
|
self.embedding_path = embedding_path
|
||||||
|
self.save_embedding = save_embedding
|
||||||
|
self.faiss_gpu = faiss_gpu
|
||||||
|
|
||||||
|
self.gpu_num = torch.cuda.device_count()
|
||||||
|
# prepare save dir
|
||||||
|
print(self.save_dir)
|
||||||
|
if not os.path.exists(self.save_dir):
|
||||||
|
os.makedirs(self.save_dir)
|
||||||
|
else:
|
||||||
|
if not self._check_dir(self.save_dir):
|
||||||
|
warnings.warn("Some files already exists in save dir and may be overwritten.", UserWarning)
|
||||||
|
|
||||||
|
self.index_save_path = os.path.join(self.save_dir, f"{self.retrieval_method}_{self.faiss_type}.index")
|
||||||
|
|
||||||
|
self.embedding_save_path = os.path.join(self.save_dir, f"emb_{self.retrieval_method}.memmap")
|
||||||
|
|
||||||
|
self.corpus = load_corpus(self.corpus_path)
|
||||||
|
|
||||||
|
print("Finish loading...")
|
||||||
|
@staticmethod
|
||||||
|
def _check_dir(dir_path):
|
||||||
|
r"""Check if the dir path exists and if there is content.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
if os.path.isdir(dir_path):
|
||||||
|
if len(os.listdir(dir_path)) > 0:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
os.makedirs(dir_path, exist_ok=True)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def build_index(self):
|
||||||
|
r"""Constructing different indexes based on selective retrieval method.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if self.retrieval_method == "bm25":
|
||||||
|
self.build_bm25_index()
|
||||||
|
else:
|
||||||
|
self.build_dense_index()
|
||||||
|
|
||||||
|
def build_bm25_index(self):
|
||||||
|
"""Building BM25 index based on Pyserini library.
|
||||||
|
|
||||||
|
Reference: https://github.com/castorini/pyserini/blob/master/docs/usage-index.md#building-a-bm25-index-direct-java-implementation
|
||||||
|
"""
|
||||||
|
|
||||||
|
# to use pyserini pipeline, we first need to place jsonl file in the folder
|
||||||
|
self.save_dir = os.path.join(self.save_dir, "bm25")
|
||||||
|
os.makedirs(self.save_dir, exist_ok=True)
|
||||||
|
temp_dir = self.save_dir + "/temp"
|
||||||
|
temp_file_path = temp_dir + "/temp.jsonl"
|
||||||
|
os.makedirs(temp_dir)
|
||||||
|
|
||||||
|
# if self.have_contents:
|
||||||
|
# shutil.copyfile(self.corpus_path, temp_file_path)
|
||||||
|
# else:
|
||||||
|
# with open(temp_file_path, "w") as f:
|
||||||
|
# for item in self.corpus:
|
||||||
|
# f.write(json.dumps(item) + "\n")
|
||||||
|
shutil.copyfile(self.corpus_path, temp_file_path)
|
||||||
|
|
||||||
|
print("Start building bm25 index...")
|
||||||
|
pyserini_args = ["--collection", "JsonCollection",
|
||||||
|
"--input", temp_dir,
|
||||||
|
"--index", self.save_dir,
|
||||||
|
"--generator", "DefaultLuceneDocumentGenerator",
|
||||||
|
"--threads", "1"]
|
||||||
|
|
||||||
|
subprocess.run(["python", "-m", "pyserini.index.lucene"] + pyserini_args)
|
||||||
|
|
||||||
|
shutil.rmtree(temp_dir)
|
||||||
|
|
||||||
|
print("Finish!")
|
||||||
|
|
||||||
|
def _load_embedding(self, embedding_path, corpus_size, hidden_size):
|
||||||
|
all_embeddings = np.memmap(
|
||||||
|
embedding_path,
|
||||||
|
mode="r",
|
||||||
|
dtype=np.float32
|
||||||
|
).reshape(corpus_size, hidden_size)
|
||||||
|
return all_embeddings
|
||||||
|
|
||||||
|
def _save_embedding(self, all_embeddings):
|
||||||
|
memmap = np.memmap(
|
||||||
|
self.embedding_save_path,
|
||||||
|
shape=all_embeddings.shape,
|
||||||
|
mode="w+",
|
||||||
|
dtype=all_embeddings.dtype
|
||||||
|
)
|
||||||
|
length = all_embeddings.shape[0]
|
||||||
|
# add in batch
|
||||||
|
save_batch_size = 10000
|
||||||
|
if length > save_batch_size:
|
||||||
|
for i in tqdm(range(0, length, save_batch_size), leave=False, desc="Saving Embeddings"):
|
||||||
|
j = min(i + save_batch_size, length)
|
||||||
|
memmap[i: j] = all_embeddings[i: j]
|
||||||
|
else:
|
||||||
|
memmap[:] = all_embeddings
|
||||||
|
|
||||||
|
def encode_all(self):
|
||||||
|
if self.gpu_num > 1:
|
||||||
|
print("Use multi gpu!")
|
||||||
|
self.encoder = torch.nn.DataParallel(self.encoder)
|
||||||
|
self.batch_size = self.batch_size * self.gpu_num
|
||||||
|
|
||||||
|
all_embeddings = []
|
||||||
|
|
||||||
|
for start_idx in tqdm(range(0, len(self.corpus), self.batch_size), desc='Inference Embeddings:'):
|
||||||
|
|
||||||
|
batch_data_title = self.corpus[start_idx:start_idx+self.batch_size]['title']
|
||||||
|
batch_data_text = self.corpus[start_idx:start_idx+self.batch_size]['text']
|
||||||
|
batch_data = ['"' + title + '"\n' + text for title, text in zip(batch_data_title, batch_data_text)]
|
||||||
|
|
||||||
|
if self.retrieval_method == "e5":
|
||||||
|
batch_data = [f"passage: {doc}" for doc in batch_data]
|
||||||
|
|
||||||
|
inputs = self.tokenizer(
|
||||||
|
batch_data,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
return_tensors='pt',
|
||||||
|
max_length=self.max_length,
|
||||||
|
).to('cuda')
|
||||||
|
|
||||||
|
inputs = {k: v.cuda() for k, v in inputs.items()}
|
||||||
|
|
||||||
|
#TODO: support encoder-only T5 model
|
||||||
|
if "T5" in type(self.encoder).__name__:
|
||||||
|
# T5-based retrieval model
|
||||||
|
decoder_input_ids = torch.zeros(
|
||||||
|
(inputs['input_ids'].shape[0], 1), dtype=torch.long
|
||||||
|
).to(inputs['input_ids'].device)
|
||||||
|
output = self.encoder(
|
||||||
|
**inputs, decoder_input_ids=decoder_input_ids, return_dict=True
|
||||||
|
)
|
||||||
|
embeddings = output.last_hidden_state[:, 0, :]
|
||||||
|
|
||||||
|
else:
|
||||||
|
output = self.encoder(**inputs, return_dict=True)
|
||||||
|
embeddings = pooling(output.pooler_output,
|
||||||
|
output.last_hidden_state,
|
||||||
|
inputs['attention_mask'],
|
||||||
|
self.pooling_method)
|
||||||
|
if "dpr" not in self.retrieval_method:
|
||||||
|
embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
|
||||||
|
|
||||||
|
embeddings = cast(torch.Tensor, embeddings)
|
||||||
|
embeddings = embeddings.detach().cpu().numpy()
|
||||||
|
all_embeddings.append(embeddings)
|
||||||
|
|
||||||
|
all_embeddings = np.concatenate(all_embeddings, axis=0)
|
||||||
|
all_embeddings = all_embeddings.astype(np.float32)
|
||||||
|
|
||||||
|
return all_embeddings
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def build_dense_index(self):
|
||||||
|
"""Obtain the representation of documents based on the embedding model(BERT-based) and
|
||||||
|
construct a faiss index.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if os.path.exists(self.index_save_path):
|
||||||
|
print("The index file already exists and will be overwritten.")
|
||||||
|
|
||||||
|
self.encoder, self.tokenizer = load_model(model_path = self.model_path,
|
||||||
|
use_fp16 = self.use_fp16)
|
||||||
|
if self.embedding_path is not None:
|
||||||
|
hidden_size = self.encoder.config.hidden_size
|
||||||
|
corpus_size = len(self.corpus)
|
||||||
|
all_embeddings = self._load_embedding(self.embedding_path, corpus_size, hidden_size)
|
||||||
|
else:
|
||||||
|
all_embeddings = self.encode_all()
|
||||||
|
if self.save_embedding:
|
||||||
|
self._save_embedding(all_embeddings)
|
||||||
|
del self.corpus
|
||||||
|
|
||||||
|
# build index
|
||||||
|
print("Creating index")
|
||||||
|
dim = all_embeddings.shape[-1]
|
||||||
|
faiss_index = faiss.index_factory(dim, self.faiss_type, faiss.METRIC_INNER_PRODUCT)
|
||||||
|
|
||||||
|
if self.faiss_gpu:
|
||||||
|
co = faiss.GpuMultipleClonerOptions()
|
||||||
|
co.useFloat16 = True
|
||||||
|
co.shard = True
|
||||||
|
faiss_index = faiss.index_cpu_to_all_gpus(faiss_index, co)
|
||||||
|
if not faiss_index.is_trained:
|
||||||
|
faiss_index.train(all_embeddings)
|
||||||
|
faiss_index.add(all_embeddings)
|
||||||
|
faiss_index = faiss.index_gpu_to_cpu(faiss_index)
|
||||||
|
else:
|
||||||
|
if not faiss_index.is_trained:
|
||||||
|
faiss_index.train(all_embeddings)
|
||||||
|
faiss_index.add(all_embeddings)
|
||||||
|
|
||||||
|
faiss.write_index(faiss_index, self.index_save_path)
|
||||||
|
print("Finish!")
|
||||||
|
|
||||||
|
|
||||||
|
MODEL2POOLING = {
|
||||||
|
"e5": "mean",
|
||||||
|
"bge": "cls",
|
||||||
|
"contriever": "mean",
|
||||||
|
'jina': 'mean'
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description = "Creating index.")
|
||||||
|
|
||||||
|
# Basic parameters
|
||||||
|
parser.add_argument('--retrieval_method', type=str)
|
||||||
|
parser.add_argument('--model_path', type=str, default=None)
|
||||||
|
parser.add_argument('--corpus_path', type=str)
|
||||||
|
parser.add_argument('--save_dir', default='indexes/', type=str)
|
||||||
|
|
||||||
|
# Parameters for building dense index
|
||||||
|
parser.add_argument('--max_length', type=int, default=180)
|
||||||
|
parser.add_argument('--batch_size', type=int, default=512)
|
||||||
|
parser.add_argument('--use_fp16', default=False, action='store_true')
|
||||||
|
parser.add_argument('--pooling_method', type=str, default=None)
|
||||||
|
parser.add_argument('--faiss_type', default=None, type=str)
|
||||||
|
parser.add_argument('--embedding_path', default=None, type=str)
|
||||||
|
parser.add_argument('--save_embedding', action='store_true', default=False)
|
||||||
|
parser.add_argument('--faiss_gpu', default=False, action='store_true')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.pooling_method is None:
|
||||||
|
pooling_method = 'mean'
|
||||||
|
for k,v in MODEL2POOLING.items():
|
||||||
|
if k in args.retrieval_method.lower():
|
||||||
|
pooling_method = v
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
if args.pooling_method not in ['mean','cls','pooler']:
|
||||||
|
raise NotImplementedError
|
||||||
|
else:
|
||||||
|
pooling_method = args.pooling_method
|
||||||
|
|
||||||
|
|
||||||
|
index_builder = Index_Builder(
|
||||||
|
retrieval_method = args.retrieval_method,
|
||||||
|
model_path = args.model_path,
|
||||||
|
corpus_path = args.corpus_path,
|
||||||
|
save_dir = args.save_dir,
|
||||||
|
max_length = args.max_length,
|
||||||
|
batch_size = args.batch_size,
|
||||||
|
use_fp16 = args.use_fp16,
|
||||||
|
pooling_method = pooling_method,
|
||||||
|
faiss_type = args.faiss_type,
|
||||||
|
embedding_path = args.embedding_path,
|
||||||
|
save_embedding = args.save_embedding,
|
||||||
|
faiss_gpu = args.faiss_gpu
|
||||||
|
)
|
||||||
|
index_builder.build_index()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
368
search_r1/search/retrieval.py
Normal file
@@ -0,0 +1,368 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
from typing import List, Dict
|
||||||
|
import functools
|
||||||
|
from tqdm import tqdm
|
||||||
|
from multiprocessing import Pool
|
||||||
|
import faiss
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from transformers import AutoConfig, AutoTokenizer, AutoModel
|
||||||
|
import argparse
|
||||||
|
import datasets
|
||||||
|
|
||||||
|
|
||||||
|
def load_corpus(corpus_path: str):
|
||||||
|
corpus = datasets.load_dataset(
|
||||||
|
'json',
|
||||||
|
data_files=corpus_path,
|
||||||
|
split="train",
|
||||||
|
num_proc=4)
|
||||||
|
return corpus
|
||||||
|
|
||||||
|
|
||||||
|
def read_jsonl(file_path):
|
||||||
|
data = []
|
||||||
|
|
||||||
|
with open(file_path, "r") as f:
|
||||||
|
readin = f.readlines()
|
||||||
|
for line in readin:
|
||||||
|
data.append(json.loads(line))
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def load_docs(corpus, doc_idxs):
|
||||||
|
results = [corpus[int(idx)] for idx in doc_idxs]
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(
|
||||||
|
model_path: str,
|
||||||
|
use_fp16: bool = False
|
||||||
|
):
|
||||||
|
model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
||||||
|
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
|
||||||
|
model.eval()
|
||||||
|
model.cuda()
|
||||||
|
if use_fp16:
|
||||||
|
model = model.half()
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True)
|
||||||
|
|
||||||
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def pooling(
|
||||||
|
pooler_output,
|
||||||
|
last_hidden_state,
|
||||||
|
attention_mask = None,
|
||||||
|
pooling_method = "mean"
|
||||||
|
):
|
||||||
|
if pooling_method == "mean":
|
||||||
|
last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
|
||||||
|
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
|
||||||
|
elif pooling_method == "cls":
|
||||||
|
return last_hidden_state[:, 0]
|
||||||
|
elif pooling_method == "pooler":
|
||||||
|
return pooler_output
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Pooling method not implemented!")
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder:
|
||||||
|
def __init__(self, model_name, model_path, pooling_method, max_length, use_fp16):
|
||||||
|
self.model_name = model_name
|
||||||
|
self.model_path = model_path
|
||||||
|
self.pooling_method = pooling_method
|
||||||
|
self.max_length = max_length
|
||||||
|
self.use_fp16 = use_fp16
|
||||||
|
|
||||||
|
self.model, self.tokenizer = load_model(model_path=model_path,
|
||||||
|
use_fp16=use_fp16)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def encode(self, query_list: List[str], is_query=True) -> np.ndarray:
|
||||||
|
# processing query for different encoders
|
||||||
|
if isinstance(query_list, str):
|
||||||
|
query_list = [query_list]
|
||||||
|
|
||||||
|
if "e5" in self.model_name.lower():
|
||||||
|
if is_query:
|
||||||
|
query_list = [f"query: {query}" for query in query_list]
|
||||||
|
else:
|
||||||
|
query_list = [f"passage: {query}" for query in query_list]
|
||||||
|
|
||||||
|
if "bge" in self.model_name.lower():
|
||||||
|
if is_query:
|
||||||
|
query_list = [f"Represent this sentence for searching relevant passages: {query}" for query in query_list]
|
||||||
|
|
||||||
|
inputs = self.tokenizer(query_list,
|
||||||
|
max_length=self.max_length,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt"
|
||||||
|
)
|
||||||
|
inputs = {k: v.cuda() for k, v in inputs.items()}
|
||||||
|
|
||||||
|
if "T5" in type(self.model).__name__:
|
||||||
|
# T5-based retrieval model
|
||||||
|
decoder_input_ids = torch.zeros(
|
||||||
|
(inputs['input_ids'].shape[0], 1), dtype=torch.long
|
||||||
|
).to(inputs['input_ids'].device)
|
||||||
|
output = self.model(
|
||||||
|
**inputs, decoder_input_ids=decoder_input_ids, return_dict=True
|
||||||
|
)
|
||||||
|
query_emb = output.last_hidden_state[:, 0, :]
|
||||||
|
|
||||||
|
else:
|
||||||
|
output = self.model(**inputs, return_dict=True)
|
||||||
|
query_emb = pooling(output.pooler_output,
|
||||||
|
output.last_hidden_state,
|
||||||
|
inputs['attention_mask'],
|
||||||
|
self.pooling_method)
|
||||||
|
if "dpr" not in self.model_name.lower():
|
||||||
|
query_emb = torch.nn.functional.normalize(query_emb, dim=-1)
|
||||||
|
|
||||||
|
query_emb = query_emb.detach().cpu().numpy()
|
||||||
|
query_emb = query_emb.astype(np.float32, order="C")
|
||||||
|
return query_emb
|
||||||
|
|
||||||
|
|
||||||
|
class BaseRetriever:
|
||||||
|
"""Base object for all retrievers."""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
self.config = config
|
||||||
|
self.retrieval_method = config.retrieval_method
|
||||||
|
self.topk = config.retrieval_topk
|
||||||
|
|
||||||
|
self.index_path = config.index_path
|
||||||
|
self.corpus_path = config.corpus_path
|
||||||
|
|
||||||
|
# self.cache_save_path = os.path.join(config.save_dir, 'retrieval_cache.json')
|
||||||
|
|
||||||
|
def _search(self, query: str, num: int, return_score:bool) -> List[Dict[str, str]]:
|
||||||
|
r"""Retrieve topk relevant documents in corpus.
|
||||||
|
Return:
|
||||||
|
list: contains information related to the document, including:
|
||||||
|
contents: used for building index
|
||||||
|
title: (if provided)
|
||||||
|
text: (if provided)
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _batch_search(self, query_list, num, return_score):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def search(self, *args, **kwargs):
|
||||||
|
return self._search(*args, **kwargs)
|
||||||
|
|
||||||
|
def batch_search(self, *args, **kwargs):
|
||||||
|
return self._batch_search(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class BM25Retriever(BaseRetriever):
|
||||||
|
r"""BM25 retriever based on pre-built pyserini index."""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
from pyserini.search.lucene import LuceneSearcher
|
||||||
|
self.searcher = LuceneSearcher(self.index_path)
|
||||||
|
self.contain_doc = self._check_contain_doc()
|
||||||
|
if not self.contain_doc:
|
||||||
|
self.corpus = load_corpus(self.corpus_path)
|
||||||
|
self.max_process_num = 8
|
||||||
|
|
||||||
|
def _check_contain_doc(self):
|
||||||
|
r"""Check if the index contains document content
|
||||||
|
"""
|
||||||
|
return self.searcher.doc(0).raw() is not None
|
||||||
|
|
||||||
|
def _search(self, query: str, num: int = None, return_score = False) -> List[Dict[str, str]]:
|
||||||
|
if num is None:
|
||||||
|
num = self.topk
|
||||||
|
|
||||||
|
hits = self.searcher.search(query, num)
|
||||||
|
if len(hits) < 1:
|
||||||
|
if return_score:
|
||||||
|
return [],[]
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
|
scores = [hit.score for hit in hits]
|
||||||
|
if len(hits) < num:
|
||||||
|
warnings.warn('Not enough documents retrieved!')
|
||||||
|
else:
|
||||||
|
hits = hits[:num]
|
||||||
|
|
||||||
|
if self.contain_doc:
|
||||||
|
all_contents = [json.loads(self.searcher.doc(hit.docid).raw())['contents'] for hit in hits]
|
||||||
|
results = [{'title': content.split("\n")[0].strip("\""),
|
||||||
|
'text': "\n".join(content.split("\n")[1:]),
|
||||||
|
'contents': content} for content in all_contents]
|
||||||
|
else:
|
||||||
|
results = load_docs(self.corpus, [hit.docid for hit in hits])
|
||||||
|
|
||||||
|
if return_score:
|
||||||
|
return results, scores
|
||||||
|
else:
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _batch_search(self, query_list, num: int = None, return_score = False):
|
||||||
|
# TODO: modify batch method
|
||||||
|
results = []
|
||||||
|
scores = []
|
||||||
|
for query in query_list:
|
||||||
|
item_result, item_score = self._search(query, num,True)
|
||||||
|
results.append(item_result)
|
||||||
|
scores.append(item_score)
|
||||||
|
|
||||||
|
if return_score:
|
||||||
|
return results, scores
|
||||||
|
else:
|
||||||
|
return results
|
||||||
|
|
||||||
|
def get_available_gpu_memory():
|
||||||
|
memory_info = []
|
||||||
|
for i in range(torch.cuda.device_count()):
|
||||||
|
total_memory = torch.cuda.get_device_properties(i).total_memory
|
||||||
|
allocated_memory = torch.cuda.memory_allocated(i)
|
||||||
|
free_memory = total_memory - allocated_memory
|
||||||
|
memory_info.append((i, free_memory / 1e9)) # Convert to GB
|
||||||
|
return memory_info
|
||||||
|
|
||||||
|
|
||||||
|
class DenseRetriever(BaseRetriever):
|
||||||
|
r"""Dense retriever based on pre-built faiss index."""
|
||||||
|
|
||||||
|
def __init__(self, config: dict):
|
||||||
|
super().__init__(config)
|
||||||
|
self.index = faiss.read_index(self.index_path)
|
||||||
|
if config.faiss_gpu:
|
||||||
|
co = faiss.GpuMultipleClonerOptions()
|
||||||
|
co.useFloat16 = True
|
||||||
|
co.shard = True
|
||||||
|
self.index = faiss.index_cpu_to_all_gpus(self.index, co=co)
|
||||||
|
# self.index = faiss.index_cpu_to_all_gpus(self.index)
|
||||||
|
|
||||||
|
self.corpus = load_corpus(self.corpus_path)
|
||||||
|
self.encoder = Encoder(
|
||||||
|
model_name = self.retrieval_method,
|
||||||
|
model_path = config.retrieval_model_path,
|
||||||
|
pooling_method = config.retrieval_pooling_method,
|
||||||
|
max_length = config.retrieval_query_max_length,
|
||||||
|
use_fp16 = config.retrieval_use_fp16
|
||||||
|
)
|
||||||
|
self.topk = config.retrieval_topk
|
||||||
|
self.batch_size = self.config.retrieval_batch_size
|
||||||
|
|
||||||
|
def _search(self, query: str, num: int = None, return_score = False):
|
||||||
|
if num is None:
|
||||||
|
num = self.topk
|
||||||
|
query_emb = self.encoder.encode(query)
|
||||||
|
scores, idxs = self.index.search(query_emb, k=num)
|
||||||
|
idxs = idxs[0]
|
||||||
|
scores = scores[0]
|
||||||
|
|
||||||
|
results = load_docs(self.corpus, idxs)
|
||||||
|
if return_score:
|
||||||
|
return results, scores
|
||||||
|
else:
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _batch_search(self, query_list: List[str], num: int = None, return_score = False):
|
||||||
|
if isinstance(query_list, str):
|
||||||
|
query_list = [query_list]
|
||||||
|
if num is None:
|
||||||
|
num = self.topk
|
||||||
|
|
||||||
|
batch_size = self.batch_size
|
||||||
|
|
||||||
|
results = []
|
||||||
|
scores = []
|
||||||
|
|
||||||
|
for start_idx in tqdm(range(0, len(query_list), batch_size), desc='Retrieval process: '):
|
||||||
|
query_batch = query_list[start_idx:start_idx + batch_size]
|
||||||
|
|
||||||
|
# from time import time
|
||||||
|
# a = time()
|
||||||
|
batch_emb = self.encoder.encode(query_batch)
|
||||||
|
# b = time()
|
||||||
|
# print(f'################### encode time {b-a} #####################')
|
||||||
|
batch_scores, batch_idxs = self.index.search(batch_emb, k=num)
|
||||||
|
batch_scores = batch_scores.tolist()
|
||||||
|
batch_idxs = batch_idxs.tolist()
|
||||||
|
# print(f'################### search time {time()-b} #####################')
|
||||||
|
# exit()
|
||||||
|
|
||||||
|
flat_idxs = sum(batch_idxs, [])
|
||||||
|
batch_results = load_docs(self.corpus, flat_idxs)
|
||||||
|
batch_results = [batch_results[i*num : (i+1)*num] for i in range(len(batch_idxs))]
|
||||||
|
|
||||||
|
scores.extend(batch_scores)
|
||||||
|
results.extend(batch_results)
|
||||||
|
|
||||||
|
if return_score:
|
||||||
|
return results, scores
|
||||||
|
else:
|
||||||
|
return results
|
||||||
|
|
||||||
|
def get_retriever(config):
|
||||||
|
r"""Automatically select retriever class based on config's retrieval method
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (dict): configuration with 'retrieval_method' key
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Retriever: retriever instance
|
||||||
|
"""
|
||||||
|
if config.retrieval_method == "bm25":
|
||||||
|
return BM25Retriever(config)
|
||||||
|
else:
|
||||||
|
return DenseRetriever(config)
|
||||||
|
|
||||||
|
|
||||||
|
def get_dataset(config):
|
||||||
|
"""Load dataset from config."""
|
||||||
|
|
||||||
|
split_path = os.path.join(config.dataset_path, f'{config.data_split}.jsonl')
|
||||||
|
return read_jsonl(split_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description = "Retrieval")
|
||||||
|
|
||||||
|
# Basic parameters
|
||||||
|
parser.add_argument('--retrieval_method', type=str)
|
||||||
|
parser.add_argument('--retrieval_topk', type=int, default=10)
|
||||||
|
parser.add_argument('--index_path', type=str, default=None)
|
||||||
|
parser.add_argument('--corpus_path', type=str)
|
||||||
|
parser.add_argument('--dataset_path', default=None, type=str)
|
||||||
|
|
||||||
|
parser.add_argument('--faiss_gpu', default=True, type=bool)
|
||||||
|
parser.add_argument('--data_split', default="train", type=str)
|
||||||
|
|
||||||
|
parser.add_argument('--retrieval_model_path', type=str, default=None)
|
||||||
|
parser.add_argument('--retrieval_pooling_method', default='mean', type=str)
|
||||||
|
parser.add_argument('--retrieval_query_max_length', default=256, type=int)
|
||||||
|
parser.add_argument('--retrieval_use_fp16', action='store_true', default=False)
|
||||||
|
parser.add_argument('--retrieval_batch_size', default=512, type=int)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
args.index_path = os.path.join(args.index_path, f'{args.retrieval_method}_Flat.index') if args.retrieval_method != 'bm25' else os.path.join(args.index_path, 'bm25')
|
||||||
|
|
||||||
|
# load dataset
|
||||||
|
all_split = get_dataset(args)
|
||||||
|
|
||||||
|
input_query = [sample['question'] for sample in all_split[:512]]
|
||||||
|
|
||||||
|
# initialize the retriever and conduct retrieval
|
||||||
|
retriever = get_retriever(args)
|
||||||
|
print('Start Retrieving ...')
|
||||||
|
results, scores = retriever.batch_search(input_query, return_score=True)
|
||||||
|
|
||||||
|
# from IPython import embed
|
||||||
|
# embed()
|
||||||
25
search_r1/search/retrieval.sh
Normal file
@@ -0,0 +1,25 @@
DATA_NAME=nq

DATASET_PATH="/home/peterjin/mnt/data/$DATA_NAME"

SPLIT='test'
TOPK=3

INDEX_PATH=/home/peterjin/mnt/index/wiki-18
CORPUS_PATH=/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl
SAVE_NAME=e5_${TOPK}_wiki18.json

# INDEX_PATH=/home/peterjin/rm_retrieval_corpus/index/wiki-21
# CORPUS_PATH=/home/peterjin/rm_retrieval_corpus/corpora/wiki/enwiki-dec2021/text-list-100-sec.jsonl
# SAVE_NAME=e5_${TOPK}_wiki21.json

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python retrieval.py --retrieval_method e5 \
    --retrieval_topk $TOPK \
    --index_path $INDEX_PATH \
    --corpus_path $CORPUS_PATH \
    --dataset_path $DATASET_PATH \
    --data_split $SPLIT \
    --retrieval_model_path "intfloat/e5-base-v2" \
    --retrieval_pooling_method "mean" \
    --retrieval_batch_size 512
23
search_r1/search/retrieval_request.py
Normal file
@@ -0,0 +1,23 @@
import requests

# URL for your local FastAPI server
url = "http://127.0.0.1:8000/retrieve"

# Example payload
payload = {
    "queries": ["What is the capital of France?", "Explain neural networks."] * 200,
    "topk": 5,
    "return_scores": True
}

# Send POST request
response = requests.post(url, json=payload)

# Raise an exception if the request failed
response.raise_for_status()

# Get the JSON response
retrieved_data = response.json()

print("Response from server:")
print(retrieved_data)
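
# Sketch of the expected response shape, inferred from the /retrieve endpoint in
# retrieval_server.py (comment only; the document fields depend on the corpus): with
# "return_scores": true, each query maps to a list of its topk hits, e.g.
# {
#     "result": [
#         [{"document": {...}, "score": 0.87}, ...],   # hits for queries[0]
#         ...
#     ]
# }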
382
search_r1/search/retrieval_server.py
Normal file
@@ -0,0 +1,382 @@
import json
import os
import warnings
from typing import List, Dict, Optional
import argparse

import faiss
import torch
import numpy as np
from transformers import AutoConfig, AutoTokenizer, AutoModel
from tqdm import tqdm
import datasets

import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel


parser = argparse.ArgumentParser(description="Launch the local faiss retriever.")
parser.add_argument("--index_path", type=str, default="/home/peterjin/mnt/index/wiki-18/e5_Flat.index", help="Corpus indexing file.")
parser.add_argument("--corpus_path", type=str, default="/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl", help="Local corpus file.")
parser.add_argument("--topk", type=int, default=3, help="Number of retrieved passages for one query.")
parser.add_argument("--retriever_model", type=str, default="intfloat/e5-base-v2", help="Name of the retriever model.")

args = parser.parse_args()


def load_corpus(corpus_path: str):
    corpus = datasets.load_dataset(
        'json',
        data_files=corpus_path,
        split="train",
        num_proc=4
    )
    return corpus


def read_jsonl(file_path):
    data = []
    with open(file_path, "r") as f:
        for line in f:
            data.append(json.loads(line))
    return data


def load_docs(corpus, doc_idxs):
    results = [corpus[int(idx)] for idx in doc_idxs]
    return results


def load_model(model_path: str, use_fp16: bool = False):
    model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
    model.eval()
    model.cuda()
    if use_fp16:
        model = model.half()
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True)
    return model, tokenizer


def pooling(
    pooler_output,
    last_hidden_state,
    attention_mask = None,
    pooling_method = "mean"
):
    if pooling_method == "mean":
        last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    elif pooling_method == "cls":
        return last_hidden_state[:, 0]
    elif pooling_method == "pooler":
        return pooler_output
    else:
        raise NotImplementedError("Pooling method not implemented!")
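
# Shape note (comment only, toy sizes assumed): last_hidden_state is (batch, seq_len, hidden)
# and attention_mask is (batch, seq_len); "mean" pooling zeroes the padded positions, sums over
# seq_len and divides by the number of real tokens, giving one (batch, hidden) embedding per text.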


class Encoder:

    def __init__(self, model_name, model_path, pooling_method, max_length, use_fp16):
        self.model_name = model_name
        self.model_path = model_path
        self.pooling_method = pooling_method
        self.max_length = max_length
        self.use_fp16 = use_fp16

        self.model, self.tokenizer = load_model(model_path=model_path, use_fp16=use_fp16)
        self.model.eval()

    @torch.no_grad()
    def encode(self, query_list: List[str], is_query=True) -> np.ndarray:
        # processing query for different encoders
        if isinstance(query_list, str):
            query_list = [query_list]

        if "e5" in self.model_name.lower():
            if is_query:
                query_list = [f"query: {query}" for query in query_list]
            else:
                query_list = [f"passage: {query}" for query in query_list]

        if "bge" in self.model_name.lower():
            if is_query:
                query_list = [f"Represent this sentence for searching relevant passages: {query}" for query in query_list]

        inputs = self.tokenizer(query_list,
                                max_length=self.max_length,
                                padding=True,
                                truncation=True,
                                return_tensors="pt"
                                )
        inputs = {k: v.cuda() for k, v in inputs.items()}

        if "T5" in type(self.model).__name__:
            # T5-based retrieval model
            decoder_input_ids = torch.zeros(
                (inputs['input_ids'].shape[0], 1), dtype=torch.long
            ).to(inputs['input_ids'].device)
            output = self.model(
                **inputs, decoder_input_ids=decoder_input_ids, return_dict=True
            )
            query_emb = output.last_hidden_state[:, 0, :]
        else:
            output = self.model(**inputs, return_dict=True)
            query_emb = pooling(output.pooler_output,
                                output.last_hidden_state,
                                inputs['attention_mask'],
                                self.pooling_method)
            if "dpr" not in self.model_name.lower():
                query_emb = torch.nn.functional.normalize(query_emb, dim=-1)

        query_emb = query_emb.detach().cpu().numpy()
        query_emb = query_emb.astype(np.float32, order="C")
        return query_emb
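
# Usage sketch (comment only; names and sizes are illustrative):
# enc = Encoder("e5", "intfloat/e5-base-v2", pooling_method="mean", max_length=256, use_fp16=True)
# enc.encode(["what is faiss?"])  # -> np.float32 array of shape (1, hidden_dim), L2-normalized for e5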


class BaseRetriever:

    def __init__(self, config):
        self.config = config
        self.retrieval_method = config.retrieval_method
        self.topk = config.retrieval_topk

        self.index_path = config.index_path
        self.corpus_path = config.corpus_path

    def _search(self, query: str, num: int, return_score: bool):
        raise NotImplementedError

    def _batch_search(self, query_list: List[str], num: int, return_score: bool):
        raise NotImplementedError

    def search(self, query: str, num: int = None, return_score: bool = False):
        return self._search(query, num, return_score)

    def batch_search(self, query_list: List[str], num: int = None, return_score: bool = False):
        return self._batch_search(query_list, num, return_score)


class BM25Retriever(BaseRetriever):

    def __init__(self, config):
        super().__init__(config)
        from pyserini.search.lucene import LuceneSearcher
        self.searcher = LuceneSearcher(self.index_path)
        self.contain_doc = self._check_contain_doc()
        if not self.contain_doc:
            self.corpus = load_corpus(self.corpus_path)
        self.max_process_num = 8

    def _check_contain_doc(self):
        return self.searcher.doc(0).raw() is not None

    def _search(self, query: str, num: int = None, return_score: bool = False):
        if num is None:
            num = self.topk
        hits = self.searcher.search(query, num)
        if len(hits) < 1:
            if return_score:
                return [], []
            else:
                return []
        scores = [hit.score for hit in hits]
        if len(hits) < num:
            warnings.warn('Not enough documents retrieved!')
        else:
            hits = hits[:num]

        if self.contain_doc:
            all_contents = [
                json.loads(self.searcher.doc(hit.docid).raw())['contents']
                for hit in hits
            ]
            results = [
                {
                    'title': content.split("\n")[0].strip("\""),
                    'text': "\n".join(content.split("\n")[1:]),
                    'contents': content
                }
                for content in all_contents
            ]
        else:
            results = load_docs(self.corpus, [hit.docid for hit in hits])

        if return_score:
            return results, scores
        else:
            return results

    def _batch_search(self, query_list: List[str], num: int = None, return_score: bool = False):
        results = []
        scores = []
        for query in query_list:
            item_result, item_score = self._search(query, num, True)
            results.append(item_result)
            scores.append(item_score)
        if return_score:
            return results, scores
        else:
            return results


class DenseRetriever(BaseRetriever):

    def __init__(self, config):
        super().__init__(config)
        self.index = faiss.read_index(self.index_path)
        if config.faiss_gpu:
            co = faiss.GpuMultipleClonerOptions()
            co.useFloat16 = True
            co.shard = True
            self.index = faiss.index_cpu_to_all_gpus(self.index, co=co)

        self.corpus = load_corpus(self.corpus_path)
        self.encoder = Encoder(
            model_name = self.retrieval_method,
            model_path = config.retrieval_model_path,
            pooling_method = config.retrieval_pooling_method,
            max_length = config.retrieval_query_max_length,
            use_fp16 = config.retrieval_use_fp16
        )
        self.topk = config.retrieval_topk
        self.batch_size = config.retrieval_batch_size

    def _search(self, query: str, num: int = None, return_score: bool = False):
        if num is None:
            num = self.topk
        query_emb = self.encoder.encode(query)
        scores, idxs = self.index.search(query_emb, k=num)
        idxs = idxs[0]
        scores = scores[0]
        results = load_docs(self.corpus, idxs)
        if return_score:
            return results, scores.tolist()
        else:
            return results

    def _batch_search(self, query_list: List[str], num: int = None, return_score: bool = False):
        if isinstance(query_list, str):
            query_list = [query_list]
        if num is None:
            num = self.topk

        results = []
        scores = []
        for start_idx in tqdm(range(0, len(query_list), self.batch_size), desc='Retrieval process: '):
            query_batch = query_list[start_idx:start_idx + self.batch_size]
            batch_emb = self.encoder.encode(query_batch)
            batch_scores, batch_idxs = self.index.search(batch_emb, k=num)
            batch_scores = batch_scores.tolist()
            batch_idxs = batch_idxs.tolist()

            # load_docs is not vectorized, but is a python list approach
            flat_idxs = sum(batch_idxs, [])
            batch_results = load_docs(self.corpus, flat_idxs)
            # chunk them back
            batch_results = [batch_results[i*num : (i+1)*num] for i in range(len(batch_idxs))]

            results.extend(batch_results)
            scores.extend(batch_scores)
        if return_score:
            return results, scores
        else:
            return results


def get_retriever(config):
    if config.retrieval_method == "bm25":
        return BM25Retriever(config)
    else:
        return DenseRetriever(config)


#####################################
# FastAPI server below
#####################################

class Config:
    """
    Minimal config class (simulating your argparse).
    Replace this with your real arguments or load them dynamically.
    """
    def __init__(
        self,
        retrieval_method: str = "bm25",
        retrieval_topk: int = 10,
        index_path: str = "./index/bm25",
        corpus_path: str = "./data/corpus.jsonl",
        dataset_path: str = "./data",
        data_split: str = "train",
        faiss_gpu: bool = True,
        retrieval_model_path: str = "./model",
        retrieval_pooling_method: str = "mean",
        retrieval_query_max_length: int = 256,
        retrieval_use_fp16: bool = False,
        retrieval_batch_size: int = 128
    ):
        self.retrieval_method = retrieval_method
        self.retrieval_topk = retrieval_topk
        self.index_path = index_path
        self.corpus_path = corpus_path
        self.dataset_path = dataset_path
        self.data_split = data_split
        self.faiss_gpu = faiss_gpu
        self.retrieval_model_path = retrieval_model_path
        self.retrieval_pooling_method = retrieval_pooling_method
        self.retrieval_query_max_length = retrieval_query_max_length
        self.retrieval_use_fp16 = retrieval_use_fp16
        self.retrieval_batch_size = retrieval_batch_size


class QueryRequest(BaseModel):
    queries: List[str]
    topk: Optional[int] = None
    return_scores: bool = False


app = FastAPI()

# 1) Build a config (could also parse from arguments).
#    In real usage, you'd parse your CLI arguments or environment variables.
config = Config(
    retrieval_method="e5",  # or "bm25"
    index_path=args.index_path,
    corpus_path=args.corpus_path,
    retrieval_topk=args.topk,
    faiss_gpu=True,
    retrieval_model_path=args.retriever_model,
    retrieval_pooling_method="mean",
    retrieval_query_max_length=256,
    retrieval_use_fp16=True,
    retrieval_batch_size=512,
)

# 2) Instantiate a global retriever so it is loaded once and reused.
retriever = get_retriever(config)


@app.post("/retrieve")
def retrieve_endpoint(request: QueryRequest):
    """
    Endpoint that accepts queries and performs retrieval.
    Input format:
    {
      "queries": ["What is Python?", "Tell me about neural networks."],
      "topk": 3,
      "return_scores": true
    }
    """
    if not request.topk:
        request.topk = config.retrieval_topk  # fallback to default

    # Perform batch retrieval; always ask for scores so unpacking below is safe,
    # and only include them in the response when the client requested them.
    results, scores = retriever.batch_search(
        query_list=request.queries,
        num=request.topk,
        return_score=True
    )

    # Format response
    resp = []
    for i, single_result in enumerate(results):
        if request.return_scores:
            # If scores are requested, combine them with results
            combined = []
            for doc, score in zip(single_result, scores[i]):
                combined.append({"document": doc, "score": score})
            resp.append(combined)
        else:
            resp.append(single_result)
    return {"result": resp}


if __name__ == "__main__":
    # 3) Launch the server. It listens on port 8000 on all interfaces (0.0.0.0).
    uvicorn.run(app, host="0.0.0.0", port=8000)
54
setup.py
Normal file
@@ -0,0 +1,54 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# setup.py is the fallback installation script when pyproject.toml does not work
from setuptools import setup, find_packages
import os

version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__)))

with open(os.path.join(version_folder, 'verl/version/version')) as f:
    __version__ = f.read().strip()

with open('requirements.txt') as f:
    required = f.read().splitlines()
    install_requires = [item.strip() for item in required if item.strip()[0] != '#']

extras_require = {
    'test': ['pytest', 'yapf']
}

from pathlib import Path
this_directory = Path(__file__).parent
long_description = (this_directory / "README.md").read_text()

setup(
    name='verl',
    version=__version__,
    package_dir={'': '.'},
    packages=find_packages(where='.'),
    url='https://github.com/volcengine/verl',
    license='Apache 2.0',
    author='Bytedance - Seed - MLSys',
    author_email='zhangchi.usc1992@bytedance.com, gmsheng@connect.hku.hk',
    description='veRL: Volcano Engine Reinforcement Learning for LLM',
    install_requires=install_requires,
    extras_require=extras_require,
    package_data={'': ['version/*'],
                  'verl': ['trainer/config/*.yaml'],},
    include_package_data=True,
    long_description=long_description,
    long_description_content_type='text/markdown'
)
90
train_ppo.sh
Normal file
@@ -0,0 +1,90 @@
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export DATA_DIR='data/nq_search'

WAND_PROJECT='Search-R1'

export BASE_MODEL='meta-llama/Llama-3.2-3B'
export EXPERIMENT_NAME=nq-search-r1-ppo-llama3.2-3b-em
# export BASE_MODEL='meta-llama/Llama-3.2-3B-Instruct'
# export EXPERIMENT_NAME=nq-search-r1-ppo-llama3.2-3b-it-em
# export BASE_MODEL='meta-llama/Llama-3.1-8B'
# export EXPERIMENT_NAME=nq-search-r1-ppo-llama3.1-8b-em
# export BASE_MODEL='meta-llama/Llama-3.1-8B-Instruct'
# export EXPERIMENT_NAME=nq-search-r1-ppo-llama3.1-8b-it-em

# export BASE_MODEL='Qwen/Qwen2.5-3B'
# export EXPERIMENT_NAME=nq-search-r1-ppo-qwen2.5-3b-em
# export BASE_MODEL='Qwen/Qwen2.5-3B-Instruct'
# export EXPERIMENT_NAME=nq-search-r1-ppo-qwen2.5-3b-it-em
# export BASE_MODEL='Qwen/Qwen2.5-7B'
# export EXPERIMENT_NAME=nq-search-r1-ppo-qwen2.5-7b-em
# export BASE_MODEL='Qwen/Qwen2.5-7B-Instruct'
# export EXPERIMENT_NAME=nq-search-r1-ppo-qwen2.5-7b-it-em

# set -x
export VLLM_ATTENTION_BACKEND=XFORMERS  # vllm + qwen2-7b with flash_attn has some issues

# max_prompt_length = (config['training']['max_start_length'] + config['training']['max_response_length'] * (config['training']['max_turns'] - 1) + config['training']['max_obs_length'] * config['training']['max_turns'])
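# Worked example of the bound above using the values passed below (an illustrative check,
# not something the trainer computes): 2048 + 500 * (2 - 1) + 500 * 2 = 3548, which fits
# within the data.max_prompt_length=4096 configured for this run.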

PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
    data.train_files=$DATA_DIR/train.parquet \
    data.val_files=$DATA_DIR/test.parquet \
    data.train_data_num=null \
    data.val_data_num=null \
    data.train_batch_size=512 \
    data.val_batch_size=256 \
    data.max_prompt_length=4096 \
    data.max_response_length=500 \
    data.max_start_length=2048 \
    data.max_obs_length=500 \
    data.shuffle_train_dataloader=True \
    algorithm.adv_estimator=gae \
    actor_rollout_ref.model.path=$BASE_MODEL \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.enable_gradient_checkpointing=true \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.95 \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size=64 \
    actor_rollout_ref.actor.fsdp_config.param_offload=true \
    actor_rollout_ref.actor.fsdp_config.grad_offload=true \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=true \
    actor_rollout_ref.rollout.log_prob_micro_batch_size=128 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.ref.log_prob_micro_batch_size=128 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    actor_rollout_ref.rollout.n_agent=1 \
    actor_rollout_ref.rollout.temperature=1 \
    actor_rollout_ref.actor.state_masking=true \
    critic.optim.lr=1e-5 \
    critic.model.use_remove_padding=True \
    critic.optim.lr_warmup_steps_ratio=0.05 \
    critic.model.path=$BASE_MODEL \
    critic.model.enable_gradient_checkpointing=true \
    critic.ppo_micro_batch_size=8 \
    critic.model.fsdp_config.param_offload=true \
    critic.model.fsdp_config.grad_offload=true \
    critic.model.fsdp_config.optimizer_offload=true \
    algorithm.kl_ctrl.kl_coef=0.001 \
    algorithm.no_think_rl=false \
    trainer.critic_warmup=0 \
    trainer.logger=['wandb'] \
    +trainer.val_only=false \
    +trainer.val_before_train=true \
    trainer.default_hdfs_dir=null \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=100 \
    trainer.test_freq=100 \
    trainer.project_name=$WAND_PROJECT \
    trainer.experiment_name=$EXPERIMENT_NAME \
    trainer.total_epochs=15 \
    trainer.total_training_steps=300 \
    trainer.default_local_dir=verl_checkpoints/$EXPERIMENT_NAME \
    max_turns=2 \
    retriever.url="http://127.0.0.1:8000/retrieve" \
    retriever.topk=3 \
    2>&1 | tee $EXPERIMENT_NAME.log
27
verl/__init__.py
Normal file
@@ -0,0 +1,27 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__)))

with open(os.path.join(version_folder, 'version/version')) as f:
    __version__ = f.read().strip()

from .protocol import DataProto

from .utils.logging_utils import set_basic_config
import logging

set_basic_config(level=logging.WARNING)
35
verl/models/README.md
Normal file
@@ -0,0 +1,35 @@
# Models
Common modelzoo such as huggingface/transformers struggles when using Pytorch native model parallelism. Following the design principle of vLLM, we keep simple, parallelizable, highly-optimized model implementations with packed inputs in verl.

## Adding a New Huggingface Model
### Step 1: Copy the model file from HF to verl
- Add a new file under verl/models/hf
- Copy ONLY the model file from huggingface/transformers/models to verl/models/hf

### Step 2: Modify the model file to use packed inputs
- Remove all the code related to inference (kv cache)
- Modify the inputs to include only (see the sketch after this list)
  - input_ids (total_nnz,)
  - cu_seqlens (batch_size + 1,)
  - max_seqlen_in_batch: int
- Note that this requires using flash attention with a causal mask.
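
A minimal sketch (not verl code; the tensors and pad id are made up) of how a padded batch maps to these packed inputs:

```python
import torch
import torch.nn.functional as F

# Hypothetical padded batch, pad id = 0
input_ids_padded = torch.tensor([[5, 6, 7, 0],
                                 [8, 9, 0, 0]])
attention_mask = input_ids_padded.ne(0)

seqlens = attention_mask.sum(dim=1)                  # tensor([3, 2])
input_ids = input_ids_padded[attention_mask]         # (total_nnz,) -> tensor([5, 6, 7, 8, 9])
cu_seqlens = F.pad(seqlens.cumsum(0), (1, 0)).int()  # (batch_size + 1,) -> tensor([0, 3, 5])
max_seqlen_in_batch = int(seqlens.max())             # 3
```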

### Step 2.5: Add tests
- Add a test to compare this version and the huggingface version
- Follow the existing test infrastructure and add tests to tests/models/hf

### Step 3: Add a function to apply tensor parallelism
- Please follow
  - https://pytorch.org/docs/stable/distributed.tensor.parallel.html
  - https://pytorch.org/tutorials/intermediate/TP_tutorial.html
- General comments
  - Tensor Parallelism in native Pytorch is NOT auto-parallelism. The way it works is to specify how model parameters and input/output reshard using configs. These configs are then registered as hooks to perform input/output resharding before/after the model forward. A short sketch of this hook-based API follows this list.
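
A minimal sketch of the hook-based TP API (assumptions: a toy `TinyMLP` module invented here, a single TP group spanning all ranks, launched with `torchrun`; this is illustrative, not verl code):

```python
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel


class TinyMLP(nn.Module):
    def __init__(self, hidden=128, inter=256):
        super().__init__()
        self.gate_proj = nn.Linear(hidden, inter, bias=False)
        self.up_proj = nn.Linear(hidden, inter, bias=False)
        self.down_proj = nn.Linear(inter, hidden, bias=False)

    def forward(self, x):
        return self.down_proj(torch.nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))


dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
tp_mesh = init_device_mesh("cuda", (dist.get_world_size(),))

model = TinyMLP().cuda()
# The plan maps submodule names to resharding styles; parallelize_module registers hooks
# that reshard inputs/outputs around each forward.
model = parallelize_module(model, tp_mesh, {
    "gate_proj": ColwiseParallel(),   # shard the output dimension
    "up_proj": ColwiseParallel(),
    "down_proj": RowwiseParallel(),   # shard the input dimension, all-reduce the output
})
out = model(torch.randn(4, 128, device="cuda"))  # replicated (4, 128) output on every rank
```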

### Step 4: Add a function to apply data parallelism
- Please use FSDP2 APIs
- See the demo here: https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py#L413

### Step 5: Add a function to apply pipeline parallelism
- Comes in Pytorch 2.4
- Currently only in alpha in the nightly version
- Check torchtitan for more details
13
verl/models/__init__.py
Normal file
@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13
verl/models/llama/__init__.py
Normal file
@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
24
verl/models/llama/megatron/__init__.py
Normal file
@@ -0,0 +1,24 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .modeling_llama_megatron import (
    # original model with megatron
    ParallelLlamaModel,
    ParallelLlamaForCausalLM,
    # rmpad with megatron
    ParallelLlamaForCausalLMRmPad,
    ParallelLlamaForValueRmPad,
    # rmpad with megatron and pipeline parallelism
    ParallelLlamaForCausalLMRmPadPP,
    ParallelLlamaForValueRmPadPP)
13
verl/models/llama/megatron/checkpoint_utils/__init__.py
Normal file
@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
446
verl/models/llama/megatron/checkpoint_utils/llama_loader.py
Normal file
@@ -0,0 +1,446 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import time
from typing import Dict, Any, Callable, Optional
import torch.distributed as dist


def _megatron_calc_layer_map(config):
    """Calculate the mapping of global layer_idx to local layer_idx

    Returns:
        layer_map (Dict: int -> tuple(int, int, int)):
            mapping from the global layer index to
            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)
    """
    import megatron
    from megatron.core import mpu

    pp_size = mpu.get_pipeline_model_parallel_world_size()
    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1

    layer_map = dict()
    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers

    for pp_rank_idx in range(pp_size):
        for virtual_pp_rank_idx in range(virtual_pp_size):
            layer_offset = (virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) +
                            pp_rank_idx * num_layers_per_model)
            for layer_idx in range(num_layers_per_model):
                layer_map[layer_offset + layer_idx] = (
                    pp_rank_idx,
                    virtual_pp_rank_idx,
                    layer_idx,
                )
    return layer_map
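
# Worked example (comment only, values assumed for illustration): with num_hidden_layers=8,
# pp_size=2 and virtual_pp_size=2, each model chunk holds 2 layers and the map is
#   0 -> (0, 0, 0), 1 -> (0, 0, 1), 2 -> (1, 0, 0), 3 -> (1, 0, 1),
#   4 -> (0, 1, 0), 5 -> (0, 1, 1), 6 -> (1, 1, 0), 7 -> (1, 1, 1)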


def load_state_dict_to_megatron_llama(state_dict, wrapped_models, config, params_dtype, is_value_model=False):
    """Load a merged state_dict into the sharded Megatron module during training."""
    import megatron
    from megatron.core import mpu
    from megatron.utils import print_rank_0, unwrap_model
    from megatron.core.transformer.module import Float16Module
    from megatron.core import DistributedDataParallel as LocalDDP
    from torch.nn.parallel import DistributedDataParallel as torchDDP

    start_time = time.time()

    def _get_gpt_model(model):
        return model

    def broadcast_params(module):
        for param in module.parameters():
            torch.distributed.broadcast(param.data,
                                        src=mpu.get_data_parallel_src_rank(),
                                        group=mpu.get_data_parallel_group())

    dp_rank = mpu.get_data_parallel_rank()
    pp_rank = mpu.get_pipeline_model_parallel_rank()
    pp_size = mpu.get_pipeline_model_parallel_world_size()
    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
    mp_group = mpu.get_model_parallel_group()

    if torch.distributed.get_rank() == 0:
        assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0"
        assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
        assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"

    if not isinstance(wrapped_models, (list, tuple)):
        wrapped_models = [wrapped_models]

    assert len(wrapped_models) == virtual_pp_size
    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers

    models = [None] * len(wrapped_models)

    for i, wrapped_model in enumerate(wrapped_models):
        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
        gpt_model_module = _get_gpt_model(models[i])
        assert len(gpt_model_module.model.layers) == num_layers_per_model

    def _broadcast_tensor(tensor, name) -> torch.Tensor:
        """broadcast tensor from rank0 across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        if torch.distributed.get_rank() == 0:
            if name in state_dict:
                weight = state_dict[name]
                tensor_shape = weight.shape
            else:
                tensor_shape = None
        else:
            weight = None
            tensor_shape = None

        obj_list = [tensor_shape]
        dist.broadcast_object_list(obj_list, src=0, group=mp_group)
        tensor_shape = obj_list[0]

        if tensor_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tensor:[{name}] not in state_dict, skip load")
            return

        if tensor is None:
            tensor = torch.empty(
                tensor_shape,
                dtype=params_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
        if torch.distributed.get_rank() == 0:
            tensor.data.copy_(weight)
        dist.broadcast(tensor, src=0, group=mp_group)

    def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
        """broadcast tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()

        if torch.distributed.get_rank() == 0:
            if name in state_dict:
                full_weight = state_dict[name]

                if mutate_func is not None:
                    full_weight = mutate_func(full_weight)
                tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
                chunk_shape = tensor_chunk[0].shape
            else:
                chunk_shape = None
        else:
            chunk_shape = None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=0, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
            return

        if tensor is None:
            sync_tensor = torch.empty(
                chunk_shape,
                dtype=params_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
        else:
            assert (tensor.shape == chunk_shape
                   ), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
            sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)

        for i in range(tp_size):
            if torch.distributed.get_rank() == 0:
                sync_tensor.data.copy_(tensor_chunk[i])
            dist.broadcast(sync_tensor, src=0, group=mp_group)
            if (i == tp_rank) and (tensor is not None):
                tensor.data.copy_(sync_tensor)

    def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
        """broadcast tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()

        if torch.distributed.get_rank() == 0:
            if name in state_dict:
                full_weight = state_dict[name]
                if mutate_func is not None:
                    full_weight = mutate_func(full_weight)
                tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
                chunk_shape = tensor_chunk[0].shape
            else:
                chunk_shape = None
        else:
            chunk_shape = None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=0, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
            return

        if tensor is None:
            sync_tensor = torch.empty(
                chunk_shape,
                dtype=params_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
        else:
            assert (tensor.shape == chunk_shape
                   ), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
            sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)

        for i in range(tp_size):
            if torch.distributed.get_rank() == 0:
                sync_tensor.data.copy_(tensor_chunk[i])
            dist.broadcast(sync_tensor, src=0, group=mp_group)
            if (i == tp_rank) and (tensor is not None):
                tensor.data.copy_(sync_tensor)

    def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:
        """broadcast tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()

        if torch.distributed.get_rank() == 0:
            gate_weight = state_dict[gate_name]
            up_weight = state_dict[up_name]
            new_gate_up_weight = torch.empty(config.intermediate_size * 2,
                                             config.hidden_size,
                                             dtype=params_dtype,
                                             device=torch.cuda.current_device())
            for i in range(tp_size):
                intermediate_size_tp = config.intermediate_size // tp_size
                gate_weight_tp = gate_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp]
                up_weight_tp = up_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp]
                new_gate_up_weight[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)].copy_(
                    torch.cat([gate_weight_tp, up_weight_tp], dim=0))

            tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)
            chunk_shape = tensor_chunk[0].shape
        else:
            chunk_shape = None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=0, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading")
            return

        if tensor is None:
            sync_tensor = torch.empty(
                chunk_shape,
                dtype=params_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
        else:
            assert (
                tensor.shape == chunk_shape
            ), f"rank #{torch.distributed.get_rank()} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}"
            sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)

        for i in range(tp_size):
            if torch.distributed.get_rank() == 0:
                sync_tensor.data.copy_(tensor_chunk[i])
            dist.broadcast(sync_tensor, src=0, group=mp_group)
            if (i == tp_rank) and (tensor is not None):
                tensor.data.copy_(sync_tensor)

    def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor:
        """broadcast tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()

        if torch.distributed.get_rank() == 0:
            assert (q_name in state_dict and k_name in state_dict and v_name in state_dict)
            full_weight_q = state_dict[q_name]
            full_weight_k = state_dict[k_name]
            full_weight_v = state_dict[v_name]

            hidden_size_per_head = config.hidden_size // config.num_attention_heads

            if config.num_key_value_heads >= tp_size:
                q_size_tp = config.hidden_size // tp_size
                kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
                total_size = q_size_tp + 2 * kv_size_tp
                new_weight_qkv = torch.empty(total_size * tp_size,
                                             config.hidden_size,
                                             dtype=params_dtype,
                                             device=torch.cuda.current_device())
                for i in range(tp_size):
                    q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp]
                    k_part = full_weight_k[i * kv_size_tp:(i + 1) * kv_size_tp]
                    v_part = full_weight_v[i * kv_size_tp:(i + 1) * kv_size_tp]
                    new_weight_qkv[i * total_size:(i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part],
                                                                                        dim=0))

            else:
                q_size_tp = config.hidden_size // tp_size
                kv_size_tp = hidden_size_per_head
                total_size = q_size_tp + 2 * kv_size_tp
                new_weight_qkv = torch.empty(total_size * tp_size,
                                             config.hidden_size,
                                             dtype=params_dtype,
                                             device=torch.cuda.current_device())
                for i in range(tp_size):
                    q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp]
                    start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head
                    end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head
                    k_part = full_weight_k[start_idx:end_idx]
                    v_part = full_weight_v[start_idx:end_idx]
                    new_weight_qkv[i * total_size:(i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part],
                                                                                        dim=0))

            tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)
            chunk_shape = tensor_chunk[0].shape
        else:
            chunk_shape = None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=0, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading")
            return

        if tensor is None:
            sync_tensor = torch.empty(
                chunk_shape,
                dtype=params_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
        else:
            assert (tensor.shape == chunk_shape
                   ), f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}"
            sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)

        for i in range(tp_size):
            if torch.distributed.get_rank() == 0:
                sync_tensor.data.copy_(tensor_chunk[i])
            dist.broadcast(sync_tensor, src=0, group=mp_group)
            if (i == tp_rank) and (tensor is not None):
                tensor.data.copy_(sync_tensor)

    if dp_rank == 0:
        # Embeddings
        # -------------------
        print_rank_0("loading embeddings...")
        gpt_model_module = _get_gpt_model(models[0])
        embed_tokens_weight = None
        if pp_rank == 0:
            embed_tokens_weight = gpt_model_module.model.embed_tokens.weight
        _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight")

        # Transformer layers
        # -------------------
        layer_map = _megatron_calc_layer_map(config)

        for layer in range(config.num_hidden_layers):
            print_rank_0(f"loading layer #{layer}...")
            layer_name = f"model.layers.{layer}"
            dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]

            gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])
            sync_layer = gpt_model_module.model.layers[dst_layer_idx]

            _broadcast_tensor(
                sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None,
                f"{layer_name}.input_layernorm.weight",
            )

            _broadcast_tp_shard_tensor_qkv(
                sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None,
                f"{layer_name}.self_attn.q_proj.weight",
                f"{layer_name}.self_attn.k_proj.weight",
                f"{layer_name}.self_attn.v_proj.weight",
            )

            _broadcast_tp_shard_tensor(
                sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None,
                f"{layer_name}.self_attn.o_proj.weight",
                chunk_dim=1,
            )

            _broadcast_tensor(
                sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None,
                f"{layer_name}.post_attention_layernorm.weight",
            )

            _broadcast_tp_shard_tensor_gate_up(sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None,
                                               f"{layer_name}.mlp.gate_proj.weight", f"{layer_name}.mlp.up_proj.weight")

            _broadcast_tp_shard_tensor(
                sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None,
                f"{layer_name}.mlp.down_proj.weight",
                chunk_dim=1,
            )
        # Final Layernorm
        # -------------------
        print_rank_0("loading final layernorm...")
        gpt_model_module = _get_gpt_model(models[-1])
        _broadcast_tensor(
            getattr(gpt_model_module.model.norm, "weight", None),
            "model.norm.weight",
        )

        print_rank_0("loading lm_head...")
        lm_head_weight = None
        if pp_rank + 1 == pp_size:
            lm_head_weight = gpt_model_module.lm_head.weight

        if is_value_model:
            # if torch.distributed.get_rank() == 0:
            if 'lm_head.weight' in state_dict and state_dict['lm_head.weight'].shape[0] == 1:
                _broadcast_tensor(lm_head_weight, "lm_head.weight")
            elif 'reward_head.weight' in state_dict and state_dict['reward_head.weight'].shape[0] == 1:
                _broadcast_tensor(lm_head_weight, "reward_head.weight")
                print_rank_0('load lm_head from value_head weight')
            else:
                _broadcast_tensor(None, "lm_head.weight")
                print_rank_0('fail to match lm_head in value_model')
            # else:
            #     _broadcast_tensor(lm_head_weight, "lm_head.weight")

        else:
            _broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight")
    dist.barrier()
    # Broadcast weights inside data parallel groups
    for wrapped_model in wrapped_models:
        broadcast_params(wrapped_model)

    torch.cuda.empty_cache()
    print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s")
449
verl/models/llama/megatron/checkpoint_utils/llama_saver.py
Normal file
@@ -0,0 +1,449 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import megatron
from megatron.core import mpu
from megatron.utils import print_rank_0, unwrap_model
from megatron.model import Float16Module
from megatron.model import DistributedDataParallel as LocalDDP
from torch.nn.parallel import DistributedDataParallel as torchDDP
import torch
import time
from typing import Optional
import torch.distributed as dist
from megatron import get_args


def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0):
    """Given the TP, DP and PP ranks, compute the global rank."""

    args = get_args()
    tp_size = mpu.get_tensor_model_parallel_world_size()
    dp_size = mpu.get_data_parallel_world_size()
    pp_size = mpu.get_pipeline_model_parallel_world_size()
    assert (tp_size * dp_size * pp_size == torch.distributed.get_world_size()
           ), f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}"
    if args.switch_dp_and_pp_grouping:
        # TP-PP-DP grouping
        return (dp_rank * pp_size + pp_rank) * tp_size + tp_rank
    else:
        # TP-DP-PP grouping
        return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank
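
# Worked example (comment only, sizes assumed): with tp_size=2, dp_size=2, pp_size=2 and the
# default TP-DP-PP grouping, (tp_rank=1, dp_rank=0, pp_rank=1) maps to (1 * 2 + 0) * 2 + 1 = 5.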
|
||||||
|
|
||||||
|
|
||||||
|
def _megatron_calc_layer_map(config):
|
||||||
|
"""Calculate the mapping of global layer_idx to local layer_idx
|
||||||
|
Returns:
|
||||||
|
layer_map (Dict: int -> tuple(int, int, int)):
|
||||||
|
mapping from the global layer index to
|
||||||
|
a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)
|
||||||
|
"""
|
||||||
|
import megatron
|
||||||
|
from megatron.core import mpu
|
||||||
|
|
||||||
|
pp_size = mpu.get_pipeline_model_parallel_world_size()
|
||||||
|
virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
|
||||||
|
|
||||||
|
args = megatron.get_args()
|
||||||
|
layer_map = dict()
|
||||||
|
num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
|
||||||
|
assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
|
||||||
|
|
||||||
|
for pp_rank_idx in range(pp_size):
|
||||||
|
for virtual_pp_rank_idx in range(virtual_pp_size):
|
||||||
|
layer_offset = (virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) +
|
||||||
|
pp_rank_idx * num_layers_per_model)
|
||||||
|
for layer_idx in range(num_layers_per_model):
|
||||||
|
layer_map[layer_offset + layer_idx] = (
|
||||||
|
pp_rank_idx,
|
||||||
|
virtual_pp_rank_idx,
|
||||||
|
layer_idx,
|
||||||
|
)
|
||||||
|
return layer_map
|
||||||
|
|
||||||
|
|
||||||
|
def merge_megatron_ckpt_llama(wrapped_models, config, is_value_model=False, dtype='bf16'):
|
||||||
|
"""Merge sharded parameters of a Megatron module into a merged checkpoint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
wrapped_modelss (list of megatron.model.DistributedDataParallel):
|
||||||
|
The local DDP wrapped megatron modules.
|
||||||
|
dtype (str or None):
|
||||||
|
The data type of state_dict. if None, the data type of the original parameters
|
||||||
|
is used.
|
||||||
|
gpt_model_key: key to access model
|
||||||
|
Returns:
|
||||||
|
state_dict (dict):
|
||||||
|
The merged state_dict in rank 0, and an empty dictionary in other ranks.
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
args = megatron.get_args()
|
||||||
|
|
||||||
|
def _get_gpt_model(model):
|
||||||
|
return model
|
||||||
|
|
||||||
|
dp_rank = mpu.get_data_parallel_rank()
|
||||||
|
pp_size = mpu.get_pipeline_model_parallel_world_size()
|
||||||
|
pp_rank = mpu.get_pipeline_model_parallel_rank()
|
||||||
|
virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
|
||||||
|
mp_group = mpu.get_model_parallel_group()
|
||||||
|
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0"
|
||||||
|
assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
|
||||||
|
assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
|
||||||
|
|
||||||
|
if not isinstance(wrapped_models, (list, tuple)):
    wrapped_models = [wrapped_models]
|
||||||
|
|
||||||
|
assert len(wrapped_models) == virtual_pp_size
|
||||||
|
num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
|
||||||
|
assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
|
||||||
|
|
||||||
|
models = [None] * len(wrapped_models)
|
||||||
|
|
||||||
|
for i, wrapped_model in enumerate(wrapped_models):
|
||||||
|
models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
|
||||||
|
assert len(models[i].model.layers) == num_layers_per_model, \
    'len model layers {} not equal to num_layers_per_model {}'.format(len(models[i].model.layers), num_layers_per_model)
|
||||||
|
|
||||||
|
state_dict = dict()
|
||||||
|
|
||||||
|
def _get_cpu_tensor(tensor: torch.Tensor):
|
||||||
|
if tensor is None:
|
||||||
|
return None
|
||||||
|
if tensor.device == torch.device("cpu"):
|
||||||
|
return tensor.detach().clone()
|
||||||
|
return tensor.detach().cpu()
|
||||||
|
|
||||||
|
def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor:
|
||||||
|
"""broadcast tensor across mp_group"""
|
||||||
|
nonlocal state_dict
|
||||||
|
nonlocal mp_group
|
||||||
|
src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)
|
||||||
|
|
||||||
|
if torch.distributed.get_rank() == src_rank:
|
||||||
|
if tensor is None:
|
||||||
|
weight = None
|
||||||
|
tensor_shape = None
|
||||||
|
else:
|
||||||
|
weight = tensor
|
||||||
|
tensor_shape = weight.shape
|
||||||
|
else:
|
||||||
|
weight = None
|
||||||
|
tensor_shape = None
|
||||||
|
|
||||||
|
obj_list = [tensor_shape]
|
||||||
|
dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
|
||||||
|
tensor_shape = obj_list[0]
|
||||||
|
|
||||||
|
if tensor_shape is None:
|
||||||
|
# all or none ranks in the mp_group should reach here
|
||||||
|
print_rank_0(f"tensor:[{name}] not exist, skip collect")
|
||||||
|
return
|
||||||
|
|
||||||
|
if weight is None:
|
||||||
|
weight = torch.empty(
|
||||||
|
tensor_shape,
|
||||||
|
dtype=args.params_dtype,
|
||||||
|
device=torch.cuda.current_device(),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
dist.broadcast(weight, src=src_rank, group=mp_group)
|
||||||
|
|
||||||
|
if torch.distributed.get_rank() == 0:
|
||||||
|
state_dict[name] = _get_cpu_tensor(weight)
|
||||||
|
|
||||||
|
def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor:
|
||||||
|
"""broadcast tensor in tp shards across mp_group"""
|
||||||
|
nonlocal state_dict
|
||||||
|
nonlocal mp_group
|
||||||
|
tp_rank = mpu.get_tensor_model_parallel_rank()
|
||||||
|
tp_size = mpu.get_tensor_model_parallel_world_size()
|
||||||
|
src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)
|
||||||
|
|
||||||
|
if torch.distributed.get_rank() == src_rank:
|
||||||
|
chunk_shape = tensor.shape
|
||||||
|
else:
|
||||||
|
chunk_shape = None
|
||||||
|
|
||||||
|
obj_list = [chunk_shape]
|
||||||
|
dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
|
||||||
|
chunk_shape = obj_list[0]
|
||||||
|
if chunk_shape is None:
|
||||||
|
# all or none ranks in the mp_group should reach here
|
||||||
|
print_rank_0(f"tp_shard tensor:[{name}] not exist, skip collecting")
|
||||||
|
return
|
||||||
|
|
||||||
|
buffer_tensor = torch.empty(
|
||||||
|
chunk_shape,
|
||||||
|
dtype=args.params_dtype,
|
||||||
|
device=torch.cuda.current_device(),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
chunk_tensors = [None] * tp_size
|
||||||
|
|
||||||
|
for i in range(tp_size):
|
||||||
|
cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)
|
||||||
|
sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
|
||||||
|
dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)
|
||||||
|
|
||||||
|
if torch.distributed.get_rank() == 0:
|
||||||
|
chunk_tensors[i] = _get_cpu_tensor(sync_tensor)
|
||||||
|
|
||||||
|
if torch.distributed.get_rank() == 0:
|
||||||
|
full_tensor = torch.concat(chunk_tensors, dim=concat_dim)
|
||||||
|
if mutate_func is not None:
|
||||||
|
full_tensor = mutate_func(full_tensor)
|
||||||
|
state_dict[name] = full_tensor
|
||||||
|
|
||||||
|
def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor:
|
||||||
|
"""broadcast tensor in tp shards across mp_group"""
|
||||||
|
nonlocal state_dict
|
||||||
|
nonlocal mp_group
|
||||||
|
tp_rank = mpu.get_tensor_model_parallel_rank()
|
||||||
|
tp_size = mpu.get_tensor_model_parallel_world_size()
|
||||||
|
src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)
|
||||||
|
|
||||||
|
if torch.distributed.get_rank() == src_rank:
|
||||||
|
chunk_shape = tensor.shape
|
||||||
|
else:
|
||||||
|
chunk_shape = None
|
||||||
|
|
||||||
|
obj_list = [chunk_shape]
|
||||||
|
dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
|
||||||
|
chunk_shape = obj_list[0]
|
||||||
|
if chunk_shape is None:
|
||||||
|
# all or none ranks in the mp_group should reach here
|
||||||
|
print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting")
|
||||||
|
return
|
||||||
|
|
||||||
|
buffer_tensor = torch.empty(
|
||||||
|
chunk_shape,
|
||||||
|
dtype=args.params_dtype,
|
||||||
|
device=torch.cuda.current_device(),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
chunk_tensors = [None] * tp_size
|
||||||
|
|
||||||
|
for i in range(tp_size):
|
||||||
|
cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)
|
||||||
|
sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
|
||||||
|
dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)
|
||||||
|
|
||||||
|
if torch.distributed.get_rank() == 0:
|
||||||
|
chunk_tensors[i] = _get_cpu_tensor(sync_tensor)
|
||||||
|
|
||||||
|
if torch.distributed.get_rank() == 0:
|
||||||
|
full_tensor = torch.concat(chunk_tensors, dim=0)
|
||||||
|
intermediate_size_tp = config.intermediate_size // tp_size
|
||||||
|
gate_weight_list = []
|
||||||
|
up_weight_list = []
|
||||||
|
for i in range(tp_size):
|
||||||
|
gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)]
|
||||||
|
gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp]
|
||||||
|
up_weight_tp = gate_up_weight_tp[intermediate_size_tp:]
|
||||||
|
gate_weight_list.append(gate_weight_tp)
|
||||||
|
up_weight_list.append(up_weight_tp)
|
||||||
|
|
||||||
|
state_dict[gate_name] = torch.cat(gate_weight_list, dim=0)
|
||||||
|
state_dict[up_name] = torch.cat(up_weight_list, dim=0)
|
||||||
|
|
||||||
|
def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank):
|
||||||
|
"""broadcast tensor in tp shards across mp_group"""
|
||||||
|
nonlocal state_dict
|
||||||
|
nonlocal mp_group
|
||||||
|
tp_rank = mpu.get_tensor_model_parallel_rank()
|
||||||
|
tp_size = mpu.get_tensor_model_parallel_world_size()
|
||||||
|
src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)
|
||||||
|
|
||||||
|
if torch.distributed.get_rank() == src_rank:
|
||||||
|
chunk_shape = tensor.shape
|
||||||
|
else:
|
||||||
|
chunk_shape = None
|
||||||
|
|
||||||
|
obj_list = [chunk_shape]
|
||||||
|
dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
|
||||||
|
chunk_shape = obj_list[0]
|
||||||
|
if chunk_shape is None:
|
||||||
|
# all or none ranks in the mp_group should reach here
|
||||||
|
print_rank_0(f"tp_shard tensor:[{q_name}] not exist, skip collecting")
|
||||||
|
return
|
||||||
|
|
||||||
|
buffer_tensor = torch.empty(
|
||||||
|
chunk_shape,
|
||||||
|
dtype=args.params_dtype,
|
||||||
|
device=torch.cuda.current_device(),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
chunk_tensors = [None] * tp_size
|
||||||
|
|
||||||
|
for i in range(tp_size):
|
||||||
|
cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)
|
||||||
|
sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
|
||||||
|
dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)
|
||||||
|
|
||||||
|
if torch.distributed.get_rank() == 0:
|
||||||
|
chunk_tensors[i] = _get_cpu_tensor(sync_tensor)
|
||||||
|
|
||||||
|
if torch.distributed.get_rank() == 0:
|
||||||
|
full_tensor = torch.concat(chunk_tensors, dim=0)
|
||||||
|
q_weight_list = []
|
||||||
|
k_weight_list = []
|
||||||
|
v_weight_list = []
|
||||||
|
hidden_size_per_head = config.hidden_size // config.num_attention_heads
|
||||||
|
|
||||||
|
if config.num_key_value_heads >= tp_size:
|
||||||
|
q_size_tp = config.hidden_size // tp_size
|
||||||
|
kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
|
||||||
|
total_size = q_size_tp + 2 * kv_size_tp
|
||||||
|
for i in range(tp_size):
|
||||||
|
qkv_part = full_tensor[i * total_size:(i + 1) * total_size]
|
||||||
|
q_part = qkv_part[:q_size_tp]
|
||||||
|
k_part = qkv_part[q_size_tp:q_size_tp + kv_size_tp]
|
||||||
|
v_part = qkv_part[q_size_tp + kv_size_tp:total_size]
|
||||||
|
q_weight_list.append(q_part)
|
||||||
|
k_weight_list.append(k_part)
|
||||||
|
v_weight_list.append(v_part)
|
||||||
|
else:
|
||||||
|
q_size_tp = config.hidden_size // tp_size
|
||||||
|
kv_size_tp = hidden_size_per_head
|
||||||
|
total_size = q_size_tp + 2 * kv_size_tp
|
||||||
|
for i in range(tp_size):
|
||||||
|
qkv_part = full_tensor[i * total_size:(i + 1) * total_size]
|
||||||
|
q_part = qkv_part[:q_size_tp]
|
||||||
|
k_part = qkv_part[q_size_tp:q_size_tp + kv_size_tp]
|
||||||
|
v_part = qkv_part[q_size_tp + kv_size_tp:total_size]
|
||||||
|
q_weight_list.append(q_part)
|
||||||
|
if i * config.num_key_value_heads % tp_size == 0:
|
||||||
|
k_weight_list.append(k_part)
|
||||||
|
v_weight_list.append(v_part)
|
||||||
|
|
||||||
|
state_dict[q_name] = torch.cat(q_weight_list, dim=0)
|
||||||
|
state_dict[k_name] = torch.cat(k_weight_list, dim=0)
|
||||||
|
state_dict[v_name] = torch.cat(v_weight_list, dim=0)
|
||||||
|
|
||||||
|
# empty cache before collecting weights
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
# Embeddings
|
||||||
|
# -------------------
|
||||||
|
if dp_rank == 0:
|
||||||
|
# Embeddings
|
||||||
|
# -------------------
|
||||||
|
print_rank_0("collecting embeddings...")
|
||||||
|
gpt_model_module = _get_gpt_model(models[0])
|
||||||
|
_broadcast_tp_shard_tensor(
|
||||||
|
gpt_model_module.model.embed_tokens.weight if pp_rank == 0 else None,
|
||||||
|
"model.embed_tokens.weight",
|
||||||
|
src_pp_rank=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Transformer layers
|
||||||
|
# -------------------
|
||||||
|
layer_map = _megatron_calc_layer_map(config)
|
||||||
|
for layer in range(config.num_hidden_layers):
|
||||||
|
print_rank_0(f"collecting layer #{layer}...")
|
||||||
|
layer_name = f"model.layers.{layer}"
|
||||||
|
src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer]
|
||||||
|
|
||||||
|
gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank])
|
||||||
|
sync_layer = gpt_model_module.model.layers[src_layer_idx]
|
||||||
|
|
||||||
|
_broadcast_tensor(
|
||||||
|
sync_layer.input_layernorm.weight,
|
||||||
|
f"{layer_name}.input_layernorm.weight",
|
||||||
|
src_pp_rank=src_pp_rank,
|
||||||
|
)
|
||||||
|
|
||||||
|
_broadcast_tp_shard_tensor_qkv(
|
||||||
|
sync_layer.self_attn.qkv_proj.weight,
|
||||||
|
f"{layer_name}.self_attn.q_proj.weight",
|
||||||
|
f"{layer_name}.self_attn.k_proj.weight",
|
||||||
|
f"{layer_name}.self_attn.v_proj.weight",
|
||||||
|
src_pp_rank=src_pp_rank,
|
||||||
|
)
|
||||||
|
|
||||||
|
_broadcast_tp_shard_tensor(
|
||||||
|
sync_layer.self_attn.o_proj.weight,
|
||||||
|
f"{layer_name}.self_attn.o_proj.weight",
|
||||||
|
concat_dim=1,
|
||||||
|
src_pp_rank=src_pp_rank,
|
||||||
|
)
|
||||||
|
|
||||||
|
_broadcast_tensor(
|
||||||
|
sync_layer.post_attention_layernorm.weight,
|
||||||
|
f"{layer_name}.post_attention_layernorm.weight",
|
||||||
|
src_pp_rank=src_pp_rank,
|
||||||
|
)
|
||||||
|
|
||||||
|
_broadcast_tp_shard_tensor_gate_up(sync_layer.mlp.gate_up_proj.weight,
|
||||||
|
f"{layer_name}.mlp.gate_proj.weight",
|
||||||
|
f"{layer_name}.mlp.up_proj.weight",
|
||||||
|
src_pp_rank=src_pp_rank)
|
||||||
|
|
||||||
|
_broadcast_tp_shard_tensor(
|
||||||
|
sync_layer.mlp.down_proj.weight,
|
||||||
|
f"{layer_name}.mlp.down_proj.weight",
|
||||||
|
concat_dim=1,
|
||||||
|
src_pp_rank=src_pp_rank,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Final Layernorm
|
||||||
|
# -------------------
|
||||||
|
print_rank_0("collecting final layernorm...")
|
||||||
|
gpt_model_module = _get_gpt_model(models[-1])
|
||||||
|
_broadcast_tensor(
|
||||||
|
getattr(gpt_model_module.model.norm, "weight", None),
|
||||||
|
"model.norm.weight",
|
||||||
|
src_pp_rank=pp_size - 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
print_rank_0("collecting lm_head...")
|
||||||
|
|
||||||
|
if is_value_model:
|
||||||
|
_broadcast_tensor(getattr(gpt_model_module.lm_head, "weight", None) if pp_rank == pp_size - 1 else None,
|
||||||
|
"reward_head.weight",
|
||||||
|
src_pp_rank=pp_size - 1)
|
||||||
|
|
||||||
|
else:
|
||||||
|
_broadcast_tp_shard_tensor(
|
||||||
|
getattr(gpt_model_module.lm_head, "weight", None) if pp_rank == pp_size - 1 else None,
|
||||||
|
"lm_head.weight",
|
||||||
|
src_pp_rank=pp_size - 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
if torch.distributed.get_rank() == 0:
|
||||||
|
if dtype == "fp16":
|
||||||
|
dtype = torch.float16
|
||||||
|
elif dtype == "bf16":
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
elif dtype is None or dtype == "fp32":
|
||||||
|
dtype = torch.float32
|
||||||
|
else:
|
||||||
|
print(f'Unknown/unsupported dtype to save: {dtype}')
|
||||||
|
exit(1)
|
||||||
|
for k, v in state_dict.items():
|
||||||
|
if dtype != v.dtype:
|
||||||
|
state_dict[k] = v.to(dtype)
|
||||||
|
|
||||||
|
print_rank_0(f"merge megatron ckpt done, time elapsed {time.time() - start_time}s")
|
||||||
|
return state_dict
|
||||||
18
verl/models/llama/megatron/layers/__init__.py
Normal file
@@ -0,0 +1,18 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from .parallel_attention import ParallelLlamaAttention
|
||||||
|
from .parallel_decoder import ParallelLlamaDecoderLayer, ParallelLlamaDecoderLayerRmPad
|
||||||
|
from .parallel_mlp import ParallelLlamaMLP
|
||||||
|
from .parallel_rmsnorm import ParallelLlamaRMSNorm
|
||||||
418
verl/models/llama/megatron/layers/parallel_attention.py
Normal file
@@ -0,0 +1,418 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
# and OPT implementations in this library. It has been modified from its
|
||||||
|
# original forms to accommodate minor architectural differences compared
|
||||||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from megatron.core import parallel_state as mpu
|
||||||
|
from megatron.core import tensor_parallel
|
||||||
|
from megatron.core import ModelParallelConfig
|
||||||
|
from torch import nn
|
||||||
|
from transformers import LlamaConfig
|
||||||
|
from verl.models.llama.megatron.layers.parallel_linear import QKVParallelLinear
|
||||||
|
|
||||||
|
from verl.utils.megatron import tensor_parallel as tp_utils
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaRotaryEmbedding(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.base = base
|
||||||
|
inv_freq = 1.0 / (self.base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
||||||
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
|
|
||||||
|
# Build here to make `torch.jit.trace` work.
|
||||||
|
self._set_cos_sin_cache(seq_len=max_position_embeddings,
|
||||||
|
device=self.inv_freq.device,
|
||||||
|
dtype=torch.get_default_dtype())
|
||||||
|
|
||||||
|
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||||
|
self.max_seq_len_cached = seq_len
|
||||||
|
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
||||||
|
|
||||||
|
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||||
|
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||||
|
emb = torch.cat((freqs, freqs), dim=-1)
|
||||||
|
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||||
|
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||||
|
|
||||||
|
def forward(self, x, seq_len=None):
|
||||||
|
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||||
|
if seq_len > self.max_seq_len_cached:
|
||||||
|
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
||||||
|
|
||||||
|
return (
|
||||||
|
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
||||||
|
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
|
||||||
|
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
||||||
|
|
||||||
|
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||||
|
self.scaling_factor = scaling_factor
|
||||||
|
super().__init__(dim, max_position_embeddings, base, device)
|
||||||
|
|
||||||
|
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||||
|
self.max_seq_len_cached = seq_len
|
||||||
|
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
||||||
|
t = t / self.scaling_factor
|
||||||
|
|
||||||
|
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||||
|
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||||
|
emb = torch.cat((freqs, freqs), dim=-1)
|
||||||
|
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||||
|
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
|
||||||
|
"""LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
||||||
|
|
||||||
|
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||||
|
self.scaling_factor = scaling_factor
|
||||||
|
super().__init__(dim, max_position_embeddings, base, device)
|
||||||
|
|
||||||
|
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||||
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
if seq_len > self.max_position_embeddings:
|
||||||
|
base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) -
|
||||||
|
(self.scaling_factor - 1))**(self.dim / (self.dim - 2))
|
||||||
|
inv_freq = 1.0 / (base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
||||||
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
|
|
||||||
|
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
||||||
|
|
||||||
|
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||||
|
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||||
|
emb = torch.cat((freqs, freqs), dim=-1)
|
||||||
|
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||||
|
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||||
|
|
||||||
|
|
||||||
|
def rotate_half(x):
|
||||||
|
"""Rotates half the hidden dims of the input."""
|
||||||
|
x1 = x[..., :x.shape[-1] // 2]
|
||||||
|
x2 = x[..., x.shape[-1] // 2:]
|
||||||
|
return torch.cat((-x2, x1), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
||||||
|
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
||||||
|
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
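# RoPE rotates each feature pair (x_i, x_{i + dim/2}) by a position-dependent angle:
# x_i' = x_i * cos - x_{i + dim/2} * sin and x_{i + dim/2}' = x_{i + dim/2} * cos + x_i * sin;
# rotate_half supplies the (-x2, x1) term of that rotation.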
|
||||||
|
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||||
|
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||||
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||||
|
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||||
|
"""
|
||||||
|
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||||
|
if n_rep == 1:
|
||||||
|
return hidden_states
|
||||||
|
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||||
|
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||||
|
|
||||||
|
|
||||||
|
class ParallelLlamaAttention(nn.Module):
|
||||||
|
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||||
|
|
||||||
|
def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.megatron_config = megatron_config
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.head_dim = self.hidden_size // self.num_heads
|
||||||
|
self.num_key_value_heads = config.num_key_value_heads
|
||||||
|
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||||
|
self.max_position_embeddings = config.max_position_embeddings
|
||||||
|
self.rope_theta = config.rope_theta
|
||||||
|
|
||||||
|
# assign values after tp
|
||||||
|
tp_size = mpu.get_tensor_model_parallel_world_size()
|
||||||
|
assert self.num_heads % tp_size == 0, f'num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}'
|
||||||
|
assert self.num_key_value_heads % tp_size == 0, \
|
||||||
|
f'num_key_value_heads must be divisible by tp_size. Got num_key_value_heads={self.num_key_value_heads}, tp_size={tp_size}'
|
||||||
|
|
||||||
|
self.num_heads_per_tp = self.num_heads // tp_size
|
||||||
|
self.num_key_value_heads_per_tp = self.num_key_value_heads // tp_size
|
||||||
|
self.hidden_size_per_tp = self.hidden_size // tp_size
|
||||||
|
|
||||||
|
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||||
|
raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
||||||
|
f" and `num_heads`: {self.num_heads}).")
|
||||||
|
|
||||||
|
column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
|
||||||
|
row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()
|
||||||
|
|
||||||
|
if megatron_config is not None:
|
||||||
|
assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
|
||||||
|
assert row_kwargs.get('config', False), 'must have ModelParallelConfig'
|
||||||
|
tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)
|
||||||
|
tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)
|
||||||
|
|
||||||
|
# [self.q_size, self.k_size, self.v_size]
|
||||||
|
self.qkv_proj = QKVParallelLinear(input_size=self.hidden_size,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
num_key_value_heads=self.num_key_value_heads,
|
||||||
|
head_dim=self.head_dim,
|
||||||
|
bias=config.attention_bias,
|
||||||
|
gather_output=False,
|
||||||
|
skip_bias_add=False,
|
||||||
|
**column_kwargs)
|
||||||
|
|
||||||
|
self.q_size = self.num_heads_per_tp * self.head_dim
|
||||||
|
self.k_size = self.num_key_value_heads_per_tp * self.head_dim
|
||||||
|
self.v_size = self.num_key_value_heads_per_tp * self.head_dim
|
||||||
|
|
||||||
|
self.o_proj = tensor_parallel.RowParallelLinear(input_size=self.num_heads * self.head_dim,
|
||||||
|
output_size=self.hidden_size,
|
||||||
|
bias=config.attention_bias,
|
||||||
|
input_is_parallel=True,
|
||||||
|
skip_bias_add=False,
|
||||||
|
**row_kwargs)
|
||||||
|
|
||||||
|
self._init_rope()
|
||||||
|
|
||||||
|
def _init_rope(self):
|
||||||
|
if self.config.rope_scaling is None:
|
||||||
|
self.rotary_emb = LlamaRotaryEmbedding(
|
||||||
|
self.head_dim,
|
||||||
|
max_position_embeddings=self.max_position_embeddings,
|
||||||
|
base=self.rope_theta,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
scaling_type = self.config.rope_scaling["type"]
|
||||||
|
scaling_factor = self.config.rope_scaling["factor"]
|
||||||
|
if scaling_type == "linear":
|
||||||
|
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
|
||||||
|
self.head_dim,
|
||||||
|
max_position_embeddings=self.max_position_embeddings,
|
||||||
|
scaling_factor=scaling_factor,
|
||||||
|
base=self.rope_theta,
|
||||||
|
)
|
||||||
|
elif scaling_type == "dynamic":
|
||||||
|
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
|
||||||
|
self.head_dim,
|
||||||
|
max_position_embeddings=self.max_position_embeddings,
|
||||||
|
scaling_factor=scaling_factor,
|
||||||
|
base=self.rope_theta,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
||||||
|
|
||||||
|
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||||
|
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
qkv = self.qkv_proj(hidden_states)[0]
|
||||||
|
query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
|
||||||
|
|
||||||
|
query_states = query_states.view(bsz, q_len, self.num_heads_per_tp, self.head_dim).transpose(1, 2)
|
||||||
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2)
|
||||||
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||||
|
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
if attn_weights.size() != (bsz, self.num_heads_per_tp, q_len, kv_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention weights should be of size {(bsz, self.num_heads_per_tp, q_len, kv_seq_len)}, but is"
|
||||||
|
f" {attn_weights.size()}")
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}")
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
|
if attn_output.size() != (bsz, self.num_heads_per_tp, q_len, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, self.num_heads_per_tp, q_len, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}")
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size_per_tp)
|
||||||
|
attn_output = self.o_proj(attn_output)[0]
|
||||||
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Remove padding Attention
|
||||||
|
- Using Flash-attn 2
|
||||||
|
- Compatible with sequence parallel
|
||||||
|
"""
|
||||||
|
|
||||||
|
from transformers.utils import is_flash_attn_2_available
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
if is_flash_attn_2_available():
|
||||||
|
from flash_attn import flash_attn_varlen_func
|
||||||
|
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_length):
|
||||||
|
batch_size = position_ids.shape[0]
|
||||||
|
|
||||||
|
q = pad_input(q, indices, batch_size, sequence_length) # (batch_size, seqlen, num_head, head_dim)
|
||||||
|
k = pad_input(k, indices, batch_size, sequence_length)
|
||||||
|
cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
|
||||||
|
sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
|
||||||
|
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||||
|
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||||
|
|
||||||
|
q_embed = index_first_axis(rearrange(q_embed, "b s ... -> (b s) ..."), indices)
|
||||||
|
k_embed = index_first_axis(rearrange(k_embed, "b s ... -> (b s) ..."), indices)
|
||||||
|
|
||||||
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
|
from flash_attn.layers.rotary import apply_rotary_emb
|
||||||
|
|
||||||
|
|
||||||
|
# use flash-attn rotary embeddings with rmpad
|
||||||
|
# cos/sin should be: (seq_length, rotary_dim / 2)
|
||||||
|
def apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen):
|
||||||
|
q_embed = apply_rotary_emb(q,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
interleaved=False,
|
||||||
|
inplace=False,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
max_seqlen=max_seqlen)
|
||||||
|
k_embed = apply_rotary_emb(k,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
interleaved=False,
|
||||||
|
inplace=False,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
max_seqlen=max_seqlen)
|
||||||
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
|
class ParallelLlamaAttentionRmPad(ParallelLlamaAttention):
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
sequence_length: int = None,
|
||||||
|
indices: torch.Tensor = None,
|
||||||
|
cu_seqlens: torch.Tensor = None,
|
||||||
|
max_seqlen_in_batch: int = None):
|
||||||
|
total_nnz, _, _ = hidden_states.size() # This is the total_nnz padded after sequence parallel
|
||||||
|
|
||||||
|
if self.megatron_config.sequence_parallel:
|
||||||
|
total_nnz = total_nnz * mpu.get_tensor_model_parallel_world_size()
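# With sequence parallel, this rank only sees total_nnz // tp_size tokens locally; the
# all-gather inside the column-parallel qkv_proj restores the full padded token dimension,
# so the count is scaled back up here before splitting Q/K/V.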
|
||||||
|
|
||||||
|
qkv = self.qkv_proj(hidden_states)[0]
|
||||||
|
query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size],
|
||||||
|
dim=-1) # (total_nnz, 1, hidden_size)
|
||||||
|
|
||||||
|
if self.megatron_config.sequence_parallel:
|
||||||
|
sequence_parallel_pad = total_nnz - cu_seqlens[-1]
|
||||||
|
total_nnz = cu_seqlens[-1] # total_nnz before sp padding
|
||||||
|
query_states = query_states[:total_nnz]
|
||||||
|
key_states = key_states[:total_nnz]
|
||||||
|
value_states = value_states[:total_nnz]
|
||||||
|
|
||||||
|
# Flash attention requires the input to have the shape
|
||||||
|
# batch_size x seq_length x head_dim x hidden_dim
|
||||||
|
# therefore we just need to keep the original shape
|
||||||
|
query_states = query_states.view(total_nnz, self.num_heads_per_tp, self.head_dim)
|
||||||
|
key_states = key_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim)
|
||||||
|
value_states = value_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim)
|
||||||
|
|
||||||
|
cos, sin = self.rotary_emb(value_states, seq_len=sequence_length)
|
||||||
|
cos, sin = cos[:, :cos.shape[1] // 2], sin[:, :sin.shape[1] // 2] # flash attn only needs half
|
||||||
|
query_states, key_states = apply_rotary_pos_emb_rmpad_flash(query_states,
|
||||||
|
key_states,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
max_seqlen=max_seqlen_in_batch)
|
||||||
|
# query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin, position_ids, indices,
|
||||||
|
|
||||||
|
# TODO: llama does not have dropout in the config??
|
||||||
|
# It is recommended to use dropout with FA according to the docs
|
||||||
|
# when training.
|
||||||
|
dropout_rate = 0.0 # if not self.training else self.attn_dropout
|
||||||
|
|
||||||
|
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||||
|
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||||
|
# cast them back in float16 just to be sure everything works as expected.
|
||||||
|
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
||||||
|
# in fp32. (LlamaRMSNorm handles it correctly)
|
||||||
|
input_dtype = query_states.dtype
|
||||||
|
if input_dtype == torch.float32:
|
||||||
|
query_states = query_states.to(torch.float16)
|
||||||
|
key_states = key_states.to(torch.float16)
|
||||||
|
value_states = value_states.to(torch.float16)
|
||||||
|
|
||||||
|
attn_output_unpad = flash_attn_varlen_func(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
cu_seqlens_q=cu_seqlens,
|
||||||
|
cu_seqlens_k=cu_seqlens,
|
||||||
|
max_seqlen_q=max_seqlen_in_batch,
|
||||||
|
max_seqlen_k=max_seqlen_in_batch,
|
||||||
|
dropout_p=dropout_rate,
|
||||||
|
softmax_scale=None,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output_unpad = attn_output_unpad.to(input_dtype)
|
||||||
|
attn_output_unpad = attn_output_unpad.reshape(total_nnz, 1, self.hidden_size_per_tp).contiguous()
|
||||||
|
|
||||||
|
# sequence parallel reduce_scatter is performed inside RowColumnParallel if enabled
|
||||||
|
# Here we need to repad
|
||||||
|
if self.megatron_config.sequence_parallel:
|
||||||
|
attn_output_unpad = F.pad(attn_output_unpad, pad=(0, 0, 0, 0, 0, sequence_parallel_pad))
|
||||||
|
|
||||||
|
attn_output_unpad = self.o_proj(attn_output_unpad)[0]
|
||||||
|
return attn_output_unpad
|
||||||
146
verl/models/llama/megatron/layers/parallel_decoder.py
Normal file
@@ -0,0 +1,146 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
# and OPT implementations in this library. It has been modified from its
|
||||||
|
# original forms to accommodate minor architectural differences compared
|
||||||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from transformers import LlamaConfig
|
||||||
|
from megatron.core import ModelParallelConfig
|
||||||
|
|
||||||
|
from .parallel_attention import ParallelLlamaAttention, ParallelLlamaAttentionRmPad
|
||||||
|
from .parallel_mlp import ParallelLlamaMLP
|
||||||
|
from .parallel_rmsnorm import ParallelLlamaRMSNorm
|
||||||
|
|
||||||
|
|
||||||
|
class ParallelLlamaDecoderLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.self_attn = ParallelLlamaAttention(config=config, megatron_config=megatron_config)
|
||||||
|
|
||||||
|
self.mlp = ParallelLlamaMLP(config, megatron_config=megatron_config)
|
||||||
|
self.input_layernorm = ParallelLlamaRMSNorm(config, megatron_config)
|
||||||
|
self.post_attention_layernorm = ParallelLlamaRMSNorm(config, megatron_config)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||||
|
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||||
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||||
|
output_attentions (`bool`, *optional*):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||||
|
returned tensors for more detail.
|
||||||
|
use_cache (`bool`, *optional*):
|
||||||
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||||
|
(see `past_key_values`).
|
||||||
|
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||||
|
"""
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
|
# Note: sequence parallel is hidden inside ColumnParallelLinear
|
||||||
|
# reduce scatter is hidden inside RowParallelLinear
|
||||||
|
|
||||||
|
# Self Attention
|
||||||
|
hidden_states = self.self_attn(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: add sequence parallel operator reduce_scatter here
|
||||||
|
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
|
|
||||||
|
# TODO: add sequence parallel operator all_gather here
|
||||||
|
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
|
||||||
|
# TODO: add sequence parallel operator reduce_scatter here
|
||||||
|
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
outputs = hidden_states
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class ParallelLlamaDecoderLayerRmPad(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.megatron_config = megatron_config
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.self_attn = ParallelLlamaAttentionRmPad(config=config, megatron_config=megatron_config)
|
||||||
|
|
||||||
|
self.mlp = ParallelLlamaMLP(config, megatron_config=megatron_config)
|
||||||
|
self.input_layernorm = ParallelLlamaRMSNorm(config, megatron_config)
|
||||||
|
self.post_attention_layernorm = ParallelLlamaRMSNorm(config, megatron_config)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
sequence_length: int = None,
|
||||||
|
indices: torch.Tensor = None,
|
||||||
|
cu_seqlens: int = None,
|
||||||
|
max_seqlen_in_batch: int = None
|
||||||
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
|
residual = hidden_states # (total_nnz // sp, 1, hidden_size)
|
||||||
|
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
|
# Self Attention
|
||||||
|
# (total_nnz // sp, 1, hidden_size) -> all-gather (total_nnz, 1, hidden_size)
|
||||||
|
# -> col + row -> reduce-scatter -> (total_nnz // sp, 1, hidden_size)
|
||||||
|
hidden_states = self.self_attn(hidden_states=hidden_states,
|
||||||
|
position_ids=position_ids,
|
||||||
|
sequence_length=sequence_length,
|
||||||
|
indices=indices,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
max_seqlen_in_batch=max_seqlen_in_batch)
|
||||||
|
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
# shape changes same as attn
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
outputs = hidden_states
|
||||||
|
|
||||||
|
return outputs
|
||||||
74
verl/models/llama/megatron/layers/parallel_linear.py
Normal file
@@ -0,0 +1,74 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py
|
||||||
|
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
from megatron.core import tensor_parallel
|
||||||
|
|
||||||
|
|
||||||
|
class QKVParallelLinear(tensor_parallel.ColumnParallelLinear):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
input_size,
|
||||||
|
num_heads,
|
||||||
|
num_key_value_heads,
|
||||||
|
head_dim,
|
||||||
|
*,
|
||||||
|
bias=True,
|
||||||
|
gather_output=True,
|
||||||
|
skip_bias_add=False,
|
||||||
|
**kwargs):
|
||||||
|
# Keep input parameters, and already restrict the head numbers
|
||||||
|
self.input_size = input_size
|
||||||
|
self.q_output_size = num_heads * head_dim
|
||||||
|
self.kv_output_size = num_key_value_heads * head_dim
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.gather_output = gather_output
|
||||||
|
self.skip_bias_add = skip_bias_add
|
||||||
|
|
||||||
|
input_size = self.input_size
|
||||||
|
output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim
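# Example: a LLaMA-7B-like config (num_heads=32, num_key_value_heads=32, head_dim=128)
# yields a fused output of (32 + 2 * 32) * 128 = 12288 columns, which the parent
# ColumnParallelLinear then shards across TP ranks.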
|
||||||
|
|
||||||
|
super().__init__(input_size=input_size,
|
||||||
|
output_size=output_size,
|
||||||
|
bias=bias,
|
||||||
|
gather_output=gather_output,
|
||||||
|
skip_bias_add=skip_bias_add,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
input_size,
|
||||||
|
gate_ouput_size,
|
||||||
|
up_output_size,
|
||||||
|
*,
|
||||||
|
bias=True,
|
||||||
|
gather_output=True,
|
||||||
|
skip_bias_add=False,
|
||||||
|
**kwargs):
|
||||||
|
# Keep input parameters, and already restrict the head numbers
|
||||||
|
self.input_size = input_size
|
||||||
|
self.output_size = gate_ouput_size + up_output_size
|
||||||
|
self.gather_output = gather_output
|
||||||
|
self.skip_bias_add = skip_bias_add
|
||||||
|
|
||||||
|
super().__init__(input_size=self.input_size,
|
||||||
|
output_size=self.output_size,
|
||||||
|
bias=bias,
|
||||||
|
gather_output=gather_output,
|
||||||
|
skip_bias_add=skip_bias_add,
|
||||||
|
**kwargs)
|
||||||
74
verl/models/llama/megatron/layers/parallel_mlp.py
Normal file
@@ -0,0 +1,74 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
# and OPT implementations in this library. It has been modified from its
|
||||||
|
# original forms to accommodate minor architectural differences compared
|
||||||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from megatron.core import parallel_state as mpu
|
||||||
|
from megatron.core import tensor_parallel
|
||||||
|
from megatron.core import ModelParallelConfig
|
||||||
|
from torch import nn
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
from verl.models.llama.megatron.layers.parallel_linear import MergedColumnParallelLinear
|
||||||
|
|
||||||
|
from verl.utils.megatron import tensor_parallel as tp_utils
|
||||||
|
|
||||||
|
|
||||||
|
class ParallelLlamaMLP(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.intermediate_size = config.intermediate_size
|
||||||
|
# The weight is only [hidden_size, intermediate_size // model_parallel_world_size]
|
||||||
|
|
||||||
|
column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
|
||||||
|
row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()
|
||||||
|
|
||||||
|
if megatron_config is not None:
|
||||||
|
assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
|
||||||
|
assert row_kwargs.get('config', False), 'must have ModelParallelConfig'
|
||||||
|
tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)
|
||||||
|
tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)
|
||||||
|
|
||||||
|
tp_size = mpu.get_tensor_model_parallel_world_size()
|
||||||
|
|
||||||
|
self.gate_up_proj = MergedColumnParallelLinear(
|
||||||
|
input_size=self.hidden_size,
|
||||||
|
gate_ouput_size=self.intermediate_size,
|
||||||
|
up_output_size=self.intermediate_size,
|
||||||
|
bias=False,
|
||||||
|
gather_output=False,
|
||||||
|
skip_bias_add=False,
|
||||||
|
**column_kwargs,
|
||||||
|
)
|
||||||
|
self.gate_size = self.intermediate_size // tp_size
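# Each TP rank therefore produces 2 * intermediate_size // tp_size activation columns;
# forward() splits them into gate and up halves of gate_size columns each.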
|
||||||
|
|
||||||
|
self.down_proj = tensor_parallel.RowParallelLinear(input_size=self.intermediate_size,
|
||||||
|
output_size=self.hidden_size,
|
||||||
|
bias=False,
|
||||||
|
input_is_parallel=True,
|
||||||
|
skip_bias_add=False,
|
||||||
|
**row_kwargs)
|
||||||
|
|
||||||
|
self.act_fn = ACT2FN[config.hidden_act]
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
gate_up = self.gate_up_proj(x)[0]
|
||||||
|
gate, up = gate_up.split(self.gate_size, dim=-1)
|
||||||
|
return self.down_proj(self.act_fn(gate) * up)[0]
|
||||||
46
verl/models/llama/megatron/layers/parallel_rmsnorm.py
Normal file
@@ -0,0 +1,46 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import numbers
|
||||||
|
import torch
|
||||||
|
from megatron.core import ModelParallelConfig
|
||||||
|
from torch import nn
|
||||||
|
from transformers import LlamaConfig
|
||||||
|
|
||||||
|
from apex.normalization.fused_layer_norm import fused_rms_norm_affine
|
||||||
|
from verl.utils.megatron import sequence_parallel as sp_utils
|
||||||
|
|
||||||
|
|
||||||
|
class ParallelLlamaRMSNorm(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
|
||||||
|
"""
|
||||||
|
LlamaRMSNorm is equivalent to T5LayerNorm
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
if isinstance(config.hidden_size, numbers.Integral):
|
||||||
|
normalized_shape = (config.hidden_size,)
|
||||||
|
self.normalized_shape = torch.Size(normalized_shape)
|
||||||
|
self.weight = nn.Parameter(torch.ones(self.normalized_shape))
|
||||||
|
self.variance_epsilon = config.rms_norm_eps
|
||||||
|
|
||||||
|
if megatron_config.sequence_parallel:
|
||||||
|
sp_utils.mark_parameter_as_sequence_parallel(self.weight)
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
return fused_rms_norm_affine(input=hidden_states,
|
||||||
|
weight=self.weight,
|
||||||
|
normalized_shape=self.normalized_shape,
|
||||||
|
eps=self.variance_epsilon,
|
||||||
|
memory_efficient=True)
|
||||||
656
verl/models/llama/megatron/modeling_llama_megatron.py
Normal file
@@ -0,0 +1,656 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
# and OPT implementations in this library. It has been modified from its
|
||||||
|
# original forms to accommodate minor architectural differences compared
|
||||||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""PyTorch LLaMA model with Megatron-style acceleration."""
|
||||||
|
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from megatron.core import tensor_parallel
|
||||||
|
from megatron.core import ModelParallelConfig
|
||||||
|
from torch import nn
|
||||||
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
|
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||||
|
from transformers.models.llama.modeling_llama import CausalLMOutputWithPast
|
||||||
|
|
||||||
|
from verl.utils.megatron import sequence_parallel as sp_utils
|
||||||
|
from verl.utils.megatron import tensor_parallel as tp_utils
|
||||||
|
from .layers import ParallelLlamaDecoderLayer, ParallelLlamaRMSNorm, ParallelLlamaDecoderLayerRmPad
|
||||||
|
"""
|
||||||
|
TODO:
|
||||||
|
1. Add weight initialization. Here we need to be careful on TP weight init.
|
||||||
|
2. Add sequence parallel
|
||||||
|
3. Load checkpoint from meta LLama pretrained checkpoint
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
||||||
|
def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device):
|
||||||
|
"""
|
||||||
|
Make causal mask used for bi-directional self-attention.
|
||||||
|
"""
|
||||||
|
bsz, tgt_len = input_ids_shape
|
||||||
|
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
|
||||||
|
mask_cond = torch.arange(mask.size(-1), device=device)
|
||||||
|
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||||
|
mask = mask.to(dtype)
|
||||||
|
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len)
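# Example: for tgt_len=3 the mask rows are [0, -inf, -inf], [0, 0, -inf], [0, 0, 0]
# (with -inf = torch.finfo(dtype).min), broadcast to shape (bsz, 1, 3, 3).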
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
||||||
|
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
||||||
|
"""
|
||||||
|
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||||
|
"""
|
||||||
|
bsz, src_len = mask.size()
|
||||||
|
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||||
|
|
||||||
|
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
||||||
|
|
||||||
|
inverted_mask = 1.0 - expanded_mask
|
||||||
|
|
||||||
|
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||||
|
|
||||||
|
|
||||||
|
class ParallelLlamaModel(nn.Module):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        # keep a handle to the megatron config; it is referenced below (mirrors the RmPad variants)
        self.megatron_config = megatron_config
        embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
        if megatron_config is not None:
            assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)
        self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size,
                                                                   embedding_dim=config.hidden_size,
                                                                   **embedding_kwargs)

        self.layers = nn.ModuleList(
            [ParallelLlamaDecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)])
        self.norm = ParallelLlamaRMSNorm(config, megatron_config)

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype,
                                              tgt_len=input_shape[-1]).to(inputs_embeds.device)
            combined_attention_mask = (expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask +
                                       combined_attention_mask)

        return combined_attention_mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """

        Args:
            input_ids: input ids. shape (batch_size, seq_length)
            attention_mask: attention_mask. shape (batch_size, seq_length)
            position_ids: position ids. shape (batch_size, seq_length)

        Returns:

        """
        batch_size, seq_length = input_ids.shape
        inputs_embeds = self.embed_tokens(input_ids)
        # embed positions

        attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds)

        hidden_states = inputs_embeds

        for idx, decoder_layer in enumerate(self.layers):
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )

            hidden_states = layer_outputs

        hidden_states = self.norm(hidden_states)

        return hidden_states


class ParallelLlamaForCausalLM(nn.Module):

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
        super().__init__()
        self.model = ParallelLlamaModel(config, megatron_config=megatron_config)
        self.vocab_size = config.vocab_size
        # keep a handle to the megatron config; it is used below when updating the column-parallel kwargs
        self.megatron_config = megatron_config

        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
        if megatron_config is not None:
            assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)

        self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=config.hidden_size,
                                                            output_size=config.vocab_size,
                                                            bias=False,
                                                            gather_output=False,
                                                            skip_bias_add=False,
                                                            **column_kwargs)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
        """

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )

        hidden_states = outputs
        logits = self.lm_head(hidden_states)[0]

        logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)

        logits = logits.float()
        return CausalLMOutputWithPast(
            loss=None,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )


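# Illustrative sketch (not part of the original file): building and calling the tensor-parallel
# causal-LM wrapper above. It assumes megatron.core.parallel_state has already been initialized
# by the trainer and that a CUDA device is available; the tiny config values are made up.
# The helper is never called at import time.
def _example_parallel_llama_forward(megatron_config: ModelParallelConfig):
    config = LlamaConfig(hidden_size=128, intermediate_size=256, num_hidden_layers=2,
                         num_attention_heads=4, num_key_value_heads=4, vocab_size=1024)
    model = ParallelLlamaForCausalLM(config, megatron_config).cuda()
    input_ids = torch.randint(0, config.vocab_size, (2, 16), device='cuda')
    attention_mask = torch.ones_like(input_ids)
    position_ids = torch.arange(16, device='cuda').unsqueeze(0).expand(2, -1)
    out = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
    return out.logits.shape  # (2, 16, vocab_size) after the tensor-parallel gather

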
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa


class ParallelLlamaModelRmPad(nn.Module):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
        self.megatron_config = megatron_config
        if megatron_config is not None:
            assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)
        self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size,
                                                                   embedding_dim=config.hidden_size,
                                                                   **embedding_kwargs)

        self.layers = nn.ModuleList(
            [ParallelLlamaDecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)])
        self.norm = ParallelLlamaRMSNorm(config, megatron_config)

    def forward(self,
                input_ids: torch.Tensor,
                position_ids: Optional[torch.LongTensor] = None,
                sequence_length: int = None,
                indices: torch.Tensor = None,
                cu_seqlens: int = None,
                max_seqlen_in_batch: int = None) -> Union[Tuple, BaseModelOutputWithPast]:
        """

        Args:
            input_ids: input ids. shape (1, total_nnz)
            position_ids: position ids. shape (batch_size, seq_length)

        Returns:

        """
        inputs_embeds = self.embed_tokens(input_ids)  # (1, total_nnz) -> (1, total_nnz, hidden_size)

        # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size)
        inputs_embeds = inputs_embeds.transpose(0, 1)
        if self.megatron_config.sequence_parallel:
            inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)

        hidden_states = inputs_embeds
        for idx, decoder_layer in enumerate(self.layers):
            layer_outputs = decoder_layer(hidden_states,
                                          position_ids=position_ids,
                                          sequence_length=sequence_length,
                                          indices=indices,
                                          cu_seqlens=cu_seqlens,
                                          max_seqlen_in_batch=max_seqlen_in_batch)

            hidden_states = layer_outputs

        hidden_states = self.norm(hidden_states)

        return hidden_states


class ParallelLlamaForCausalLMRmPad(nn.Module):

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
        super().__init__()
        self.config = config
        self.megatron_config = megatron_config
        self.model = ParallelLlamaModelRmPad(config, megatron_config=megatron_config)
        self.vocab_size = config.vocab_size
        self._init_head()

    def _init_head(self):
        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
        if self.megatron_config is not None:
            assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
        self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=self.config.hidden_size,
                                                            output_size=self.config.vocab_size,
                                                            bias=False,
                                                            gather_output=False,
                                                            skip_bias_add=False,
                                                            **column_kwargs)

    def _forward_head(self, hidden_states):
        # all_gather from sequence parallel region is performed inside lm_head
        logits = self.lm_head(hidden_states)[0]
        logits = logits.float()  # (total_nnz_padded, 1, vocab_size // tp)
        logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)  # (total_nnz_padded, 1, vocab_size)
        return logits

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
        """
        batch_size, sequence_length = input_ids.shape

        # remove padding here
        input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1),
                                                                              attention_mask)  # (total_nnz, 1)

        # pad input_ids to multiple of tp for all tp ranks
        # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap
        if self.megatron_config.sequence_parallel:
            input_ids = sp_utils.pad_to_sequence_parallel(input_ids)

        input_ids = input_ids.transpose(0, 1)  # (1, total_nnz+pad)

        outputs = self.model(input_ids=input_ids,
                             position_ids=position_ids,
                             sequence_length=sequence_length,
                             indices=indices,
                             cu_seqlens=cu_seqlens,
                             max_seqlen_in_batch=max_seqlen_in_batch)

        hidden_states = outputs

        logits = self._forward_head(hidden_states)

        # remove padding from sequence parallel
        if self.megatron_config.sequence_parallel:
            total_nnz = cu_seqlens[-1]
            logits = logits[:total_nnz]  # (total_nnz_padded)

        logits = torch.squeeze(logits, dim=1)  # remove the artificial batch dimension
        # add removed padding back
        logits = pad_input(logits, indices, batch_size,
                           seqlen=sequence_length)  # (batch_size, sequence_length, vocab_size)

        return CausalLMOutputWithPast(
            loss=None,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )


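# Illustrative sketch (not part of the original file): the unpad/pad round trip that the RmPad
# forward above relies on. The tensors are made up for the example and only the flash_attn
# helpers already imported in this module are used; the helper is never called at import time.
def _example_rmpad_round_trip():
    input_ids = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])
    attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
    # drop the padded positions: (batch, seqlen) -> (total_nnz, 1)
    flat_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask)
    assert flat_ids.shape[0] == attention_mask.sum()  # 5 real tokens
    # ... the model runs on the flattened tokens ...
    per_token_out = flat_ids.float()  # stand-in for per-token logits/values
    # scatter the per-token outputs back to (batch, seqlen, ...)
    restored = pad_input(per_token_out, indices, 2, seqlen=4)
    return restored.shape  # (2, 4, 1)

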
class ParallelLlamaForValueRmPad(ParallelLlamaForCausalLMRmPad):

    def _init_head(self):
        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
        if self.megatron_config is not None:
            assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
        self.lm_head = nn.Linear(in_features=self.config.hidden_size, out_features=1, bias=False)
        # lm_head is effectively the same as sequence parallel
        sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight)

    def _forward_head(self, hidden_states):
        logits = self.lm_head(hidden_states)  # (total_nnz_padded // tp, 1, 1)
        logits = logits.float()
        if self.megatron_config.sequence_parallel:
            logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)
        return logits

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output = super().forward(input_ids, attention_mask, position_ids)
        output.logits = torch.squeeze(output.logits, dim=-1)
        return output


"""
Support pipeline parallelism
"""


class ParallelLlamaModelRmPadPP(nn.Module):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
    This model definition supports pipeline parallelism. To support pp and vpp,
    - This model only contains the layers of this pp stage and vpp chunk
    - When calling get_model in Megatron, this rank will instantiate all the vpp chunks in this pp.
    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pre_process, post_process):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.pre_process = pre_process
        self.post_process = post_process
        self.megatron_config = megatron_config
        embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
        if megatron_config is not None:
            assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)
        if pre_process:
            self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size,
                                                                       embedding_dim=config.hidden_size,
                                                                       **embedding_kwargs)
        else:
            self.embed_tokens = None

        # pp_rank = megatron_config.pipeline_model_parallel_rank
        pp_size = megatron_config.pipeline_model_parallel_size
        self.num_layer_per_pp = config.num_hidden_layers // pp_size
        vpp_size = megatron_config.virtual_pipeline_model_parallel_size

        if vpp_size is not None:
            self.num_layer_vpp_chunk = self.num_layer_per_pp // vpp_size
            self.num_layer_this_model = self.num_layer_vpp_chunk
            # vpp_rank = megatron_config.virtual_pipeline_model_parallel_rank
            # self.offset = vpp_rank * (
            #     config.num_hidden_layers // megatron_config.virtual_pipeline_model_parallel_size) + \
            #     (megatron_config.pipeline_model_parallel_rank * self.num_layer_vpp_chunk)
        else:
            self.num_layer_this_model = self.num_layer_per_pp
            # self.offset = pp_rank * self.num_layer_per_pp

        layers = []
        for i in range(self.num_layer_this_model):
            layer = ParallelLlamaDecoderLayerRmPad(config, megatron_config)
            # setattr(layer, 'hidden_layer_index', self.offset + i)
            layers.append(layer)

        self.layers = nn.ModuleList(layers)

        if post_process:
            self.norm = ParallelLlamaRMSNorm(config, megatron_config)
        else:
            self.norm = None

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func"""
        self.input_tensor = input_tensor

    def forward(self,
                input_ids: torch.Tensor,
                position_ids: Optional[torch.LongTensor] = None,
                sequence_length: int = None,
                indices: torch.Tensor = None,
                cu_seqlens: int = None,
                max_seqlen_in_batch: int = None) -> Union[Tuple, BaseModelOutputWithPast]:
        """

        Args:
            input_ids: input ids. shape (1, total_nnz)
            position_ids: position ids. shape (batch_size, seq_length)

        Returns:

        """
        if self.pre_process:
            inputs_embeds = self.embed_tokens(input_ids)  # (1, total_nnz) -> (1, total_nnz, hidden_size)

            # vocab parallel embedding will not do sequence parallel reduce-scatter in open source megatron
            # so we need to handle it here:
            # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size)
            inputs_embeds = inputs_embeds.transpose(0, 1)
            if self.megatron_config.sequence_parallel:
                inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)

            hidden_states = inputs_embeds
        else:
            # self.hidden_states should be passed by Megatron
            hidden_states = self.input_tensor

        for idx, decoder_layer in enumerate(self.layers):
            layer_outputs = decoder_layer(hidden_states,
                                          position_ids=position_ids,
                                          sequence_length=sequence_length,
                                          indices=indices,
                                          cu_seqlens=cu_seqlens,
                                          max_seqlen_in_batch=max_seqlen_in_batch)

            hidden_states = layer_outputs

        if self.post_process:
            hidden_states = self.norm(hidden_states)

        return hidden_states


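# Illustrative sketch (not part of the original file): how the constructor above splits layers
# across pipeline (pp) and virtual-pipeline (vpp) ranks. The numbers are made up; the helper is
# pure arithmetic and is never called at import time.
def _example_layer_partition(num_hidden_layers=32, pp_size=4, vpp_size=2):
    num_layer_per_pp = num_hidden_layers // pp_size         # 8 layers owned by each pp rank
    if vpp_size is not None:
        num_layer_vpp_chunk = num_layer_per_pp // vpp_size  # 4 layers per vpp chunk
        num_layer_this_model = num_layer_vpp_chunk          # each model chunk builds 4 layers
    else:
        num_layer_this_model = num_layer_per_pp
    return num_layer_this_model

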
class ParallelLlamaForCausalLMRmPadPP(nn.Module):

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pre_process, post_process):
        super().__init__()
        self.config = config
        self.megatron_config = megatron_config
        self.model = ParallelLlamaModelRmPadPP(config,
                                               megatron_config=megatron_config,
                                               pre_process=pre_process,
                                               post_process=post_process)
        self.share_embeddings_and_output_weights = None  # workaround, megatron requires this attr
        self.vocab_size = config.vocab_size
        self.pre_process = pre_process
        self.post_process = post_process
        if post_process:
            self._init_head()

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func"""
        assert len(input_tensor) == 1
        self.model.set_input_tensor(input_tensor[0])

    def _init_head(self):
        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
        if self.megatron_config is not None:
            assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
        self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=self.config.hidden_size,
                                                            output_size=self.config.vocab_size,
                                                            bias=False,
                                                            gather_output=False,
                                                            skip_bias_add=False,
                                                            **column_kwargs)

    def _forward_head(self, hidden_states):
        # all_gather from sequence parallel region is performed inside lm_head
        # e.g. hidden_states.shape = (4, 32, 4096) before the head
        logits = self.lm_head(hidden_states)[0]
        # e.g. logits.shape = (8, 32, 8) after the head (vocab sharded across tp)
        logits = logits.float()  # (total_nnz_padded, 1, vocab_size // tp)
        return logits

    def forward(
        self,
        # original input
        *,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
        """

        # Note that input_ids, attention_mask and position_ids should be passed to every pp layer.
        # In the first pp, input_ids will be used, in other pp layers hidden_states will be used inside self.model
        batch_size, sequence_length = input_ids.shape
        # remove padding here
        input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1),
                                                                                    attention_mask)  # (total_nnz, 1)

        # pad input_ids to multiple of tp for all tp ranks
        # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap
        if self.megatron_config.sequence_parallel:
            input_ids_rmpad = sp_utils.pad_to_sequence_parallel(input_ids_rmpad)

        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz+pad)

        outputs = self.model(input_ids=input_ids_rmpad,
                             position_ids=position_ids,
                             sequence_length=sequence_length,
                             indices=indices,
                             cu_seqlens=cu_seqlens,
                             max_seqlen_in_batch=max_seqlen_in_batch)

        if self.post_process:
            hidden_states = outputs
            # print(f'hidden_states.shape = {hidden_states.shape}')  # torch.Size([4, 32, 4096])
            logits = self._forward_head(hidden_states)
            logits = torch.squeeze(logits, dim=1)  # remove the artificial batch dimension  # torch.Size([8, 32, 16])

            # remove padding from sequence parallel
            if self.megatron_config.sequence_parallel:
                total_nnz = cu_seqlens[-1]
                logits = logits[:total_nnz]  # (total_nnz_padded)
            # add removed padding back. If input is already rmpad, we let the caller pad_input
            logits = pad_input(logits, indices, batch_size,
                               seqlen=sequence_length)  # (batch_size, sequence_length, vocab_size)

            return CausalLMOutputWithPast(
                loss=None,
                logits=logits,
                past_key_values=None,
                hidden_states=None,
                attentions=None,
            )
        else:
            return outputs


class ParallelLlamaForValueRmPadPP(ParallelLlamaForCausalLMRmPadPP):

    def _init_head(self):
        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
        if self.megatron_config is not None:
            assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
        self.lm_head = nn.Linear(in_features=self.config.hidden_size, out_features=1, bias=False)
        # lm_head is effectively the same as sequence parallel
        sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight)

    def _forward_head(self, hidden_states):
        logits = self.lm_head(hidden_states)  # (total_nnz_padded // tp, 1, 1)
        logits = logits.float()
        if self.megatron_config.sequence_parallel:
            logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)
        return logits

    def forward(
        self,
        *,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
        if self.post_process:
            output.logits = torch.squeeze(output.logits, dim=-1)
            return output
        else:
            return output
66  verl/models/registry.py  Normal file
@@ -0,0 +1,66 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
from typing import List, Optional, Type

import torch.nn as nn

# Supported models using HF Rmpad
# TODO(sgm): HF may support more models than listed here, we should add more after testing
from transformers import LlamaConfig, MistralConfig, GemmaConfig, Qwen2Config

_REOVEPAD_MODELS = {'llama': LlamaConfig, 'mistral': MistralConfig, 'gemma': GemmaConfig, 'qwen2': Qwen2Config}


def check_model_support_rmpad(model_type: str):
    assert isinstance(model_type, str)
    if model_type not in _REOVEPAD_MODELS.keys():
        raise ValueError(f"Model architecture {model_type} is not supported for now. "
                         f"RMPad supported architectures: {_REOVEPAD_MODELS.keys()}. "
                         f"Please set `use_remove_padding=False` in the model config.")


# Supported models in Megatron-LM
# Architecture -> (module, class).
_MODELS = {
    "LlamaForCausalLM":
        ("llama", ("ParallelLlamaForCausalLMRmPadPP", "ParallelLlamaForValueRmPadPP", "ParallelLlamaForCausalLMRmPad")),
    "MistralForCausalLM": ("mistral", ("ParallelMistralForCausalLMRmPadPP", "ParallelMistralForValueRmPadPP",
                                       "ParallelMistralForCausalLMRmPad"))
}


# return model class
class ModelRegistry:

    @staticmethod
    def load_model_cls(model_arch: str, value=False) -> Optional[Type[nn.Module]]:
        if model_arch not in _MODELS:
            return None

        megatron = "megatron"

        module_name, model_cls_name = _MODELS[model_arch]
        if not value:  # actor/ref
            model_cls_name = model_cls_name[0]
        elif value:  # critic/rm
            model_cls_name = model_cls_name[1]

        module = importlib.import_module(f"verl.models.{module_name}.{megatron}.modeling_{module_name}_megatron")
        return getattr(module, model_cls_name, None)

    @staticmethod
    def get_supported_archs() -> List[str]:
        return list(_MODELS.keys())
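

# Illustrative sketch (not part of the original file): resolving a Megatron model class from the
# registry above. The architecture string normally comes from `config.architectures[0]` of a
# HuggingFace config; `value=True` selects the critic/reward-model variant. Never called at import time.
def _example_registry_lookup():
    check_model_support_rmpad('llama')  # raises ValueError for unsupported model types
    actor_cls = ModelRegistry.load_model_cls("LlamaForCausalLM", value=False)
    critic_cls = ModelRegistry.load_model_cls("LlamaForCausalLM", value=True)
    # actor_cls -> ParallelLlamaForCausalLMRmPadPP, critic_cls -> ParallelLlamaForValueRmPadPP
    return actor_cls, critic_cls, ModelRegistry.get_supported_archs()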
13  verl/models/transformers/__init__.py  Normal file
@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
145  verl/models/transformers/llama.py  Normal file
@@ -0,0 +1,145 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
from typing import Optional, List, Union, Tuple, Unpack, Callable

from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from transformers.cache_utils import Cache
from transformers.utils import logging
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from verl.utils.ulysses import gather_heads_scatter_seq, gather_seq_scatter_heads, get_ulysses_sequence_parallel_world_size

logger = logging.get_logger(__name__)


def llama_flash_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """
    adapt from transformers 4.47.1
    """
    output_attentions = False

    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # Flash attention requires the input to have the shape
    # batch_size x seq_length x head_dim x hidden_dim
    # therefore we just need to keep the original shape
    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    # trade off: repeat first and then all to all
    # key_states = repeat_kv(key_states, self.num_key_value_groups)
    # value_states = repeat_kv(value_states, self.num_key_value_groups)

    ########## AlltoAll for Ulysses ##########
    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()

    if ulysses_sp_size > 1:
        # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim)
        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)

    full_q_len = query_states.size(2)  # full seq length

    if position_embeddings is None:
        logger.warning_once(
            "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
            "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
            "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
            "removed and `position_embeddings` will be mandatory.")
        cos, sin = self.rotary_emb(value_states, position_ids)
    else:
        cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
    # to be able to avoid many of these transpose/reshape/view.
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    dropout_rate = self.attention_dropout if self.training else 0.0

    # In PEFT, usually we cast the layer norms in float32 for training stability reasons
    # therefore the input hidden states gets silently casted in float32. Hence, we need
    # cast them back in the correct dtype just to be sure everything works as expected.
    # This might slowdown training & inference so it is recommended to not cast the LayerNorms
    # in fp32. (LlamaRMSNorm handles it correctly)

    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(self.config, "_pre_quantization_dtype"):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        logger.warning_once(
            f"The input hidden states seems to be silently casted in float32, this might be related to"
            f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
            f" {target_dtype}.")

        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    attn_output = _flash_attention_forward(
        query_states,
        key_states,
        value_states,
        attention_mask,
        full_q_len,
        position_ids=position_ids,
        dropout=dropout_rate,
        sliding_window=getattr(self, "sliding_window", None),
        use_top_left_mask=self._flash_attn_uses_top_left_mask,
        is_causal=self.is_causal,
        **kwargs,
    )

    attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()
    ########## AlltoAll for Ulysses ##########
    if ulysses_sp_size > 1:
        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)
    attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value
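

# Illustrative sketch (not part of the original file): the shape bookkeeping behind the Ulysses
# all-to-all above, for made-up sizes. No distributed state is touched here; it only mirrors the
# comments on gather_seq_scatter_heads / gather_heads_scatter_seq. Never called at import time.
def _example_ulysses_shapes(bsz=2, n_head=32, seq_len=4096, head_dim=128, sp_size=4):
    # before attention: each rank holds all heads but only seq_len / sp tokens
    local_q = (bsz, n_head, seq_len // sp_size, head_dim)
    # gather_seq_scatter_heads: all-to-all -> full sequence, 1/sp of the heads
    attn_input = (bsz, n_head // sp_size, seq_len, head_dim)
    # gather_heads_scatter_seq after attention restores the original partitioning
    attn_output = (bsz, seq_len // sp_size, n_head * head_dim)
    return local_q, attn_input, attn_output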
74  verl/models/transformers/monkey_patch.py  Normal file
@@ -0,0 +1,74 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Apply monkey-patch functions to models
"""

#### Open Source Models
#### transformers version < 4.48


def apply_monkey_patch_to_llama():
    from transformers.models.llama.modeling_llama import LlamaFlashAttention2
    from verl.models.transformers.llama import llama_flash_attn_forward
    LlamaFlashAttention2.forward = llama_flash_attn_forward


def apply_monkey_patch_to_qwen2():
    from transformers.models.qwen2.modeling_qwen2 import Qwen2FlashAttention2
    from verl.models.transformers.qwen2 import qwen2_flash_attn_forward
    Qwen2FlashAttention2.forward = qwen2_flash_attn_forward


_PATCH_NAME_TO_FUNC = {
    'llama': apply_monkey_patch_to_llama,
    'qwen2': apply_monkey_patch_to_qwen2,
}

from transformers import PretrainedConfig


def apply_monkey_patch(config: PretrainedConfig, verbose=True):
    if not is_transformers_version_in_range("4.45.0", "4.47.1"):
        raise AssertionError("The installed `transformers` version doesn't support ulysses patch. "
                             "Please install a version between 4.45.0 and 4.47.1 to use this ulysses feature.")
    success_apply_monkey_patch = False
    if config.model_type in _PATCH_NAME_TO_FUNC:
        _PATCH_NAME_TO_FUNC[config.model_type]()
        success_apply_monkey_patch = True

    if success_apply_monkey_patch and verbose:
        print(f'Applying monkey patch to model {config.model_type}')
    elif not success_apply_monkey_patch:
        raise NotImplementedError(f'Ulysses for model {config.model_type} is not implemented, '
                                  f'please set `ulysses_sequence_parallel_size=1`')

    return success_apply_monkey_patch


from functools import lru_cache
from packaging import version
import importlib.metadata


@lru_cache()
def is_transformers_version_in_range(min_version: str, max_version: str) -> bool:
    try:
        # Get the installed version of the transformers library
        transformers_version = importlib.metadata.version("transformers")
    except importlib.metadata.PackageNotFoundError:
        raise ModuleNotFoundError("The `transformers` package is not installed.")

    # Check if the version is within the specified range
    return version.parse(min_version) <= version.parse(transformers_version) <= version.parse(max_version)
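

# Illustrative sketch (not part of the original file): a typical call site for the patch above.
# The model path is hypothetical; the HuggingFace config would normally be loaded by the trainer.
# Never called at import time.
def _example_apply_patch():
    from transformers import AutoConfig
    config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")  # hypothetical model path
    # replaces LlamaFlashAttention2.forward with the Ulysses-aware version before the model is
    # instantiated, so every newly built model picks up the patched attention
    apply_monkey_patch(config, verbose=True)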
137  verl/models/transformers/qwen2.py  Normal file
@@ -0,0 +1,137 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
from typing import Optional, Tuple

from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from transformers.cache_utils import Cache
from transformers.utils import logging
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from verl.utils.ulysses import gather_heads_scatter_seq, gather_seq_scatter_heads, get_ulysses_sequence_parallel_world_size

logger = logging.get_logger(__name__)


def qwen2_flash_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
):
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

    ########## AlltoAll for Ulysses ##########
    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()

    if ulysses_sp_size > 1:
        # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim)
        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)

    full_q_len = query_states.size(2)  # full seq length

    if position_embeddings is None:
        logger.warning_once(
            "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
            "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
            "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
            "removed and `position_embeddings` will be mandatory.")
        cos, sin = self.rotary_emb(value_states, position_ids)
    else:
        cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    # repeat k/v heads if n_kv_heads < n_heads
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)
    dropout_rate = 0.0 if not self.training else self.attention_dropout

    # In PEFT, usually we cast the layer norms in float32 for training stability reasons
    # therefore the input hidden states gets silently casted in float32. Hence, we need
    # cast them back in float16 just to be sure everything works as expected.
    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(self.config, "_pre_quantization_dtype"):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        logger.warning_once(
            f"The input hidden states seems to be silently casted in float32, this might be related to"
            f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
            f" {target_dtype}.")

        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    # Reshape to the expected shape for Flash Attention
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    if (self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None and
            self.layer_idx >= self.config.max_window_layers):
        sliding_window = self.config.sliding_window
    else:
        sliding_window = None

    attn_output = _flash_attention_forward(
        query_states,
        key_states,
        value_states,
        attention_mask,
        full_q_len,
        position_ids=position_ids,
        dropout=dropout_rate,
        sliding_window=sliding_window,
        is_causal=self.is_causal,
        use_top_left_mask=self._flash_attn_uses_top_left_mask,
    )

    # use full_q_len to reshape
    attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()
    ########## AlltoAll for Ulysses ##########
    if ulysses_sp_size > 1:
        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)
    attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value
23  verl/models/weight_loader_registry.py  Normal file
@@ -0,0 +1,23 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def get_weight_loader(arch: str):
    from verl.models.llama.megatron.checkpoint_utils.llama_loader import load_state_dict_to_megatron_llama
    _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY = {'LlamaForCausalLM': load_state_dict_to_megatron_llama}

    if arch in _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY:
        return _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY[arch]
    raise ValueError(f"Model architecture {arch} is not supported for now. "
                     f"Supported architectures: {_MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY.keys()}")
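

# Illustrative sketch (not part of the original file): looking up the Megatron weight loader for
# an architecture. The returned callable copies (and re-shards) a HuggingFace state_dict into the
# tensor/pipeline-parallel parameters; see llama_loader for its exact argument list, which is not
# repeated here. Never called at import time.
def _example_weight_loader_lookup():
    loader = get_weight_loader('LlamaForCausalLM')  # -> load_state_dict_to_megatron_llama
    return loader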
639  verl/protocol.py  Normal file
@@ -0,0 +1,639 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Implement base data transfer protocol between any two functions, modules.
We can subclass Protocol to define more detailed batch info with specific keys
"""

import pickle
import numpy as np
import copy
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Union

import torch
import tensordict
from tensordict import TensorDict
from torch.utils.data import DataLoader, Dataset

from verl.utils.py_functional import union_two_dict

__all__ = ['DataProto', 'union_tensor_dict']

try:
    tensordict.set_lazy_legacy(False).set()
except:
    pass


def pad_dataproto_to_divisor(data: 'DataProto', size_divisor: int):
    """Pad a DataProto to a size divisible by size_divisor

    Args:
        data (DataProto): the DataProto to pad
        size_divisor (int): size divisor

    Returns:
        data (DataProto): the padded DataProto
        pad_size (int)
    """
    assert isinstance(data, DataProto), 'data must be a DataProto'
    if len(data) % size_divisor != 0:
        pad_size = size_divisor - len(data) % size_divisor
        data_padded = DataProto.concat([data, data[:pad_size]])
    else:
        pad_size = 0
        data_padded = data
    return data_padded, pad_size


def unpad_dataproto(data: 'DataProto', pad_size):
    if pad_size != 0:
        data = data[:-pad_size]
    return data


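# Illustrative sketch (not part of the original file): padding a batch so it can be split evenly
# across data-parallel workers, then dropping the duplicated tail afterwards. The tensor contents
# are made up for the example; the helper is never called at import time.
def _example_pad_unpad():
    batch = TensorDict({'input_ids': torch.arange(10).view(5, 2)}, batch_size=[5])
    data = DataProto(batch=batch)
    padded, pad_size = pad_dataproto_to_divisor(data, size_divisor=4)  # 5 -> 8, pad_size = 3
    # ... dispatch `padded` to 4 workers, gather the results ...
    restored = unpad_dataproto(padded, pad_size)
    assert len(restored) == len(data)
    return restored

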
def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict:
    """Union two tensordicts."""
    assert tensor_dict1.batch_size == tensor_dict2.batch_size, \
        f'Two tensor dict must have identical batch size. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}'
    for key in tensor_dict2.keys():
        if key not in tensor_dict1.keys():
            tensor_dict1[key] = tensor_dict2[key]
        else:
            assert tensor_dict1[key].equal(tensor_dict2[key]), \
                f'{key} in tensor_dict1 and tensor_dict2 are not the same object'

    return tensor_dict1


def union_numpy_dict(tensor_dict1: dict[np.ndarray], tensor_dict2: dict[np.ndarray]) -> dict[np.ndarray]:
    for key, val in tensor_dict2.items():
        if key in tensor_dict1:
            assert isinstance(tensor_dict2[key], np.ndarray)
            assert isinstance(tensor_dict1[key], np.ndarray)
            assert np.all(tensor_dict2[key] == tensor_dict1[key]), \
                f'{key} in tensor_dict1 and tensor_dict2 are not the same object'
        tensor_dict1[key] = val

    return tensor_dict1


def list_of_dict_to_dict_of_list(list_of_dict: list[dict]):
    if len(list_of_dict) == 0:
        return {}
    keys = list_of_dict[0].keys()
    output = {key: [] for key in keys}
    for data in list_of_dict:
        for key, item in data.items():
            assert key in output
            output[key].append(item)
    return output


def fold_batch_dim(data: 'DataProto', new_batch_size):
    """
    Fold a batch dim from [bsz, xxx] into [new_bsz, bsz // new_bsz, xxx]
    """
    batch_size = data.batch.batch_size[0]

    assert batch_size % new_batch_size == 0

    tensor: TensorDict = data.batch
    non_tensor = data.non_tensor_batch

    tensor = tensor.view(new_batch_size, -1)
    tensor.auto_batch_size_(batch_dims=1)

    for key, val in non_tensor.items():
        non_tensor[key] = np.reshape(val, newshape=(new_batch_size, -1, *val.shape[1:]))

    return DataProto(batch=tensor, non_tensor_batch=non_tensor, meta_info=data.meta_info)


def unfold_batch_dim(data: 'DataProto', batch_dims=2):
    """
    Unfold the first n dims as new batch dim
    """
    tensor: TensorDict = data.batch
    non_tensor = data.non_tensor_batch
    tensor.auto_batch_size_(batch_dims=batch_dims)
    tensor = tensor.view(-1)

    batch_size = tensor.batch_size[0]

    non_tensor_new = {}

    for key, val in non_tensor.items():
        non_tensor_new[key] = np.reshape(val, newshape=(batch_size, *val.shape[batch_dims:]))

    return DataProto(batch=tensor, non_tensor_batch=non_tensor_new, meta_info=data.meta_info)


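# Illustrative sketch (not part of the original file): folding a [6, ...] batch into [2, 3, ...]
# and back, with made-up tensors. Never called at import time.
def _example_fold_unfold():
    batch = TensorDict({'obs': torch.zeros(6, 4)}, batch_size=[6])
    data = DataProto(batch=batch)
    folded = fold_batch_dim(data, new_batch_size=2)    # batch_size becomes [2], obs is (2, 3, 4)
    restored = unfold_batch_dim(folded, batch_dims=2)  # back to batch_size [6], obs is (6, 4)
    return folded, restored

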
def collate_fn(x: list['DataProtoItem']):
|
||||||
|
batch = []
|
||||||
|
non_tensor_batch = []
|
||||||
|
for data in x:
|
||||||
|
batch.append(data.batch)
|
||||||
|
non_tensor_batch.append(data.non_tensor_batch)
|
||||||
|
batch = torch.stack(batch).contiguous()
|
||||||
|
non_tensor_batch = list_of_dict_to_dict_of_list(non_tensor_batch)
|
||||||
|
for key, val in non_tensor_batch.items():
|
||||||
|
non_tensor_batch[key] = np.array(val, dtype=object)
|
||||||
|
return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DataProtoItem:
|
||||||
|
# TODO(zhangchi.usc1992) add consistency check
|
||||||
|
batch: TensorDict = None
|
||||||
|
non_tensor_batch: Dict = field(default_factory=dict)
|
||||||
|
meta_info: Dict = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DataProto:
|
||||||
|
"""
|
||||||
|
A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions.
|
||||||
|
It contains a batch (TensorDict), a non_tensor_batch (a dict of numpy arrays) and a meta_info (Dict). The batch is a TensorDict https://pytorch.org/tensordict/.
|
||||||
|
TensorDict allows you to manipulate a dictionary of Tensors like a single Tensor. Ideally, the tensors with the
|
||||||
|
same batch size should be put inside batch.
|
||||||
|
"""
|
||||||
|
batch: TensorDict = None
|
||||||
|
non_tensor_batch: Dict = field(default_factory=dict)
|
||||||
|
meta_info: Dict = field(default_factory=dict)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
# perform necessary checking
|
||||||
|
self.check_consistency()
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
if self.batch is not None:
|
||||||
|
return self.batch.batch_size[0]
|
||||||
|
elif self.non_tensor_batch is not None and len(self.non_tensor_batch) > 0:
|
||||||
|
random_key = list(self.non_tensor_batch.keys())[0]
|
||||||
|
return self.non_tensor_batch[random_key].shape[0]
|
||||||
|
else:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
tensor_data = self.batch[item]
|
||||||
|
non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
|
||||||
|
return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
import io
|
||||||
|
buffer = io.BytesIO()
|
||||||
|
if tensordict.__version__ >= '0.5.0' and self.batch is not None:
|
||||||
|
self.batch = self.batch.contiguous()
|
||||||
|
self.batch = self.batch.consolidate()
|
||||||
|
torch.save(self.batch, buffer)
|
||||||
|
buffer_bytes = buffer.getvalue()
|
||||||
|
return buffer_bytes, self.non_tensor_batch, self.meta_info
|
||||||
|
|
||||||
|
def __setstate__(self, data):
|
||||||
|
import io
|
||||||
|
batch_deserialized_bytes, non_tensor_batch, meta_info = data
|
||||||
|
batch_deserialized = io.BytesIO(initial_bytes=batch_deserialized_bytes)
|
||||||
|
batch = torch.load(batch_deserialized,
|
||||||
|
weights_only=False,
|
||||||
|
map_location='cpu' if not torch.cuda.is_available() else None)
|
||||||
|
self.batch = batch
|
||||||
|
self.non_tensor_batch = non_tensor_batch
|
||||||
|
self.meta_info = meta_info
|
||||||
|
|
||||||
|
def save_to_disk(self, filepath):
|
||||||
|
with open(filepath, 'wb') as f:
|
||||||
|
pickle.dump(self, f)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_from_disk(filepath) -> 'DataProto':
|
||||||
|
with open(filepath, 'rb') as f:
|
||||||
|
data = pickle.load(f)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def print_size(self, prefix=""):
|
||||||
|
size_of_tensordict = 0
|
||||||
|
for key, tensor in self.batch.items():
|
||||||
|
size_of_tensordict += tensor.element_size() * tensor.numel()
|
||||||
|
size_of_numpy_array = 0
|
||||||
|
for key, numpy_array in self.non_tensor_batch.items():
|
||||||
|
size_of_numpy_array += numpy_array.nbytes
|
||||||
|
|
||||||
|
size_of_numpy_array /= 1024**3
|
||||||
|
size_of_tensordict /= 1024**3
|
||||||
|
|
||||||
|
message = f'Size of tensordict: {size_of_tensordict} GB, size of non_tensor_batch: {size_of_numpy_array} GB'
|
||||||
|
|
||||||
|
if prefix:
|
||||||
|
message = f'{prefix}, ' + message
|
||||||
|
print(message)
|
||||||
|
|
||||||
|
def check_consistency(self):
|
||||||
|
"""Check the consistency of the DataProto. Mainly for batch and non_tensor_batch
|
||||||
|
We expose this function as a public one so that users can call it directly
|
||||||
|
"""
|
||||||
|
if self.batch is not None:
|
||||||
|
assert len(self.batch.batch_size) == 1, 'only support num_batch_dims=1'
|
||||||
|
|
||||||
|
if self.non_tensor_batch is not None:
|
||||||
|
for key, val in self.non_tensor_batch.items():
|
||||||
|
assert isinstance(val, np.ndarray)
|
||||||
|
|
||||||
|
if self.batch is not None and len(self.non_tensor_batch) != 0:
|
||||||
|
# TODO: we can actually lift this restriction if needed
|
||||||
|
assert len(self.batch.batch_size) == 1, 'only support num_batch_dims=1 when non_tensor_batch is not empty.'
|
||||||
|
|
||||||
|
batch_size = self.batch.batch_size[0]
|
||||||
|
for key, val in self.non_tensor_batch.items():
|
||||||
|
assert isinstance(
|
||||||
|
val, np.ndarray
|
||||||
|
) and val.dtype == object, 'data in the non_tensor_batch must be a numpy.array with dtype=object'
|
||||||
|
assert val.shape[
|
||||||
|
0] == batch_size, f'key {key} length {len(val)} is not equal to batch size {batch_size}'
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_single_dict(cls, data: Dict[str, Union[torch.Tensor, np.ndarray]], meta_info=None):
|
||||||
|
tensors = {}
|
||||||
|
non_tensors = {}
|
||||||
|
|
||||||
|
for key, val in data.items():
|
||||||
|
if isinstance(val, torch.Tensor):
|
||||||
|
tensors[key] = val
|
||||||
|
elif isinstance(val, np.ndarray):
|
||||||
|
non_tensors[key] = val
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Unsupported type in data {type(val)}')
|
||||||
|
|
||||||
|
return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, tensors: Dict[str, torch.Tensor], non_tensors=None, meta_info=None, num_batch_dims=1):
|
||||||
|
"""Create a DataProto from a dict of tensors. This assumes that
|
||||||
|
1. All the tensors in `tensors` share the same dim0
|
||||||
|
2. Only dim0 is the batch dim
|
||||||
|
"""
|
||||||
|
assert len(tensors) > 0, 'tensors must not be empty'
|
||||||
|
assert num_batch_dims > 0, 'num_batch_dims must be greater than zero'
|
||||||
|
if non_tensors is not None:
|
||||||
|
assert num_batch_dims == 1, 'only support num_batch_dims=1 when non_tensors is not None.'
|
||||||
|
|
||||||
|
if meta_info is None:
|
||||||
|
meta_info = {}
|
||||||
|
if non_tensors is None:
|
||||||
|
non_tensors = {}
|
||||||
|
|
||||||
|
assert isinstance(non_tensors, dict)
|
||||||
|
|
||||||
|
# get and check batch size
|
||||||
|
batch_size = None
|
||||||
|
pivot_key = None
|
||||||
|
for key, tensor in tensors.items():
|
||||||
|
if batch_size is None:
|
||||||
|
batch_size = tensor.shape[:num_batch_dims]
|
||||||
|
pivot_key = key
|
||||||
|
else:
|
||||||
|
current_batch = tensor.shape[:num_batch_dims]
|
||||||
|
assert batch_size == current_batch, \
|
||||||
|
f'Not all the tensor in tensors have the same batch size with batch_dims={num_batch_dims}. Got {pivot_key} has {batch_size}, {key} has {current_batch}'
|
||||||
|
|
||||||
|
for key, val in non_tensors.items():
|
||||||
|
non_tensors[key] = np.array(val, dtype=object)
|
||||||
|
|
||||||
|
tensor_dict = TensorDict(source=tensors, batch_size=batch_size)
|
||||||
|
return cls(batch=tensor_dict, non_tensor_batch=non_tensors, meta_info=meta_info)
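# A hypothetical call, for illustration only (tensor names and sizes are made up):
#   DataProto.from_dict(tensors={'obs': torch.zeros(4, 8)},
#                       non_tensors={'uid': ['a', 'b', 'c', 'd']},
#                       meta_info={'temperature': 1.0})
# yields a DataProto with batch_size 4 whose 'uid' entry is stored as an
# object-dtype numpy array.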
|
||||||
|
|
||||||
|
def to(self, device) -> 'DataProto':
|
||||||
|
"""move the batch to device
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device (torch.device, str): torch device
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataProto: the current DataProto
|
||||||
|
|
||||||
|
"""
|
||||||
|
if self.batch is not None:
|
||||||
|
self.batch = self.batch.to(device)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None, deepcopy=False) -> 'DataProto':
|
||||||
|
"""Select a subset of the DataProto via batch_keys and meta_info_keys
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_keys (list, optional): a list of strings indicating the keys in batch to select
|
||||||
|
meta_info_keys (list, optional): a list of keys indicating the meta info to select
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataProto: the DataProto with the selected batch_keys and meta_info_keys
|
||||||
|
"""
|
||||||
|
# TODO (zhangchi.usc1992) whether to copy
|
||||||
|
if batch_keys is not None:
|
||||||
|
batch_keys = tuple(batch_keys)
|
||||||
|
sub_batch = self.batch.select(*batch_keys)
|
||||||
|
else:
|
||||||
|
sub_batch = self.batch
|
||||||
|
|
||||||
|
if non_tensor_batch_keys is not None:
|
||||||
|
non_tensor_batch = {key: val for key, val in self.non_tensor_batch.items() if key in non_tensor_batch_keys}
|
||||||
|
else:
|
||||||
|
non_tensor_batch = self.non_tensor_batch
|
||||||
|
|
||||||
|
if deepcopy:
|
||||||
|
non_tensor_batch = copy.deepcopy(non_tensor_batch)
|
||||||
|
|
||||||
|
if meta_info_keys is not None:
|
||||||
|
sub_meta_info = {key: val for key, val in self.meta_info.items() if key in meta_info_keys}
|
||||||
|
else:
|
||||||
|
sub_meta_info = self.meta_info
|
||||||
|
|
||||||
|
if deepcopy:
|
||||||
|
sub_meta_info = copy.deepcopy(sub_meta_info)
|
||||||
|
|
||||||
|
return DataProto(batch=sub_batch, non_tensor_batch=non_tensor_batch, meta_info=sub_meta_info)
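# A hypothetical call, for illustration only (key names are made up):
#   data.select(batch_keys=['input_ids'], meta_info_keys=['temperature'])
# keeps only 'input_ids' in the batch and 'temperature' in meta_info; with
# deepcopy=False the selected tensors still reference the original storage.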
|
||||||
|
|
||||||
|
def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> 'DataProto':
|
||||||
|
"""Pop a subset of the DataProto via `batch_keys` and `meta_info_keys`
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_keys (list, optional): a list of strings indicating the keys in batch to pop
|
||||||
|
meta_info_keys (list, optional): a list of keys indicating the meta info to pop
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataProto: the DataProto with the popped batch_keys and meta_info_keys
|
||||||
|
"""
|
||||||
|
assert batch_keys is not None
|
||||||
|
if meta_info_keys is None:
|
||||||
|
meta_info_keys = []
|
||||||
|
if non_tensor_batch_keys is None:
|
||||||
|
non_tensor_batch_keys = []
|
||||||
|
|
||||||
|
tensors = {}
|
||||||
|
# tensor batch
|
||||||
|
for key in batch_keys:
|
||||||
|
assert key in self.batch.keys()
|
||||||
|
tensors[key] = self.batch.pop(key)
|
||||||
|
non_tensors = {}
|
||||||
|
# non tensor batch
|
||||||
|
for key in non_tensor_batch_keys:
|
||||||
|
assert key in self.non_tensor_batch.keys()
|
||||||
|
non_tensors[key] = self.non_tensor_batch.pop(key)
|
||||||
|
meta_info = {}
|
||||||
|
for key in meta_info_keys:
|
||||||
|
assert key in self.meta_info.keys()
|
||||||
|
meta_info[key] = self.meta_info.pop(key)
|
||||||
|
return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info)
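# A hypothetical call, for illustration only (key name is made up):
#   popped = data.pop(batch_keys=['old_log_probs'])
# removes 'old_log_probs' from data.batch and returns it wrapped in a new
# DataProto built via DataProto.from_dict.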
|
||||||
|
|
||||||
|
def rename(self, old_keys=None, new_keys=None) -> 'DataProto':
|
||||||
|
"""
|
||||||
|
Note that this function only renames keys in the batch
|
||||||
|
"""
|
||||||
|
|
||||||
|
def validate_input(keys):
|
||||||
|
if keys is not None:
|
||||||
|
if isinstance(keys, str):
|
||||||
|
keys = [keys]
|
||||||
|
elif isinstance(keys, list):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise TypeError(f'keys must be a list or a string, but got {type(keys)}')
|
||||||
|
return keys
|
||||||
|
|
||||||
|
old_keys = validate_input(old_keys)
|
||||||
|
new_keys = validate_input(new_keys)
|
||||||
|
|
||||||
|
if len(new_keys) != len(old_keys):
|
||||||
|
raise ValueError(
|
||||||
|
f'new_keys and old_keys must have the same length, but got {len(new_keys)} and {len(old_keys)}')
|
||||||
|
|
||||||
|
self.batch.rename_key_(tuple(old_keys), tuple(new_keys))
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def union(self, other: 'DataProto') -> 'DataProto':
|
||||||
|
"""Union with another DataProto. Union batch and meta_info separately.
|
||||||
|
Throw an error if
|
||||||
|
- there are conflict keys in batch and they are not equal
|
||||||
|
- the batch sizes of the two batches are not the same
|
||||||
|
- there are conflict keys in meta_info and they are not the same.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
other (DataProto): another DataProto to union
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataProto: the DataProto after union
|
||||||
|
"""
|
||||||
|
self.batch = union_tensor_dict(self.batch, other.batch)
|
||||||
|
self.non_tensor_batch = union_numpy_dict(self.non_tensor_batch, other.non_tensor_batch)
|
||||||
|
self.meta_info = union_two_dict(self.meta_info, other.meta_info)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def make_iterator(self, mini_batch_size, epochs, seed=None, dataloader_kwargs=None):
|
||||||
|
Make an iterator from the DataProto. This relies on the fact that a TensorDict can be used as a normal PyTorch
|
||||||
|
dataset. See https://pytorch.org/tensordict/tutorials/data_fashion for more details.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mini_batch_size (int): mini-batch size when iterating the dataset. We require that
|
||||||
|
``batch.batch_size[0] % mini_batch_size == 0``
|
||||||
|
epochs (int): number of epochs when iterating the dataset.
|
||||||
|
dataloader_kwargs: internally, it returns a DataLoader over the batch.
|
||||||
|
The dataloader_kwargs is the kwargs passed to the DataLoader
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Iterator: an iterator that yields a mini-batch data at a time. The total number of iteration steps is
|
||||||
|
``self.batch.batch_size[0] * epochs // mini_batch_size``
|
||||||
|
"""
|
||||||
|
assert self.batch.batch_size[0] % mini_batch_size == 0, f"{self.batch.batch_size[0]} % {mini_batch_size} != 0"
|
||||||
|
# we can directly create a dataloader from TensorDict
|
||||||
|
if dataloader_kwargs is None:
|
||||||
|
dataloader_kwargs = {}
|
||||||
|
|
||||||
|
if seed is not None:
|
||||||
|
generator = torch.Generator()
|
||||||
|
generator.manual_seed(seed)
|
||||||
|
else:
|
||||||
|
generator = None
|
||||||
|
|
||||||
|
assert isinstance(dataloader_kwargs, Dict)
|
||||||
|
train_dataloader = DataLoader(dataset=self,
|
||||||
|
batch_size=mini_batch_size,
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
generator=generator,
|
||||||
|
**dataloader_kwargs)
|
||||||
|
|
||||||
|
def get_data():
|
||||||
|
for _ in range(epochs):
|
||||||
|
for d in train_dataloader:
|
||||||
|
d.meta_info = self.meta_info
|
||||||
|
yield d
|
||||||
|
|
||||||
|
return iter(get_data())
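# A hypothetical call, for illustration only (sizes are made up): with a batch of
# 8 samples,
#   it = data.make_iterator(mini_batch_size=2, epochs=3, seed=0)
# yields 8 * 3 // 2 = 12 mini-batches, each a DataProto of size 2 carrying the
# parent's meta_info.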
|
||||||
|
|
||||||
|
def chunk(self, chunks: int) -> List['DataProto']:
|
||||||
|
Split the batch along dim=0 into chunks. The meta_info is passed to each DataProto after split.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chunks (int): the number of chunks to split on dim=0
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[DataProto]: a list of DataProto after splitting
|
||||||
|
"""
|
||||||
|
assert len(
|
||||||
|
self) % chunks == 0, f'only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}.'
|
||||||
|
|
||||||
|
if self.batch is not None:
|
||||||
|
batch_lst = self.batch.chunk(chunks=chunks, dim=0)
|
||||||
|
else:
|
||||||
|
batch_lst = [None for _ in range(chunks)]
|
||||||
|
|
||||||
|
non_tensor_batch_lst = [{} for _ in range(chunks)]
|
||||||
|
for key, val in self.non_tensor_batch.items():
|
||||||
|
assert isinstance(val, np.ndarray)
|
||||||
|
non_tensor_lst = np.array_split(val, chunks)
|
||||||
|
assert len(non_tensor_lst) == chunks
|
||||||
|
for i in range(chunks):
|
||||||
|
non_tensor_batch_lst[i][key] = non_tensor_lst[i]
|
||||||
|
|
||||||
|
output = []
|
||||||
|
for i in range(chunks):
|
||||||
|
output.append(
|
||||||
|
DataProto(batch=batch_lst[i], non_tensor_batch=non_tensor_batch_lst[i], meta_info=self.meta_info))
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def concat(data: List['DataProto']) -> 'DataProto':
|
||||||
|
Concat a list of DataProto. The batch is concatenated along dim=0.
|
||||||
|
The meta_info is assumed to be identical and will use the first one.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data (List[DataProto]): list of DataProto
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataProto: concatenated DataProto
|
||||||
|
"""
|
||||||
|
batch_lst = []
|
||||||
|
for batch in data:
|
||||||
|
batch_lst.append(batch.batch)
|
||||||
|
if batch_lst[0] is not None:
|
||||||
|
new_batch = torch.cat(batch_lst, dim=0)
|
||||||
|
else:
|
||||||
|
new_batch = None
|
||||||
|
|
||||||
|
non_tensor_batch = list_of_dict_to_dict_of_list(list_of_dict=[d.non_tensor_batch for d in data])
|
||||||
|
for key, val in non_tensor_batch.items():
|
||||||
|
non_tensor_batch[key] = np.concatenate(val, axis=0)
|
||||||
|
|
||||||
|
return DataProto(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=data[0].meta_info)
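# A minimal sketch of the intended round trip (sizes are made up): for a DataProto
# `data` of size 8,
#   DataProto.concat(data.chunk(chunks=4))
# rebuilds a DataProto of size 8, since chunk splits dim=0 into equal pieces and
# concat stitches them back along dim=0.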
|
||||||
|
|
||||||
|
def reorder(self, indices):
|
||||||
|
"""
|
||||||
|
Note that this operation is in-place
|
||||||
|
"""
|
||||||
|
indices_np = indices.detach().numpy()
|
||||||
|
self.batch = self.batch[indices]
|
||||||
|
self.non_tensor_batch = {key: val[indices_np] for key, val in self.non_tensor_batch.items()}
|
||||||
|
|
||||||
|
def repeat(self, repeat_times=2, interleave=True):
|
||||||
|
"""
|
||||||
|
Repeat the batch data a specified number of times.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repeat_times (int): Number of times to repeat the data.
|
||||||
|
interleave (bool): Whether to interleave the repeated data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataProto: A new DataProto with repeated data.
|
||||||
|
"""
|
||||||
|
if self.batch is not None:
|
||||||
|
if interleave:
|
||||||
|
# Interleave the data
|
||||||
|
repeated_tensors = {
|
||||||
|
key: tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items()
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# Stack the data
|
||||||
|
repeated_tensors = {
|
||||||
|
key: tensor.unsqueeze(0).expand(repeat_times, *tensor.shape).reshape(-1, *tensor.shape[1:])
|
||||||
|
for key, tensor in self.batch.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
repeated_batch = TensorDict(
|
||||||
|
source=repeated_tensors,
|
||||||
|
batch_size=(self.batch.batch_size[0] * repeat_times,),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
repeated_batch = None
|
||||||
|
|
||||||
|
repeated_non_tensor_batch = {}
|
||||||
|
for key, val in self.non_tensor_batch.items():
|
||||||
|
if interleave:
|
||||||
|
repeated_non_tensor_batch[key] = np.repeat(val, repeat_times, axis=0)
|
||||||
|
else:
|
||||||
|
repeated_non_tensor_batch[key] = np.tile(val, (repeat_times,) + (1,) * (val.ndim - 1))
|
||||||
|
|
||||||
|
return DataProto(
|
||||||
|
batch=repeated_batch,
|
||||||
|
non_tensor_batch=repeated_non_tensor_batch,
|
||||||
|
meta_info=self.meta_info,
|
||||||
|
)
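# A minimal end-to-end sketch with made-up tensors (illustrative only), exercising
# the helpers defined above.
_demo = DataProto.from_single_dict({
    'input_ids': torch.zeros(4, 3, dtype=torch.long),
    'uid': np.array(['a', 'b', 'c', 'd'], dtype=object),
})
assert len(_demo) == 4
_first, _second = _demo.chunk(chunks=2)
assert len(_first) == 2
_rebuilt = DataProto.concat([_first, _second])
assert len(_rebuilt) == 4
_doubled = _rebuilt.repeat(repeat_times=2, interleave=True)
assert len(_doubled) == 8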
|
||||||
|
|
||||||
|
|
||||||
|
import ray
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DataProtoFuture:
|
||||||
|
"""
|
||||||
|
DataProtoFuture aims to eliminate actual data fetching on driver. By doing so, the driver doesn't have to wait
|
||||||
|
for data so that asynchronous execution becomes possible.
|
||||||
|
DataProtoFuture contains a list of futures from another WorkerGroup of size world_size.
|
||||||
|
- collect_fn is a Callable that reduces the list of futures to a DataProto
|
||||||
|
- dispatch_fn is a Callable that partitions the DataProto into a list of DataProto of size world_size and then selects the piece for the current rank
|
||||||
|
|
||||||
|
Potential issue: we can optimize dispatch_fn(collect_fn) such that only needed data is fetched on destination
|
||||||
|
- DataProtoFuture only supports directly passing from the output of a method to another input. You can't perform any
|
||||||
|
operation on the DataProtoFuture on the driver.
|
||||||
|
"""
|
||||||
|
collect_fn: Callable
|
||||||
|
futures: List[ray.ObjectRef]
|
||||||
|
dispatch_fn: Callable = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def concat(data: List[ray.ObjectRef]) -> 'DataProtoFuture':
|
||||||
|
output = DataProtoFuture(collect_fn=DataProto.concat, futures=data)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def chunk(self, chunks: int) -> List['DataProtoFuture']:
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
arg_future_lst = []
|
||||||
|
for i in range(chunks):
|
||||||
|
# note that we can't directly pass i and chunks
|
||||||
|
def dispatch_fn(x, i, chunks):
|
||||||
|
return x.chunk(chunks=chunks)[i]
|
||||||
|
|
||||||
|
arg_future = DataProtoFuture(collect_fn=self.collect_fn,
|
||||||
|
dispatch_fn=partial(dispatch_fn, i=i, chunks=chunks),
|
||||||
|
futures=self.futures)
|
||||||
|
arg_future_lst.append(arg_future)
|
||||||
|
return arg_future_lst
|
||||||
|
|
||||||
|
def get(self):
|
||||||
|
output = ray.get(self.futures) # dp_size.
|
||||||
|
for o in output:
|
||||||
|
assert isinstance(o, DataProto)
|
||||||
|
output = self.collect_fn(output) # select dp, concat
|
||||||
|
if self.dispatch_fn is not None:
|
||||||
|
output = self.dispatch_fn(output) # split in batch dim, select using dp
|
||||||
|
return output
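# A hypothetical flow, for illustration only: given `refs`, a list of object refs
# returned by a remote WorkerGroup call,
#   fut = DataProtoFuture.concat(refs)
#   part = fut.chunk(chunks=2)[0]
#   data = part.get()   # ray.get + collect_fn (concat), then dispatch_fn picks chunk 0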
|
||||||
20
verl/single_controller/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
with open(os.path.join(version_folder, 'version/version')) as f:
|
||||||
|
__version__ = f.read().strip()
|
||||||
16
verl/single_controller/base/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from .worker import Worker
|
||||||
|
from .worker_group import WorkerGroup, ClassWithInitArgs, ResourcePool
|
||||||
410
verl/single_controller/base/decorator.py
Normal file
@@ -0,0 +1,410 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
from types import FunctionType
|
||||||
|
from verl.protocol import DataProtoFuture
|
||||||
|
|
||||||
|
# here we add a magic number to avoid colliding with an attribute that a user-defined function may already have
|
||||||
|
MAGIC_ATTR = 'attrs_3141562937'
|
||||||
|
|
||||||
|
|
||||||
|
class Dispatch(Enum):
|
||||||
|
RANK_ZERO = 0
|
||||||
|
ONE_TO_ALL = 1
|
||||||
|
ALL_TO_ALL = 2
|
||||||
|
MEGATRON_COMPUTE = 3
|
||||||
|
MEGATRON_PP_AS_DP = 4
|
||||||
|
MEGATRON_PP_ONLY = 5
|
||||||
|
MEGATRON_COMPUTE_PROTO = 6
|
||||||
|
MEGATRON_PP_AS_DP_PROTO = 7
|
||||||
|
DP_COMPUTE = 8
|
||||||
|
DP_COMPUTE_PROTO = 9
|
||||||
|
DP_COMPUTE_PROTO_WITH_FUNC = 10
|
||||||
|
DP_COMPUTE_METRIC = 11
|
||||||
|
|
||||||
|
|
||||||
|
class Execute(Enum):
|
||||||
|
ALL = 0
|
||||||
|
RANK_ZERO = 1
|
||||||
|
|
||||||
|
|
||||||
|
def _split_args_kwargs_data_proto(chunks, *args, **kwargs):
|
||||||
|
from verl.protocol import DataProto, DataProtoFuture
|
||||||
|
splitted_args = []
|
||||||
|
for arg in args:
|
||||||
|
assert isinstance(arg, (DataProto, DataProtoFuture))
|
||||||
|
splitted_args.append(arg.chunk(chunks=chunks))
|
||||||
|
|
||||||
|
splitted_kwargs = {}
|
||||||
|
for key, val in kwargs.items():
|
||||||
|
assert isinstance(val, (DataProto, DataProtoFuture))
|
||||||
|
splitted_kwargs[key] = val.chunk(chunks=chunks)
|
||||||
|
|
||||||
|
return splitted_args, splitted_kwargs
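# A hypothetical call, for illustration only (sizes are made up): with a DataProto
# `batch` of size 8,
#   args, kwargs = _split_args_kwargs_data_proto(4, batch, refs=batch)
# gives args[0] and kwargs['refs'] as lists of 4 DataProto chunks of size 2, one
# per destination rank.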
|
||||||
|
|
||||||
|
|
||||||
|
def dispatch_one_to_all(worker_group, *args, **kwargs):
|
||||||
|
args = tuple([arg] * worker_group.world_size for arg in args)
|
||||||
|
kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()}
|
||||||
|
return args, kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def dispatch_all_to_all(worker_group, *args, **kwargs):
|
||||||
|
return args, kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def collect_all_to_all(worker_group, output):
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def dispatch_megatron_compute(worker_group, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
The user passes in dp data. The data is dispatched to all tp/pp ranks that share the same dp rank
|
||||||
|
"""
|
||||||
|
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
|
||||||
|
assert isinstance(worker_group,
|
||||||
|
MegatronWorkerGroup), f'worker_group must be MegatronWorkerGroup, Got {type(worker_group)}'
|
||||||
|
|
||||||
|
all_args = []
|
||||||
|
for arg in args:
|
||||||
|
assert isinstance(arg, (Tuple, List)) and len(arg) == worker_group.dp_size
|
||||||
|
transformed_args = []
|
||||||
|
for i in range(worker_group.world_size):
|
||||||
|
local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank
|
||||||
|
transformed_args.append(arg[local_dp_rank])
|
||||||
|
all_args.append(transformed_args)
|
||||||
|
all_args = tuple(all_args)
|
||||||
|
|
||||||
|
all_kwargs = {}
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
assert isinstance(v, (Tuple, List)) and len(v) == worker_group.dp_size
|
||||||
|
transformed_v = []
|
||||||
|
for i in range(worker_group.world_size):
|
||||||
|
local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank
|
||||||
|
transformed_v.append(v[local_dp_rank])
|
||||||
|
all_kwargs[k] = transformed_v
|
||||||
|
return all_args, all_kwargs
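# A minimal sketch of the mapping (made-up sizes): for dp_size=2 and world_size=4,
# an arg [d0, d1] (one element per dp rank) is expanded so that every global rank
# receives the element of its own dp rank, e.g. [d0, d1, d0, d1] or
# [d0, d0, d1, d1] depending on the layout reported by get_megatron_rank_info.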
|
||||||
|
|
||||||
|
|
||||||
|
def collect_megatron_compute(worker_group, output):
|
||||||
|
"""
|
||||||
|
Only collect the data from tp rank 0 and the last pp rank, across all dp ranks
|
||||||
|
"""
|
||||||
|
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
|
||||||
|
assert isinstance(worker_group, MegatronWorkerGroup)
|
||||||
|
output_in_dp = []
|
||||||
|
pp_size = worker_group.get_megatron_global_info().pp_size
|
||||||
|
for global_rank in range(worker_group.world_size):
|
||||||
|
local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank)
|
||||||
|
if local_rank_info.tp_rank == 0 and local_rank_info.pp_rank == pp_size - 1:
|
||||||
|
output_in_dp.append(output[global_rank])
|
||||||
|
return output_in_dp
|
||||||
|
|
||||||
|
|
||||||
|
def dispatch_megatron_compute_data_proto(worker_group, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
All the args and kwargs must be DataProto. The batch will be chunked by dp_size and passed to each rank
|
||||||
|
"""
|
||||||
|
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
|
||||||
|
assert isinstance(worker_group, MegatronWorkerGroup)
|
||||||
|
|
||||||
|
splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.dp_size, *args, **kwargs)
|
||||||
|
return dispatch_megatron_compute(worker_group, *splitted_args, **splitted_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _concat_data_proto_or_future(output: List):
|
||||||
|
from verl.protocol import DataProto, DataProtoFuture
|
||||||
|
import ray
|
||||||
|
|
||||||
|
# make sure all the elements in output have the same type
|
||||||
|
for o in output:
|
||||||
|
assert type(o) == type(output[0])
|
||||||
|
|
||||||
|
o = output[0]
|
||||||
|
|
||||||
|
if isinstance(o, DataProto):
|
||||||
|
return DataProto.concat(output)
|
||||||
|
elif isinstance(o, ray.ObjectRef):
|
||||||
|
return DataProtoFuture.concat(output)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def collect_megatron_compute_data_proto(worker_group, output):
|
||||||
|
"""
|
||||||
|
Each output must be a DataProto. We concatenate the outputs along dim=0
|
||||||
|
"""
|
||||||
|
from verl.protocol import DataProto
|
||||||
|
import ray
|
||||||
|
|
||||||
|
output = collect_megatron_compute(worker_group, output)
|
||||||
|
for o in output:
|
||||||
|
assert isinstance(o, (DataProto, ray.ObjectRef)), f"expecting {o} to be DataProto, but got {type(o)}"
|
||||||
|
|
||||||
|
return _concat_data_proto_or_future(output)
|
||||||
|
|
||||||
|
|
||||||
|
def dispatch_megatron_pp_as_dp(worker_group, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
treat pp as dp.
|
||||||
|
"""
|
||||||
|
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
|
||||||
|
assert isinstance(worker_group, MegatronWorkerGroup)
|
||||||
|
|
||||||
|
pp_size = worker_group.pp_size
|
||||||
|
dp_size = worker_group.dp_size
|
||||||
|
|
||||||
|
pp_dp_size = pp_size * dp_size
|
||||||
|
|
||||||
|
all_args = []
|
||||||
|
for arg in args:
|
||||||
|
assert isinstance(arg, (List, Tuple)) and len(arg) == pp_dp_size
|
||||||
|
transformed_args = []
|
||||||
|
for i in range(worker_group.world_size):
|
||||||
|
local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank
|
||||||
|
local_pp_rank = worker_group.get_megatron_rank_info(rank=i).pp_rank
|
||||||
|
# compute the rank in arg. Note that the order is dp then pp
|
||||||
|
# Also note that the outputs within a pp group are first all-gathered; then only the output of pp rank 0 is collected.
|
||||||
|
# For pp=2 dp=4, a batch of data "ABCDEFGH" should be dispatched and collected in below order:
|
||||||
|
# dispatch: pp_allgather: collect:
|
||||||
|
# dp 0 1 2 3 dp 0 1 2 3
|
||||||
|
# pp +---------+ pp +-------------+
|
||||||
|
# 0 | A C E G | 0 | AB CD EF GH | ABCDEFGH
|
||||||
|
# 1 | B D F H | 1 | AB CD EF GH |
|
||||||
|
# +---------+ +-------------+
|
||||||
|
arg_rank = local_dp_rank * worker_group.pp_size + local_pp_rank
|
||||||
|
|
||||||
|
transformed_args.append(arg[arg_rank])
|
||||||
|
all_args.append(transformed_args)
|
||||||
|
all_args = tuple(all_args)
|
||||||
|
|
||||||
|
all_kwargs = {}
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
assert isinstance(v, (List, Tuple)) and len(v) == pp_dp_size, f'expect len(v)=={pp_dp_size}, got {len(v)}'
|
||||||
|
transformed_v = []
|
||||||
|
for i in range(worker_group.world_size):
|
||||||
|
local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank
|
||||||
|
local_pp_rank = worker_group.get_megatron_rank_info(rank=i).pp_rank
|
||||||
|
# compute the rank in arg. Note that the order is dp then pp
|
||||||
|
arg_rank = local_dp_rank * worker_group.pp_size + local_pp_rank
|
||||||
|
transformed_v.append(v[arg_rank])
|
||||||
|
all_kwargs[k] = transformed_v
|
||||||
|
return all_args, all_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def collect_megatron_pp_as_dp(worker_group, output):
|
||||||
|
"""
|
||||||
|
Treat pp as dp. Only collect data from ranks with tp=0 and pp=0
|
||||||
|
"""
|
||||||
|
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
|
||||||
|
assert isinstance(worker_group, MegatronWorkerGroup)
|
||||||
|
output_in_dp = []
|
||||||
|
for global_rank in range(worker_group.world_size):
|
||||||
|
local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank)
|
||||||
|
if local_rank_info.tp_rank == 0 and local_rank_info.pp_rank == 0:
|
||||||
|
output_in_dp.append(output[global_rank])
|
||||||
|
return output_in_dp
|
||||||
|
|
||||||
|
|
||||||
|
def collect_megatron_pp_only(worker_group, output):
|
||||||
|
"""
|
||||||
|
Only collect the output of megatron pp. This is useful when examining weight names, as they are identical across tp/dp ranks
|
||||||
|
"""
|
||||||
|
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
|
||||||
|
assert isinstance(worker_group, MegatronWorkerGroup)
|
||||||
|
output_in_pp = []
|
||||||
|
for global_rank in range(worker_group.world_size):
|
||||||
|
local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank)
|
||||||
|
if local_rank_info.tp_rank == 0 and local_rank_info.dp_rank == 0:
|
||||||
|
output_in_pp.append(output[global_rank])
|
||||||
|
return output_in_pp
|
||||||
|
|
||||||
|
|
||||||
|
def dispatch_megatron_pp_as_dp_data_proto(worker_group, *args, **kwargs):
|
||||||
|
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
|
||||||
|
assert isinstance(worker_group, MegatronWorkerGroup)
|
||||||
|
|
||||||
|
pp_dp_size = worker_group.dp_size * worker_group.pp_size
|
||||||
|
splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(pp_dp_size, *args, **kwargs)
|
||||||
|
return dispatch_megatron_pp_as_dp(worker_group, *splitted_args, **splitted_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def collect_megatron_pp_as_dp_data_proto(worker_group, output):
|
||||||
|
from verl.protocol import DataProto
|
||||||
|
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
|
||||||
|
assert isinstance(worker_group, MegatronWorkerGroup)
|
||||||
|
|
||||||
|
output = collect_megatron_pp_as_dp(worker_group, output)
|
||||||
|
return _concat_data_proto_or_future(output)
|
||||||
|
|
||||||
|
|
||||||
|
def dispatch_dp_compute(worker_group, *args, **kwargs):
|
||||||
|
from verl.single_controller.base.worker_group import WorkerGroup
|
||||||
|
assert isinstance(worker_group, WorkerGroup)
|
||||||
|
for arg in args:
|
||||||
|
assert isinstance(arg, (Tuple, List)) and len(arg) == worker_group.world_size
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
assert isinstance(v, (Tuple, List)) and len(v) == worker_group.world_size
|
||||||
|
return args, kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def collect_dp_compute(worker_group, output):
|
||||||
|
from verl.single_controller.base.worker_group import WorkerGroup
|
||||||
|
assert isinstance(worker_group, WorkerGroup)
|
||||||
|
assert len(output) == worker_group.world_size
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def dispatch_dp_compute_data_proto(worker_group, *args, **kwargs):
|
||||||
|
from verl.single_controller.base.worker_group import WorkerGroup
|
||||||
|
assert isinstance(worker_group, WorkerGroup)
|
||||||
|
splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args, **kwargs)
|
||||||
|
return splitted_args, splitted_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def dispatch_dp_compute_data_proto_with_func(worker_group, *args, **kwargs):
|
||||||
|
from verl.single_controller.base.worker_group import WorkerGroup
|
||||||
|
assert isinstance(worker_group, WorkerGroup)
|
||||||
|
assert type(args[0]) == FunctionType # NOTE: The first one args is a function!
|
||||||
|
|
||||||
|
splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args[1:], **kwargs)
|
||||||
|
splitted_args_with_func = [[args[0]] * worker_group.world_size] + splitted_args
|
||||||
|
return splitted_args_with_func, splitted_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def collect_dp_compute_data_proto(worker_group, output):
|
||||||
|
from verl.protocol import DataProto
|
||||||
|
import ray
|
||||||
|
|
||||||
|
for o in output:
|
||||||
|
assert isinstance(o, (DataProto, ray.ObjectRef)), f"expecting {o} to be DataProto, but got {type(o)}"
|
||||||
|
|
||||||
|
output = collect_dp_compute(worker_group, output)
|
||||||
|
return _concat_data_proto_or_future(output)
|
||||||
|
|
||||||
|
|
||||||
|
def get_predefined_dispatch_fn(dispatch_mode):
|
||||||
|
predefined_dispatch_mode_fn = {
|
||||||
|
Dispatch.ONE_TO_ALL: {
|
||||||
|
'dispatch_fn': dispatch_one_to_all,
|
||||||
|
'collect_fn': collect_all_to_all,
|
||||||
|
},
|
||||||
|
Dispatch.ALL_TO_ALL: {
|
||||||
|
'dispatch_fn': dispatch_all_to_all,
|
||||||
|
'collect_fn': collect_all_to_all,
|
||||||
|
},
|
||||||
|
Dispatch.MEGATRON_COMPUTE: {
|
||||||
|
'dispatch_fn': dispatch_megatron_compute,
|
||||||
|
'collect_fn': collect_megatron_compute,
|
||||||
|
},
|
||||||
|
Dispatch.MEGATRON_PP_AS_DP: {
|
||||||
|
'dispatch_fn': dispatch_megatron_pp_as_dp,
|
||||||
|
'collect_fn': collect_megatron_pp_as_dp,
|
||||||
|
},
|
||||||
|
Dispatch.MEGATRON_PP_ONLY: {
|
||||||
|
'dispatch_fn': dispatch_one_to_all,
|
||||||
|
'collect_fn': collect_megatron_pp_only
|
||||||
|
},
|
||||||
|
Dispatch.MEGATRON_COMPUTE_PROTO: {
|
||||||
|
'dispatch_fn': dispatch_megatron_compute_data_proto,
|
||||||
|
'collect_fn': collect_megatron_compute_data_proto
|
||||||
|
},
|
||||||
|
Dispatch.MEGATRON_PP_AS_DP_PROTO: {
|
||||||
|
'dispatch_fn': dispatch_megatron_pp_as_dp_data_proto,
|
||||||
|
'collect_fn': collect_megatron_pp_as_dp_data_proto
|
||||||
|
},
|
||||||
|
Dispatch.DP_COMPUTE: {
|
||||||
|
'dispatch_fn': dispatch_dp_compute,
|
||||||
|
'collect_fn': collect_dp_compute
|
||||||
|
},
|
||||||
|
Dispatch.DP_COMPUTE_PROTO: {
|
||||||
|
'dispatch_fn': dispatch_dp_compute_data_proto,
|
||||||
|
'collect_fn': collect_dp_compute_data_proto
|
||||||
|
},
|
||||||
|
Dispatch.DP_COMPUTE_PROTO_WITH_FUNC: {
|
||||||
|
'dispatch_fn': dispatch_dp_compute_data_proto_with_func,
|
||||||
|
'collect_fn': collect_dp_compute_data_proto
|
||||||
|
},
|
||||||
|
Dispatch.DP_COMPUTE_METRIC: {
|
||||||
|
'dispatch_fn': dispatch_dp_compute_data_proto,
|
||||||
|
'collect_fn': collect_dp_compute
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return predefined_dispatch_mode_fn[dispatch_mode]
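# A minimal lookup sketch: the registry above is a plain dict, so
#   fns = get_predefined_dispatch_fn(Dispatch.DP_COMPUTE_PROTO)
#   fns['dispatch_fn'] is dispatch_dp_compute_data_proto
#   fns['collect_fn'] is collect_dp_compute_data_proto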
|
||||||
|
|
||||||
|
|
||||||
|
def get_predefined_execute_fn(execute_mode):
|
||||||
|
"""
|
||||||
|
Note that here we only ask that execute_all and execute_rank_zero be implemented
|
||||||
|
The choice of how these two functions handle the 'blocking' argument is left to users
|
||||||
|
"""
|
||||||
|
predefined_execute_mode_fn = {
|
||||||
|
Execute.ALL: {
|
||||||
|
'execute_fn_name': 'execute_all'
|
||||||
|
},
|
||||||
|
Execute.RANK_ZERO: {
|
||||||
|
'execute_fn_name': 'execute_rank_zero'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return predefined_execute_mode_fn[execute_mode]
|
||||||
|
|
||||||
|
|
||||||
|
def _check_dispatch_mode(dispatch_mode):
|
||||||
|
assert isinstance(dispatch_mode,
|
||||||
|
(Dispatch, Dict)), f'dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}'
|
||||||
|
if isinstance(dispatch_mode, Dict):
|
||||||
|
necessary_keys = ['dispatch_fn', 'collect_fn']
|
||||||
|
for key in necessary_keys:
|
||||||
|
assert key in dispatch_mode, f'key {key} should be in dispatch_mode if it is a dictionary'
|
||||||
|
|
||||||
|
|
||||||
|
def _check_execute_mode(execute_mode):
|
||||||
|
assert isinstance(execute_mode, Execute), f'execute_mode must be an Execute. Got {execute_mode}'
|
||||||
|
|
||||||
|
|
||||||
|
def _materialize_futures(*args, **kwargs):
|
||||||
|
new_args = []
|
||||||
|
for arg in args:
|
||||||
|
if isinstance(arg, DataProtoFuture):
|
||||||
|
arg = arg.get()
|
||||||
|
# add more type to materialize
|
||||||
|
new_args.append(arg)
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
if isinstance(v, DataProtoFuture):
|
||||||
|
kwargs[k] = v.get()
|
||||||
|
|
||||||
|
new_args = tuple(new_args)
|
||||||
|
return new_args, kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True):
|
||||||
|
_check_dispatch_mode(dispatch_mode=dispatch_mode)
|
||||||
|
_check_execute_mode(execute_mode=execute_mode)
|
||||||
|
|
||||||
|
def decorator(func):
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def inner(*args, **kwargs):
|
||||||
|
if materialize_futures:
|
||||||
|
args, kwargs = _materialize_futures(*args, **kwargs)
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
attrs = {'dispatch_mode': dispatch_mode, 'execute_mode': execute_mode, 'blocking': blocking}
|
||||||
|
setattr(inner, MAGIC_ATTR, attrs)
|
||||||
|
return inner
|
||||||
|
|
||||||
|
return decorator
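# A minimal sketch (illustrative only) of how @register is used on a worker
# method. The class below is hypothetical; it only shows that the decorator
# records its dispatch/execute/blocking attributes under MAGIC_ATTR.
class _SketchWorker:

    @register(dispatch_mode=Dispatch.ONE_TO_ALL, execute_mode=Execute.ALL, blocking=True)
    def ping(self, msg):
        return msg

_sketch_attrs = getattr(_SketchWorker.ping, MAGIC_ATTR)
assert _sketch_attrs['dispatch_mode'] == Dispatch.ONE_TO_ALL and _sketch_attrs['blocking'] is True
assert _SketchWorker().ping('hello') == 'hello'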
|
||||||
13
verl/single_controller/base/megatron/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
39
verl/single_controller/base/megatron/worker.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from verl.single_controller.base.worker import Worker, DistRankInfo, DistGlobalInfo
|
||||||
|
|
||||||
|
|
||||||
|
class MegatronWorker(Worker):
|
||||||
|
|
||||||
|
def __init__(self, cuda_visible_devices=None) -> None:
|
||||||
|
super().__init__(cuda_visible_devices)
|
||||||
|
|
||||||
|
def get_megatron_global_info(self):
|
||||||
|
from megatron.core import parallel_state as mpu
|
||||||
|
tp_size = mpu.get_tensor_model_parallel_world_size()
|
||||||
|
dp_size = mpu.get_data_parallel_world_size()
|
||||||
|
pp_size = mpu.get_pipeline_model_parallel_world_size()
|
||||||
|
info = DistGlobalInfo(tp_size=tp_size, dp_size=dp_size, pp_size=pp_size)
|
||||||
|
return info
|
||||||
|
|
||||||
|
def get_megatron_rank_info(self):
|
||||||
|
from megatron.core import parallel_state as mpu
|
||||||
|
tp_rank = mpu.get_tensor_model_parallel_rank()
|
||||||
|
dp_rank = mpu.get_data_parallel_rank()
|
||||||
|
pp_rank = mpu.get_pipeline_model_parallel_rank()
|
||||||
|
info = DistRankInfo(tp_rank=tp_rank, dp_rank=dp_rank, pp_rank=pp_rank)
|
||||||
|
return info
|
||||||
51
verl/single_controller/base/megatron/worker_group.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from .worker import DistRankInfo, DistGlobalInfo
|
||||||
|
from verl.single_controller.base import ResourcePool, WorkerGroup
|
||||||
|
|
||||||
|
|
||||||
|
class MegatronWorkerGroup(WorkerGroup):
|
||||||
|
|
||||||
|
def __init__(self, resource_pool: ResourcePool, **kwargs):
|
||||||
|
super().__init__(resource_pool=resource_pool, **kwargs)
|
||||||
|
self._megatron_rank_info = None
|
||||||
|
self._megatron_global_info: DistGlobalInfo = None
|
||||||
|
|
||||||
|
def init_megatron(self, default_megatron_kwargs: Dict = None):
|
||||||
|
raise NotImplementedError(f"MegatronWorkerGroup.init_megatron should be overwritten")
|
||||||
|
|
||||||
|
def get_megatron_rank_info(self, rank: int) -> DistRankInfo:
|
||||||
|
assert 0 <= rank < self.world_size, f'rank must be from [0, world_size), Got {rank}'
|
||||||
|
return self._megatron_rank_info[rank]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tp_size(self):
|
||||||
|
assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized"
|
||||||
|
return self._megatron_global_info.tp_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dp_size(self):
|
||||||
|
assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized"
|
||||||
|
return self._megatron_global_info.dp_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pp_size(self):
|
||||||
|
assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized"
|
||||||
|
return self._megatron_global_info.pp_size
|
||||||
|
|
||||||
|
def get_megatron_global_info(self):
|
||||||
|
return self._megatron_global_info
|
||||||
13
verl/single_controller/base/register_center/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
29
verl/single_controller/base/register_center/ray.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import ray
|
||||||
|
|
||||||
|
|
||||||
|
@ray.remote
|
||||||
|
class WorkerGroupRegisterCenter:
|
||||||
|
|
||||||
|
def __init__(self, rank_zero_info):
|
||||||
|
self.rank_zero_info = rank_zero_info
|
||||||
|
|
||||||
|
def get_rank_zero_info(self):
|
||||||
|
return self.rank_zero_info
|
||||||
|
|
||||||
|
|
||||||
|
def create_worker_group_register_center(name, info):
|
||||||
|
return WorkerGroupRegisterCenter.options(name=name).remote(info)
|
||||||
186
verl/single_controller/base/worker.py
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
The base class for Worker
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import socket
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from verl.single_controller.base.decorator import register, Dispatch, Execute
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DistRankInfo:
|
||||||
|
tp_rank: int
|
||||||
|
dp_rank: int
|
||||||
|
pp_rank: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DistGlobalInfo:
|
||||||
|
tp_size: int
|
||||||
|
dp_size: int
|
||||||
|
pp_size: int
|
||||||
|
|
||||||
|
|
||||||
|
class WorkerHelper:
|
||||||
|
|
||||||
|
def _get_node_ip(self):
|
||||||
|
|
||||||
|
def get_node_ip_by_sdk():
|
||||||
|
if os.getenv("WG_BACKEND", None) == "ray":
|
||||||
|
import ray
|
||||||
|
return ray._private.services.get_node_ip_address()
|
||||||
|
elif os.getenv("WG_BACKEND", None) == "torch_rpc":
|
||||||
|
from verl.single_controller.torchrpc.k8s_client import get_ip_addr
|
||||||
|
return get_ip_addr()
|
||||||
|
return None
|
||||||
|
|
||||||
|
host_ipv4 = os.getenv("MY_HOST_IP", None)
|
||||||
|
host_ipv6 = os.getenv("MY_HOST_IPV6", None)
|
||||||
|
host_ip_by_env = host_ipv4 or host_ipv6
|
||||||
|
host_ip_by_sdk = get_node_ip_by_sdk()
|
||||||
|
|
||||||
|
host_ip = host_ip_by_env or host_ip_by_sdk
|
||||||
|
return host_ip
|
||||||
|
|
||||||
|
def _get_free_port(self):
|
||||||
|
with socket.socket() as sock:
|
||||||
|
sock.bind(('', 0))
|
||||||
|
return sock.getsockname()[1]
|
||||||
|
|
||||||
|
def get_availale_master_addr_port(self):
|
||||||
|
return self._get_node_ip(), str(self._get_free_port())
|
||||||
|
|
||||||
|
def _get_pid(self):
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
class WorkerMeta:
|
||||||
|
keys = [
|
||||||
|
"WORLD_SIZE", "RANK", "LOCAL_WORLD_SIZE", "LOCAL_RANK", "MASTER_ADDR", "MASTER_PORT", "CUDA_VISIBLE_DEVICES"
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, store) -> None:
|
||||||
|
self._store = store
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {f"_{key.lower()}": self._store.get(f"_{key.lower()}", None) for key in WorkerMeta.keys}
|
||||||
|
|
||||||
|
|
||||||
|
# we assume that in each WorkerGroup, there is a Master Worker
|
||||||
|
class Worker(WorkerHelper):
|
||||||
|
|
||||||
|
def __new__(cls, *args, **kwargs):
|
||||||
|
instance = super().__new__(cls)
|
||||||
|
|
||||||
|
# note that we cast the env var to int so that '0' keeps worker init enabled and any non-zero value disables it
|
||||||
|
disable_worker_init = int(os.environ.get('DISABLE_WORKER_INIT', 0))
|
||||||
|
if disable_worker_init:
|
||||||
|
return instance
|
||||||
|
|
||||||
|
rank = os.environ.get("RANK", None)
|
||||||
|
worker_group_prefix = os.environ.get("WG_PREFIX", None)
|
||||||
|
|
||||||
|
# when the @ray.remote decorator is applied, __new__ is still called, but we don't want to run _configure_before_init in that case
|
||||||
|
if None not in [rank, worker_group_prefix] and 'ActorClass(' not in cls.__name__:
|
||||||
|
instance._configure_before_init(f"{worker_group_prefix}_register_center", int(rank))
|
||||||
|
|
||||||
|
return instance
|
||||||
|
|
||||||
|
def _configure_before_init(self, register_center_name: str, rank: int):
|
||||||
|
assert isinstance(rank, int), f"rank must be int, instead of {type(rank)}"
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
master_addr, master_port = self.get_availale_master_addr_port()
|
||||||
|
rank_zero_info = {
|
||||||
|
"MASTER_ADDR": master_addr,
|
||||||
|
"MASTER_PORT": master_port,
|
||||||
|
}
|
||||||
|
|
||||||
|
if os.getenv("WG_BACKEND", None) == "ray":
|
||||||
|
from verl.single_controller.base.register_center.ray import create_worker_group_register_center
|
||||||
|
self.register_center = create_worker_group_register_center(name=register_center_name,
|
||||||
|
info=rank_zero_info)
|
||||||
|
|
||||||
|
os.environ.update(rank_zero_info)
|
||||||
|
|
||||||
|
def __init__(self, cuda_visible_devices=None) -> None:
|
||||||
|
# construct a meta from environment variables. Note that the import must be inside the class because it is executed remotely
|
||||||
|
import os
|
||||||
|
world_size = int(os.environ['WORLD_SIZE'])
|
||||||
|
rank = int(os.environ['RANK'])
|
||||||
|
self._rank = rank
|
||||||
|
self._world_size = world_size
|
||||||
|
|
||||||
|
master_addr = os.environ["MASTER_ADDR"]
|
||||||
|
master_port = os.environ["MASTER_PORT"]
|
||||||
|
|
||||||
|
local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
|
||||||
|
local_rank = int(os.getenv("LOCAL_RANK", "0"))
|
||||||
|
|
||||||
|
store = {
|
||||||
|
'_world_size': world_size,
|
||||||
|
'_rank': rank,
|
||||||
|
'_local_world_size': local_world_size,
|
||||||
|
'_local_rank': local_rank,
|
||||||
|
'_master_addr': master_addr,
|
||||||
|
'_master_port': master_port
|
||||||
|
}
|
||||||
|
if cuda_visible_devices is not None:
|
||||||
|
store['_cuda_visible_devices'] = cuda_visible_devices
|
||||||
|
|
||||||
|
meta = WorkerMeta(store=store)
|
||||||
|
self._configure_with_meta(meta=meta)
|
||||||
|
|
||||||
|
def _configure_with_meta(self, meta: WorkerMeta):
|
||||||
|
"""
|
||||||
|
This function should only be called internally by WorkerGroup
|
||||||
|
"""
|
||||||
|
assert isinstance(meta, WorkerMeta)
|
||||||
|
self.__dict__.update(meta.to_dict()) # this is hacky
|
||||||
|
# print(f"__dict__: {self.__dict__}")
|
||||||
|
for key in WorkerMeta.keys:
|
||||||
|
val = self.__dict__.get(f"_{key.lower()}", None)
|
||||||
|
if val is not None:
|
||||||
|
# print(f"set {key} to {val}")
|
||||||
|
os.environ[key] = str(val)
|
||||||
|
os.environ["REDIS_STORE_SERVER_HOST"] = str(self._master_addr).replace("[", "").replace(
|
||||||
|
"]", "") if self._master_addr else ""
|
||||||
|
|
||||||
|
def get_master_addr_port(self):
|
||||||
|
return self._master_addr, self._master_port
|
||||||
|
|
||||||
|
def get_cuda_visible_devices(self):
|
||||||
|
import os
|
||||||
|
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "not set")
|
||||||
|
return cuda_visible_devices
|
||||||
|
|
||||||
|
@property
|
||||||
|
def world_size(self):
|
||||||
|
return self._world_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def rank(self):
|
||||||
|
return self._rank
|
||||||
|
|
||||||
|
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO_WITH_FUNC)
|
||||||
|
def execute_with_func_generator(self, func, *args, **kwargs):
|
||||||
|
ret_proto = func(self, *args, **kwargs)
|
||||||
|
return ret_proto
|
||||||
|
|
||||||
|
@register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO)
|
||||||
|
def execute_func_rank_zero(self, func, *args, **kwargs):
|
||||||
|
result = func(*args, **kwargs)
|
||||||
|
return result
|
||||||
196
verl/single_controller/base/worker_group.py
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
The WorkerGroup base class
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import signal
|
||||||
|
import time
|
||||||
|
from typing import List, Any, Callable, Dict
|
||||||
|
|
||||||
|
from verl.single_controller.base.decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn
|
||||||
|
|
||||||
|
|
||||||
|
class ResourcePool:
|
||||||
|
|
||||||
|
def __init__(self, process_on_nodes=None, max_collocate_count: int = 10, n_gpus_per_node=8) -> None:
|
||||||
|
if process_on_nodes is None:
|
||||||
|
process_on_nodes = []
|
||||||
|
self._store = process_on_nodes
|
||||||
|
self.max_collocate_count = max_collocate_count
|
||||||
|
self.n_gpus_per_node = n_gpus_per_node # left for future Huawei devices that may have 16 GPUs per node
|
||||||
|
|
||||||
|
def add_node(self, process_count):
|
||||||
|
self._store.append(process_count)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def world_size(self):
|
||||||
|
return sum(self._store)
|
||||||
|
|
||||||
|
def __call__(self) -> Any:
|
||||||
|
return self._store
|
||||||
|
|
||||||
|
@property
|
||||||
|
def store(self):
|
||||||
|
return self._store
|
||||||
|
|
||||||
|
def local_world_size_list(self) -> List[int]:
|
||||||
|
nested_local_world_size_list = [
|
||||||
|
[local_world_size for _ in range(local_world_size)] for local_world_size in self._store
|
||||||
|
]
|
||||||
|
return [item for row in nested_local_world_size_list for item in row]
|
||||||
|
|
||||||
|
def local_rank_list(self) -> List[int]:
|
||||||
|
nested_local_rank_list = [[i for i in range(local_world_size)] for local_world_size in self._store]
|
||||||
|
return [item for row in nested_local_rank_list for item in row]
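
# Editor's note: the following is an illustrative sketch, not part of the
# original source. It shows how ResourcePool flattens per-node process counts,
# assuming a hypothetical pool with two nodes running 2 and 3 processes.
def _example_resource_pool_usage():
    pool = ResourcePool(process_on_nodes=[2, 3])
    assert pool.world_size == 5
    assert pool.local_world_size_list() == [2, 2, 3, 3, 3]
    assert pool.local_rank_list() == [0, 1, 0, 1, 2]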
|
||||||
|
|
||||||
|
|
||||||
|
class ClassWithInitArgs:
|
||||||
|
"""
|
||||||
|
This class stores a class constructor and the args/kwargs to construct the class.
|
||||||
|
It is used to instantiate the remote class.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cls, *args, **kwargs) -> None:
|
||||||
|
self.cls = cls
|
||||||
|
self.args = args
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
# def add_arg(self, arg):
|
||||||
|
# self.args += (arg,)
|
||||||
|
|
||||||
|
# def add_kwarg(self, key, value):
|
||||||
|
# self.kwargs[key] = value
|
||||||
|
|
||||||
|
def __call__(self) -> Any:
|
||||||
|
return self.cls(*self.args, **self.kwargs)
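
# Editor's note: illustrative sketch, not part of the original source. It shows
# the deferred-construction pattern that ClassWithInitArgs implements; MyWorker
# is a hypothetical class used only for this example.
def _example_class_with_init_args():
    class MyWorker:
        def __init__(self, hidden_size):
            self.hidden_size = hidden_size

    deferred = ClassWithInitArgs(MyWorker, hidden_size=128)
    worker = deferred()  # the actual constructor call happens here
    assert worker.hidden_size == 128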
|
||||||
|
|
||||||
|
|
||||||
|
def check_workers_alive(workers: List, is_alive: Callable, gap_time: float = 1) -> None:
|
||||||
|
import time
|
||||||
|
while True:
|
||||||
|
for worker in workers:
|
||||||
|
if not is_alive(worker):
|
||||||
|
logging.warning(f"worker {worker} is not alive" + " sending signal to main thread")
|
||||||
|
signal.raise_signal(signal.SIGABRT)
|
||||||
|
time.sleep(gap_time)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkerGroup:
|
||||||
|
|
||||||
|
def __init__(self, resource_pool: ResourcePool, **kwargs) -> None:
|
||||||
|
self._is_init_with_detached_workers = True if resource_pool is None else False
|
||||||
|
|
||||||
|
if resource_pool is not None:
|
||||||
|
# handle the case when the WorkerGroup is attached to an existing one
|
||||||
|
self._procecss_dispatch_config = resource_pool()
|
||||||
|
else:
|
||||||
|
self._procecss_dispatch_config = None
|
||||||
|
|
||||||
|
self._workers = []
|
||||||
|
self._worker_names = []
|
||||||
|
|
||||||
|
self._master_addr = None
|
||||||
|
self._master_port = None
|
||||||
|
|
||||||
|
self._checker_thread: threading.Thread = None
|
||||||
|
|
||||||
|
def _is_worker_alive(self, worker):
|
||||||
|
raise NotImplementedError(f"WorkerGroup._is_worker_alive called, should be implemented in derived class.")
|
||||||
|
|
||||||
|
def _block_until_all_workers_alive(self) -> None:
|
||||||
|
while True:
|
||||||
|
all_state = [self._is_worker_alive(worker) for worker in self._workers]
|
||||||
|
if False in all_state:
|
||||||
|
time.sleep(1)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
def start_worker_aliveness_check(self, every_n_seconds=1) -> None:
|
||||||
|
# before starting checking worker aliveness, make sure all workers are already alive
|
||||||
|
self._block_until_all_workers_alive()
|
||||||
|
|
||||||
|
self._checker_thread = threading.Thread(target=check_workers_alive,
|
||||||
|
args=(self._workers, self._is_worker_alive, every_n_seconds))
|
||||||
|
self._checker_thread.start()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def world_size(self):
|
||||||
|
return len(self._workers)
|
||||||
|
|
||||||
|
# execute_all_async and execute_rank_zero_async should be implemented by RayWorkerGroup and TorchRPCWorkerGroup;
# MegatronWorkerGroup and XperfWorkerGroup should skip them
|
||||||
|
|
||||||
|
def _bind_worker_method(self, user_defined_cls, func_generator):
|
||||||
|
"""
|
||||||
|
Bind the worker method to the WorkerGroup
|
||||||
|
"""
|
||||||
|
|
||||||
|
for method_name in dir(user_defined_cls):
|
||||||
|
|
||||||
|
try:
|
||||||
|
method = getattr(user_defined_cls, method_name)
|
||||||
|
assert callable(method), f"{method_name} in {user_defined_cls} is not callable"
|
||||||
|
except Exception as e:
|
||||||
|
# if it is a property, it will fail because Class doesn't have instance property
|
||||||
|
continue
|
||||||
|
|
||||||
|
if hasattr(method, MAGIC_ATTR):
|
||||||
|
# this method is decorated by register
|
||||||
|
attribute = getattr(method, MAGIC_ATTR)
|
||||||
|
assert isinstance(attribute, Dict), f'attribute must be a dictionary. Got {type(attribute)}'
|
||||||
|
assert 'dispatch_mode' in attribute, 'attribute must contain dispatch_mode in its keys'
|
||||||
|
|
||||||
|
dispatch_mode = attribute['dispatch_mode']
|
||||||
|
execute_mode = attribute['execute_mode']
|
||||||
|
blocking = attribute['blocking']
|
||||||
|
|
||||||
|
# get dispatch fn
|
||||||
|
if isinstance(dispatch_mode, Dispatch):
|
||||||
|
# get default dispatch fn
|
||||||
|
fn = get_predefined_dispatch_fn(dispatch_mode=dispatch_mode)
|
||||||
|
dispatch_fn = fn['dispatch_fn']
|
||||||
|
collect_fn = fn['collect_fn']
|
||||||
|
else:
|
||||||
|
assert isinstance(dispatch_mode, dict)
|
||||||
|
assert 'dispatch_fn' in dispatch_mode
|
||||||
|
assert 'collect_fn' in dispatch_mode
|
||||||
|
dispatch_fn = dispatch_mode['dispatch_fn']
|
||||||
|
collect_fn = dispatch_mode['collect_fn']
|
||||||
|
|
||||||
|
# get execute_fn_name
|
||||||
|
execute_mode = get_predefined_execute_fn(execute_mode=execute_mode)
|
||||||
|
wg_execute_fn_name = execute_mode['execute_fn_name']
|
||||||
|
|
||||||
|
# get execute_fn from string
|
||||||
|
try:
|
||||||
|
execute_fn = getattr(self, wg_execute_fn_name)
|
||||||
|
assert callable(execute_fn), 'execute_fn must be callable'
|
||||||
|
except Exception as e:
|
||||||
|
print(f'execute_fn {wg_execute_fn_name} is invalid')
|
||||||
|
raise
|
||||||
|
|
||||||
|
# bind a new method to the RayWorkerGroup
|
||||||
|
func = func_generator(self,
|
||||||
|
method_name,
|
||||||
|
dispatch_fn=dispatch_fn,
|
||||||
|
collect_fn=collect_fn,
|
||||||
|
execute_fn=execute_fn,
|
||||||
|
blocking=blocking)
|
||||||
|
|
||||||
|
try:
|
||||||
|
setattr(self, method_name, func)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f'Fail to set method_name {method_name}')
|
||||||
16
verl/single_controller/ray/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from .base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, create_colocated_worker_cls
|
||||||
|
from .megatron import (MegatronRayWorkerGroup, DistRankInfo, DistGlobalInfo)
|
||||||
459
verl/single_controller/ray/base.py
Normal file
@@ -0,0 +1,459 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import Dict, List, Any, Tuple
|
||||||
|
|
||||||
|
import ray
|
||||||
|
from ray.util import list_named_actors
|
||||||
|
from ray.util.placement_group import placement_group, PlacementGroup
|
||||||
|
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy
|
||||||
|
from ray.experimental.state.api import get_actor
|
||||||
|
|
||||||
|
from verl.single_controller.base import WorkerGroup, ResourcePool, ClassWithInitArgs, Worker
|
||||||
|
|
||||||
|
__all__ = ['Worker']
|
||||||
|
|
||||||
|
|
||||||
|
def get_random_string(length: int) -> str:
|
||||||
|
import random
|
||||||
|
import string
|
||||||
|
letters_digits = string.ascii_letters + string.digits
|
||||||
|
return ''.join(random.choice(letters_digits) for _ in range(length))
|
||||||
|
|
||||||
|
|
||||||
|
def func_generator(self, method_name, dispatch_fn, collect_fn, execute_fn, blocking):
|
||||||
|
|
||||||
|
def func(*args, **kwargs):
|
||||||
|
args, kwargs = dispatch_fn(self, *args, **kwargs)
|
||||||
|
output = execute_fn(method_name, *args, **kwargs)
|
||||||
|
if blocking:
|
||||||
|
output = ray.get(output)
|
||||||
|
output = collect_fn(self, output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
return func
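
# Editor's note: illustrative sketch, not part of the original source. Each
# generated method runs a dispatch -> execute -> (optionally wait) -> collect
# pipeline over the worker group, roughly:
#
#     sharded_args, sharded_kwargs = dispatch_fn(worker_group, *args, **kwargs)
#     futures = execute_fn(method_name, *sharded_args, **sharded_kwargs)
#     outputs = ray.get(futures) if blocking else futures
#     result = collect_fn(worker_group, outputs)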
|
||||||
|
|
||||||
|
|
||||||
|
class RayResourcePool(ResourcePool):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
process_on_nodes: List[int] = None,
|
||||||
|
use_gpu: bool = True,
|
||||||
|
name_prefix: str = "",
|
||||||
|
max_colocate_count: int = 5,
|
||||||
|
detached=False) -> None:
|
||||||
|
super().__init__(process_on_nodes, max_colocate_count)
|
||||||
|
self.use_gpu = use_gpu
|
||||||
|
# print(f"in RayProcessDispatchConfiguration: name_prefix = {name_prefix}")
|
||||||
|
self.name_prefix = name_prefix
|
||||||
|
self.pgs = None
|
||||||
|
self.detached = detached
|
||||||
|
|
||||||
|
def get_placement_groups(self, strategy="STRICT_PACK", name=None):
|
||||||
|
if self.pgs is not None:
|
||||||
|
return self.pgs
|
||||||
|
|
||||||
|
pg_name_prefix = name if name else \
|
||||||
|
f"{self.name_prefix}verl_group_{'_'.join([str(count) for count in self._store])}:"
|
||||||
|
# print(f"pg_name_prefix = {pg_name_prefix}")
|
||||||
|
pg_scheme = [[{
|
||||||
|
"CPU": self.max_collocate_count,
|
||||||
|
"GPU": 1
|
||||||
|
} if self.use_gpu else {
|
||||||
|
"CPU": self.max_collocate_count
|
||||||
|
} for _ in range(process_count)] for process_count in self._store]
|
||||||
|
|
||||||
|
lifetime = 'detached' if self.detached else None
|
||||||
|
|
||||||
|
pgs = [
|
||||||
|
placement_group(bundles=bundles, strategy=strategy, name=pg_name_prefix + str(idx), lifetime=lifetime)
|
||||||
|
for idx, bundles in enumerate(pg_scheme)
|
||||||
|
]
|
||||||
|
|
||||||
|
ray.get([pg.ready() for pg in pgs])
|
||||||
|
|
||||||
|
self.pgs = pgs
|
||||||
|
return pgs
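
# Editor's note: illustrative sketch, not part of the original source. For a
# hypothetical pool with process_on_nodes=[4] and use_gpu=True, the scheme above
# requests one placement group with 4 bundles, each reserving
# {"CPU": max_collocate_count, "GPU": 1}, packed onto a single node when
# strategy="STRICT_PACK".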
|
||||||
|
|
||||||
|
|
||||||
|
def extract_pg_from_exist(resource_pools: Dict[str, RayResourcePool], src_role_names: List[str],
|
||||||
|
resource_pool: RayResourcePool) -> List:
|
||||||
|
|
||||||
|
src_pgs = [
|
||||||
|
pg for role_name, resource_pool in resource_pools.items() for pg in resource_pool.get_placement_groups()
|
||||||
|
if role_name in src_role_names
|
||||||
|
]
|
||||||
|
|
||||||
|
sorted_src_pgs = sorted(src_pgs, key=lambda pg: pg.bundle_count, reverse=True)
|
||||||
|
sorted_process_on_nodes = sorted([(val, idx) for idx, val in enumerate(resource_pool.store)], reverse=True)
|
||||||
|
|
||||||
|
unsorted_pgs: List[Tuple[int, PlacementGroup]] = []
|
||||||
|
searching_idx = 0
|
||||||
|
for request_process, original_idx in sorted_process_on_nodes:
|
||||||
|
assert searching_idx < len(sorted_src_pgs), f"not enough nodes for request: searching the {searching_idx}-th node"
|
||||||
|
assert request_process <= sorted_src_pgs[searching_idx].bundle_count, \
|
||||||
|
f"requesting {request_process} processes, bundle count cannot satisfy"
|
||||||
|
unsorted_pgs.append((original_idx, sorted_src_pgs[searching_idx]))
|
||||||
|
searching_idx += 1
|
||||||
|
|
||||||
|
return [pg for _, pg in sorted(unsorted_pgs)]
|
||||||
|
|
||||||
|
|
||||||
|
def merge_resource_pool(rp1: RayResourcePool, rp2: RayResourcePool) -> RayResourcePool:
|
||||||
|
assert rp1.use_gpu == rp2.use_gpu, 'Both RayResourcePools must agree on use_gpu'
assert rp1.max_collocate_count == rp2.max_collocate_count, 'Both RayResourcePools must have the same max_collocate_count'
assert rp1.n_gpus_per_node == rp2.n_gpus_per_node, 'Both RayResourcePools must have the same n_gpus_per_node'
assert rp1.detached == rp2.detached, 'Detached ResourcePool cannot be merged with non-detached ResourcePool'
|
||||||
|
|
||||||
|
new_store = rp1.store + rp2.store
|
||||||
|
|
||||||
|
merged = RayResourcePool(new_store, rp1.use_gpu, f"{rp1.name_prefix}_{rp2.name_prefix}")
|
||||||
|
merged.pgs = rp1.get_placement_groups() + rp2.get_placement_groups()
|
||||||
|
|
||||||
|
return merged
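
# Editor's note: illustrative sketch, not part of the original source. Merging
# two pools concatenates their per-node process counts and reuses both sets of
# placement groups; the sizes below are hypothetical and a running Ray cluster
# is required.
#
#     pool_a = RayResourcePool(process_on_nodes=[8], name_prefix="a")
#     pool_b = RayResourcePool(process_on_nodes=[8, 8], name_prefix="b")
#     merged = merge_resource_pool(pool_a, pool_b)  # merged.store == [8, 8, 8]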
|
||||||
|
|
||||||
|
|
||||||
|
class RayClassWithInitArgs(ClassWithInitArgs):
|
||||||
|
|
||||||
|
def __init__(self, cls, *args, **kwargs) -> None:
|
||||||
|
# self._options = kwargs.pop('options', dict())
|
||||||
|
super().__init__(cls, *args, **kwargs)
|
||||||
|
self._options = {}
|
||||||
|
self._additional_resource = {}
|
||||||
|
|
||||||
|
def set_additional_resource(self, additional_resource):
|
||||||
|
self._additional_resource = additional_resource
|
||||||
|
|
||||||
|
def update_options(self, options: Dict):
|
||||||
|
self._options.update(options)
|
||||||
|
|
||||||
|
def __call__(self,
|
||||||
|
placement_group,
|
||||||
|
placement_group_bundle_idx,
|
||||||
|
use_gpu: bool = True,
|
||||||
|
num_gpus=1,
|
||||||
|
sharing_with=None) -> Any:
|
||||||
|
if sharing_with is not None:
|
||||||
|
target_node_id = ray.get(sharing_with.get_node_id.remote())
|
||||||
|
cuda_visible_devices = ray.get(sharing_with.get_cuda_visible_devices.remote())
|
||||||
|
options = {"scheduling_strategy": NodeAffinitySchedulingStrategy(node_id=target_node_id, soft=False)}
|
||||||
|
return self.cls.options(**options).remote(*self.args,
|
||||||
|
cuda_visible_devices=cuda_visible_devices,
|
||||||
|
**self.kwargs)
|
||||||
|
|
||||||
|
options = {
|
||||||
|
"scheduling_strategy":
|
||||||
|
PlacementGroupSchedulingStrategy(placement_group=placement_group,
|
||||||
|
placement_group_bundle_index=placement_group_bundle_idx)
|
||||||
|
}
|
||||||
|
options.update(self._options)
|
||||||
|
|
||||||
|
if use_gpu:
|
||||||
|
options["num_gpus"] = num_gpus
|
||||||
|
|
||||||
|
if len(self._additional_resource) > 1:
|
||||||
|
for k, v in self._additional_resource.items():
|
||||||
|
options[k] = v
|
||||||
|
|
||||||
|
# print("cls:", self.cls)
|
||||||
|
# print("args: ", self.args)
|
||||||
|
# print("kwargs: ", self.kwargs)
|
||||||
|
return self.cls.options(**options).remote(*self.args, **self.kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class RayWorkerGroup(WorkerGroup):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
resource_pool: RayResourcePool = None,
|
||||||
|
ray_cls_with_init: RayClassWithInitArgs = None,
|
||||||
|
bin_pack: bool = True,
|
||||||
|
name_prefix: str = None,
|
||||||
|
detached=False,
|
||||||
|
worker_names=None,
|
||||||
|
**kwargs) -> None:
|
||||||
|
super().__init__(resource_pool=resource_pool, **kwargs)
|
||||||
|
self.ray_cls_with_init = ray_cls_with_init
|
||||||
|
self.name_prefix = get_random_string(length=6) if name_prefix is None else name_prefix
|
||||||
|
|
||||||
|
if worker_names is not None:
|
||||||
|
assert self._is_init_with_detached_workers
|
||||||
|
self._worker_names = worker_names
|
||||||
|
|
||||||
|
if self._is_init_with_detached_workers:
|
||||||
|
self._init_with_detached_workers(worker_names=worker_names)
|
||||||
|
else:
|
||||||
|
self._init_with_resource_pool(resource_pool=resource_pool,
|
||||||
|
ray_cls_with_init=ray_cls_with_init,
|
||||||
|
bin_pack=bin_pack,
|
||||||
|
detached=detached)
|
||||||
|
|
||||||
|
if ray_cls_with_init is not None:
|
||||||
|
self._bind_worker_method(self.ray_cls_with_init.cls, func_generator)
|
||||||
|
|
||||||
|
def _is_worker_alive(self, worker: ray.actor.ActorHandle):
|
||||||
|
worker_state_dict = get_actor(worker._actor_id.hex())
|
||||||
|
return worker_state_dict.get("state", "undefined") == "ALIVE" if worker_state_dict is not None else False
|
||||||
|
|
||||||
|
def _init_with_detached_workers(self, worker_names):
|
||||||
|
workers = [ray.get_actor(name=name) for name in worker_names]
|
||||||
|
self._workers = workers
|
||||||
|
self._world_size = len(worker_names)
|
||||||
|
|
||||||
|
def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, detached):
|
||||||
|
use_gpu = resource_pool.use_gpu
|
||||||
|
|
||||||
|
strategy = "PACK"
|
||||||
|
if bin_pack:
|
||||||
|
strategy = "STRICT_PACK"
|
||||||
|
pgs = resource_pool.get_placement_groups(strategy=strategy)
|
||||||
|
world_size = resource_pool.world_size
|
||||||
|
self._world_size = world_size
|
||||||
|
# cia.add_kwarg("_world_size", world_size)
|
||||||
|
num_gpus = 1 / resource_pool.max_collocate_count
|
||||||
|
|
||||||
|
rank = -1
|
||||||
|
for pg_idx, local_world_size in enumerate(resource_pool.store):
|
||||||
|
pg = pgs[pg_idx]
|
||||||
|
assert local_world_size <= pg.bundle_count, \
f"when generating workers for {self.name_prefix}, local_world_size ({local_world_size}) exceeds the bundle count ({pg.bundle_count}) of placement group {pg_idx}"
|
||||||
|
for local_rank in range(local_world_size):
|
||||||
|
rank += 1
|
||||||
|
|
||||||
|
# we pass in environment variable at option so that Worker can use environment variable to set
|
||||||
|
env_vars = {
|
||||||
|
'WORLD_SIZE': str(world_size),
|
||||||
|
'RANK': str(rank),
|
||||||
|
'WG_PREFIX': self.name_prefix,
|
||||||
|
'WG_BACKEND': 'ray',
|
||||||
|
'RAY_LOCAL_WORLD_SIZE': str(local_world_size),
|
||||||
|
'RAY_LOCAL_RANK': str(local_rank),
|
||||||
|
}
|
||||||
|
if rank != 0:
|
||||||
|
env_vars['MASTER_ADDR'] = self._master_addr
|
||||||
|
env_vars['MASTER_PORT'] = self._master_port
|
||||||
|
|
||||||
|
import re
|
||||||
|
cia_name = type(ray_cls_with_init.cls).__name__
|
||||||
|
match = re.search(r"ActorClass\(([^)]+)\)", cia_name) # ray.remote(Obj) -> "ActorClass(Obj)"
|
||||||
|
cia_name = match.group(1) if match else cia_name # "ActorClass(Obj)" -> "Obj"
|
||||||
|
name = f"{self.name_prefix}{cia_name}_{pg_idx}:{local_rank}" # e.g. Worker_2:5
|
||||||
|
|
||||||
|
ray_cls_with_init.update_options({'runtime_env': {'env_vars': env_vars}, 'name': name})
|
||||||
|
|
||||||
|
if detached:
|
||||||
|
ray_cls_with_init.update_options({'lifetime': 'detached'})
|
||||||
|
|
||||||
|
# create a worker
|
||||||
|
worker = ray_cls_with_init(placement_group=pg,
|
||||||
|
placement_group_bundle_idx=local_rank,
|
||||||
|
use_gpu=use_gpu,
|
||||||
|
num_gpus=num_gpus)
|
||||||
|
self._workers.append(worker)
|
||||||
|
self._worker_names.append(name)
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
register_center_actor = None
|
||||||
|
for _ in range(120):
|
||||||
|
if f"{self.name_prefix}_register_center" not in list_named_actors():
|
||||||
|
time.sleep(1)
|
||||||
|
else:
|
||||||
|
register_center_actor = ray.get_actor(f"{self.name_prefix}_register_center")
|
||||||
|
break
|
||||||
|
assert register_center_actor is not None, f"failed to get register_center_actor: {self.name_prefix}_register_center in {list_named_actors(all_namespaces=True)}"
|
||||||
|
rank_zero_info = ray.get(register_center_actor.get_rank_zero_info.remote())
|
||||||
|
self._master_addr, self._master_port = rank_zero_info['MASTER_ADDR'], rank_zero_info['MASTER_PORT']
|
||||||
|
# print(f"rank_zero_info: {rank_zero_info}")
|
||||||
|
# print(f"master_addr: {self._master_addr}, master_port: {self._master_port}")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def worker_names(self):
|
||||||
|
return self._worker_names
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_detached(cls, worker_names=None, ray_cls_with_init=None):
|
||||||
|
worker_group = cls(resource_pool=None,
|
||||||
|
ray_cls_with_init=ray_cls_with_init,
|
||||||
|
name_prefix=None,
|
||||||
|
worker_names=worker_names)
|
||||||
|
return worker_group
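
# Editor's note: illustrative sketch, not part of the original source. A new
# driver can re-attach to detached workers by name; the actor names below are
# hypothetical but follow the "{name_prefix}{cls_name}_{pg_idx}:{local_rank}"
# scheme used in _init_with_resource_pool.
#
#     wg = RayWorkerGroup.from_detached(
#         worker_names=["abcdefWorkerDict_0:0", "abcdefWorkerDict_0:1"],
#         ray_cls_with_init=ray_cls_with_init,
#     )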
|
||||||
|
|
||||||
|
def spawn(self, prefix_set):
|
||||||
|
"""
|
||||||
|
spawn to a dictionary of worker groups, each with a subset of method with prefix.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _rebind_actor_methods(worker_group, actor_name):
|
||||||
|
"""
|
||||||
|
bind the method with actor_prefix to its original name
|
||||||
|
"""
|
||||||
|
prefix: str = actor_name + '_'
|
||||||
|
for method_name in dir(worker_group):
|
||||||
|
if method_name.startswith(prefix):
|
||||||
|
# only valid when Python >= 3.9
|
||||||
|
original_method_name = method_name.removeprefix(prefix)
|
||||||
|
method = getattr(worker_group, method_name)
|
||||||
|
setattr(worker_group, original_method_name, method)
|
||||||
|
|
||||||
|
new_worker_group_dict = {}
|
||||||
|
for prefix in prefix_set:
|
||||||
|
new_worker_group = self.from_detached(worker_names=self._worker_names,
|
||||||
|
ray_cls_with_init=self.ray_cls_with_init)
|
||||||
|
|
||||||
|
_rebind_actor_methods(new_worker_group, prefix)
|
||||||
|
new_worker_group_dict[prefix] = new_worker_group
|
||||||
|
return new_worker_group_dict
|
||||||
|
|
||||||
|
def execute_rank_zero_sync(self, method_name: str, *args, **kwargs):
|
||||||
|
return ray.get(self.execute_rank_zero_async(method_name, *args, **kwargs))
|
||||||
|
|
||||||
|
def execute_rank_zero_async(self, method_name: str, *args, **kwargs):
|
||||||
|
remote_call = getattr(self._workers[0], method_name)
|
||||||
|
return remote_call.remote(*args, **kwargs)
|
||||||
|
|
||||||
|
def execute_rank_zero(self, method_name: str, *args, **kwargs):
|
||||||
|
return self.execute_rank_zero_async(method_name, *args, **kwargs)
|
||||||
|
|
||||||
|
def execute_all(self, method_name: str, *args, **kwargs):
|
||||||
|
return self.execute_all_async(method_name, *args, **kwargs)
|
||||||
|
|
||||||
|
def execute_all_sync(self, method_name: str, *args, **kwargs):
|
||||||
|
return ray.get(self.execute_all_async(method_name, *args, **kwargs))
|
||||||
|
|
||||||
|
def execute_all_async(self, method_name: str, *args, **kwargs):
|
||||||
|
# Here we assume that if every positional and keyword argument is a list whose length
# equals len(self._workers), each element of the list is dispatched to the corresponding
# worker; otherwise the arguments are broadcast to every worker as-is
# (see the illustrative sketch after this method).
|
||||||
|
# print(f"execute_all_async: method {method_name}({args}, {kwargs})")
|
||||||
|
length = len(self._workers)
|
||||||
|
if all(isinstance(arg, list) for arg in args) and all(isinstance(kwarg, list) for kwarg in kwargs.values()):
|
||||||
|
if all(len(arg) == length for arg in args) and all(len(kwarg) == length for kwarg in kwargs.values()):
|
||||||
|
# print(f"splitting args and kwargs into {length} shards")
|
||||||
|
result = []
|
||||||
|
for i in range(length):
|
||||||
|
sliced_args = tuple(arg[i] for arg in args)
|
||||||
|
sliced_kwargs = {k: v[i] for k, v in kwargs.items()}
|
||||||
|
remote_call = getattr(self._workers[i], method_name)
|
||||||
|
result.append(remote_call.remote(*sliced_args, **sliced_kwargs))
|
||||||
|
return result
|
||||||
|
|
||||||
|
return [getattr(worker, method_name).remote(*args, **kwargs) for worker in self._workers]
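
# Editor's note: illustrative sketch, not part of the original source, showing
# how the sharding rule above behaves for a hypothetical group of 2 workers:
#
#     wg.execute_all_async("foo", [1, 2])     # lengths match world size ->
#                                             # worker 0 runs foo(1), worker 1 runs foo(2)
#     wg.execute_all_async("foo", 3)          # not a list -> every worker runs foo(3)
#     wg.execute_all_async("foo", [1, 2, 3])  # length mismatch -> the whole list is
#                                             # broadcast to every worker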
|
||||||
|
|
||||||
|
@property
|
||||||
|
def master_address(self):
|
||||||
|
return self._master_addr
|
||||||
|
|
||||||
|
@property
|
||||||
|
def master_port(self):
|
||||||
|
return self._master_port
|
||||||
|
|
||||||
|
@property
|
||||||
|
def workers(self):
|
||||||
|
return self._workers
|
||||||
|
|
||||||
|
@property
|
||||||
|
def world_size(self):
|
||||||
|
return self._world_size
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Utilities that enable creating workers inside the same ray.Actor,
|
||||||
|
with code written in separate ray.Actors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
from verl.single_controller.base.decorator import MAGIC_ATTR
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def _bind_workers_method_to_parent(cls, key, user_defined_cls):
|
||||||
|
"""
|
||||||
|
Binds the methods of each worker to the WorkerDict.
|
||||||
|
Note that we only bind public methods that are decorated by register
|
||||||
|
"""
|
||||||
|
for method_name in dir(user_defined_cls):
|
||||||
|
try:
|
||||||
|
method = getattr(user_defined_cls, method_name)
|
||||||
|
assert callable(method), f"{method_name} in {user_defined_cls} is not callable"
|
||||||
|
except Exception as e:
|
||||||
|
# if it is a property, it will fail because Class doesn't have instance property
|
||||||
|
continue
|
||||||
|
|
||||||
|
if hasattr(method, MAGIC_ATTR):
|
||||||
|
|
||||||
|
def generate_function(name):
|
||||||
|
|
||||||
|
def func(self, *args, **kwargs):
|
||||||
|
# dispatch to the actual worker
|
||||||
|
return getattr(self.worker_dict[key], name)(*args, **kwargs)
|
||||||
|
|
||||||
|
return func
|
||||||
|
|
||||||
|
func = generate_function(method_name)
|
||||||
|
# pass MAGIC_ATTR for outer worker group
|
||||||
|
setattr(func, MAGIC_ATTR, getattr(method, MAGIC_ATTR))
|
||||||
|
try:
|
||||||
|
method_name_with_prefix = key + '_' + method_name
|
||||||
|
setattr(cls, method_name_with_prefix, func)
|
||||||
|
# print(f'Binding {method_name_with_prefix}')
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f'Fail to set method_name {method_name}')
|
||||||
|
|
||||||
|
|
||||||
|
def _unwrap_ray_remote(cls):
|
||||||
|
if hasattr(cls, '__ray_actor_class__'):
|
||||||
|
cls = cls.__ray_actor_class__
|
||||||
|
return cls
|
||||||
|
|
||||||
|
|
||||||
|
def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs]):
|
||||||
|
"""
|
||||||
|
This function returns a RayClassWithInitArgs that wraps a colocated WorkerDict actor class,
delegating calls to every cls in class_dict.
|
||||||
|
"""
|
||||||
|
cls_dict = {}
|
||||||
|
init_args_dict = {}
|
||||||
|
worker_cls = None
|
||||||
|
for key, cls in class_dict.items():
|
||||||
|
if worker_cls is None:
|
||||||
|
worker_cls = cls.cls.__ray_actor_class__.__base__
|
||||||
|
else:
|
||||||
|
assert worker_cls == cls.cls.__ray_actor_class__.__base__, \
|
||||||
|
'the worker class should be the same when sharing the same process'
|
||||||
|
cls_dict[key] = cls.cls
|
||||||
|
init_args_dict[key] = {'args': cls.args, 'kwargs': cls.kwargs}
|
||||||
|
|
||||||
|
assert cls_dict.keys() == init_args_dict.keys()
|
||||||
|
|
||||||
|
# TODO: create a class with customizable name
|
||||||
|
class WorkerDict(worker_cls):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.worker_dict = {}
|
||||||
|
for key, user_defined_cls in cls_dict.items():
|
||||||
|
user_defined_cls = _unwrap_ray_remote(user_defined_cls)
|
||||||
|
# directly instantiate the class without remote
|
||||||
|
with patch.dict(os.environ, {'DISABLE_WORKER_INIT': '1'}):
|
||||||
|
self.worker_dict[key] = user_defined_cls(*init_args_dict[key].get('args', ()),
|
||||||
|
**init_args_dict[key].get('kwargs', {}))
|
||||||
|
|
||||||
|
# now monkey-patch the methods from inner class to WorkerDict
|
||||||
|
for key, user_defined_cls in cls_dict.items():
|
||||||
|
user_defined_cls = _unwrap_ray_remote(user_defined_cls)
|
||||||
|
_bind_workers_method_to_parent(WorkerDict, key, user_defined_cls)
|
||||||
|
|
||||||
|
remote_cls = ray.remote(WorkerDict)
|
||||||
|
remote_cls = RayClassWithInitArgs(cls=remote_cls)
|
||||||
|
return remote_cls
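
# Editor's note: illustrative sketch, not part of the original source.
# ActorWorker and CriticWorker are hypothetical Worker subclasses; the sketch
# shows how two roles can be colocated in a single actor and later split back
# into per-role worker groups via WorkerGroup.spawn.
#
#     cls_dict = {
#         "actor": RayClassWithInitArgs(cls=ray.remote(ActorWorker)),
#         "critic": RayClassWithInitArgs(cls=ray.remote(CriticWorker)),
#     }
#     colocated_cls = create_colocated_worker_cls(cls_dict)
#     wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=colocated_cls)
#     spawned = wg.spawn(prefix_set={"actor", "critic"})
#     actor_wg, critic_wg = spawned["actor"], spawned["critic"]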
|
||||||
62
verl/single_controller/ray/megatron.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
import ray
|
||||||
|
|
||||||
|
from .base import RayWorkerGroup, RayResourcePool, RayClassWithInitArgs
|
||||||
|
from verl.single_controller.base.megatron.worker import DistRankInfo, DistGlobalInfo
|
||||||
|
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE(sgm): for opensource megatron-core
|
||||||
|
class NVMegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup):
|
||||||
|
"""
|
||||||
|
MegatronWorkerGroup will query each worker for its megatron rank info and store it inside the WorkerGroup
|
||||||
|
so that the dispatcher can use it to dispatch data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, resource_pool: RayResourcePool, ray_cls_with_init: RayClassWithInitArgs, **kwargs):
|
||||||
|
super().__init__(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, **kwargs)
|
||||||
|
self._megatron_rank_info: DistRankInfo = self.execute_all_sync(method_name='get_megatron_rank_info')
|
||||||
|
self._megatron_global_info: DistGlobalInfo = ray.get(
|
||||||
|
self.execute_rank_zero_async(method_name='get_megatron_global_info'))
|
||||||
|
|
||||||
|
|
||||||
|
class MegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup):
|
||||||
|
"""
|
||||||
|
MegatronWorkerGroup will query each worker for its megatron rank info and store it inside the WorkerGroup
|
||||||
|
so that the dispatcher can use it to dispatch data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
resource_pool: RayResourcePool,
|
||||||
|
ray_cls_with_init: RayClassWithInitArgs,
|
||||||
|
default_megatron_kwargs: Dict = None,
|
||||||
|
**kwargs):
|
||||||
|
super().__init__(resource_pool=resource_pool,
|
||||||
|
ray_cls_with_init=ray_cls_with_init,
|
||||||
|
default_megatron_kwargs=default_megatron_kwargs,
|
||||||
|
**kwargs)
|
||||||
|
self.init_megatron(default_megatron_kwargs=default_megatron_kwargs)
|
||||||
|
self._megatron_rank_info: DistRankInfo = self.execute_all_sync(method_name='get_megatron_rank_info')
|
||||||
|
self._megatron_global_info: DistGlobalInfo = ray.get(
|
||||||
|
self.execute_rank_zero_async(method_name='get_megatron_global_info'))
|
||||||
|
|
||||||
|
def init_megatron(self, default_megatron_kwargs: Optional[Dict] = None):
|
||||||
|
# after super, we will call init of each worker
|
||||||
|
if not self._is_init_with_detached_workers:
|
||||||
|
# only init_megatron if the WorkerGroup is created from scratch
|
||||||
|
self.execute_all_sync(method_name='init_megatron', default_megatron_kwargs=default_megatron_kwargs)
|
||||||
1
verl/single_controller/version/version
Normal file
@@ -0,0 +1 @@
|
|||||||
|
0.0.2
|
||||||
13
verl/third_party/__init__.py
vendored
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
51
verl/third_party/vllm/__init__.py
vendored
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from importlib.metadata import version, PackageNotFoundError
|
||||||
|
|
||||||
|
|
||||||
|
def get_version(pkg):
|
||||||
|
try:
|
||||||
|
return version(pkg)
|
||||||
|
except PackageNotFoundError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
package_name = 'vllm'
|
||||||
|
package_version = get_version(package_name)
|
||||||
|
|
||||||
|
if package_version == '0.3.1':
|
||||||
|
vllm_version = '0.3.1'
|
||||||
|
from .vllm_v_0_3_1.llm import LLM
|
||||||
|
from .vllm_v_0_3_1.llm import LLMEngine
|
||||||
|
from .vllm_v_0_3_1 import parallel_state
|
||||||
|
elif package_version == '0.4.2':
|
||||||
|
vllm_version = '0.4.2'
|
||||||
|
from .vllm_v_0_4_2.llm import LLM
|
||||||
|
from .vllm_v_0_4_2.llm import LLMEngine
|
||||||
|
from .vllm_v_0_4_2 import parallel_state
|
||||||
|
elif package_version == '0.5.4':
|
||||||
|
vllm_version = '0.5.4'
|
||||||
|
from .vllm_v_0_5_4.llm import LLM
|
||||||
|
from .vllm_v_0_5_4.llm import LLMEngine
|
||||||
|
from .vllm_v_0_5_4 import parallel_state
|
||||||
|
elif package_version == '0.6.3':
|
||||||
|
vllm_version = '0.6.3'
|
||||||
|
from .vllm_v_0_6_3.llm import LLM
|
||||||
|
from .vllm_v_0_6_3.llm import LLMEngine
|
||||||
|
from .vllm_v_0_6_3 import parallel_state
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f'vllm version {package_version} not supported. Currently supported versions are 0.3.1, 0.4.2, 0.5.4 and 0.6.3.'
|
||||||
|
)
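
# Editor's note: illustrative sketch, not part of the original source. The
# re-exported symbols can be used uniformly regardless of which supported vLLM
# version is installed, e.g.:
#
#     from verl.third_party.vllm import LLM, LLMEngine, vllm_version
#     print(f"using vendored patches for vLLM {vllm_version}")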
|
||||||
13
verl/third_party/vllm/vllm_v_0_3_1/__init__.py
vendored
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
228
verl/third_party/vllm/vllm_v_0_3_1/arg_utils.py
vendored
Normal file
@@ -0,0 +1,228 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/arg_utils.py
|
||||||
|
import argparse
|
||||||
|
import dataclasses
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, ParallelConfig, SchedulerConfig, LoRAConfig)
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
from .config import ModelConfig
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EngineArgs:
|
||||||
|
"""Arguments for vLLM engine."""
|
||||||
|
model_hf_config: PretrainedConfig = None
|
||||||
|
dtype: str = 'auto'
|
||||||
|
kv_cache_dtype: str = 'auto'
|
||||||
|
seed: int = 0
|
||||||
|
max_model_len: Optional[int] = None
|
||||||
|
worker_use_ray: bool = False
|
||||||
|
pipeline_parallel_size: int = 1
|
||||||
|
tensor_parallel_size: int = 1
|
||||||
|
max_parallel_loading_workers: Optional[int] = None
|
||||||
|
block_size: int = 16
|
||||||
|
swap_space: int = 4 # GiB
|
||||||
|
gpu_memory_utilization: float = 0.90
|
||||||
|
max_num_batched_tokens: Optional[int] = None
|
||||||
|
max_num_seqs: int = 256
|
||||||
|
max_paddings: int = 256
|
||||||
|
disable_log_stats: bool = False
|
||||||
|
revision: Optional[str] = None
|
||||||
|
tokenizer_revision: Optional[str] = None
|
||||||
|
quantization: Optional[str] = None
|
||||||
|
load_format: str = 'model'
|
||||||
|
enforce_eager: bool = False
|
||||||
|
max_context_len_to_capture: int = 8192
|
||||||
|
disable_custom_all_reduce: bool = False
|
||||||
|
enable_lora: bool = False
|
||||||
|
max_loras: int = 1
|
||||||
|
max_lora_rank: int = 16
|
||||||
|
lora_extra_vocab_size: int = 256
|
||||||
|
lora_dtype = 'auto'
|
||||||
|
max_cpu_loras: Optional[int] = None
|
||||||
|
device: str = 'cuda'
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
||||||
|
"""Shared CLI arguments for vLLM engine."""
|
||||||
|
# Model arguments
|
||||||
|
# TODO(shengguangming): delete the unused args
|
||||||
|
parser.add_argument('--model',
|
||||||
|
type=str,
|
||||||
|
default='facebook/opt-125m',
|
||||||
|
help='name or path of the huggingface model to use')
|
||||||
|
parser.add_argument('--tokenizer',
|
||||||
|
type=str,
|
||||||
|
default=EngineArgs.tokenizer,
|
||||||
|
help='name or path of the huggingface tokenizer to use')
|
||||||
|
parser.add_argument('--revision',
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help='the specific model version to use. It can be a branch '
|
||||||
|
'name, a tag name, or a commit id. If unspecified, will use '
|
||||||
|
'the default version.')
|
||||||
|
parser.add_argument('--tokenizer-revision',
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help='the specific tokenizer version to use. It can be a branch '
|
||||||
|
'name, a tag name, or a commit id. If unspecified, will use '
|
||||||
|
'the default version.')
|
||||||
|
parser.add_argument('--tokenizer-mode',
|
||||||
|
type=str,
|
||||||
|
default=EngineArgs.tokenizer_mode,
|
||||||
|
choices=['auto', 'slow'],
|
||||||
|
help='tokenizer mode. "auto" will use the fast '
|
||||||
|
'tokenizer if available, and "slow" will '
|
||||||
|
'always use the slow tokenizer.')
|
||||||
|
parser.add_argument('--trust-remote-code', action='store_true', help='trust remote code from huggingface')
|
||||||
|
parser.add_argument('--download-dir',
|
||||||
|
type=str,
|
||||||
|
default=EngineArgs.download_dir,
|
||||||
|
help='directory to download and load the weights, '
|
||||||
|
'default to the default cache dir of '
|
||||||
|
'huggingface')
|
||||||
|
parser.add_argument('--load-format',
|
||||||
|
type=str,
|
||||||
|
default=EngineArgs.load_format,
|
||||||
|
choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
|
||||||
|
help='The format of the model weights to load. '
|
||||||
|
'"auto" will try to load the weights in the safetensors format '
|
||||||
|
'and fall back to the pytorch bin format if safetensors format '
|
||||||
|
'is not available. '
|
||||||
|
'"pt" will load the weights in the pytorch bin format. '
|
||||||
|
'"safetensors" will load the weights in the safetensors format. '
|
||||||
|
'"npcache" will load the weights in pytorch format and store '
|
||||||
|
'a numpy cache to speed up the loading. '
|
||||||
|
'"dummy" will initialize the weights with random values, '
|
||||||
|
'which is mainly for profiling.')
|
||||||
|
parser.add_argument('--dtype',
|
||||||
|
type=str,
|
||||||
|
default=EngineArgs.dtype,
|
||||||
|
choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
|
||||||
|
help='data type for model weights and activations. '
|
||||||
|
'The "auto" option will use FP16 precision '
|
||||||
|
'for FP32 and FP16 models, and BF16 precision '
|
||||||
|
'for BF16 models.')
|
||||||
|
parser.add_argument('--max-model-len',
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help='model context length. If unspecified, '
|
||||||
|
'will be automatically derived from the model.')
|
||||||
|
# Parallel arguments
|
||||||
|
parser.add_argument('--worker-use-ray',
|
||||||
|
action='store_true',
|
||||||
|
help='use Ray for distributed serving, will be '
|
||||||
|
'automatically set when using more than 1 GPU')
|
||||||
|
parser.add_argument('--pipeline-parallel-size',
|
||||||
|
'-pp',
|
||||||
|
type=int,
|
||||||
|
default=EngineArgs.pipeline_parallel_size,
|
||||||
|
help='number of pipeline stages')
|
||||||
|
parser.add_argument('--tensor-parallel-size',
|
||||||
|
'-tp',
|
||||||
|
type=int,
|
||||||
|
default=EngineArgs.tensor_parallel_size,
|
||||||
|
help='number of tensor parallel replicas')
|
||||||
|
# KV cache arguments
|
||||||
|
parser.add_argument('--block-size',
|
||||||
|
type=int,
|
||||||
|
default=EngineArgs.block_size,
|
||||||
|
choices=[8, 16, 32],
|
||||||
|
help='token block size')
|
||||||
|
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
|
||||||
|
parser.add_argument('--seed', type=int, default=EngineArgs.seed, help='random seed')
|
||||||
|
parser.add_argument('--swap-space',
|
||||||
|
type=int,
|
||||||
|
default=EngineArgs.swap_space,
|
||||||
|
help='CPU swap space size (GiB) per GPU')
|
||||||
|
parser.add_argument('--gpu-memory-utilization',
|
||||||
|
type=float,
|
||||||
|
default=EngineArgs.gpu_memory_utilization,
|
||||||
|
help='the percentage of GPU memory to be used for '
|
||||||
|
'the model executor')
|
||||||
|
parser.add_argument('--max-num-batched-tokens',
|
||||||
|
type=int,
|
||||||
|
default=EngineArgs.max_num_batched_tokens,
|
||||||
|
help='maximum number of batched tokens per '
|
||||||
|
'iteration')
|
||||||
|
parser.add_argument('--max-num-seqs',
|
||||||
|
type=int,
|
||||||
|
default=EngineArgs.max_num_seqs,
|
||||||
|
help='maximum number of sequences per iteration')
|
||||||
|
parser.add_argument('--disable-log-stats', action='store_true', help='disable logging statistics')
|
||||||
|
# Quantization settings.
|
||||||
|
parser.add_argument('--quantization',
|
||||||
|
'-q',
|
||||||
|
type=str,
|
||||||
|
choices=['awq', None],
|
||||||
|
default=None,
|
||||||
|
help='Method used to quantize the weights')
|
||||||
|
return parser
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
|
||||||
|
# Get the list of attributes of this dataclass.
|
||||||
|
attrs = [attr.name for attr in dataclasses.fields(cls)]
|
||||||
|
# Set the attributes from the parsed arguments.
|
||||||
|
engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
|
||||||
|
return engine_args
|
||||||
|
|
||||||
|
def create_engine_configs(
|
||||||
|
self,
|
||||||
|
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
|
||||||
|
device_config = DeviceConfig(self.device)
|
||||||
|
model_config = ModelConfig(self.model_hf_config, self.dtype, self.seed, self.load_format, self.revision,
|
||||||
|
self.tokenizer_revision, self.max_model_len, self.quantization, self.enforce_eager,
|
||||||
|
self.max_context_len_to_capture)
|
||||||
|
cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype,
|
||||||
|
model_config.get_sliding_window())
|
||||||
|
parallel_config = ParallelConfig(self.pipeline_parallel_size, self.tensor_parallel_size, self.worker_use_ray,
|
||||||
|
self.max_parallel_loading_workers, self.disable_custom_all_reduce)
|
||||||
|
scheduler_config = SchedulerConfig(self.max_num_batched_tokens, self.max_num_seqs, model_config.max_model_len,
|
||||||
|
self.max_paddings)
|
||||||
|
lora_config = LoRAConfig(max_lora_rank=self.max_lora_rank,
|
||||||
|
max_loras=self.max_loras,
|
||||||
|
lora_extra_vocab_size=self.lora_extra_vocab_size,
|
||||||
|
lora_dtype=self.lora_dtype,
|
||||||
|
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else
|
||||||
|
None) if self.enable_lora else None
|
||||||
|
return (model_config, cache_config, parallel_config, scheduler_config, device_config, lora_config)
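
# Editor's note: illustrative sketch, not part of the original source. This
# vendored EngineArgs is driven by an in-memory HuggingFace config instead of a
# model path; "facebook/opt-125m" below is only an example.
#
#     from transformers import AutoConfig
#     hf_config = AutoConfig.from_pretrained("facebook/opt-125m")
#     engine_args = EngineArgs(model_hf_config=hf_config, tensor_parallel_size=1)
#     configs = engine_args.create_engine_configs()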
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AsyncEngineArgs(EngineArgs):
|
||||||
|
"""Arguments for asynchronous vLLM engine."""
|
||||||
|
engine_use_ray: bool = False
|
||||||
|
disable_log_requests: bool = False
|
||||||
|
max_log_len: Optional[int] = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
||||||
|
parser = EngineArgs.add_cli_args(parser)
|
||||||
|
parser.add_argument('--engine-use-ray',
|
||||||
|
action='store_true',
|
||||||
|
help='use Ray to start the LLM engine in a '
|
||||||
|
'separate process as the server process.')
|
||||||
|
parser.add_argument('--disable-log-requests', action='store_true', help='disable logging requests')
|
||||||
|
parser.add_argument('--max-log-len',
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help='max number of prompt characters or prompt '
|
||||||
|
'ID numbers being printed in log. '
|
||||||
|
'Default: unlimited.')
|
||||||
|
return parser
|
||||||
577
verl/third_party/vllm/vllm_v_0_3_1/config.py
vendored
Normal file
@@ -0,0 +1,577 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py
|
||||||
|
|
||||||
|
from typing import Optional, Union, ClassVar
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import torch
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
from packaging.version import Version
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.transformers_utils.config import get_config
|
||||||
|
from vllm.utils import get_cpu_memory, is_hip, get_nvcc_cuda_version
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
_GB = 1 << 30
|
||||||
|
|
||||||
|
|
||||||
|
class ModelConfig:
|
||||||
|
"""Configuration for the model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Name or path of the huggingface model to use.
|
||||||
|
tokenizer: Name or path of the huggingface tokenizer to use.
|
||||||
|
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
|
||||||
|
available, and "slow" will always use the slow tokenizer.
|
||||||
|
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
|
||||||
|
downloading the model and tokenizer.
|
||||||
|
download_dir: Directory to download and load the weights, default to the
|
||||||
|
default cache directory of huggingface.
|
||||||
|
load_format: The format of the model weights to load:
|
||||||
|
"auto" will try to load the weights in the safetensors format and
|
||||||
|
fall back to the pytorch bin format if safetensors format is
|
||||||
|
not available.
|
||||||
|
"pt" will load the weights in the pytorch bin format.
|
||||||
|
"safetensors" will load the weights in the safetensors format.
|
||||||
|
"npcache" will load the weights in pytorch format and store
|
||||||
|
a numpy cache to speed up the loading.
|
||||||
|
"dummy" will initialize the weights with random values, which is
|
||||||
|
mainly for profiling.
|
||||||
|
dtype: Data type for model weights and activations. The "auto" option
|
||||||
|
will use FP16 precision for FP32 and FP16 models, and BF16 precision
|
||||||
|
for BF16 models.
|
||||||
|
seed: Random seed for reproducibility.
|
||||||
|
revision: The specific model version to use. It can be a branch name,
|
||||||
|
a tag name, or a commit id. If unspecified, will use the default
|
||||||
|
version.
|
||||||
|
tokenizer_revision: The specific tokenizer version to use. It can be a
|
||||||
|
branch name, a tag name, or a commit id. If unspecified, will use
|
||||||
|
the default version.
|
||||||
|
max_model_len: Maximum length of a sequence (including prompt and
|
||||||
|
output). If None, will be derived from the model.
|
||||||
|
quantization: Quantization method that was used to quantize the model
|
||||||
|
weights. If None, we assume the model weights are not quantized.
|
||||||
|
enforce_eager: Whether to enforce eager execution. If True, we will
|
||||||
|
disable CUDA graph and always execute the model in eager mode.
|
||||||
|
If False, we will use CUDA graph and eager execution in hybrid.
|
||||||
|
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
|
||||||
|
When a sequence has context length larger than this, we fall back
|
||||||
|
to eager mode.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hf_config: PretrainedConfig,
|
||||||
|
dtype: str,
|
||||||
|
seed: int,
|
||||||
|
load_format: str = 'model',
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
tokenizer_revision: Optional[str] = None,
|
||||||
|
max_model_len: Optional[int] = None,
|
||||||
|
quantization: Optional[str] = None,
|
||||||
|
trust_remote_code: Optional[bool] = True,
|
||||||
|
enforce_eager: bool = False,
|
||||||
|
max_context_len_to_capture: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
self.model = hf_config._name_or_path
|
||||||
|
self.tokenizer = hf_config._name_or_path
|
||||||
|
self.load_format = load_format
|
||||||
|
self.seed = seed
|
||||||
|
self.revision = revision
|
||||||
|
self.tokenizer_revision = tokenizer_revision
|
||||||
|
self.quantization = quantization
|
||||||
|
self.trust_remote_code = trust_remote_code
|
||||||
|
self.enforce_eager = enforce_eager
|
||||||
|
self.max_context_len_to_capture = max_context_len_to_capture
|
||||||
|
|
||||||
|
# self.hf_config = get_config(model, trust_remote_code, revision)
|
||||||
|
self.hf_config = hf_config
|
||||||
|
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
|
||||||
|
self.max_model_len = _get_and_verify_max_len(self.hf_config, max_model_len)
|
||||||
|
# self._verify_load_format()
|
||||||
|
# self._verify_tokenizer_mode()
|
||||||
|
self._verify_quantization()
|
||||||
|
self._verify_cuda_graph()
|
||||||
|
|
||||||
|
def _verify_load_format(self) -> None:
|
||||||
|
load_format = self.load_format.lower()
|
||||||
|
if load_format not in ["auto", "pt", "safetensors", "npcache", "dummy", "model"]:
|
||||||
|
raise ValueError(f"Unknown load format: {self.load_format}. Must be one of "
|
||||||
|
"'auto', 'pt', 'safetensors', 'npcache', 'dummy' or 'model'.")
|
||||||
|
self.load_format = load_format
|
||||||
|
|
||||||
|
# def _verify_tokenizer_mode(self) -> None:
|
||||||
|
# tokenizer_mode = self.tokenizer_mode.lower()
|
||||||
|
# if tokenizer_mode not in ["auto", "slow"]:
|
||||||
|
# raise ValueError(
|
||||||
|
# f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
|
||||||
|
# "either 'auto' or 'slow'.")
|
||||||
|
# self.tokenizer_mode = tokenizer_mode
|
||||||
|
|
||||||
|
def _verify_quantization(self) -> None:
|
||||||
|
supported_quantization = ["awq", "gptq", "squeezellm"]
|
||||||
|
rocm_not_supported_quantization = ["awq", "gptq"]
|
||||||
|
if self.quantization is not None:
|
||||||
|
self.quantization = self.quantization.lower()
|
||||||
|
|
||||||
|
# Parse quantization method from the HF model config, if available.
|
||||||
|
hf_quant_config = getattr(self.hf_config, "quantization_config", None)
|
||||||
|
if hf_quant_config is not None:
|
||||||
|
hf_quant_method = str(hf_quant_config["quant_method"]).lower()
|
||||||
|
if self.quantization is None:
|
||||||
|
self.quantization = hf_quant_method
|
||||||
|
elif self.quantization != hf_quant_method:
|
||||||
|
raise ValueError("Quantization method specified in the model config "
|
||||||
|
f"({hf_quant_method}) does not match the quantization "
|
||||||
|
f"method specified in the `quantization` argument "
|
||||||
|
f"({self.quantization}).")
|
||||||
|
|
||||||
|
if self.quantization is not None:
|
||||||
|
if self.quantization not in supported_quantization:
|
||||||
|
raise ValueError(f"Unknown quantization method: {self.quantization}. Must "
|
||||||
|
f"be one of {supported_quantization}.")
|
||||||
|
if is_hip() and self.quantization in rocm_not_supported_quantization:
|
||||||
|
raise ValueError(f"{self.quantization} quantization is currently not supported "
|
||||||
|
f"in ROCm.")
|
||||||
|
logger.warning(f"{self.quantization} quantization is not fully "
|
||||||
|
"optimized yet. The speed can be slower than "
|
||||||
|
"non-quantized models.")
|
||||||
|
|
||||||
|
def _verify_cuda_graph(self) -> None:
|
||||||
|
if self.max_context_len_to_capture is None:
|
||||||
|
self.max_context_len_to_capture = self.max_model_len
|
||||||
|
self.max_context_len_to_capture = min(self.max_context_len_to_capture, self.max_model_len)
|
||||||
|
if (self.quantization in ["gptq", "squeezellm"] and not self.enforce_eager):
|
||||||
|
# Related issue: https://github.com/vllm-project/vllm/issues/2147
|
||||||
|
logger.warning(f"{self.quantization} does not support CUDA graph "
|
||||||
|
"yet. Disabling CUDA graph.")
|
||||||
|
self.enforce_eager = True
|
||||||
|
|
||||||
|
def verify_with_parallel_config(
|
||||||
|
self,
|
||||||
|
parallel_config: "ParallelConfig",
|
||||||
|
) -> None:
|
||||||
|
total_num_attention_heads = self.hf_config.num_attention_heads
|
||||||
|
tensor_parallel_size = parallel_config.tensor_parallel_size
|
||||||
|
if total_num_attention_heads % tensor_parallel_size != 0:
|
||||||
|
raise ValueError(f"Total number of attention heads ({total_num_attention_heads})"
|
||||||
|
" must be divisible by tensor parallel size "
|
||||||
|
f"({tensor_parallel_size}).")
|
||||||
|
|
||||||
|
total_num_hidden_layers = self.hf_config.num_hidden_layers
|
||||||
|
pipeline_parallel_size = parallel_config.pipeline_parallel_size
|
||||||
|
if total_num_hidden_layers % pipeline_parallel_size != 0:
|
||||||
|
raise ValueError(f"Total number of hidden layers ({total_num_hidden_layers}) "
|
||||||
|
"must be divisible by pipeline parallel size "
|
||||||
|
f"({pipeline_parallel_size}).")
|
||||||
|
|
||||||
|
def get_sliding_window(self) -> Optional[int]:
|
||||||
|
return getattr(self.hf_config, "sliding_window", None)
|
||||||
|
|
||||||
|
def get_vocab_size(self) -> int:
|
||||||
|
return self.hf_config.vocab_size
|
||||||
|
|
||||||
|
def get_hidden_size(self) -> int:
|
||||||
|
return self.hf_config.hidden_size
|
||||||
|
|
||||||
|
def get_head_size(self) -> int:
|
||||||
|
# FIXME(woosuk): This may not be true for all models.
|
||||||
|
return self.hf_config.hidden_size // self.hf_config.num_attention_heads
|
||||||
|
|
||||||
|
def get_total_num_kv_heads(self) -> int:
|
||||||
|
"""Returns the total number of KV heads."""
|
||||||
|
# For GPTBigCode & Falcon:
|
||||||
|
# NOTE: for falcon, when new_decoder_architecture is True, the
|
||||||
|
# multi_query flag is ignored and we use n_head_kv for the number of
|
||||||
|
# KV heads.
|
||||||
|
falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
|
||||||
|
new_decoder_arch_falcon = (self.hf_config.model_type in falcon_model_types and
|
||||||
|
getattr(self.hf_config, "new_decoder_architecture", False))
|
||||||
|
if not new_decoder_arch_falcon and getattr(self.hf_config, "multi_query", False):
|
||||||
|
# Multi-query attention, only one KV head.
|
||||||
|
# Currently, tensor parallelism is not supported in this case.
|
||||||
|
return 1
|
||||||
|
|
||||||
|
attributes = [
|
||||||
|
# For Falcon:
|
||||||
|
"n_head_kv",
|
||||||
|
"num_kv_heads",
|
||||||
|
# For LLaMA-2:
|
||||||
|
"num_key_value_heads",
|
||||||
|
# For ChatGLM:
|
||||||
|
"multi_query_group_num",
|
||||||
|
]
|
||||||
|
for attr in attributes:
|
||||||
|
num_kv_heads = getattr(self.hf_config, attr, None)
|
||||||
|
if num_kv_heads is not None:
|
||||||
|
return num_kv_heads
|
||||||
|
|
||||||
|
# For non-grouped-query attention models, the number of KV heads is
|
||||||
|
# equal to the number of attention heads.
|
||||||
|
return self.hf_config.num_attention_heads
|
||||||
|
|
||||||
|
def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
|
||||||
|
"""Returns the number of KV heads per GPU."""
|
||||||
|
total_num_kv_heads = self.get_total_num_kv_heads()
|
||||||
|
# If tensor parallelism is used, we divide the number of KV heads by
|
||||||
|
# the tensor parallel size. We will replicate the KV heads in the
|
||||||
|
# case where the number of KV heads is smaller than the tensor
|
||||||
|
# parallel size so each GPU has at least one KV head.
|
||||||
|
return max(1, total_num_kv_heads // parallel_config.tensor_parallel_size)
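# Illustrative KV-head arithmetic (assumed grouped-query model with 8 KV heads;
# not part of the vendored module):
total_num_kv_heads = 8
per_gpu_kv_heads = {tp: max(1, total_num_kv_heads // tp) for tp in (2, 8, 16)}
# {2: 4, 8: 1, 16: 1} -- with tp=16 the single KV head is replicated, not split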
|
||||||
|
|
||||||
|
def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
|
||||||
|
total_num_hidden_layers = self.hf_config.num_hidden_layers
|
||||||
|
return total_num_hidden_layers // parallel_config.pipeline_parallel_size
|
||||||
|
|
||||||
|
|
||||||
|
class CacheConfig:
|
||||||
|
"""Configuration for the KV cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
block_size: Size of a cache block in number of tokens.
|
||||||
|
gpu_memory_utilization: Fraction of GPU memory to use for the
|
||||||
|
vLLM execution.
|
||||||
|
swap_space: Size of the CPU swap space per GPU (in GiB).
|
||||||
|
cache_dtype: Data type for kv cache storage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
block_size: int,
|
||||||
|
gpu_memory_utilization: float,
|
||||||
|
swap_space: int,
|
||||||
|
cache_dtype: str,
|
||||||
|
sliding_window: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
self.block_size = block_size
|
||||||
|
self.gpu_memory_utilization = gpu_memory_utilization
|
||||||
|
self.swap_space_bytes = swap_space * _GB
|
||||||
|
self.cache_dtype = cache_dtype
|
||||||
|
self.sliding_window = sliding_window
|
||||||
|
self._verify_args()
|
||||||
|
self._verify_cache_dtype()
|
||||||
|
|
||||||
|
# Will be set after profiling.
|
||||||
|
self.num_gpu_blocks = None
|
||||||
|
self.num_cpu_blocks = None
|
||||||
|
|
||||||
|
def _verify_args(self) -> None:
|
||||||
|
if self.gpu_memory_utilization > 1.0:
|
||||||
|
raise ValueError("GPU memory utilization must be less than 1.0. Got "
|
||||||
|
f"{self.gpu_memory_utilization}.")
|
||||||
|
|
||||||
|
def _verify_cache_dtype(self) -> None:
|
||||||
|
if self.cache_dtype == "auto":
|
||||||
|
pass
|
||||||
|
elif self.cache_dtype == "fp8_e5m2":
|
||||||
|
nvcc_cuda_version = get_nvcc_cuda_version()
|
||||||
|
if nvcc_cuda_version < Version("11.8"):
|
||||||
|
raise ValueError("FP8 is not supported when cuda version is lower than 11.8.")
|
||||||
|
device_name = torch.cuda.get_device_name()
|
||||||
|
if "AMD" in device_name:
|
||||||
|
raise NotImplementedError("FP8_E5M2 KV Cache on AMD GPU has not been supported yet.")
|
||||||
|
logger.info("Using fp8_e5m2 data type to store kv cache. It reduces "
|
||||||
|
"the GPU memory footprint and boosts the performance. "
|
||||||
|
"But it may cause slight accuracy drop. "
|
||||||
|
"Currently we only support fp8 without scaling factors and "
|
||||||
|
"make e5m2 as a default format.")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
|
||||||
|
|
||||||
|
def verify_with_parallel_config(
|
||||||
|
self,
|
||||||
|
parallel_config: "ParallelConfig",
|
||||||
|
) -> None:
|
||||||
|
total_cpu_memory = get_cpu_memory()
|
||||||
|
# FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
|
||||||
|
# group are in the same node. However, the GPUs may span multiple nodes.
|
||||||
|
num_gpus_per_node = parallel_config.tensor_parallel_size
|
||||||
|
cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node
|
||||||
|
|
||||||
|
msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of "
|
||||||
|
f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
|
||||||
|
"allocated for the swap space.")
|
||||||
|
if cpu_memory_usage > 0.7 * total_cpu_memory:
|
||||||
|
raise ValueError("Too large swap space. " + msg)
|
||||||
|
elif cpu_memory_usage > 0.4 * total_cpu_memory:
|
||||||
|
logger.warning("Possibly too large swap space. " + msg)
|
||||||
|
|
||||||
|
|
||||||
|
class ParallelConfig:
|
||||||
|
"""Configuration for the distributed execution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pipeline_parallel_size: Number of pipeline parallel groups.
|
||||||
|
tensor_parallel_size: Number of tensor parallel groups.
|
||||||
|
worker_use_ray: Whether to use Ray for model workers. Will be set to
|
||||||
|
True if either pipeline_parallel_size or tensor_parallel_size is
|
||||||
|
greater than 1.
|
||||||
|
max_parallel_loading_workers: Maximum number of workers used to load
the model weights in sequential batches; limits peak CPU RAM usage
when using tensor parallelism with large models.
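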
|
||||||
|
disable_custom_all_reduce: Disable the custom all-reduce kernel and
|
||||||
|
fall back to NCCL.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
pipeline_parallel_size: int,
|
||||||
|
tensor_parallel_size: int,
|
||||||
|
worker_use_ray: bool,
|
||||||
|
max_parallel_loading_workers: Optional[int] = None,
|
||||||
|
disable_custom_all_reduce: bool = False,
|
||||||
|
) -> None:
|
||||||
|
self.pipeline_parallel_size = pipeline_parallel_size
|
||||||
|
self.tensor_parallel_size = tensor_parallel_size
|
||||||
|
self.worker_use_ray = worker_use_ray
|
||||||
|
self.max_parallel_loading_workers = max_parallel_loading_workers
|
||||||
|
self.disable_custom_all_reduce = disable_custom_all_reduce
|
||||||
|
|
||||||
|
self.world_size = pipeline_parallel_size * tensor_parallel_size
|
||||||
|
if self.world_size > 1:
|
||||||
|
self.worker_use_ray = True
|
||||||
|
self._verify_args()
|
||||||
|
|
||||||
|
def _verify_args(self) -> None:
|
||||||
|
if self.pipeline_parallel_size > 1:
|
||||||
|
raise NotImplementedError("Pipeline parallelism is not supported yet.")
|
||||||
|
if not self.disable_custom_all_reduce and self.world_size > 1:
|
||||||
|
if is_hip():
|
||||||
|
self.disable_custom_all_reduce = True
|
||||||
|
logger.info("Disabled the custom all-reduce kernel because it is not "
|
||||||
|
"supported on AMD GPUs.")
|
||||||
|
elif self.pipeline_parallel_size > 1:
|
||||||
|
self.disable_custom_all_reduce = True
|
||||||
|
logger.info("Disabled the custom all-reduce kernel because it is not "
|
||||||
|
"supported with pipeline parallelism.")
|
||||||
|
|
||||||
|
# FIXME(woosuk): Fix the stability issues and re-enable the custom
|
||||||
|
# all-reduce kernel.
|
||||||
|
if not self.disable_custom_all_reduce and self.world_size > 1:
|
||||||
|
self.disable_custom_all_reduce = True
|
||||||
|
logger.info("Custom all-reduce kernels are temporarily disabled due to "
|
||||||
|
"stability issues. We will re-enable them once the issues are "
|
||||||
|
"resolved.")
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulerConfig:
|
||||||
|
"""Scheduler configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_num_batched_tokens: Maximum number of tokens to be processed in
|
||||||
|
a single iteration.
|
||||||
|
max_num_seqs: Maximum number of sequences to be processed in a single
|
||||||
|
iteration.
|
||||||
|
max_model_len: Maximum length of a sequence (including prompt
|
||||||
|
and generated text).
|
||||||
|
max_paddings: Maximum number of paddings to be added to a batch.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_num_batched_tokens: Optional[int],
|
||||||
|
max_num_seqs: int,
|
||||||
|
max_model_len: int,
|
||||||
|
max_paddings: int,
|
||||||
|
) -> None:
|
||||||
|
if max_num_batched_tokens is not None:
|
||||||
|
self.max_num_batched_tokens = max_num_batched_tokens
|
||||||
|
else:
|
||||||
|
# If max_model_len is too short, use 2048 as the default value for
|
||||||
|
# higher throughput.
|
||||||
|
self.max_num_batched_tokens = max(max_model_len, 2048)
|
||||||
|
self.max_num_seqs = max_num_seqs
|
||||||
|
self.max_model_len = max_model_len
|
||||||
|
self.max_paddings = max_paddings
|
||||||
|
self._verify_args()
|
||||||
|
|
||||||
|
def _verify_args(self) -> None:
|
||||||
|
if self.max_num_batched_tokens < self.max_model_len:
|
||||||
|
raise ValueError(f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
|
||||||
|
f"smaller than max_model_len ({self.max_model_len}). "
|
||||||
|
"This effectively limits the maximum sequence length to "
|
||||||
|
"max_num_batched_tokens and makes vLLM reject longer "
|
||||||
|
"sequences. Please increase max_num_batched_tokens or "
|
||||||
|
"decrease max_model_len.")
|
||||||
|
if self.max_num_batched_tokens < self.max_num_seqs:
|
||||||
|
raise ValueError(f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
|
||||||
|
"be greater than or equal to max_num_seqs "
|
||||||
|
f"({self.max_num_seqs}).")
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceConfig:
|
||||||
|
|
||||||
|
def __init__(self, device: str = "cuda") -> None:
|
||||||
|
self.device = torch.device(device)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LoRAConfig:
|
||||||
|
max_lora_rank: int
|
||||||
|
max_loras: int
|
||||||
|
max_cpu_loras: Optional[int] = None
|
||||||
|
lora_dtype: Optional[torch.dtype] = None
|
||||||
|
lora_extra_vocab_size: int = 256
|
||||||
|
# This is a constant.
|
||||||
|
lora_vocab_padding_size: ClassVar[int] = 256
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
# Keep this in sync with csrc/punica/bgmv/bgmv_config.h
|
||||||
|
possible_max_ranks = (8, 16, 32, 64)
|
||||||
|
possible_lora_extra_vocab_size = (0, 256, 512)
|
||||||
|
if self.max_lora_rank not in possible_max_ranks:
|
||||||
|
raise ValueError(f"max_lora_rank ({self.max_lora_rank}) must be one of "
|
||||||
|
f"{possible_max_ranks}.")
|
||||||
|
if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
|
||||||
|
raise ValueError(f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
|
||||||
|
f"must be one of {possible_lora_extra_vocab_size}.")
|
||||||
|
if self.max_loras < 1:
|
||||||
|
raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")
|
||||||
|
if self.max_cpu_loras is None:
|
||||||
|
self.max_cpu_loras = self.max_loras
|
||||||
|
elif self.max_cpu_loras < self.max_loras:
|
||||||
|
raise ValueError(f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
|
||||||
|
f"max_loras ({self.max_loras})")
|
||||||
|
|
||||||
|
def verify_with_model_config(self, model_config: ModelConfig):
|
||||||
|
if self.lora_dtype in (None, "auto"):
|
||||||
|
self.lora_dtype = model_config.dtype
|
||||||
|
elif isinstance(self.lora_dtype, str):
|
||||||
|
self.lora_dtype = getattr(torch, self.lora_dtype)
|
||||||
|
if model_config.quantization is not None:
|
||||||
|
raise ValueError("LoRA is not supported with quantized models yet.")
|
||||||
|
|
||||||
|
def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
|
||||||
|
if scheduler_config.max_num_batched_tokens > 65528:
|
||||||
|
raise ValueError("Due to limitations of the custom LoRA CUDA kernel, "
|
||||||
|
"max_num_batched_tokens must be <= 65528 when "
|
||||||
|
"LoRA is enabled.")
|
||||||
|
|
||||||
|
|
||||||
|
_STR_DTYPE_TO_TORCH_DTYPE = {
|
||||||
|
"half": torch.float16,
|
||||||
|
"float16": torch.float16,
|
||||||
|
"float": torch.float32,
|
||||||
|
"float32": torch.float32,
|
||||||
|
"bfloat16": torch.bfloat16,
|
||||||
|
}
|
||||||
|
|
||||||
|
_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_and_verify_dtype(
|
||||||
|
config: PretrainedConfig,
|
||||||
|
dtype: Union[str, torch.dtype],
|
||||||
|
) -> torch.dtype:
|
||||||
|
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
|
||||||
|
# because config.torch_dtype can be None.
|
||||||
|
config_dtype = getattr(config, "torch_dtype", None)
|
||||||
|
if config_dtype is None:
|
||||||
|
config_dtype = torch.float32
|
||||||
|
|
||||||
|
if isinstance(dtype, str):
|
||||||
|
dtype = dtype.lower()
|
||||||
|
if dtype == "auto":
|
||||||
|
if config_dtype == torch.float32:
|
||||||
|
# Following the common practice, we use float16 for float32
|
||||||
|
# models.
|
||||||
|
torch_dtype = torch.float16
|
||||||
|
else:
|
||||||
|
torch_dtype = config_dtype
|
||||||
|
else:
|
||||||
|
if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
|
||||||
|
raise ValueError(f"Unknown dtype: {dtype}")
|
||||||
|
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
|
||||||
|
elif isinstance(dtype, torch.dtype):
|
||||||
|
torch_dtype = dtype
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown dtype: {dtype}")
|
||||||
|
|
||||||
|
if is_hip() and torch_dtype == torch.float32:
|
||||||
|
rocm_supported_dtypes = [
|
||||||
|
k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items() if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
|
||||||
|
]
|
||||||
|
raise ValueError(f"dtype \'{dtype}\' is not supported in ROCm. "
|
||||||
|
f"Supported dtypes are {rocm_supported_dtypes}")
|
||||||
|
|
||||||
|
# Verify the dtype.
|
||||||
|
if torch_dtype != config_dtype:
|
||||||
|
if torch_dtype == torch.float32:
|
||||||
|
# Upcasting to float32 is allowed.
|
||||||
|
pass
|
||||||
|
elif config_dtype == torch.float32:
|
||||||
|
# Downcasting from float32 to float16 or bfloat16 is allowed.
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# Casting between float16 and bfloat16 is allowed with a warning.
|
||||||
|
logger.warning(f"Casting {config_dtype} to {torch_dtype}.")
|
||||||
|
|
||||||
|
return torch_dtype
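# Illustrative dtype resolution (sketch, not part of the vendored module; the
# SimpleNamespace stands in for a HF PretrainedConfig):
from types import SimpleNamespace
cfg = SimpleNamespace(torch_dtype=torch.float32)
assert _get_and_verify_dtype(cfg, "auto") == torch.float16       # fp32 configs are downcast
assert _get_and_verify_dtype(cfg, "bfloat16") == torch.bfloat16  # explicit override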
|
||||||
|
|
||||||
|
|
||||||
|
def _get_and_verify_max_len(
|
||||||
|
hf_config: PretrainedConfig,
|
||||||
|
max_model_len: Optional[int],
|
||||||
|
) -> int:
|
||||||
|
"""Get and verify the model's maximum length."""
|
||||||
|
derived_max_model_len = float("inf")
|
||||||
|
possible_keys = [
|
||||||
|
# OPT
|
||||||
|
"max_position_embeddings",
|
||||||
|
# GPT-2
|
||||||
|
"n_positions",
|
||||||
|
# MPT
|
||||||
|
"max_seq_len",
|
||||||
|
# ChatGLM2
|
||||||
|
"seq_length",
|
||||||
|
# Others
|
||||||
|
"max_sequence_length",
|
||||||
|
"max_seq_length",
|
||||||
|
"seq_len",
|
||||||
|
]
|
||||||
|
for key in possible_keys:
|
||||||
|
max_len_key = getattr(hf_config, key, None)
|
||||||
|
if max_len_key is not None:
|
||||||
|
derived_max_model_len = min(derived_max_model_len, max_len_key)
|
||||||
|
if derived_max_model_len == float("inf"):
|
||||||
|
if max_model_len is not None:
|
||||||
|
# If max_model_len is specified, we use it.
|
||||||
|
return max_model_len
|
||||||
|
|
||||||
|
default_max_len = 2048
|
||||||
|
logger.warning("The model's config.json does not contain any of the following "
|
||||||
|
"keys to determine the original maximum length of the model: "
|
||||||
|
f"{possible_keys}. Assuming the model's maximum length is "
|
||||||
|
f"{default_max_len}.")
|
||||||
|
derived_max_model_len = default_max_len
|
||||||
|
|
||||||
|
rope_scaling = getattr(hf_config, "rope_scaling", None)
|
||||||
|
if rope_scaling is not None:
|
||||||
|
assert "factor" in rope_scaling
|
||||||
|
scaling_factor = rope_scaling["factor"]
|
||||||
|
if rope_scaling["type"] == "yarn":
|
||||||
|
derived_max_model_len = rope_scaling["original_max_position_embeddings"]
|
||||||
|
derived_max_model_len *= scaling_factor
|
||||||
|
|
||||||
|
if max_model_len is None:
|
||||||
|
max_model_len = derived_max_model_len
|
||||||
|
elif max_model_len > derived_max_model_len:
|
||||||
|
raise ValueError(f"User-specified max_model_len ({max_model_len}) is greater than "
|
||||||
|
f"the derived max_model_len ({max_len_key}={derived_max_model_len}"
|
||||||
|
" in model's config.json). This may lead to incorrect model "
|
||||||
|
"outputs or CUDA errors. Make sure the value is correct and "
|
||||||
|
"within the model context size.")
|
||||||
|
return int(max_model_len)
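# Illustrative derivation, including the yarn rope-scaling branch (sketch, not
# part of the vendored module; SimpleNamespace stands in for a HF config):
from types import SimpleNamespace
plain = SimpleNamespace(max_position_embeddings=4096)
assert _get_and_verify_max_len(plain, max_model_len=None) == 4096
yarn = SimpleNamespace(max_position_embeddings=4096,
                       rope_scaling={"type": "yarn", "factor": 4.0,
                                     "original_max_position_embeddings": 4096})
assert _get_and_verify_max_len(yarn, max_model_len=None) == 4096 * 4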
|
||||||
275
verl/third_party/vllm/vllm_v_0_3_1/llm.py
vendored
Normal file
@@ -0,0 +1,275 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py
|
||||||
|
|
||||||
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
import torch.nn as nn
|
||||||
|
from .arg_utils import EngineArgs
|
||||||
|
from .llm_engine_sp import LLMEngine
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
|
from vllm.outputs import RequestOutput
|
||||||
|
from vllm.sampling_params import SamplingParams
|
||||||
|
from vllm.utils import Counter
|
||||||
|
import torch
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class LLM:
|
||||||
|
"""An LLM for generating texts from given prompts and sampling parameters.
|
||||||
|
|
||||||
|
This class includes a tokenizer, a language model (possibly distributed
|
||||||
|
across multiple GPUs), and GPU memory space allocated for intermediate
|
||||||
|
states (aka KV cache). Given a batch of prompts and sampling parameters,
|
||||||
|
this class generates texts from the model, using an intelligent batching
|
||||||
|
mechanism and efficient memory management.
|
||||||
|
|
||||||
|
NOTE: This class is intended to be used for offline inference. For online
|
||||||
|
serving, use the `AsyncLLMEngine` class instead.
|
||||||
|
NOTE: For the comprehensive list of arguments, see `EngineArgs`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: A HuggingFace Transformers model instance.
|
||||||
|
tokenizer: A HuggingFace Transformers tokenizer instance.
|
||||||
|
tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
|
||||||
|
if available, and "slow" will always use the slow tokenizer.
|
||||||
|
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
|
||||||
|
downloading the model and tokenizer.
|
||||||
|
tensor_parallel_size: The number of GPUs to use for distributed
|
||||||
|
execution with tensor parallelism.
|
||||||
|
dtype: The data type for the model weights and activations. Currently,
|
||||||
|
we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
|
||||||
|
the `torch_dtype` attribute specified in the model config file.
|
||||||
|
However, if the `torch_dtype` in the config is `float32`, we will
|
||||||
|
use `float16` instead.
|
||||||
|
quantization: The method used to quantize the model weights. Currently,
|
||||||
|
we support "awq". If None, we assume the model weights are not
|
||||||
|
quantized and use `dtype` to determine the data type of the weights.
|
||||||
|
revision: The specific model version to use. It can be a branch name,
|
||||||
|
a tag name, or a commit id.
|
||||||
|
tokenizer_revision: The specific tokenizer version to use. It can be a
|
||||||
|
branch name, a tag name, or a commit id.
|
||||||
|
seed: The seed to initialize the random number generator for sampling.
|
||||||
|
gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
|
||||||
|
reserve for the model weights, activations, and KV cache. Higher
|
||||||
|
values will increase the KV cache size and thus improve the model's
|
||||||
|
throughput. However, if the value is too high, it may cause out-of-
|
||||||
|
memory (OOM) errors.
|
||||||
|
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
|
||||||
|
This can be used for temporarily storing the states of the requests
|
||||||
|
when their `best_of` sampling parameters are larger than 1. If all
|
||||||
|
requests will have `best_of=1`, you can safely set this to 0.
|
||||||
|
Otherwise, too small values may cause out-of-memory (OOM) errors.
|
||||||
|
enforce_eager: Whether to enforce eager execution. If True, we will
|
||||||
|
disable CUDA graph and always execute the model in eager mode.
|
||||||
|
If False, we will use CUDA graph and eager execution in hybrid.
|
||||||
|
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
|
||||||
|
When a sequence has context length larger than this, we fall back
|
||||||
|
to eager mode.
|
||||||
|
disable_custom_all_reduce: See ParallelConfig
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: Union[nn.Module, Dict], # model itself or its parameter dict
|
||||||
|
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer],
|
||||||
|
model_hf_config: PretrainedConfig,
|
||||||
|
tokenizer_mode: str = "auto",
|
||||||
|
trust_remote_code: bool = False,
|
||||||
|
tensor_parallel_size: int = 1,
|
||||||
|
dtype: str = "auto",
|
||||||
|
quantization: Optional[str] = None,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
tokenizer_revision: Optional[str] = None,
|
||||||
|
seed: int = 0,
|
||||||
|
gpu_memory_utilization: float = 0.9,
|
||||||
|
swap_space: int = 4,
|
||||||
|
enforce_eager: bool = False,
|
||||||
|
max_context_len_to_capture: int = 8192,
|
||||||
|
disable_custom_all_reduce: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
if "disable_log_stats" not in kwargs:
|
||||||
|
kwargs["disable_log_stats"] = True
|
||||||
|
engine_args = EngineArgs(
|
||||||
|
model_hf_config=model_hf_config,
|
||||||
|
tensor_parallel_size=tensor_parallel_size,
|
||||||
|
dtype=dtype,
|
||||||
|
quantization=quantization,
|
||||||
|
revision=revision,
|
||||||
|
tokenizer_revision=tokenizer_revision,
|
||||||
|
seed=seed,
|
||||||
|
gpu_memory_utilization=gpu_memory_utilization,
|
||||||
|
swap_space=swap_space,
|
||||||
|
enforce_eager=enforce_eager,
|
||||||
|
max_context_len_to_capture=max_context_len_to_capture,
|
||||||
|
disable_custom_all_reduce=disable_custom_all_reduce,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
tokenizer_cls = (PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer)
|
||||||
|
if not isinstance(tokenizer, tokenizer_cls):
|
||||||
|
raise ValueError(
|
||||||
|
f"Unexpected tokenizer type: {type(tokenizer)}. Must be"
|
||||||
|
"one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.workers.rollout.HybridEngineBaseTokenizer"
|
||||||
|
)
|
||||||
|
self.llm_engine = LLMEngine.from_engine_args(model, tokenizer, engine_args)
|
||||||
|
self.request_counter = Counter()
|
||||||
|
|
||||||
|
def init_cache_engine(self):
|
||||||
|
self.llm_engine.init_cache_engine()
|
||||||
|
|
||||||
|
def free_cache_engine(self):
|
||||||
|
self.llm_engine.free_cache_engine()
|
||||||
|
|
||||||
|
def get_tokenizer(self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
|
||||||
|
return self.llm_engine.tokenizer
|
||||||
|
|
||||||
|
def set_tokenizer(
|
||||||
|
self,
|
||||||
|
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||||
|
) -> None:
|
||||||
|
self.llm_engine.tokenizer = tokenizer
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
prompts: Optional[Union[str, List[str]]] = None,
|
||||||
|
sampling_params: Optional[SamplingParams] = None,
|
||||||
|
prompt_token_ids: Optional[List[List[int]]] = None,
|
||||||
|
prefix_pos: Optional[Union[int, List[int]]] = None,
|
||||||
|
use_tqdm: bool = True,
|
||||||
|
lora_request: Optional[LoRARequest] = None,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Generates the completions for the input prompts.
|
||||||
|
|
||||||
|
NOTE: This class automatically batches the given prompts, considering
|
||||||
|
the memory constraint. For the best performance, put all of your prompts
|
||||||
|
into a single list and pass it to this method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompts: A list of prompts to generate completions for.
|
||||||
|
sampling_params: The sampling parameters for text generation. If
|
||||||
|
None, we use the default sampling parameters.
|
||||||
|
prompt_token_ids: A list of token IDs for the prompts. If None, we
|
||||||
|
use the tokenizer to convert the prompts to token IDs.
|
||||||
|
use_tqdm: Whether to use tqdm to display the progress bar.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of padded output token ids and logprobs, post-processed from
the finished `RequestOutput`s in the same order as the input prompts.
|
||||||
|
"""
|
||||||
|
if prompts is None and prompt_token_ids is None:
|
||||||
|
raise ValueError("Either prompts or prompt_token_ids must be "
|
||||||
|
"provided.")
|
||||||
|
if isinstance(prompts, str):
|
||||||
|
# Convert a single prompt to a list.
|
||||||
|
prompts = [prompts]
|
||||||
|
if prompts is not None and prompt_token_ids is not None:
|
||||||
|
if len(prompts) != len(prompt_token_ids):
|
||||||
|
raise ValueError("The lengths of prompts and prompt_token_ids "
|
||||||
|
"must be the same.")
|
||||||
|
if sampling_params is None:
|
||||||
|
# Use default sampling params.
|
||||||
|
sampling_params = SamplingParams()
|
||||||
|
|
||||||
|
# Add requests to the engine.
|
||||||
|
num_requests = len(prompts) if prompts is not None else len(prompt_token_ids)
|
||||||
|
for i in range(num_requests):
|
||||||
|
prompt = prompts[i] if prompts is not None else None
|
||||||
|
prefix_pos_i = prefix_pos[i] if prefix_pos is not None else None
|
||||||
|
token_ids = None if prompt_token_ids is None else prompt_token_ids[i]
|
||||||
|
if not isinstance(token_ids, list):
|
||||||
|
# NOTE(shengguangming): convert the rollout input into List[int]
|
||||||
|
token_ids = self._pre_process_inputs(token_ids)
|
||||||
|
self._add_request(prompt, sampling_params, token_ids, lora_request=lora_request, prefix_pos=prefix_pos_i)
|
||||||
|
return self._run_engine(use_tqdm)
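# Illustrative rollout call pattern (comment-only sketch, not part of the vendored
# module; `actor_module`, `hf_tokenizer`, `hf_config` and the left-padded
# `prompt_ids` tensor are assumed to exist, and a torch.distributed environment
# must already be initialized):
#
#   llm = LLM(actor_module, hf_tokenizer, hf_config, tensor_parallel_size=2)
#   params = SamplingParams(temperature=1.0, max_tokens=128, logprobs=1)
#   token_ids, logprobs = llm.generate(prompts=None,
#                                      sampling_params=params,
#                                      prompt_token_ids=prompt_ids,
#                                      use_tqdm=False)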
|
||||||
|
|
||||||
|
def _add_request(
|
||||||
|
self,
|
||||||
|
prompt: Optional[str],
|
||||||
|
sampling_params: SamplingParams,
|
||||||
|
prompt_token_ids: Optional[List[int]],
|
||||||
|
lora_request: Optional[LoRARequest] = None,
|
||||||
|
prefix_pos: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
request_id = str(next(self.request_counter))
|
||||||
|
self.llm_engine.add_request(request_id,
|
||||||
|
prompt,
|
||||||
|
sampling_params,
|
||||||
|
prompt_token_ids,
|
||||||
|
lora_request=lora_request,
|
||||||
|
prefix_pos=prefix_pos)
|
||||||
|
|
||||||
|
def _run_engine(self, use_tqdm: bool) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# Initialize tqdm.
|
||||||
|
if use_tqdm:
|
||||||
|
num_requests = self.llm_engine.get_num_unfinished_requests()
|
||||||
|
pbar = tqdm(total=num_requests, desc="Processed prompts")
|
||||||
|
# Run the engine.
|
||||||
|
outputs: List[RequestOutput] = []
|
||||||
|
while self.llm_engine.has_unfinished_requests():
|
||||||
|
step_outputs = self.llm_engine.step()
|
||||||
|
for output in step_outputs:
|
||||||
|
if output.finished:
|
||||||
|
outputs.append(output)
|
||||||
|
if use_tqdm:
|
||||||
|
pbar.update(1)
|
||||||
|
if use_tqdm:
|
||||||
|
pbar.close()
|
||||||
|
# Sort the outputs by request ID.
|
||||||
|
# This is necessary because some requests may be finished earlier than
|
||||||
|
# its previous requests.
|
||||||
|
outputs = sorted(outputs, key=lambda x: int(x.request_id))
|
||||||
|
# TODO(shengguangming): maybe we can hook into the autoregressive decoding loop directly rather than only applying post-processing, for better performance
|
||||||
|
return self._post_process_outputs(outputs)
|
||||||
|
|
||||||
|
# NOTE(shengguangming): add for verl
|
||||||
|
# TODO(sgm): we can optimize it by making the dataloader yield List[int] without padding.
|
||||||
|
def _pre_process_inputs(self, prompt_token_ids: torch.Tensor) -> List[int]:
|
||||||
|
# remove the left padding in the prompt token_id
|
||||||
|
pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id
|
||||||
|
non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
|
||||||
|
token_ids = prompt_token_ids[non_pad_index:].tolist()
|
||||||
|
return token_ids
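# Illustrative left-padding removal, mirroring the logic above (sketch, not part
# of the vendored module; 0 plays the role of the pad token id):
_prompt = torch.tensor([0, 0, 0, 5, 6, 7])
_first = torch.nonzero(_prompt != 0, as_tuple=False)[0][0]
assert _prompt[_first:].tolist() == [5, 6, 7]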
|
||||||
|
|
||||||
|
# NOTE(shengguangming): add for verl
|
||||||
|
def _post_process_outputs(self, outputs: List[RequestOutput]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
output_token_ids = []
|
||||||
|
logprobs = []
|
||||||
|
for output in outputs: # List[RequestOutput]
|
||||||
|
output = output.outputs
|
||||||
|
for output in output: # List[CompletionOutput], usually len == 1
|
||||||
|
output_token_ids.append(torch.tensor(output.token_ids))
|
||||||
|
# TODO(shengguangming): can be optimized by rewriting the Sampler._get_logprobs() logic
|
||||||
|
logprobs_dicts = output.logprobs
|
||||||
|
if logprobs_dicts is not None:
|
||||||
|
logprob = []
|
||||||
|
for logprobs_dict, id in zip(logprobs_dicts, output.token_ids):
|
||||||
|
logprob.append(logprobs_dict[id])
|
||||||
|
logprobs.append(torch.tensor(logprob))
|
||||||
|
|
||||||
|
pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id
|
||||||
|
output_token_ids = pad_sequence(output_token_ids, batch_first=True, padding_value=pad_token_id)
|
||||||
|
if len(logprobs) > 0:
|
||||||
|
logprobs = pad_sequence(logprobs, batch_first=True, padding_value=pad_token_id)
|
||||||
|
return output_token_ids, logprobs
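# Illustrative padding of ragged outputs, as done above (sketch, not part of the
# vendored module; 2 plays the role of the pad/eos token id):
_outs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
_padded = pad_sequence(_outs, batch_first=True, padding_value=2)
assert _padded.tolist() == [[5, 6, 7], [8, 9, 2]]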
|
||||||
|
|
||||||
|
def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor]) -> None:
|
||||||
|
self.llm_engine.sync_model_weights(actor_weights=actor_weights)
|
||||||
|
|
||||||
|
def offload_model_weights(self) -> None:
|
||||||
|
self.llm_engine.offload_model_weights()
|
||||||
765
verl/third_party/vllm/vllm_v_0_3_1/llm_engine_sp.py
vendored
Normal file
@@ -0,0 +1,765 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/llm_engine.py
|
||||||
|
|
||||||
|
import os
|
||||||
|
import socket
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
|
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, ParallelConfig, SchedulerConfig, LoRAConfig)
|
||||||
|
from vllm.core.scheduler import Scheduler, SchedulerOutputs
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.outputs import RequestOutput
|
||||||
|
from vllm.sampling_params import SamplingParams
|
||||||
|
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup, SequenceGroupMetadata, SequenceGroupOutput,
|
||||||
|
SequenceOutput, SequenceStatus)
|
||||||
|
from vllm.transformers_utils.tokenizer import detokenize_incrementally
|
||||||
|
from vllm.engine.metrics import StatLogger, Stats
|
||||||
|
from vllm.utils import Counter
|
||||||
|
import torch.nn as nn
|
||||||
|
from .arg_utils import EngineArgs
|
||||||
|
from .tokenizer import TokenizerGroup
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
_LOCAL_LOGGING_INTERVAL_SEC = 5
|
||||||
|
|
||||||
|
|
||||||
|
class LLMEngine:
|
||||||
|
"""An LLM engine that receives requests and generates texts.
|
||||||
|
|
||||||
|
This is the main class for the vLLM engine. It receives requests
|
||||||
|
from clients and generates texts from the LLM. It includes a tokenizer, a
|
||||||
|
language model (possibly distributed across multiple GPUs), and GPU memory
|
||||||
|
space allocated for intermediate states (aka KV cache). This class utilizes
|
||||||
|
iteration-level scheduling and efficient memory management to maximize the
|
||||||
|
serving throughput.
|
||||||
|
|
||||||
|
The `LLM` class wraps this class for offline batched inference and the
|
||||||
|
`AsyncLLMEngine` class wraps this class for online serving.
|
||||||
|
|
||||||
|
NOTE: The config arguments are derived from the `EngineArgs` class. For the
|
||||||
|
comprehensive list of arguments, see `EngineArgs`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_config: The configuration related to the LLM model.
|
||||||
|
cache_config: The configuration related to the KV cache memory
|
||||||
|
management.
|
||||||
|
parallel_config: The configuration related to distributed execution.
|
||||||
|
scheduler_config: The configuration related to the request scheduler.
|
||||||
|
distributed_init_method: The initialization method for distributed
|
||||||
|
execution. See `torch.distributed.init_process_group` for details.
|
||||||
|
placement_group: Ray placement group for distributed execution.
|
||||||
|
Required for distributed execution.
|
||||||
|
log_stats: Whether to log statistics.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: Union[nn.Module, Dict], # model itself or its parameter dict
|
||||||
|
tokenizer: nn.Module,
|
||||||
|
model_config: ModelConfig,
|
||||||
|
cache_config: CacheConfig,
|
||||||
|
parallel_config: ParallelConfig,
|
||||||
|
scheduler_config: SchedulerConfig,
|
||||||
|
device_config: DeviceConfig,
|
||||||
|
lora_config: Optional[LoRAConfig],
|
||||||
|
distributed_init_method: str,
|
||||||
|
placement_group: Optional[None],
|
||||||
|
log_stats: bool,
|
||||||
|
) -> None:
|
||||||
|
logger.info("Initializing an LLM engine with config: "
|
||||||
|
f"model={model_config.model!r}, "
|
||||||
|
f"tokenizer={model_config.tokenizer!r}, "
|
||||||
|
# f"tokenizer_mode={model_config.tokenizer_mode}, "
|
||||||
|
f"revision={model_config.revision}, "
|
||||||
|
f"tokenizer_revision={model_config.tokenizer_revision}, "
|
||||||
|
# f"trust_remote_code={model_config.trust_remote_code}, "
|
||||||
|
f"dtype={model_config.dtype}, "
|
||||||
|
f"max_seq_len={model_config.max_model_len}, "
|
||||||
|
# f"download_dir={model_config.download_dir!r}, "
|
||||||
|
# f"load_format={model_config.load_format}, "
|
||||||
|
f"disable_custom_all_reduce={parallel_config.disable_custom_all_reduce}, "
|
||||||
|
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
|
||||||
|
f"quantization={model_config.quantization}, "
|
||||||
|
f"seed={model_config.seed})")
|
||||||
|
# TODO(woosuk): Print more configs in debug mode.
|
||||||
|
|
||||||
|
self.model_config = model_config # TODO: currently is hfconfig
|
||||||
|
self.cache_config = cache_config
|
||||||
|
self.lora_config = lora_config
|
||||||
|
assert self.cache_config.sliding_window == getattr(self.model_config.hf_config, "sliding_window", None)
|
||||||
|
self.parallel_config = parallel_config
|
||||||
|
self.scheduler_config = scheduler_config
|
||||||
|
self.device_config = device_config
|
||||||
|
self.log_stats = log_stats
|
||||||
|
self._verify_args()
|
||||||
|
|
||||||
|
# self.model = model # should not store the model, it should be deleted
|
||||||
|
# TODO(shengguangming): maybe we can choose to initialize the tokenizer here or take it from the arguments
|
||||||
|
self._init_tokenizer(tokenizer)
|
||||||
|
|
||||||
|
self.seq_counter = Counter()
|
||||||
|
|
||||||
|
# Create the parallel GPU workers.
|
||||||
|
self._init_workers_sp(model, distributed_init_method)
|
||||||
|
|
||||||
|
# Profile the memory usage and initialize the cache.
|
||||||
|
self._init_cache_sp()
|
||||||
|
|
||||||
|
# Create the scheduler.
|
||||||
|
# NOTE(shengguangming): each process will have an independent scheduler
|
||||||
|
self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
|
||||||
|
|
||||||
|
# Metric Logging.
|
||||||
|
if self.log_stats:
|
||||||
|
self.stat_logger = StatLogger(local_interval=_LOCAL_LOGGING_INTERVAL_SEC)
|
||||||
|
|
||||||
|
# Logging.
|
||||||
|
self.last_logging_time = 0.0
|
||||||
|
# List of (timestamp, num_tokens)
|
||||||
|
self.num_prompt_tokens: List[Tuple[float, int]] = []
|
||||||
|
# List of (timestamp, num_tokens)
|
||||||
|
self.num_generation_tokens: List[Tuple[float, int]] = []
|
||||||
|
|
||||||
|
def _init_tokenizer(self, tokenizer, **tokenizer_init_kwargs):
|
||||||
|
init_kwargs = dict(enable_lora=bool(self.lora_config),
|
||||||
|
max_num_seqs=self.scheduler_config.max_num_seqs,
|
||||||
|
max_input_length=None)
|
||||||
|
init_kwargs.update(tokenizer_init_kwargs)
|
||||||
|
self.tokenizer: TokenizerGroup = TokenizerGroup(tokenizer, **init_kwargs)
|
||||||
|
|
||||||
|
# TODO: check get_lora_tokenizer func
|
||||||
|
def get_tokenizer_for_seq(self, sequence: Sequence):
|
||||||
|
return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
|
||||||
|
|
||||||
|
def _init_workers_sp(self, model, distributed_init_method: str):
|
||||||
|
# Lazy import the Worker to avoid importing torch.cuda/xformers
|
||||||
|
# before CUDA_VISIBLE_DEVICES is set in the Worker
|
||||||
|
from .worker import Worker # pylint: disable=import-outside-toplevel
|
||||||
|
|
||||||
|
rank = int(os.getenv("RANK"))
|
||||||
|
|
||||||
|
self.worker = Worker(
|
||||||
|
model,
|
||||||
|
self.model_config,
|
||||||
|
self.parallel_config,
|
||||||
|
self.scheduler_config,
|
||||||
|
self.device_config,
|
||||||
|
rank,
|
||||||
|
distributed_init_method,
|
||||||
|
lora_config=self.lora_config,
|
||||||
|
kv_cache_dtype=self.cache_config.cache_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
# NOTE(shengguangming): torch.distributed.init_process_group will be called inside the init_model()
|
||||||
|
self.worker.init_model()
|
||||||
|
self.worker.load_model()
|
||||||
|
|
||||||
|
def _verify_args(self) -> None:
|
||||||
|
self.model_config.verify_with_parallel_config(self.parallel_config)
|
||||||
|
self.cache_config.verify_with_parallel_config(self.parallel_config)
|
||||||
|
|
||||||
|
def _init_cache_sp(self) -> None:
|
||||||
|
"""Profiles the memory usage and initializes the KV cache."""
|
||||||
|
# Get the maximum number of blocks that can be allocated on GPU and CPU.
|
||||||
|
num_blocks = self.worker.profile_num_available_blocks(
|
||||||
|
block_size=self.cache_config.block_size,
|
||||||
|
gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
|
||||||
|
cpu_swap_space=self.cache_config.swap_space_bytes,
|
||||||
|
cache_dtype=self.cache_config.cache_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
# NOTE(shengguangming): we no longer use a shared centralized controller; each process will
|
||||||
|
# have its own scheduler
|
||||||
|
num_gpu_blocks = num_blocks[0]
|
||||||
|
num_cpu_blocks = num_blocks[1]
|
||||||
|
|
||||||
|
# FIXME(woosuk): Change to debug log.
|
||||||
|
logger.info(f"# GPU blocks: {num_gpu_blocks}, "
|
||||||
|
f"# CPU blocks: {num_cpu_blocks}")
|
||||||
|
|
||||||
|
if num_gpu_blocks <= 0:
|
||||||
|
raise ValueError("No available memory for the cache blocks. "
|
||||||
|
"Try increasing `gpu_memory_utilization` when "
|
||||||
|
"initializing the engine.")
|
||||||
|
|
||||||
|
max_seq_len = self.cache_config.block_size * num_gpu_blocks
|
||||||
|
if self.model_config.max_model_len > max_seq_len:
|
||||||
|
raise ValueError(f"The model's max seq len ({self.model_config.max_model_len}) "
|
||||||
|
"is larger than the maximum number of tokens that can be "
|
||||||
|
f"stored in KV cache ({max_seq_len}). Try increasing "
|
||||||
|
"`gpu_memory_utilization` or decreasing `max_model_len` when "
|
||||||
|
"initializing the engine.")
|
||||||
|
|
||||||
|
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||||
|
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||||
|
|
||||||
|
# Initialize the cache.
|
||||||
|
self.worker.init_cache_engine(cache_config=self.cache_config)
|
||||||
|
self.worker.warm_up_model()
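# Rough KV-cache capacity arithmetic behind the check above (illustrative
# numbers, not part of the vendored module):
_block_size = 16
_num_gpu_blocks = 2048            # whatever profiling reports on the target GPU
_max_seq_len = _block_size * _num_gpu_blocks   # 32768 cacheable tokens
# the engine refuses to start if model_config.max_model_len exceeds this value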
|
||||||
|
|
||||||
|
def init_cache_engine(self):
|
||||||
|
self.worker.init_cache_engine(cache_config=self.cache_config)
|
||||||
|
|
||||||
|
def free_cache_engine(self):
|
||||||
|
self.worker.free_cache_engine()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_engine_args(cls, model, tokenizer, engine_args: EngineArgs) -> "LLMEngine":
|
||||||
|
"""Creates an LLM engine from the engine arguments."""
|
||||||
|
# Create the engine configs.
|
||||||
|
engine_configs = engine_args.create_engine_configs()
|
||||||
|
parallel_config = engine_configs[2]
|
||||||
|
# Initialize the cluster.
|
||||||
|
distributed_init_method, placement_group = initialize_cluster(parallel_config)
|
||||||
|
# Create the LLM engine.
|
||||||
|
engine = cls(model,
|
||||||
|
tokenizer,
|
||||||
|
*engine_configs,
|
||||||
|
distributed_init_method,
|
||||||
|
placement_group,
|
||||||
|
log_stats=not engine_args.disable_log_stats)
|
||||||
|
return engine
|
||||||
|
|
||||||
|
def add_request(
|
||||||
|
self,
|
||||||
|
request_id: str,
|
||||||
|
prompt: Optional[str],
|
||||||
|
sampling_params: SamplingParams,
|
||||||
|
prompt_token_ids: Optional[List[int]] = None,
|
||||||
|
arrival_time: Optional[float] = None,
|
||||||
|
lora_request: Optional[LoRARequest] = None,
|
||||||
|
prefix_pos: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Add a request to the engine's request pool.
|
||||||
|
|
||||||
|
The request is added to the request pool and will be processed by the
|
||||||
|
scheduler as `engine.step()` is called. The exact scheduling policy is
|
||||||
|
determined by the scheduler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_id: The unique ID of the request.
|
||||||
|
prompt: The prompt string. Can be None if prompt_token_ids is
|
||||||
|
provided.
|
||||||
|
sampling_params: The sampling parameters for text generation.
|
||||||
|
prompt_token_ids: The token IDs of the prompt. If None, we
|
||||||
|
use the tokenizer to convert the prompts to token IDs.
|
||||||
|
arrival_time: The arrival time of the request. If None, we use
|
||||||
|
the current monotonic time.
|
||||||
|
prefix_pos: If not None, we use the given position as the prefix
|
||||||
|
position for each prompt. We will cache the prefix's KV
|
||||||
|
cache and reuse it for the next request with the same prefix.
|
||||||
|
This is an experimental feature, and may be replaced with
|
||||||
|
automatic prefix caching in the future.
|
||||||
|
|
||||||
|
Details:
|
||||||
|
- Set arrival_time to the current time if it is None.
|
||||||
|
- Set prompt_token_ids to the encoded prompt if it is None.
|
||||||
|
- Create `best_of` number of :class:`~vllm.Sequence` objects.
|
||||||
|
- Create a :class:`~vllm.SequenceGroup` object
|
||||||
|
from the list of :class:`~vllm.Sequence`.
|
||||||
|
- Add the :class:`~vllm.SequenceGroup` object to the scheduler.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> # initialize engine
|
||||||
|
>>> engine = LLMEngine.from_engine_args(engine_args)
|
||||||
|
>>> # set request arguments
|
||||||
|
>>> example_prompt = "Who is the president of the United States?"
|
||||||
|
>>> sampling_params = SamplingParams(temperature=0.0)
|
||||||
|
>>> request_id = 0
|
||||||
|
>>>
|
||||||
|
>>> # add the request to the engine
|
||||||
|
>>> engine.add_request(
|
||||||
|
>>> str(request_id),
|
||||||
|
>>> example_prompt,
|
||||||
|
>>> SamplingParams(temperature=0.0))
|
||||||
|
>>> # continue the request processing
|
||||||
|
>>> ...
|
||||||
|
"""
|
||||||
|
if lora_request is not None and not self.lora_config:
|
||||||
|
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
|
||||||
|
"not enabled!")
|
||||||
|
if arrival_time is None:
|
||||||
|
arrival_time = time.monotonic()
|
||||||
|
if prompt_token_ids is None:
|
||||||
|
assert prompt is not None
|
||||||
|
prompt_token_ids = self.tokenizer.encode(prompt)
|
||||||
|
|
||||||
|
# Create the sequences.
|
||||||
|
block_size = self.cache_config.block_size
|
||||||
|
seq_id = next(self.seq_counter)
|
||||||
|
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, lora_request)
|
||||||
|
|
||||||
|
# Check whether the input specifies prefix
|
||||||
|
prefix = self.scheduler.prefix_pool.add_or_get_prefix(prompt_token_ids[:prefix_pos], lora_request.lora_int_id if
|
||||||
|
lora_request else 0) if prefix_pos is not None else None
|
||||||
|
|
||||||
|
# Create the sequence group.
|
||||||
|
seq_group = SequenceGroup(request_id, [seq], sampling_params, arrival_time, lora_request, prefix)
|
||||||
|
|
||||||
|
# Add the sequence group to the scheduler.
|
||||||
|
self.scheduler.add_seq_group(seq_group)
|
||||||
|
|
||||||
|
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
|
||||||
|
"""Aborts a request(s) with the given ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_id: The ID(s) of the request to abort.
|
||||||
|
|
||||||
|
Details:
|
||||||
|
- Refer to the
|
||||||
|
:meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
|
||||||
|
from class :class:`~vllm.core.scheduler.Scheduler`.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> # initialize engine and add a request with request_id
|
||||||
|
>>> request_id = str(0)
|
||||||
|
>>> # abort the request
|
||||||
|
>>> engine.abort_request(request_id)
|
||||||
|
"""
|
||||||
|
self.scheduler.abort_seq_group(request_id)
|
||||||
|
|
||||||
|
def get_model_config(self) -> ModelConfig:
|
||||||
|
"""Gets the model configuration."""
|
||||||
|
return self.model_config
|
||||||
|
|
||||||
|
def get_num_unfinished_requests(self) -> int:
|
||||||
|
"""Gets the number of unfinished requests."""
|
||||||
|
return self.scheduler.get_num_unfinished_seq_groups()
|
||||||
|
|
||||||
|
def has_unfinished_requests(self) -> bool:
|
||||||
|
"""Returns True if there are unfinished requests."""
|
||||||
|
return self.scheduler.has_unfinished_seqs()
|
||||||
|
|
||||||
|
def _check_beam_search_early_stopping(
|
||||||
|
self,
|
||||||
|
early_stopping: Union[bool, str],
|
||||||
|
sampling_params: SamplingParams,
|
||||||
|
best_running_seq: Sequence,
|
||||||
|
current_worst_seq: Sequence,
|
||||||
|
) -> bool:
|
||||||
|
assert sampling_params.use_beam_search
|
||||||
|
length_penalty = sampling_params.length_penalty
|
||||||
|
if early_stopping is True:
|
||||||
|
return True
|
||||||
|
|
||||||
|
current_worst_score = (current_worst_seq.get_beam_search_score(
|
||||||
|
length_penalty=length_penalty, eos_token_id=self.get_tokenizer_for_seq(current_worst_seq).eos_token_id))
|
||||||
|
if early_stopping is False:
|
||||||
|
highest_attainable_score = (best_running_seq.get_beam_search_score(
|
||||||
|
length_penalty=length_penalty, eos_token_id=self.get_tokenizer_for_seq(best_running_seq).eos_token_id))
|
||||||
|
else:
|
||||||
|
assert early_stopping == "never"
|
||||||
|
if length_penalty > 0.0:
|
||||||
|
# If length_penalty > 0.0, beam search will prefer longer
|
||||||
|
# sequences. The highest attainable score calculation is
|
||||||
|
# based on the longest possible sequence length in this case.
|
||||||
|
max_possible_length = max(best_running_seq.get_prompt_len() + sampling_params.max_tokens,
|
||||||
|
self.scheduler_config.max_model_len)
|
||||||
|
highest_attainable_score = (best_running_seq.get_beam_search_score(
|
||||||
|
length_penalty=length_penalty,
|
||||||
|
eos_token_id=self.get_tokenizer_for_seq(best_running_seq).eos_token_id,
|
||||||
|
seq_len=max_possible_length))
|
||||||
|
else:
|
||||||
|
# Otherwise, beam search will prefer shorter sequences. The
|
||||||
|
# highest attainable score calculation is based on the current
|
||||||
|
# sequence length.
|
||||||
|
highest_attainable_score = (best_running_seq.get_beam_search_score(
|
||||||
|
length_penalty=length_penalty,
|
||||||
|
eos_token_id=self.get_tokenizer_for_seq(best_running_seq).eos_token_id))
return current_worst_score >= highest_attainable_score
|
||||||
|
|
||||||
|
def _process_sequence_group_outputs(self, seq_group: SequenceGroup, outputs: SequenceGroupOutput) -> None:
|
||||||
|
|
||||||
|
# Process prompt logprobs
|
||||||
|
prompt_logprobs = outputs.prompt_logprobs
|
||||||
|
if prompt_logprobs is not None:
|
||||||
|
seq_group.prompt_logprobs = prompt_logprobs
|
||||||
|
|
||||||
|
# Process samples
|
||||||
|
samples = outputs.samples
|
||||||
|
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
|
||||||
|
existing_finished_seqs = seq_group.get_finished_seqs()
|
||||||
|
parent_child_dict = {parent_seq.seq_id: [] for parent_seq in parent_seqs}
|
||||||
|
for sample in samples:
|
||||||
|
parent_child_dict[sample.parent_seq_id].append(sample)
|
||||||
|
# List of (child, parent)
|
||||||
|
child_seqs: List[Tuple[Sequence, Sequence]] = []
|
||||||
|
|
||||||
|
# Process the child samples for each parent sequence
|
||||||
|
for parent in parent_seqs:
|
||||||
|
child_samples: List[SequenceOutput] = parent_child_dict[parent.seq_id]
|
||||||
|
if len(child_samples) == 0:
|
||||||
|
# This parent sequence has no children samples. Remove
|
||||||
|
# the parent sequence from the sequence group since it will
|
||||||
|
# not be used in the future iterations.
|
||||||
|
parent.status = SequenceStatus.FINISHED_ABORTED
|
||||||
|
seq_group.remove(parent.seq_id)
|
||||||
|
self.scheduler.free_seq(parent)
|
||||||
|
continue
|
||||||
|
# Fork the parent sequence if there are multiple child samples.
|
||||||
|
for child_sample in child_samples[:-1]:
|
||||||
|
new_child_seq_id = next(self.seq_counter)
|
||||||
|
child = parent.fork(new_child_seq_id)
|
||||||
|
child.append_token_id(child_sample.output_token, child_sample.logprobs)
|
||||||
|
child_seqs.append((child, parent))
|
||||||
|
# Continue the parent sequence for the last child sample.
|
||||||
|
# We reuse the parent sequence here to reduce redundant memory
|
||||||
|
# copies, especially when using non-beam search sampling methods.
|
||||||
|
last_child_sample = child_samples[-1]
|
||||||
|
parent.append_token_id(last_child_sample.output_token, last_child_sample.logprobs)
|
||||||
|
child_seqs.append((parent, parent))
|
||||||
|
|
||||||
|
for seq, _ in child_seqs:
|
||||||
|
# self._decode_sequence(seq, seq_group.sampling_params)
|
||||||
|
self._check_stop(seq, seq_group.sampling_params)
|
||||||
|
|
||||||
|
# Non-beam search case
|
||||||
|
if not seq_group.sampling_params.use_beam_search:
|
||||||
|
# For newly created child sequences, add them to the sequence group
|
||||||
|
# and fork them in block manager if they are not finished.
|
||||||
|
for seq, parent in child_seqs:
|
||||||
|
if seq is not parent:
|
||||||
|
seq_group.add(seq)
|
||||||
|
if not seq.is_finished():
|
||||||
|
self.scheduler.fork_seq(parent, seq)
|
||||||
|
|
||||||
|
# Free the finished and selected parent sequences' memory in block
|
||||||
|
# manager. Keep them in the sequence group as candidate output.
|
||||||
|
# NOTE: we need to fork the new sequences before freeing the
|
||||||
|
# old sequences.
|
||||||
|
for seq, parent in child_seqs:
|
||||||
|
if seq is parent and seq.is_finished():
|
||||||
|
self.scheduler.free_seq(seq)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Beam search case
|
||||||
|
# Select the child sequences to keep in the sequence group.
|
||||||
|
selected_child_seqs = []
|
||||||
|
unselected_child_seqs = []
|
||||||
|
beam_width = seq_group.sampling_params.best_of
|
||||||
|
length_penalty = seq_group.sampling_params.length_penalty
|
||||||
|
|
||||||
|
# Select the newly finished sequences with the highest scores
|
||||||
|
# to replace existing finished sequences.
|
||||||
|
# Tuple of (seq, parent, is_new)
|
||||||
|
existing_finished_seqs = [(seq, None, False) for seq in existing_finished_seqs]
|
||||||
|
new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs if seq.is_finished()]
|
||||||
|
all_finished_seqs = existing_finished_seqs + new_finished_seqs
|
||||||
|
# Sort the finished sequences by their scores.
|
||||||
|
all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
|
||||||
|
length_penalty=length_penalty, eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id),
|
||||||
|
reverse=True)
|
||||||
|
for seq, parent, is_new in all_finished_seqs[:beam_width]:
|
||||||
|
if is_new:
|
||||||
|
# A newly generated child sequence finishes and has a high
|
||||||
|
# score, so we will add it into the sequence group.
|
||||||
|
selected_child_seqs.append((seq, parent))
|
||||||
|
for seq, parent, is_new in all_finished_seqs[beam_width:]:
|
||||||
|
if is_new:
|
||||||
|
# A newly generated child sequence finishes but has a low
|
||||||
|
# score, so we will not add it into the sequence group.
|
||||||
|
# Additionally, if this sequence is a continuation of a
|
||||||
|
# parent sequence, we will need to remove the parent sequence
|
||||||
|
# from the sequence group.
|
||||||
|
unselected_child_seqs.append((seq, parent))
|
||||||
|
else:
|
||||||
|
# An existing finished sequence has a low score, so we will
|
||||||
|
# remove it from the sequence group.
|
||||||
|
seq_group.remove(seq.seq_id)
|
||||||
|
|
||||||
|
# select the top beam_width sequences from the running
|
||||||
|
# sequences for the next iteration to continue the beam
|
||||||
|
# search.
|
||||||
|
running_child_seqs = [(seq, parent) for seq, parent in child_seqs if not seq.is_finished()]
|
||||||
|
# Sort the running sequences by their scores.
|
||||||
|
running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
|
||||||
|
length_penalty=length_penalty, eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id),
|
||||||
|
reverse=True)
|
||||||
|
|
||||||
|
# Check if we can stop the beam search.
|
||||||
|
if len(running_child_seqs) == 0:
|
||||||
|
# No running sequences, stop the beam search.
|
||||||
|
stop_beam_search = True
|
||||||
|
elif len(all_finished_seqs) < beam_width:
|
||||||
|
# Not enough finished sequences, continue the beam search.
|
||||||
|
stop_beam_search = False
|
||||||
|
else:
|
||||||
|
# Check the early stopping criteria
|
||||||
|
best_running_seq = running_child_seqs[0][0]
|
||||||
|
current_worst_seq = all_finished_seqs[beam_width - 1][0]
|
||||||
|
stop_beam_search = self._check_beam_search_early_stopping(seq_group.sampling_params.early_stopping,
|
||||||
|
seq_group.sampling_params, best_running_seq,
|
||||||
|
current_worst_seq)
|
||||||
|
|
||||||
|
if stop_beam_search:
|
||||||
|
# Stop the beam search and remove all the running sequences from
|
||||||
|
# the sequence group.
|
||||||
|
unselected_child_seqs.extend(running_child_seqs)
|
||||||
|
else:
|
||||||
|
# Continue the beam search and select the top beam_width sequences
|
||||||
|
# to continue the beam search.
|
||||||
|
selected_child_seqs.extend(running_child_seqs[:beam_width])
|
||||||
|
# The remaining running sequences will not be used in the next
|
||||||
|
# iteration. Again, if these sequences are continuations of
|
||||||
|
# parent sequences, we will need to remove the parent sequences
|
||||||
|
# from the sequence group.
|
||||||
|
unselected_child_seqs.extend(running_child_seqs[beam_width:])
|
||||||
|
|
||||||
|
# For newly created child sequences, add them to the sequence group
|
||||||
|
# and fork them in block manager if they are not finished.
|
||||||
|
for seq, parent in selected_child_seqs:
|
||||||
|
if seq is not parent:
|
||||||
|
seq_group.add(seq)
|
||||||
|
if not seq.is_finished():
|
||||||
|
self.scheduler.fork_seq(parent, seq)
|
||||||
|
|
||||||
|
# Free the finished and selected parent sequences' memory in block
|
||||||
|
# manager. Keep them in the sequence group as candidate output.
|
||||||
|
for seq, parent in selected_child_seqs:
|
||||||
|
if seq is parent and seq.is_finished():
|
||||||
|
self.scheduler.free_seq(seq)
|
||||||
|
|
||||||
|
# Remove the unselected parent sequences from the sequence group and
|
||||||
|
# free their memory in block manager.
|
||||||
|
for seq, parent in unselected_child_seqs:
|
||||||
|
if seq is parent:
|
||||||
|
# Remove the parent sequence if it is not selected for next
|
||||||
|
# iteration
|
||||||
|
seq_group.remove(seq.seq_id)
|
||||||
|
self.scheduler.free_seq(seq)
|
||||||
|
|
||||||
|
def _process_model_outputs(self, output: SamplerOutput, scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
|
||||||
|
# Update the scheduled sequence groups with the model outputs.
|
||||||
|
scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
|
||||||
|
for seq_group, outputs in zip(scheduled_seq_groups, output):
|
||||||
|
self._process_sequence_group_outputs(seq_group, outputs)
|
||||||
|
|
||||||
|
# Free the finished sequence groups.
|
||||||
|
self.scheduler.free_finished_seq_groups()
|
||||||
|
|
||||||
|
# Create the outputs.
|
||||||
|
request_outputs: List[RequestOutput] = []
|
||||||
|
for seq_group in scheduled_seq_groups:
|
||||||
|
request_output = RequestOutput.from_seq_group(seq_group)
|
||||||
|
request_outputs.append(request_output)
|
||||||
|
for seq_group in scheduler_outputs.ignored_seq_groups:
|
||||||
|
request_output = RequestOutput.from_seq_group(seq_group)
|
||||||
|
request_outputs.append(request_output)
|
||||||
|
|
||||||
|
# Update the prefix state; all previously uncomputed prefixes are now computed.
|
||||||
|
for seq_group in scheduled_seq_groups:
|
||||||
|
if (seq_group.prefix is not None and seq_group.prefix.allocated and not seq_group.prefix.computed):
|
||||||
|
seq_group.prefix.computed = True
|
||||||
|
|
||||||
|
# Log stats.
|
||||||
|
if self.log_stats:
|
||||||
|
self.stat_logger.log(self._get_stats(scheduler_outputs))
|
||||||
|
|
||||||
|
return request_outputs
|
||||||
|
|
||||||
|
def step(self) -> List[RequestOutput]:
|
||||||
|
"""Performs one decoding iteration and returns newly generated results.
|
||||||
|
|
||||||
|
This function performs one decoding iteration of the engine. It first
|
||||||
|
schedules the sequences to be executed in the next iteration and the
|
||||||
|
token blocks to be swapped in/out/copied. Then, it executes the model
|
||||||
|
and updates the scheduler with the model outputs. Finally, it decodes
|
||||||
|
the sequences and returns the newly generated results.
|
||||||
|
"""
|
||||||
|
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
|
||||||
|
if not scheduler_outputs.is_empty():
|
||||||
|
output = self.worker.execute_model(
|
||||||
|
seq_group_metadata_list=seq_group_metadata_list, # TODO: check this input
|
||||||
|
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
|
||||||
|
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
|
||||||
|
blocks_to_copy=scheduler_outputs.blocks_to_copy,)
|
||||||
|
else:
|
||||||
|
return [RequestOutput.from_seq_group(seq_group) for seq_group in scheduler_outputs.ignored_seq_groups]
|
||||||
|
|
||||||
|
return self._process_model_outputs(output, scheduler_outputs)
|
||||||
|
|
||||||
|
def do_log_stats(self) -> None:
|
||||||
|
"""Forced log when no requests active."""
|
||||||
|
if self.log_stats:
|
||||||
|
self.stat_logger.log(self._get_stats(scheduler_outputs=None))
|
||||||
|
|
||||||
|
def _get_stats(self, scheduler_outputs: Optional[SchedulerOutputs]) -> Stats:
|
||||||
|
"""Get Stats to be Logged to Prometheus."""
|
||||||
|
now = time.monotonic()
|
||||||
|
|
||||||
|
# KV Cache Usage in %.
|
||||||
|
num_total_gpu = self.cache_config.num_gpu_blocks
|
||||||
|
num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks()
|
||||||
|
gpu_cache_usage = 1.0 - (num_free_gpu / num_total_gpu)
|
||||||
|
|
||||||
|
num_total_cpu = self.cache_config.num_cpu_blocks
|
||||||
|
cpu_cache_usage = 0.
|
||||||
|
if num_total_cpu > 0:
|
||||||
|
num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks()
|
||||||
|
cpu_cache_usage = 1.0 - (num_free_cpu / num_total_cpu)
|
||||||
|
|
||||||
|
# Scheduler State
|
||||||
|
num_running = len(self.scheduler.running)
|
||||||
|
num_swapped = len(self.scheduler.swapped)
|
||||||
|
num_waiting = len(self.scheduler.waiting)
|
||||||
|
|
||||||
|
# Iteration stats if we have scheduler output.
|
||||||
|
num_prompt_tokens = 0
|
||||||
|
num_generation_tokens = 0
|
||||||
|
time_to_first_tokens = []
|
||||||
|
time_per_output_tokens = []
|
||||||
|
time_e2e_requests = []
|
||||||
|
if scheduler_outputs is not None:
|
||||||
|
prompt_run = scheduler_outputs.prompt_run
|
||||||
|
|
||||||
|
# Number of Tokens.
|
||||||
|
if prompt_run:
|
||||||
|
num_prompt_tokens = scheduler_outputs.num_batched_tokens
|
||||||
|
else:
|
||||||
|
num_generation_tokens = scheduler_outputs.num_batched_tokens
|
||||||
|
|
||||||
|
# Latency Timings.
|
||||||
|
time_last_iters = []
|
||||||
|
for seq_group in scheduler_outputs.scheduled_seq_groups:
|
||||||
|
# Time since last token. (n.b. updates seq_group.last_token_time)
|
||||||
|
time_last_iters.append(seq_group.get_last_latency(now))
|
||||||
|
# Time since arrival for all finished requests.
|
||||||
|
if seq_group.is_finished():
|
||||||
|
time_e2e_requests.append(now - seq_group.arrival_time)
|
||||||
|
|
||||||
|
time_to_first_tokens = time_last_iters if prompt_run else []
|
||||||
|
time_per_output_tokens = [] if prompt_run else time_last_iters
|
||||||
|
|
||||||
|
return Stats(
|
||||||
|
now=now,
|
||||||
|
num_running=num_running,
|
||||||
|
num_swapped=num_swapped,
|
||||||
|
num_waiting=num_waiting,
|
||||||
|
gpu_cache_usage=gpu_cache_usage,
|
||||||
|
cpu_cache_usage=cpu_cache_usage,
|
||||||
|
num_prompt_tokens=num_prompt_tokens,
|
||||||
|
num_generation_tokens=num_generation_tokens,
|
||||||
|
time_to_first_tokens=time_to_first_tokens,
|
||||||
|
time_per_output_tokens=time_per_output_tokens,
|
||||||
|
time_e2e_requests=time_e2e_requests,
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: we may not need to decode
|
||||||
|
def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
|
||||||
|
"""Decodes the new token for a sequence."""
|
||||||
|
(new_tokens, new_output_text, prefix_offset, read_offset) = detokenize_incrementally(
|
||||||
|
self.get_tokenizer_for_seq(seq),
|
||||||
|
all_input_ids=seq.get_token_ids(),
|
||||||
|
prev_tokens=seq.tokens,
|
||||||
|
prefix_offset=seq.prefix_offset,
|
||||||
|
read_offset=seq.read_offset,
|
||||||
|
skip_special_tokens=prms.skip_special_tokens,
|
||||||
|
spaces_between_special_tokens=prms.spaces_between_special_tokens,
|
||||||
|
)
|
||||||
|
if seq.tokens is None:
|
||||||
|
seq.tokens = new_tokens
|
||||||
|
else:
|
||||||
|
seq.tokens.extend(new_tokens)
|
||||||
|
seq.prefix_offset = prefix_offset
|
||||||
|
seq.read_offset = read_offset
|
||||||
|
seq.output_text += new_output_text
|
||||||
|
|
||||||
|
def _check_stop(self, seq: Sequence, sampling_params: SamplingParams) -> None:
|
||||||
|
"""Stop the finished sequences."""
|
||||||
|
# for stop_str in sampling_params.stop:
|
||||||
|
# if seq.output_text.endswith(stop_str):
|
||||||
|
# self._finalize_sequence(seq, sampling_params, stop_str)
|
||||||
|
# seq.status = SequenceStatus.FINISHED_STOPPED
|
||||||
|
# return
|
||||||
|
# if seq.get_last_token_id() in sampling_params.stop_token_ids:
|
||||||
|
# stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(seq.get_last_token_id())
|
||||||
|
# self._finalize_sequence(seq, sampling_params, stop_str)
|
||||||
|
# seq.status = SequenceStatus.FINISHED_STOPPED
|
||||||
|
# return
|
||||||
|
|
||||||
|
# Check if the sequence has reached max_model_len.
|
||||||
|
if seq.get_len() > self.scheduler_config.max_model_len:
|
||||||
|
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check if the sequence has reached max_tokens.
|
||||||
|
if seq.get_output_len() == sampling_params.max_tokens:
|
||||||
|
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check if the sequence has generated the EOS token.
|
||||||
|
if ((not sampling_params.ignore_eos) and
|
||||||
|
seq.get_last_token_id() == self.get_tokenizer_for_seq(seq).eos_token_id):
|
||||||
|
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||||
|
return
|
||||||
|
|
||||||
|
def _finalize_sequence(self, seq: Sequence, sampling_params: SamplingParams, stop_string: str) -> None:
|
||||||
|
if not sampling_params.include_stop_str_in_output and stop_string:
|
||||||
|
# Truncate the output text so that the stop string is
|
||||||
|
# not included in the output.
|
||||||
|
seq.output_text = seq.output_text[:-len(stop_string)]
|
||||||
|
|
||||||
|
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||||
|
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
|
||||||
|
return self.worker.add_lora(lora_request)
|
||||||
|
|
||||||
|
def remove_lora(self, lora_id: int) -> bool:
|
||||||
|
assert lora_id > 0, "lora_id must be greater than 0."
|
||||||
|
return self.worker.remove_lora(lora_id)
|
||||||
|
|
||||||
|
def list_loras(self) -> List[int]:
|
||||||
|
return self.worker.list_loras()
|
||||||
|
|
||||||
|
def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor]) -> None:
|
||||||
|
self.worker.sync_model_weights(actor_weights=actor_weights)
|
||||||
|
|
||||||
|
def offload_model_weights(self) -> None:
|
||||||
|
self.worker.offload_model_weights()
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_cluster(
|
||||||
|
parallel_config: ParallelConfig,
|
||||||
|
engine_use_ray: bool = False,
|
||||||
|
ray_address: Optional[str] = None,
|
||||||
|
) -> Tuple[str, Optional[None]]:
|
||||||
|
"""Initialize the distributed cluster probably with Ray.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
parallel_config: The configurations for parallel execution.
|
||||||
|
engine_use_ray: Whether to use Ray for async engine.
|
||||||
|
ray_address: The address of the Ray cluster. If None, uses
|
||||||
|
the default Ray cluster address.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of (`distributed_init_method`, `placement_group`). The
|
||||||
|
`distributed_init_method` is the address for initializing the
|
||||||
|
distributed backend. `placement_group` includes the specification
|
||||||
|
of the resources for each distributed worker.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Initialize cluster locally.
|
||||||
|
port = get_open_port()
|
||||||
|
# We need to set up the distributed init method to make sure
|
||||||
|
# the distributed megatron code (e.g., get world size) works correctly.
|
||||||
|
distributed_init_method = f"tcp://localhost:{port}"
|
||||||
|
return distributed_init_method, None
|
||||||
|
|
||||||
|
|
||||||
|
def get_open_port():
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
|
s.bind(("", 0))
|
||||||
|
return s.getsockname()[1]
|
||||||
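The beam-search bookkeeping in _process_sequence_group_outputs above ranks finished and running children by get_beam_search_score. As a rough, self-contained illustration of that ranking (a minimal sketch assuming the usual length-penalized cumulative-logprob score; it is not verl's or vLLM's actual implementation):

    # Minimal sketch of a length-penalized beam score used to rank candidates.
    # Assumes score = cumulative_logprob / len**length_penalty; the real
    # Sequence.get_beam_search_score may differ in details (e.g. EOS handling).
    def beam_search_score(cumulative_logprob: float,
                          seq_len: int,
                          length_penalty: float = 1.0) -> float:
        # length_penalty > 1.0 favors longer sequences, < 1.0 favors shorter ones.
        return cumulative_logprob / (seq_len ** length_penalty)

    # Ranking candidates mirrors the sorts in the engine code above:
    candidates = [(-3.2, 12), (-2.9, 10), (-4.0, 15)]  # (cum_logprob, length)
    ranked = sorted(candidates,
                    key=lambda c: beam_search_score(*c, length_penalty=1.0),
                    reverse=True)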
275
verl/third_party/vllm/vllm_v_0_3_1/model_loader.py
vendored
Normal file
275
verl/third_party/vllm/vllm_v_0_3_1/model_loader.py
vendored
Normal file
@@ -0,0 +1,275 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader
|
||||||
|
"""Utilities for selecting and loading models."""
|
||||||
|
import contextlib
|
||||||
|
from typing import Dict, Type, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from transformers import PretrainedConfig, PreTrainedModel
|
||||||
|
from megatron.core.tensor_parallel.utils import VocabUtility
|
||||||
|
|
||||||
|
from vllm.model_executor.models import ModelRegistry
|
||||||
|
from vllm.model_executor.weight_utils import (get_quant_config, initialize_dummy_weights)
|
||||||
|
|
||||||
|
from .config import ModelConfig
|
||||||
|
from vllm.config import DeviceConfig, LoRAConfig
|
||||||
|
from .weight_loaders import *
|
||||||
|
from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors
|
||||||
|
from vllm.sequence import SamplerOutput
|
||||||
|
from typing import Optional
|
||||||
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
|
from vllm.model_executor.layers.sampler import _prune_hidden_states, _apply_logits_processors, _apply_penalties, _apply_top_k_top_p, _apply_min_p, _apply_penalties, _sample, _get_logprobs, _build_sampler_output
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def _set_default_torch_dtype(dtype: torch.dtype):
|
||||||
|
"""Sets the default torch dtype to the given dtype."""
|
||||||
|
old_dtype = torch.get_default_dtype()
|
||||||
|
torch.set_default_dtype(dtype)
|
||||||
|
yield
|
||||||
|
torch.set_default_dtype(old_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
|
||||||
|
architectures = getattr(config, "architectures", [])
|
||||||
|
for arch in architectures:
|
||||||
|
model_cls = ModelRegistry.load_model_cls(arch)
|
||||||
|
if model_cls is not None:
|
||||||
|
return model_cls
|
||||||
|
raise ValueError(f"Model architectures {architectures} are not supported for now. "
|
||||||
|
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
|
||||||
|
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.linear import *
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead
|
||||||
|
from vllm.model_executor.layers.activation import ScaledActivation
|
||||||
|
|
||||||
|
__LAYER_WEIGHT_LOADER_REGISTRY__ = {
|
||||||
|
ColumnParallelLinear: parallel_weight_loader,
|
||||||
|
MergedColumnParallelLinear: parallel_weight_loader,
|
||||||
|
QKVParallelLinear: parallel_weight_loader,
|
||||||
|
RowParallelLinear: parallel_weight_loader,
|
||||||
|
VocabParallelEmbedding: parallel_weight_loader,
|
||||||
|
ParallelLMHead: parallel_weight_loader
|
||||||
|
# "ScaledActivation.weight_loader": ScaledActivation, # TODO(shengguangming): latest commit in vllm fix awq for this function and add load_weights
|
||||||
|
# "default_weight_loader": default_weight_loader
|
||||||
|
}
|
||||||
|
|
||||||
|
# NOTE(gmsheng): change the weight_loader function at runtime
|
||||||
|
for layer_class, weight_loader in __LAYER_WEIGHT_LOADER_REGISTRY__.items():
|
||||||
|
layer_class.weight_loader = weight_loader
|
||||||
|
|
||||||
|
__MODEL_WEIGHT_LOADER_REGISTRY__ = {
|
||||||
|
'GPT2LMHeadModel': gpt2_weight_loader,
|
||||||
|
'LlamaForCausalLM': llama_weight_loader,
|
||||||
|
'LLaMAForCausalLM': llama_weight_loader,
|
||||||
|
'MistralForCausalLM': mistral_weight_loader,
|
||||||
|
}
|
||||||
|
|
||||||
|
# FIXME(shengguangming): the vLLM vocab is padded to a multiple of 64, which may cause out-of-bounds indexing,
|
||||||
|
# so we need to rewrite the vocab init function.
|
||||||
|
DEFAULT_VOCAB_PADDING_SIZE = 64
|
||||||
|
|
||||||
|
|
||||||
|
def vocab_init(self,
|
||||||
|
num_embeddings: int,
|
||||||
|
embedding_dim: int,
|
||||||
|
params_dtype: Optional[torch.dtype] = None,
|
||||||
|
org_num_embeddings: Optional[int] = None,
|
||||||
|
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
|
||||||
|
super(VocabParallelEmbedding, self).__init__()
|
||||||
|
|
||||||
|
# Keep the input dimensions.
|
||||||
|
# TODO: pad so that the size is divisible by 4
|
||||||
|
self.num_embeddings = num_embeddings
|
||||||
|
self.org_vocab_size = org_num_embeddings or num_embeddings
|
||||||
|
|
||||||
|
# self.num_embeddings_padded = pad_vocab_size(num_embeddings,
|
||||||
|
# padding_size)
|
||||||
|
self.embedding_dim = embedding_dim
|
||||||
|
if params_dtype is None:
|
||||||
|
params_dtype = torch.get_default_dtype()
|
||||||
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
# Divide the weight matrix along the vocabulary dimension.
|
||||||
|
|
||||||
|
self.vocab_start_index, self.vocab_end_index = (VocabUtility.vocab_range_from_global_vocab_size(
|
||||||
|
self.num_embeddings, get_tensor_model_parallel_rank(), self.tp_size))
|
||||||
|
self.num_embeddings_per_partition = (self.vocab_end_index - self.vocab_start_index)
|
||||||
|
self.weight = Parameter(
|
||||||
|
torch.empty(
|
||||||
|
self.num_embeddings_per_partition,
|
||||||
|
self.embedding_dim,
|
||||||
|
# device=torch.cuda.current_device(),
|
||||||
|
dtype=params_dtype))
|
||||||
|
set_weight_attrs(self.weight, {"parallel_dim": 0, "weight_loader": self.weight_loader})
|
||||||
|
|
||||||
|
|
||||||
|
VocabParallelEmbedding.__init__ = vocab_init
|
||||||
|
|
||||||
|
|
||||||
|
def _get_model_weight_loader(arch: str):
|
||||||
|
if arch in __MODEL_WEIGHT_LOADER_REGISTRY__:
|
||||||
|
return __MODEL_WEIGHT_LOADER_REGISTRY__[arch]
|
||||||
|
raise ValueError(f"Model architectures {arch} are not supported for now. "
|
||||||
|
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
|
||||||
|
|
||||||
|
|
||||||
|
def get_model(actor_model: Union[PreTrainedModel, Dict],
|
||||||
|
model_config: ModelConfig,
|
||||||
|
device_config: DeviceConfig,
|
||||||
|
lora_config: Optional[LoRAConfig] = None) -> nn.Module:
|
||||||
|
model_class = _get_model_architecture(model_config.hf_config)
|
||||||
|
|
||||||
|
# Get the quantization config.
|
||||||
|
linear_method = None
|
||||||
|
quant_config = None
|
||||||
|
if model_config.quantization is not None:
|
||||||
|
quant_config = get_quant_config(model_config.quantization, model_config.model, model_config.hf_config,
|
||||||
|
model_config.download_dir)
|
||||||
|
capability = torch.cuda.get_device_capability()
|
||||||
|
capability = capability[0] * 10 + capability[1]
|
||||||
|
if capability < quant_config.get_min_capability():
|
||||||
|
raise ValueError(f"The quantization method {model_config.quantization} is not "
|
||||||
|
"supported for the current GPU. "
|
||||||
|
f"Minimum capability: {quant_config.get_min_capability()}. "
|
||||||
|
f"Current capability: {capability}.")
|
||||||
|
supported_dtypes = quant_config.get_supported_act_dtypes()
|
||||||
|
if model_config.dtype not in supported_dtypes:
|
||||||
|
raise ValueError(f"{model_config.dtype} is not supported for quantization "
|
||||||
|
f"method {model_config.quantization}. Supported dtypes: "
|
||||||
|
f"{supported_dtypes}")
|
||||||
|
linear_method = quant_config.get_linear_method()
|
||||||
|
|
||||||
|
with _set_default_torch_dtype(model_config.dtype):
|
||||||
|
# Create a model instance.
|
||||||
|
# The weights will be initialized as empty tensors.
|
||||||
|
# with torch.device(device_config.device):
|
||||||
|
# NOTE(sgm): init the model in cpu
|
||||||
|
model = model_class(model_config.hf_config, linear_method)
|
||||||
|
|
||||||
|
if model_config.load_format == "dummy":
|
||||||
|
model = model.cuda()
|
||||||
|
# NOTE(woosuk): For accurate performance evaluation, we assign
|
||||||
|
# random values to the weights.
|
||||||
|
initialize_dummy_weights(model)
|
||||||
|
elif model_config.load_format == 'model' or model_config.load_format == 'auto':
|
||||||
|
# NOTE(shengguangming) Load the weights from the actor model
|
||||||
|
if isinstance(actor_model, nn.Module):
|
||||||
|
load_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model)
|
||||||
|
else:
|
||||||
|
load_weights(actor_weights=actor_model, vllm_model=model)
|
||||||
|
|
||||||
|
# NOTE(sgm): some weights already point to GPU memory, but we still need this.
|
||||||
|
model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage
|
||||||
|
return model.eval()
|
||||||
|
|
||||||
|
|
||||||
|
# the actor model is .state_dict()
|
||||||
|
def load_weights(actor_weights: Dict, vllm_model: nn.Module):
|
||||||
|
weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__)
|
||||||
|
weight_loader(actor_weights, vllm_model)
|
||||||
|
# NOTE(sgm): to reduce peak memory usage, we offload the vLLM model to CPU
|
||||||
|
# after init, so we need to move it back to GPU after syncing model weights in the first iteration.
|
||||||
|
vllm_model = vllm_model.cuda()
|
||||||
|
|
||||||
|
|
||||||
|
# FIXME(sgm): hack the Sampler function in vllm v0.3.1
|
||||||
|
# Since upstream vLLM uses Ray, the sampler result only needs to be returned to the driver node,
|
||||||
|
# so a gather is enough. However, we use SPMD instead of a central scheduler,
|
||||||
|
# so an all_gather is required (aligned with v0.2.6).
|
||||||
|
def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
|
||||||
|
embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||||
|
# Get the logits for the next tokens.
|
||||||
|
logits = torch.matmul(hidden_states, embedding.t())
|
||||||
|
if embedding_bias is not None:
|
||||||
|
logits += embedding_bias
|
||||||
|
logits = tensor_model_parallel_all_gather(logits)
|
||||||
|
# Remove paddings in vocab (if any).
|
||||||
|
if logits is not None:
|
||||||
|
logits = logits[:, :self.org_vocab_size]
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
embedding: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
embedding_bias: Optional[torch.Tensor] = None,
|
||||||
|
) -> Optional[SamplerOutput]:
|
||||||
|
# Get the hidden states that we use for sampling.
|
||||||
|
hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
|
||||||
|
|
||||||
|
# Get the logits for the next tokens.
|
||||||
|
logits = self._get_logits(hidden_states, embedding, embedding_bias)
|
||||||
|
# Save the original (unprocessed) logprobs for the sampler output.
|
||||||
|
origin_logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
|
||||||
|
|
||||||
|
# Only perform sampling in the driver worker.
|
||||||
|
# Note: `_get_logits` is still distributed across TP workers because
|
||||||
|
# the `embedding` weight is distributed across TP workers.
|
||||||
|
# TODO(zhuohan): Change the get_logits part to a separate stage.
|
||||||
|
if not sampling_metadata.perform_sampling:
|
||||||
|
return None
|
||||||
|
|
||||||
|
assert logits is not None
|
||||||
|
_, vocab_size = logits.shape
|
||||||
|
|
||||||
|
# Apply logits processors (if any).
|
||||||
|
logits = _apply_logits_processors(logits, sampling_metadata)
|
||||||
|
|
||||||
|
# Prepare sampling tensors with pinned memory to avoid blocking.
|
||||||
|
(sampling_tensors, do_penalties, do_top_p_top_k,
|
||||||
|
do_min_p) = SamplingTensors.from_sampling_metadata(sampling_metadata, vocab_size, logits.device, logits.dtype)
|
||||||
|
|
||||||
|
# Apply presence and frequency penalties.
|
||||||
|
if do_penalties:
|
||||||
|
logits = _apply_penalties(logits, sampling_tensors.prompt_tokens, sampling_tensors.output_tokens,
|
||||||
|
sampling_tensors.presence_penalties, sampling_tensors.frequency_penalties,
|
||||||
|
sampling_tensors.repetition_penalties)
|
||||||
|
|
||||||
|
# Apply temperature scaling.
|
||||||
|
# Use in-place division to avoid creating a new tensor.
|
||||||
|
logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1))
|
||||||
|
|
||||||
|
if do_top_p_top_k:
|
||||||
|
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, sampling_tensors.top_ks)
|
||||||
|
|
||||||
|
if do_min_p:
|
||||||
|
logits = _apply_min_p(logits, sampling_tensors.min_ps)
|
||||||
|
|
||||||
|
# We use float32 for probabilities and log probabilities.
|
||||||
|
# Compute the probabilities.
|
||||||
|
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
|
||||||
|
# Compute the log probabilities.
|
||||||
|
# Use log_softmax to ensure numerical stability.
|
||||||
|
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
|
||||||
|
|
||||||
|
# Sample the next tokens.
|
||||||
|
sample_results = _sample(probs, logprobs, sampling_metadata)
|
||||||
|
|
||||||
|
# Get the logprobs query results.
|
||||||
|
# prompt_logprobs, sample_logprobs = _get_logprobs(
|
||||||
|
# logprobs, sampling_metadata, sample_results)
|
||||||
|
prompt_logprobs, sample_logprobs = _get_logprobs(origin_logprobs, sampling_metadata, sample_results)
|
||||||
|
|
||||||
|
return _build_sampler_output(sample_results, sampling_metadata, prompt_logprobs, sample_logprobs)
|
||||||
|
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
|
|
||||||
|
Sampler._get_logits = _get_logits
|
||||||
|
Sampler.forward = forward
|
||||||
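model_loader.py above customizes vLLM by assigning new functions onto existing classes at runtime (layer weight_loader attributes, VocabParallelEmbedding.__init__, Sampler._get_logits and Sampler.forward). A toy, self-contained sketch of that monkey-patching pattern, using made-up classes rather than the vLLM types:

    # Toy illustration of runtime method patching, the mechanism used above for
    # weight loaders and the Sampler. The classes here are hypothetical.
    class LinearLayer:
        def weight_loader(self, param, loaded_weight):
            raise NotImplementedError("default loader")

    def shared_weight_loader(self, param, loaded_weight):
        # Re-point the parameter at the already materialized weight instead of
        # copying, mirroring the intent of parallel_weight_loader.
        param["data"] = loaded_weight

    # Patch every registered layer class with the replacement loader.
    for layer_cls in (LinearLayer,):
        layer_cls.weight_loader = shared_weight_loader

    layer, param = LinearLayer(), {"data": None}
    layer.weight_loader(param, loaded_weight=[1.0, 2.0])
    assert param["data"] == [1.0, 2.0]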
285
verl/third_party/vllm/vllm_v_0_3_1/model_runner.py
vendored
Normal file
285
verl/third_party/vllm/vllm_v_0_3_1/model_runner.py
vendored
Normal file
@@ -0,0 +1,285 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/model_runner.py
|
||||||
|
|
||||||
|
from typing import Dict, List, Optional, Tuple, Set, Union
|
||||||
|
import contextlib
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from vllm.config import (DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig, SchedulerConfig)
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor import InputMetadata, SamplingMetadata
|
||||||
|
from vllm.sampling_params import SamplingParams, SamplingType
|
||||||
|
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
|
||||||
|
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
|
||||||
|
from vllm.lora.layers import LoRAMapping
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
|
from vllm.utils import in_wsl
|
||||||
|
from vllm.worker.model_runner import ModelRunner, CUDAGraphRunner, _async_h2d
|
||||||
|
|
||||||
|
from .model_loader import get_model
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
_PAD_SLOT_ID = -1
|
||||||
|
LORA_WARMUP_RANK = 8
|
||||||
|
# Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
|
||||||
|
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
|
||||||
|
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
|
||||||
|
|
||||||
|
|
||||||
|
class ModelRunner(ModelRunner):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: Union[nn.Module, Dict], # model itself or its parameter dict
|
||||||
|
model_config: ModelConfig,
|
||||||
|
parallel_config: ParallelConfig,
|
||||||
|
scheduler_config: SchedulerConfig,
|
||||||
|
device_config: DeviceConfig,
|
||||||
|
lora_config: Optional[LoRAConfig],
|
||||||
|
kv_cache_dtype: Optional[str] = "auto",
|
||||||
|
):
|
||||||
|
self.model_config = model_config
|
||||||
|
self.parallel_config = parallel_config
|
||||||
|
self.scheduler_config = scheduler_config
|
||||||
|
self.lora_config = lora_config
|
||||||
|
|
||||||
|
# model_config can be None in tests/samplers/test_sampler.py.
|
||||||
|
# FIXME(woosuk): This is a hack to make the tests work. Refactor this.
|
||||||
|
self.sliding_window = (model_config.get_sliding_window() if model_config is not None else None)
|
||||||
|
|
||||||
|
self.device_config = (device_config if device_config is not None else DeviceConfig())
|
||||||
|
self.device = self.device_config.device
|
||||||
|
|
||||||
|
self.model = model # this will be replaced by get_model()
|
||||||
|
self.block_size = None # Set after initial profiling.
|
||||||
|
self.lora_manager = None
|
||||||
|
|
||||||
|
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
|
||||||
|
self.graph_memory_pool = None # Set during graph capture.
|
||||||
|
|
||||||
|
self.max_context_len_to_capture = (self.model_config.max_context_len_to_capture
|
||||||
|
if self.model_config is not None else 0)
|
||||||
|
# When using CUDA graph, the input block tables must be padded to
|
||||||
|
# max_context_len_to_capture. However, creating the block table in
|
||||||
|
# Python can be expensive. To optimize this, we cache the block table
|
||||||
|
# in numpy and only copy the actual input content at every iteration.
|
||||||
|
# The shape of the cached block table will be
|
||||||
|
# (max batch size to capture, max context len to capture / block size).
|
||||||
|
self.graph_block_tables = None # Set after initial profiling.
|
||||||
|
# cache in_wsl result
|
||||||
|
self.in_wsl = in_wsl()
|
||||||
|
self.kv_cache_dtype = kv_cache_dtype
|
||||||
|
|
||||||
|
def load_model(self) -> None:
|
||||||
|
self.model = get_model(actor_model=self.model,
|
||||||
|
model_config=self.model_config,
|
||||||
|
device_config=self.device_config,
|
||||||
|
lora_config=self.lora_config)
|
||||||
|
vocab_size = self.model.config.vocab_size
|
||||||
|
|
||||||
|
if self.lora_config:
|
||||||
|
assert hasattr(
|
||||||
|
self.model,
|
||||||
|
"supported_lora_modules") and self.model.supported_lora_modules, "Model does not support LoRA"
|
||||||
|
assert hasattr(self.model, "embedding_modules"), "Model does not have embedding_modules"
|
||||||
|
assert hasattr(self.model, "embedding_padding_modules"), "Model does not have embedding_padding_modules"
|
||||||
|
self.lora_manager = LRUCacheWorkerLoRAManager(
|
||||||
|
self.scheduler_config.max_num_seqs,
|
||||||
|
self.scheduler_config.max_num_batched_tokens + self.scheduler_config.max_paddings, vocab_size,
|
||||||
|
self.lora_config, self.device, self.model.embedding_modules, self.model.embedding_padding_modules)
|
||||||
|
self.model = self.lora_manager.create_lora_manager(self.model)
|
||||||
|
|
||||||
|
def _prepare_sample(
|
||||||
|
self,
|
||||||
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
|
prompt_lens: List[int],
|
||||||
|
subquery_lens: Optional[List[int]],
|
||||||
|
) -> SamplingMetadata:
|
||||||
|
seq_groups: List[Tuple[List[int], SamplingParams]] = []
|
||||||
|
selected_token_indices: List[int] = []
|
||||||
|
selected_token_start_idx = 0
|
||||||
|
categorized_sample_indices = {t: [] for t in SamplingType}
|
||||||
|
categorized_sample_indices_start_idx = 0
|
||||||
|
|
||||||
|
max_subquery_len = max(subquery_lens) if subquery_lens else 1
|
||||||
|
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
|
||||||
|
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||||
|
sampling_params = seq_group_metadata.sampling_params
|
||||||
|
seq_groups.append((seq_ids, sampling_params))
|
||||||
|
|
||||||
|
if seq_group_metadata.is_prompt:
|
||||||
|
assert len(seq_ids) == 1
|
||||||
|
assert subquery_lens is not None
|
||||||
|
subquery_len = subquery_lens[i]
|
||||||
|
if sampling_params.prompt_logprobs is not None:
|
||||||
|
# NOTE: prompt token positions do not need sampling, so skip them.
|
||||||
|
categorized_sample_indices_start_idx += subquery_len - 1
|
||||||
|
|
||||||
|
categorized_sample_indices[sampling_params.sampling_type].append(categorized_sample_indices_start_idx)
|
||||||
|
categorized_sample_indices_start_idx += 1
|
||||||
|
|
||||||
|
if sampling_params.prompt_logprobs is not None:
|
||||||
|
selected_token_indices.extend(
|
||||||
|
range(selected_token_start_idx, selected_token_start_idx + subquery_len - 1))
|
||||||
|
selected_token_indices.append(selected_token_start_idx + subquery_len - 1)
|
||||||
|
selected_token_start_idx += max_subquery_len
|
||||||
|
else:
|
||||||
|
num_seqs = len(seq_ids)
|
||||||
|
selected_token_indices.extend(range(selected_token_start_idx, selected_token_start_idx + num_seqs))
|
||||||
|
selected_token_start_idx += num_seqs
|
||||||
|
|
||||||
|
categorized_sample_indices[sampling_params.sampling_type].extend(
|
||||||
|
range(categorized_sample_indices_start_idx, categorized_sample_indices_start_idx + num_seqs))
|
||||||
|
categorized_sample_indices_start_idx += num_seqs
|
||||||
|
|
||||||
|
selected_token_indices = _async_h2d(selected_token_indices,
|
||||||
|
dtype=torch.long,
|
||||||
|
target_device=self.device,
|
||||||
|
pin_memory=not self.in_wsl)
|
||||||
|
categorized_sample_indices = {
|
||||||
|
t: _async_h2d(seq_ids, dtype=torch.int, target_device=self.device, pin_memory=not self.in_wsl)
|
||||||
|
for t, seq_ids in categorized_sample_indices.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
seq_data: Dict[int, SequenceData] = {}
|
||||||
|
for seq_group_metadata in seq_group_metadata_list:
|
||||||
|
seq_data.update(seq_group_metadata.seq_data)
|
||||||
|
|
||||||
|
sampling_metadata = SamplingMetadata(
|
||||||
|
seq_groups=seq_groups,
|
||||||
|
seq_data=seq_data,
|
||||||
|
prompt_lens=prompt_lens,
|
||||||
|
selected_token_indices=selected_token_indices,
|
||||||
|
categorized_sample_indices=categorized_sample_indices,
|
||||||
|
)
|
||||||
|
return sampling_metadata
|
||||||
|
|
||||||
|
def prepare_input_tensors(
|
||||||
|
self,
|
||||||
|
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata, Set[int], LoRAMapping]:
|
||||||
|
# NOTE: We assume that all sequences in the group are all prompts or
|
||||||
|
# all decodes.
|
||||||
|
is_prompt = seq_group_metadata_list[0].is_prompt
|
||||||
|
# Prepare input tensors.
|
||||||
|
if is_prompt:
|
||||||
|
(input_tokens, input_positions, input_metadata, prompt_lens, subquery_lens, lora_index_mapping,
|
||||||
|
lora_prompt_mapping, lora_requests) = self._prepare_prompt(seq_group_metadata_list)
|
||||||
|
else:
|
||||||
|
(input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping,
|
||||||
|
lora_requests) = self._prepare_decode(seq_group_metadata_list)
|
||||||
|
prompt_lens = []
|
||||||
|
subquery_lens = None
|
||||||
|
sampling_metadata = self._prepare_sample(seq_group_metadata_list, prompt_lens, subquery_lens)
|
||||||
|
if self.lora_config:
|
||||||
|
flat_lora_index_mapping = [item for sublist in lora_index_mapping for item in sublist]
|
||||||
|
lora_mapping = LoRAMapping(
|
||||||
|
flat_lora_index_mapping,
|
||||||
|
lora_prompt_mapping,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
lora_mapping = None
|
||||||
|
|
||||||
|
return (input_tokens, input_positions, input_metadata, sampling_metadata, lora_requests, lora_mapping)
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def execute_model(
|
||||||
|
self,
|
||||||
|
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||||
|
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
) -> Optional[SamplerOutput]:
|
||||||
|
(input_tokens, input_positions, input_metadata, sampling_metadata, lora_requests,
|
||||||
|
lora_mapping) = self.prepare_input_tensors(seq_group_metadata_list)
|
||||||
|
|
||||||
|
if self.lora_config:
|
||||||
|
self.set_active_loras(lora_requests, lora_mapping)
|
||||||
|
|
||||||
|
# Execute the model.
|
||||||
|
if input_metadata.use_cuda_graph:
|
||||||
|
graph_batch_size = input_tokens.shape[0]
|
||||||
|
model_executable = self.graph_runners[graph_batch_size]
|
||||||
|
else:
|
||||||
|
model_executable = self.model
|
||||||
|
hidden_states = model_executable(
|
||||||
|
input_ids=input_tokens,
|
||||||
|
positions=input_positions,
|
||||||
|
kv_caches=kv_caches,
|
||||||
|
input_metadata=input_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Sample the next token.
|
||||||
|
output = self.model.sample(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
sampling_metadata=sampling_metadata,
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def profile_run(self) -> None:
|
||||||
|
# Enable top-k sampling to reflect the accurate memory usage.
|
||||||
|
vocab_size = self.model_config.get_vocab_size()
|
||||||
|
# FIXME(sgm): these sampling params trigger a cumsum() call, which makes the
|
||||||
|
# deterministic cumsum implementation throw an error.
|
||||||
|
sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
|
||||||
|
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
|
||||||
|
max_num_seqs = self.scheduler_config.max_num_seqs
|
||||||
|
|
||||||
|
# This represents the maximum number of different requests
|
||||||
|
# that will have unique LoRAs, and therefore the maximum memory
|
||||||
|
# consumption. Create dummy LoRA request copies from the LoRA request
|
||||||
|
# passed in, which contains a LoRA from the LoRA warmup path.
|
||||||
|
dummy_lora_requests = []
|
||||||
|
dummy_lora_requests_per_seq = []
|
||||||
|
if self.lora_config:
|
||||||
|
for idx in range(self.lora_config.max_loras):
|
||||||
|
lora_id = idx + 1
|
||||||
|
dummy_lora_request = LoRARequest(
|
||||||
|
lora_name=f"warmup_{lora_id}",
|
||||||
|
lora_int_id=lora_id,
|
||||||
|
lora_local_path="/not/a/real/path",
|
||||||
|
)
|
||||||
|
self.lora_manager.add_dummy_lora(dummy_lora_request, rank=LORA_WARMUP_RANK)
|
||||||
|
dummy_lora_requests.append(dummy_lora_request)
|
||||||
|
dummy_lora_requests_per_seq = [
|
||||||
|
dummy_lora_requests[idx % len(dummy_lora_requests)] for idx in range(max_num_seqs)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Profile memory usage with max_num_sequences sequences and the total
|
||||||
|
# number of tokens equal to max_num_batched_tokens.
|
||||||
|
seqs: List[SequenceGroupMetadata] = []
|
||||||
|
for group_id in range(max_num_seqs):
|
||||||
|
seq_len = (max_num_batched_tokens // max_num_seqs + (group_id < max_num_batched_tokens % max_num_seqs))
|
||||||
|
seq_data = SequenceData([0] * seq_len)
|
||||||
|
seq = SequenceGroupMetadata(
|
||||||
|
request_id=str(group_id),
|
||||||
|
is_prompt=True,
|
||||||
|
seq_data={group_id: seq_data},
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
block_tables=None,
|
||||||
|
lora_request=dummy_lora_requests_per_seq[group_id] if dummy_lora_requests_per_seq else None,
|
||||||
|
)
|
||||||
|
seqs.append(seq)
|
||||||
|
|
||||||
|
# Run the model with the dummy inputs.
|
||||||
|
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
||||||
|
kv_caches = [(None, None)] * num_layers
|
||||||
|
self.execute_model(seqs, kv_caches)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
return
|
||||||
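One detail of profile_run above worth spelling out: the dummy prompt lengths are chosen so that max_num_batched_tokens is split as evenly as possible over max_num_seqs sequences. A standalone sketch of that arithmetic (not the ModelRunner code itself):

    from typing import List

    # Each dummy group gets floor(total / n) tokens, and the first (total % n)
    # groups get one extra token, so the lengths sum exactly to the budget.
    def dummy_seq_lens(max_num_batched_tokens: int, max_num_seqs: int) -> List[int]:
        lens = [
            max_num_batched_tokens // max_num_seqs
            + (group_id < max_num_batched_tokens % max_num_seqs)
            for group_id in range(max_num_seqs)
        ]
        assert sum(lens) == max_num_batched_tokens
        return lens

    print(dummy_seq_lens(10, 4))  # [3, 3, 2, 2]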
147
verl/third_party/vllm/vllm_v_0_3_1/parallel_state.py
vendored
Normal file
147
verl/third_party/vllm/vllm_v_0_3_1/parallel_state.py
vendored
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Adapted from
|
||||||
|
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
|
||||||
|
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
"""Model and data parallel groups."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
|
||||||
|
import vllm.model_executor.parallel_utils.parallel_state as ps
|
||||||
|
"""
|
||||||
|
This version is strongly tied to Megatron in order to implement the HybridEngine and weight sharing between vLLM and Megatron.
|
||||||
|
- We assume the Megatron tp+dp+pp world is already established before calling this function.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Tensor model parallel group that the current rank belongs to.
|
||||||
|
_TENSOR_MODEL_PARALLEL_GROUP = None
|
||||||
|
|
||||||
|
# Micro data parallel group: an additional DP group that originates from splitting the training TP
|
||||||
|
# into infer_tp and micro_dp. By default, we use the order micro_dp - tp.
|
||||||
|
_MICRO_DATA_PARALLEL_GROUP = None
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_model_parallel_from_megatron(
|
||||||
|
tensor_model_parallel_size=None # we set None for backward compatibility to set infer_tp = train_tp
|
||||||
|
) -> None:
|
||||||
|
from megatron.core import parallel_state as mpu
|
||||||
|
from megatron.distributed import new_group
|
||||||
|
# Get world size and rank. Ensure some consistencies.
|
||||||
|
assert torch.distributed.is_initialized()
|
||||||
|
|
||||||
|
if tensor_model_parallel_size is None:
|
||||||
|
tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
|
||||||
|
else:
|
||||||
|
assert isinstance(tensor_model_parallel_size, int)
|
||||||
|
|
||||||
|
# Build the tensor model-parallel groups.
|
||||||
|
assert ps._TENSOR_MODEL_PARALLEL_GROUP is None, ("tensor model parallel group is already initialized")
|
||||||
|
|
||||||
|
assert tensor_model_parallel_size <= mpu.get_tensor_model_parallel_world_size(
|
||||||
|
), 'Not implemented for infer_tp > train_tp'
|
||||||
|
|
||||||
|
global _TENSOR_MODEL_PARALLEL_GROUP
|
||||||
|
global _MICRO_DATA_PARALLEL_GROUP
|
||||||
|
|
||||||
|
assert mpu.get_tensor_model_parallel_world_size() % tensor_model_parallel_size == 0
|
||||||
|
|
||||||
|
micro_dp_size = mpu.get_tensor_model_parallel_world_size() // tensor_model_parallel_size
|
||||||
|
|
||||||
|
world_size: int = torch.distributed.get_world_size()
|
||||||
|
|
||||||
|
num_micro_dp_groups = world_size // micro_dp_size
|
||||||
|
|
||||||
|
rank = torch.distributed.get_rank()
|
||||||
|
|
||||||
|
# Build the micro dp groups.
|
||||||
|
assert _MICRO_DATA_PARALLEL_GROUP is None, ("micro data parallel group is already initialized")
|
||||||
|
for i in range(num_micro_dp_groups):
|
||||||
|
ranks = range(i * micro_dp_size, (i + 1) * micro_dp_size)
|
||||||
|
group = new_group(rank=rank, ranks=ranks, group_type='micro_dp')
|
||||||
|
if rank in ranks:
|
||||||
|
_MICRO_DATA_PARALLEL_GROUP = group
|
||||||
|
|
||||||
|
if tensor_model_parallel_size == mpu.get_tensor_model_parallel_world_size():
|
||||||
|
# using the same tp group as Megatron
|
||||||
|
ps._TENSOR_MODEL_PARALLEL_GROUP = mpu.get_tensor_model_parallel_group()
|
||||||
|
|
||||||
|
_TENSOR_MODEL_PARALLEL_GROUP = mpu.get_tensor_model_parallel_group()
|
||||||
|
# no _MICRO_DATA_PARALLEL_GROUP
|
||||||
|
else:
|
||||||
|
# initialize a micro_dp group and a tp group
|
||||||
|
# Assume training tp=4 and inference tp=2; then the weights are partitioned as
|
||||||
|
# [1], [2], [3], [4] for training and [1,2], [1,2], [3,4], [3,4] for inference
|
||||||
|
|
||||||
|
# Build the inference tp groups
|
||||||
|
train_tp = mpu.get_tensor_model_parallel_world_size()
|
||||||
|
num_tensor_model_parallel_groups_per_train_tp = train_tp // tensor_model_parallel_size
|
||||||
|
num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
|
||||||
|
assert _TENSOR_MODEL_PARALLEL_GROUP is None, ("tensor model parallel group is already initialized")
|
||||||
|
for i in range(num_tensor_model_parallel_groups // num_tensor_model_parallel_groups_per_train_tp):
|
||||||
|
start = train_tp * i
|
||||||
|
end = train_tp * (i + 1)
|
||||||
|
for j in range(num_tensor_model_parallel_groups_per_train_tp):
|
||||||
|
ranks = list(range(start, end, num_tensor_model_parallel_groups_per_train_tp))
|
||||||
|
for i in range(len(ranks)):
|
||||||
|
ranks[i] += j
|
||||||
|
# group = torch.distributed.new_group(ranks)
|
||||||
|
group = new_group(rank=rank, ranks=ranks, group_type='infer_tp')
|
||||||
|
if rank in ranks:
|
||||||
|
_TENSOR_MODEL_PARALLEL_GROUP = group
|
||||||
|
ps._TENSOR_MODEL_PARALLEL_GROUP = _TENSOR_MODEL_PARALLEL_GROUP
|
||||||
|
# Build the pipeline model-parallel groups.
|
||||||
|
# global _PIPELINE_MODEL_PARALLEL_GROUP
|
||||||
|
# global _PIPELINE_GLOBAL_RANKS
|
||||||
|
# assert ps._PIPELINE_MODEL_PARALLEL_GROUP is None, ("pipeline model parallel group is already initialized")
|
||||||
|
|
||||||
|
# ps._PIPELINE_MODEL_PARALLEL_GROUP = mpu.get_pipeline_model_parallel_group()
|
||||||
|
# ps._PIPELINE_GLOBAL_RANKS = mpu.get_pipeline_model_parallel_ranks()
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Tensor model parallel utilities
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_tensor_model_parallel_group():
|
||||||
|
"""Get the tensor model parallel group the caller rank belongs to."""
|
||||||
|
assert _TENSOR_MODEL_PARALLEL_GROUP is not None, ("tensor model parallel group is not initialized")
|
||||||
|
return _TENSOR_MODEL_PARALLEL_GROUP
|
||||||
|
|
||||||
|
|
||||||
|
def get_tensor_model_parallel_world_size():
|
||||||
|
"""Return world size for the tensor model parallel group."""
|
||||||
|
return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
|
||||||
|
|
||||||
|
|
||||||
|
def get_tensor_model_parallel_rank():
|
||||||
|
"""Return my rank for the tensor model parallel group."""
|
||||||
|
return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
|
||||||
|
|
||||||
|
|
||||||
|
def get_tensor_model_parallel_src_rank():
|
||||||
|
"""Calculate the global rank corresponding to the first local rank
|
||||||
|
in the tensor model parallel group."""
|
||||||
|
global_rank = torch.distributed.get_rank()
|
||||||
|
local_world_size = get_tensor_model_parallel_world_size()
|
||||||
|
return (global_rank // local_world_size) * local_world_size
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Micro Data parallel group
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_micro_data_parallel_group():
|
||||||
|
assert _MICRO_DATA_PARALLEL_GROUP is not None
|
||||||
|
return _MICRO_DATA_PARALLEL_GROUP
|
||||||
|
|
||||||
|
|
||||||
|
def get_micro_data_parallel_world_size():
|
||||||
|
return torch.distributed.get_world_size(group=get_micro_data_parallel_group())
|
||||||
|
|
||||||
|
|
||||||
|
def get_micro_data_parallel_rank():
|
||||||
|
return torch.distributed.get_rank(group=get_micro_data_parallel_group())
|
||||||
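To make the rank bookkeeping in initialize_model_parallel_from_megatron concrete, the sketch below reproduces the inference tensor-parallel rank enumeration as plain arithmetic (no torch.distributed calls), for the train_tp=4, infer_tp=2 example from the comments:

    # Sketch of the inference-TP group construction above: each training TP
    # group of size train_tp is split into train_tp // infer_tp interleaved
    # inference TP groups.
    def infer_tp_groups(world_size: int, train_tp: int, infer_tp: int):
        per_train_tp = train_tp // infer_tp
        num_infer_groups = world_size // infer_tp
        groups = []
        for i in range(num_infer_groups // per_train_tp):
            start, end = train_tp * i, train_tp * (i + 1)
            for j in range(per_train_tp):
                groups.append([r + j for r in range(start, end, per_train_tp)])
        return groups

    # world_size=8, train_tp=4, infer_tp=2 -> [[0, 2], [1, 3], [4, 6], [5, 7]]
    print(infer_tp_groups(8, 4, 2))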
72
verl/third_party/vllm/vllm_v_0_3_1/tokenizer.py
vendored
Normal file
72
verl/third_party/vllm/vllm_v_0_3_1/tokenizer.py
vendored
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
|
||||||
|
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast)
|
||||||
|
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
|
from vllm.utils import make_async, LRUCache
|
||||||
|
from vllm.transformers_utils.tokenizers import *
|
||||||
|
|
||||||
|
|
||||||
|
class TokenizerGroup:
|
||||||
|
"""A group of tokenizers that can be used for LoRA adapters."""
|
||||||
|
|
||||||
|
def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int,
|
||||||
|
max_input_length: Optional[int]):
|
||||||
|
self.enable_lora = enable_lora
|
||||||
|
self.max_input_length = max_input_length
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
if enable_lora:
|
||||||
|
self.lora_tokenizers = LRUCache(capacity=max_num_seqs)
|
||||||
|
else:
|
||||||
|
self.lora_tokenizers = None
|
||||||
|
|
||||||
|
def encode(self,
|
||||||
|
prompt: str,
|
||||||
|
request_id: Optional[str] = None,
|
||||||
|
lora_request: Optional[LoRARequest] = None) -> List[int]:
|
||||||
|
tokenizer = self.get_lora_tokenizer(lora_request)
|
||||||
|
return tokenizer.encode(prompt)
|
||||||
|
|
||||||
|
async def encode_async(self,
|
||||||
|
prompt: str,
|
||||||
|
request_id: Optional[str] = None,
|
||||||
|
lora_request: Optional[LoRARequest] = None) -> List[int]:
|
||||||
|
tokenizer = await self.get_lora_tokenizer_async(lora_request)
|
||||||
|
return tokenizer.encode(prompt)
|
||||||
|
|
||||||
|
def get_lora_tokenizer(self, lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
|
||||||
|
if not lora_request or not self.enable_lora:
|
||||||
|
return self.tokenizer
|
||||||
|
if lora_request.lora_int_id not in self.lora_tokenizers:
|
||||||
|
# TODO(sgm): the LoRA tokenizer is also passed in, but it may be different
|
||||||
|
tokenizer = self.tokenizer
|
||||||
|
# tokenizer = (get_lora_tokenizer(
|
||||||
|
# lora_request, **self.tokenizer_config) or self.tokenizer)
|
||||||
|
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
|
||||||
|
return tokenizer
|
||||||
|
else:
|
||||||
|
return self.lora_tokenizers.get(lora_request.lora_int_id)
|
||||||
|
|
||||||
|
# FIXME(sgm): for simplicity, we take the special tokens from the base tokenizer here
|
||||||
|
@property
|
||||||
|
def pad_token_id(self):
|
||||||
|
return self.tokenizer.pad_token_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def eos_token_id(self):
|
||||||
|
return self.tokenizer.eos_token_id
|
||||||
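A brief usage sketch of the TokenizerGroup above. The model name is only a placeholder; with enable_lora=False, get_lora_tokenizer always falls back to the base tokenizer, so encode uses it directly:

    # Hypothetical usage of TokenizerGroup; "gpt2" is a placeholder model name.
    from transformers import AutoTokenizer

    base_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    group = TokenizerGroup(tokenizer=base_tokenizer,
                           enable_lora=False,
                           max_num_seqs=256,
                           max_input_length=None)
    token_ids = group.encode("hello world")          # uses the base tokenizer
    assert group.eos_token_id == base_tokenizer.eos_token_id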
95
verl/third_party/vllm/vllm_v_0_3_1/weight_loaders.py
vendored
Normal file
95
verl/third_party/vllm/vllm_v_0_3_1/weight_loaders.py
vendored
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models
|
||||||
|
|
||||||
|
from typing import Dict
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE(shengguangming): replace the original weight_loader function of the class
|
||||||
|
def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
|
||||||
|
"""Parallel Linear weight loader."""
|
||||||
|
assert param.size() == loaded_weight.size(
|
||||||
|
), 'the parameter size is not aligned with the loaded weight size, param size: {}, loaded_weight size: {}'.format(
|
||||||
|
param.size(), loaded_weight.size())
|
||||||
|
assert param.data.dtype == loaded_weight.data.dtype, "if we want to share weights, the data types must also match"
|
||||||
|
|
||||||
|
param.data = loaded_weight.data
|
||||||
|
|
||||||
|
|
||||||
|
def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
|
||||||
|
"""Default weight loader."""
|
||||||
|
assert param.size() == loaded_weight.size()
|
||||||
|
assert param.data.dtype == loaded_weight.data.dtype, "if we want to share weights, the data types must also match"
|
||||||
|
|
||||||
|
param.data = loaded_weight.data
|
||||||
|
|
||||||
|
|
||||||
|
def gpt2_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
|
||||||
|
params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
|
||||||
|
for name, loaded_weight in actor_weights.items():
|
||||||
|
if "lm_head.weight" in name:
|
||||||
|
# GPT-2 ties the weights of the embedding layer and the final
|
||||||
|
# linear layer.
|
||||||
|
continue
|
||||||
|
if ".attn.bias" in name or ".attn.masked_bias" in name:
|
||||||
|
# Skip attention mask.
|
||||||
|
# NOTE: "c_attn.bias" should not be skipped.
|
||||||
|
continue
|
||||||
|
if not name.startswith("transformer."):
|
||||||
|
name = "transformer." + name
|
||||||
|
param = params_dict[name]
|
||||||
|
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
|
||||||
|
# Because of this, we need to transpose the weights.
|
||||||
|
# Note(zhuohan): the logic below might break quantized models.
|
||||||
|
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
|
||||||
|
if conv1d_weight_name not in name:
|
||||||
|
continue
|
||||||
|
if not name.endswith(".weight"):
|
||||||
|
continue
|
||||||
|
# TODO: check megatron
|
||||||
|
loaded_weight = loaded_weight.t()
|
||||||
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
|
def llama_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
|
||||||
|
# NOTE(shengguangming): Megatron llama checkpoints may carry this prefix
|
||||||
|
prefix = '0.module.module.'
|
||||||
|
params_dict = dict(vllm_model.named_parameters())
|
||||||
|
for name, loaded_weight in actor_weights.items():
|
||||||
|
if name[:len(prefix)] == prefix:
|
||||||
|
name = name[len(prefix):]
|
||||||
|
if "rotary_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
|
def mistral_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
|
||||||
|
# TODO: need to implement a general way to deal with prefix
|
||||||
|
prefix = '0.module.module.'
|
||||||
|
params_dict = dict(vllm_model.named_parameters())
|
||||||
|
for name, loaded_weight in actor_weights.items():
|
||||||
|
if name[:len(prefix)] == prefix:
|
||||||
|
name = name[len(prefix):]
|
||||||
|
if "rotary_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
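The llama and mistral loaders above do two things before dispatching to each parameter's weight_loader: strip the Megatron wrapper prefix '0.module.module.' from parameter names and skip rotary_emb.inv_freq buffers. A standalone sketch of just that name normalization (plain dict manipulation, no torch):

    # Sketch of the name handling in llama_weight_loader / mistral_weight_loader:
    # drop the Megatron wrapper prefix and skip buffers vLLM recomputes itself.
    MEGATRON_PREFIX = "0.module.module."

    def normalize_names(actor_weights: dict) -> dict:
        normalized = {}
        for name, weight in actor_weights.items():
            if name.startswith(MEGATRON_PREFIX):
                name = name[len(MEGATRON_PREFIX):]
            if "rotary_emb.inv_freq" in name:
                continue  # recomputed by vLLM, never loaded
            normalized[name] = weight
        return normalized

    weights = {"0.module.module.model.embed_tokens.weight": "W"}
    print(normalize_names(weights))  # {'model.embed_tokens.weight': 'W'}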
314
verl/third_party/vllm/vllm_v_0_3_1/worker.py
vendored
Normal file
314
verl/third_party/vllm/vllm_v_0_3_1/worker.py
vendored
Normal file
@@ -0,0 +1,314 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/worker.py
"""A GPU worker class."""
import os
import gc
from typing import Dict, List, Tuple, Optional, Union, Set

import torch
import torch.distributed
import torch.nn as nn

from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, ParallelConfig, SchedulerConfig, LoRAConfig)
from vllm.model_executor import InputMetadata, set_random_seed
from vllm.model_executor.parallel_utils.parallel_state import (initialize_model_parallel)
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine
from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar
from vllm.model_executor.parallel_utils.parallel_state import get_tensor_model_parallel_group

from .model_runner import ModelRunner
from .model_loader import load_weights
from .parallel_state import initialize_model_parallel_from_megatron
from vllm.lora.request import LoRARequest


class Worker:
    """A worker class that executes (a partition of) the model on a GPU.

    Each worker is associated with a single GPU. The worker is responsible for
    maintaining the KV cache and executing the model on the GPU. In case of
    distributed inference, each worker is assigned a partition of the model.
    """

    def __init__(
        self,
        model: Union[nn.Module, Dict],  # model itself or its parameter dict
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        rank: Optional[int] = None,
        distributed_init_method: Optional[str] = None,
        lora_config: Optional[LoRAConfig] = None,
        kv_cache_dtype: Optional[str] = "auto",
    ) -> None:
        # self.model = model  # will be replaced in the init_model
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.lora_config = lora_config

        self.model_runner = ModelRunner(
            model,
            model_config,
            parallel_config,
            scheduler_config,
            device_config,
            lora_config=self.lora_config,
            kv_cache_dtype=kv_cache_dtype,
        )

        # Uninitialized cache engine. Will be initialized by
        # self.init_cache_engine().
        self.cache_config = None
        self.block_size = None
        self.sliding_window = None
        self.cache_engine = None
        self.cache_events = None
        self.gpu_cache = None

        # For offloading inference engine params
        self.cpu_model = None

    def init_model(self, cupy_port: Optional[int] = None):
        # torch.distributed.all_reduce does not free the input tensor until
        # the synchronization point. This causes the memory usage to grow
        # as the number of all_reduce calls increases. This env var disables
        # this behavior.
        # Related issue:
        # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
        os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

        # Env vars will be set by TORCHRUN.
        self.rank = self.rank if self.rank is not None else int(os.getenv("RANK", "-1"))
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        self.device = torch.device(f"cuda:{local_rank}")
        if self.rank < 0:
            raise ValueError("Invalid or unspecified rank.")
        torch.cuda.set_device(self.device)

        _check_if_gpu_supports_dtype(self.model_config.dtype)

        # Initialize the distributed environment.
        # TODO: do not use cupy
        _init_distributed_environment(self.parallel_config, self.rank, self.distributed_init_method)
        if not self.parallel_config.disable_custom_all_reduce:
            init_custom_ar()
        # Initialize the model.
        set_random_seed(self.model_config.seed)
        # self.model = get_model(actor_model=self.model, model_config=self.model_config)

    def load_model(self):
        self.model_runner.load_model()

    @torch.inference_mode()
    def profile_num_available_blocks(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        cpu_swap_space: int,
        cache_dtype: str,
    ) -> Tuple[int, int]:
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.cuda.empty_cache()
        # torch.cuda.reset_peak_memory_stats()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.cuda.synchronize()
        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
        peak_memory = total_gpu_memory - free_gpu_memory

        cache_block_size = CacheEngine.get_cache_block_size(block_size, cache_dtype, self.model_config,
                                                            self.parallel_config)
        # NOTE(sgm) use the remaining memory
        num_gpu_blocks = int((free_gpu_memory * gpu_memory_utilization) // cache_block_size)
        # num_gpu_blocks = int((total_gpu_memory * gpu_memory_utilization - peak_memory) // cache_block_size)
        num_cpu_blocks = int(cpu_swap_space // cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)
        if self.model_runner.lora_manager:
            self.model_runner.remove_all_loras()
        gc.collect()
        torch.cuda.empty_cache()
        # Synchronize the number of blocks across all ranks.
        num_gpu_blocks = torch.tensor([num_gpu_blocks], device='cuda')
        num_cpu_blocks = torch.tensor([num_cpu_blocks], device='cuda')
        torch.distributed.all_reduce(num_gpu_blocks,
                                     op=torch.distributed.ReduceOp.MIN,
                                     group=get_tensor_model_parallel_group())
        torch.distributed.all_reduce(num_cpu_blocks,
                                     op=torch.distributed.ReduceOp.MIN,
                                     group=get_tensor_model_parallel_group())
        num_gpu_blocks = num_gpu_blocks.item()
        num_cpu_blocks = num_cpu_blocks.item()
        return num_gpu_blocks, num_cpu_blocks
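The MIN all-reduce above keeps every tensor-parallel rank agreeing on the same KV-cache size even when free memory differs across GPUs. A standalone sketch of the same idea; the process group argument is an assumption of the caller, not part of the vendored file:

import torch
import torch.distributed as dist


def agree_on_block_count(local_num_blocks: int, group=None) -> int:
    # Each rank proposes a block count based on its own free memory; taking the
    # minimum guarantees that no rank over-allocates its share of the KV cache.
    proposal = torch.tensor([local_num_blocks], device='cuda')
    dist.all_reduce(proposal, op=dist.ReduceOp.MIN, group=group)
    return int(proposal.item())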
    def init_cache_engine(self, cache_config: CacheConfig) -> None:
        if self.cache_engine is None and self.gpu_cache is None:
            self.cache_config = cache_config
            self.cache_engine = CacheEngine(self.cache_config, self.model_config, self.parallel_config)
            self.cache_events = self.cache_engine.events
            self.gpu_cache = self.cache_engine.gpu_cache
            self.model_runner.set_block_size(self.cache_engine.block_size)

    def free_cache_engine(self):
        # ensure `enforce_eager=True`
        self.cache_engine = None
        self.gpu_cache = None

    def warm_up_model(self) -> None:
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model(self.gpu_cache)
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def cache_swap(
        self,
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        # Issue cache operations.
        issued_cache_op = False
        if blocks_to_swap_in:
            self.cache_engine.swap_in(blocks_to_swap_in)
            issued_cache_op = True
        if blocks_to_swap_out:
            self.cache_engine.swap_out(blocks_to_swap_out)
            issued_cache_op = True
        if blocks_to_copy:
            self.cache_engine.copy(blocks_to_copy)
            issued_cache_op = True

        cache_events = self.cache_events if issued_cache_op else None

        # Wait for cache operations to finish.
        # TODO(woosuk): Profile swapping overhead and optimize if needed.
        if cache_events is not None:
            for event in cache_events:
                event.wait()

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> SamplerOutput:
        num_seq_groups = len(seq_group_metadata_list)
        self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)

        # If there is no input, we don't need to execute the model.
        if num_seq_groups == 0:
            return {}
        output = self.model_runner.execute_model(seq_group_metadata_list, self.gpu_cache)
        return output

        # # Prepare input tensors.
        # # NOTE(shengguangming): currently we pad in our dataloader and unpad it in pre_process_input, j
        # # we can just input un-padded sequence for better performance
        # input_tokens, input_positions, input_metadata = self._prepare_inputs(seq_group_metadata_list)

        # # Execute the model.
        # output = self.model(
        #     input_ids=input_tokens,
        #     positions=input_positions,
        #     kv_caches=self.gpu_cache,
        #     input_metadata=input_metadata,
        #     cache_events=cache_events,
        # )
        # return output

    # assume the input is .state_dict()
    def sync_model_weights(self, actor_weights: Dict):
        load_weights(actor_weights, self.model_runner.model)

    def offload_model_weights(self) -> None:
        if self.cpu_model is None:
            self.cpu_model = {}
            for name, params in self.model_runner.model.named_parameters():
                self.cpu_model[name] = torch.empty_like(params, device='cpu')
                params.data = self.cpu_model[name]
        else:
            for name, params in self.model_runner.model.named_parameters():
                params.data = self.cpu_model[name]

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.model_runner.remove_lora(lora_id)

    def list_loras(self) -> Set[int]:
        return self.model_runner.list_loras()


def _init_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
) -> None:
    """Initialize the distributed environment."""
    if torch.distributed.is_initialized():
        print('The distributed environment has been initialized before vLLM')
    elif not distributed_init_method:
        raise ValueError("distributed_init_method must be set if torch.distributed "
                         "is not already initialized")
    else:
        torch.distributed.init_process_group(
            backend="nccl",
            world_size=parallel_config.world_size,
            rank=rank,
            # init_method=distributed_init_method,
        )

    # A small all_reduce for warmup.
    torch.distributed.all_reduce(torch.zeros(1).cuda())
    # TODO (shengguangming): maybe we should also flag that megatron is initialized
    if torch.distributed.get_world_size() > 1:
        initialize_model_parallel_from_megatron(tensor_model_parallel_size=parallel_config.tensor_parallel_size)
    else:
        initialize_model_parallel()


def _pad_to_alignment(x: List[int], multiple_of: int, pad: int) -> List[int]:
    return x + [pad] * ((-len(x)) % multiple_of)


def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
    return x + [pad] * (max_len - len(x))


def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
    # Check if the GPU supports the dtype.
    if torch_dtype == torch.bfloat16:
        compute_capability = torch.cuda.get_device_capability()
        if compute_capability[0] < 8:
            gpu_name = torch.cuda.get_device_name()
            raise ValueError("Bfloat16 is only supported on GPUs with compute capability "
                             f"of at least 8.0. Your {gpu_name} GPU has compute capability "
                             f"{compute_capability[0]}.{compute_capability[1]}.")
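The methods above are what verl's hybrid engine drives during a rollout step: sync the latest actor weights in, size and allocate the KV cache, generate, then free the cache and offload the weights again. A rough usage sketch, assuming the config objects, the wrapped model, and a Megatron actor state_dict are constructed elsewhere by the caller:

# Illustrative only: model, the *_config objects, cache_config and actor_state_dict
# are assumptions supplied by the caller, not defined in this file.
worker = Worker(model, model_config, parallel_config, scheduler_config, device_config,
                rank=0, distributed_init_method="env://")
worker.init_model()
worker.load_model()

num_gpu_blocks, num_cpu_blocks = worker.profile_num_available_blocks(
    block_size=16, gpu_memory_utilization=0.9, cpu_swap_space=4 * (1 << 30), cache_dtype="auto")
cache_config.num_gpu_blocks = num_gpu_blocks
cache_config.num_cpu_blocks = num_cpu_blocks
worker.init_cache_engine(cache_config)
worker.warm_up_model()

worker.sync_model_weights(actor_state_dict)   # push the newest actor weights into vLLM
# ... rollout via worker.execute_model(...) ...
worker.free_cache_engine()                    # only safe with enforce_eager=True
worker.offload_model_weights()                # park vLLM params on CPU between iterations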
13
verl/third_party/vllm/vllm_v_0_4_2/__init__.py
vendored
Normal file
@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
320
verl/third_party/vllm/vllm_v_0_4_2/arg_utils.py
vendored
Normal file
@@ -0,0 +1,320 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/arg_utils.py

import os
import argparse
import dataclasses
from dataclasses import dataclass
from typing import List, Optional, Union

import torch.nn as nn

from transformers import PretrainedConfig
from .config import ModelConfig, LoadConfig

from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, EngineConfig, LoRAConfig, ParallelConfig,
                         SchedulerConfig, SpeculativeConfig, TokenizerPoolConfig, VisionLanguageConfig)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import str_to_int_tuple


def nullable_str(val: str):
    if not val or val == "None":
        return None
    return val


@dataclass
class EngineArgs:
    """Arguments for vLLM engine."""
    model_hf_config: PretrainedConfig = None
    skip_tokenizer_init: bool = False
    served_model_name: Optional[Union[str, List[str]]] = None  # TODO
    download_dir: Optional[str] = None
    load_format: str = 'auto'
    dtype: str = 'auto'
    kv_cache_dtype: str = 'auto'
    quantization_param_path: Optional[str] = None
    seed: int = 0
    max_model_len: Optional[int] = None
    worker_use_ray: bool = False
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    max_parallel_loading_workers: Optional[int] = None
    block_size: int = 16
    enable_prefix_caching: bool = False
    use_v2_block_manager: bool = False
    swap_space: int = 4  # GiB
    gpu_memory_utilization: float = 0.90
    max_num_batched_tokens: Optional[int] = None
    max_num_seqs: int = 256
    max_logprobs: int = 5  # OpenAI default value
    disable_log_stats: bool = False
    revision: Optional[str] = None
    code_revision: Optional[str] = None
    tokenizer_revision: Optional[str] = None
    quantization: Optional[str] = None
    enforce_eager: bool = False
    max_context_len_to_capture: Optional[int] = None
    max_seq_len_to_capture: int = 8192
    disable_custom_all_reduce: bool = False
    tokenizer_pool_size: int = 0
    tokenizer_pool_type: str = "ray"
    tokenizer_pool_extra_config: Optional[dict] = None
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
    fully_sharded_loras: bool = False
    lora_extra_vocab_size: int = 256
    lora_dtype = 'auto'
    max_cpu_loras: Optional[int] = None
    device: str = 'auto'
    ray_workers_use_nsight: bool = False
    num_gpu_blocks_override: Optional[int] = None
    num_lookahead_slots: int = 0
    model_loader_extra_config: Optional[dict] = None

    # Related to Vision-language models such as llava
    image_input_type: Optional[str] = None
    image_token_id: Optional[int] = None
    image_input_shape: Optional[str] = None
    image_feature_size: Optional[int] = None
    scheduler_delay_factor: float = 0.0
    enable_chunked_prefill: bool = False

    guided_decoding_backend: str = 'outlines'
    # Speculative decoding configuration.
    speculative_model: Optional[str] = None
    num_speculative_tokens: Optional[int] = None
    speculative_max_model_len: Optional[int] = None
    ngram_prompt_lookup_max: Optional[int] = None
    ngram_prompt_lookup_min: Optional[int] = None

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Shared CLI arguments for vLLM engine."""
        # Model arguments
        # TODO(shengguangming): delete the unused args
        parser.add_argument('--model',
                            type=str,
                            default='facebook/opt-125m',
                            help='name or path of the huggingface model to use')
        parser.add_argument('--tokenizer',
                            type=str,
                            default=EngineArgs.tokenizer,
                            help='name or path of the huggingface tokenizer to use')
        parser.add_argument('--revision',
                            type=str,
                            default=None,
                            help='the specific model version to use. It can be a branch '
                            'name, a tag name, or a commit id. If unspecified, will use '
                            'the default version.')
        parser.add_argument('--tokenizer-revision',
                            type=str,
                            default=None,
                            help='the specific tokenizer version to use. It can be a branch '
                            'name, a tag name, or a commit id. If unspecified, will use '
                            'the default version.')
        parser.add_argument('--tokenizer-mode',
                            type=str,
                            default=EngineArgs.tokenizer_mode,
                            choices=['auto', 'slow'],
                            help='tokenizer mode. "auto" will use the fast '
                            'tokenizer if available, and "slow" will '
                            'always use the slow tokenizer.')
        parser.add_argument('--trust-remote-code', action='store_true', help='trust remote code from huggingface')
        parser.add_argument('--download-dir',
                            type=str,
                            default=EngineArgs.download_dir,
                            help='directory to download and load the weights, '
                            'default to the default cache dir of '
                            'huggingface')
        parser.add_argument('--load-format',
                            type=str,
                            default=EngineArgs.load_format,
                            choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
                            help='The format of the model weights to load. '
                            '"auto" will try to load the weights in the safetensors format '
                            'and fall back to the pytorch bin format if safetensors format '
                            'is not available. '
                            '"pt" will load the weights in the pytorch bin format. '
                            '"safetensors" will load the weights in the safetensors format. '
                            '"npcache" will load the weights in pytorch format and store '
                            'a numpy cache to speed up the loading. '
                            '"dummy" will initialize the weights with random values, '
                            'which is mainly for profiling.')
        parser.add_argument('--dtype',
                            type=str,
                            default=EngineArgs.dtype,
                            choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
                            help='data type for model weights and activations. '
                            'The "auto" option will use FP16 precision '
                            'for FP32 and FP16 models, and BF16 precision '
                            'for BF16 models.')
        parser.add_argument('--max-model-len',
                            type=int,
                            default=None,
                            help='model context length. If unspecified, '
                            'will be automatically derived from the model.')
        # Parallel arguments
        parser.add_argument('--worker-use-ray',
                            action='store_true',
                            help='use Ray for distributed serving, will be '
                            'automatically set when using more than 1 GPU')
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
                            default=EngineArgs.pipeline_parallel_size,
                            help='number of pipeline stages')
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
                            default=EngineArgs.tensor_parallel_size,
                            help='number of tensor parallel replicas')
        # KV cache arguments
        parser.add_argument('--block-size',
                            type=int,
                            default=EngineArgs.block_size,
                            choices=[8, 16, 32],
                            help='token block size')
        # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
        parser.add_argument('--seed', type=int, default=EngineArgs.seed, help='random seed')
        parser.add_argument('--swap-space',
                            type=int,
                            default=EngineArgs.swap_space,
                            help='CPU swap space size (GiB) per GPU')
        parser.add_argument('--gpu-memory-utilization',
                            type=float,
                            default=EngineArgs.gpu_memory_utilization,
                            help='the percentage of GPU memory to be used for '
                            'the model executor')
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
                            default=EngineArgs.max_num_batched_tokens,
                            help='maximum number of batched tokens per '
                            'iteration')
        parser.add_argument('--max-num-seqs',
                            type=int,
                            default=EngineArgs.max_num_seqs,
                            help='maximum number of sequences per iteration')
        parser.add_argument('--disable-log-stats', action='store_true', help='disable logging statistics')
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
                            type=str,
                            choices=['awq', None],
                            default=None,
                            help='Method used to quantize the weights')
        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args

    def create_engine_config(
        self,
    ) -> EngineConfig:
        device_config = DeviceConfig(self.device)
        # NOTE(sgm): we only modify ModelConfig, other configs are imported from vllm
        model_config = ModelConfig(self.model_hf_config, self.dtype, self.seed, self.revision, self.code_revision,
                                   self.tokenizer_revision, self.max_model_len, self.quantization,
                                   self.quantization_param_path, self.enforce_eager, self.max_context_len_to_capture,
                                   self.max_seq_len_to_capture, self.max_logprobs, self.skip_tokenizer_init,
                                   self.served_model_name)
        cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization,
                                   self.swap_space, self.kv_cache_dtype, self.num_gpu_blocks_override,
                                   model_config.get_sliding_window(), self.enable_prefix_caching)
        parallel_config = ParallelConfig(
            self.pipeline_parallel_size, self.tensor_parallel_size, self.worker_use_ray,
            self.max_parallel_loading_workers, self.disable_custom_all_reduce,
            TokenizerPoolConfig.create_config(
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
            ), self.ray_workers_use_nsight)

        # Use the world_size set by TORCHRUN
        world_size = int(os.getenv("WORLD_SIZE", "-1"))
        assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN"
        parallel_config.world_size = world_size

        # TODO: spec config
        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
            num_speculative_tokens=self.num_speculative_tokens,
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            use_v2_block_manager=self.use_v2_block_manager,
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
        )

        scheduler_config = SchedulerConfig(
            self.max_num_batched_tokens,
            self.max_num_seqs,
            model_config.max_model_len,
            self.use_v2_block_manager,
            num_lookahead_slots=(self.num_lookahead_slots
                                 if speculative_config is None else speculative_config.num_lookahead_slots),
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
        )

        lora_config = LoRAConfig(max_lora_rank=self.max_lora_rank,
                                 max_loras=self.max_loras,
                                 fully_sharded_loras=self.fully_sharded_loras,
                                 lora_extra_vocab_size=self.lora_extra_vocab_size,
                                 lora_dtype=self.lora_dtype,
                                 max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else
                                 None) if self.enable_lora else None

        load_config = LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
        )

        if self.image_input_type:
            if (not self.image_token_id or not self.image_input_shape or not self.image_feature_size):
                raise ValueError('Specify `image_token_id`, `image_input_shape` and '
                                 '`image_feature_size` together with `image_input_type`.')
            vision_language_config = VisionLanguageConfig(
                image_input_type=VisionLanguageConfig.get_image_input_enum_type(self.image_input_type),
                image_token_id=self.image_token_id,
                image_input_shape=str_to_int_tuple(self.image_input_shape),
                image_feature_size=self.image_feature_size,
            )
        else:
            vision_language_config = None

        decoding_config = DecodingConfig(guided_decoding_backend=self.guided_decoding_backend)

        return EngineConfig(model_config=model_config,
                            cache_config=cache_config,
                            parallel_config=parallel_config,
                            scheduler_config=scheduler_config,
                            device_config=device_config,
                            lora_config=lora_config,
                            vision_language_config=vision_language_config,
                            speculative_config=speculative_config,
                            load_config=load_config,
                            decoding_config=decoding_config)
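A short sketch of how this adapted EngineArgs is meant to be driven from an in-memory HF config rather than a model path. The Llama checkpoint name is only an example, and WORLD_SIZE must already be set (normally by torchrun) for create_engine_config to pass its assertion:

import os

from transformers import AutoConfig

os.environ.setdefault("WORLD_SIZE", "1")  # normally exported by torchrun

hf_config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")  # example model, any HF config works
engine_args = EngineArgs(
    model_hf_config=hf_config,
    dtype='bfloat16',
    tensor_parallel_size=1,
    gpu_memory_utilization=0.5,
    enforce_eager=True,
)
engine_config = engine_args.create_engine_config()  # bundles model/cache/parallel/scheduler/... configs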
200
verl/third_party/vllm/vllm_v_0_4_2/config.py
vendored
Normal file
@@ -0,0 +1,200 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py

import enum
import json
from typing import List, Optional, Union
from dataclasses import dataclass, field, fields

from transformers import PretrainedConfig

from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import get_quantization_config
from vllm.transformers_utils.config import get_hf_text_config
from vllm.utils import is_hip
# Add for verl
from vllm.config import ModelConfig, _get_and_verify_dtype, _get_and_verify_max_len

GPTQMarlinConfig = get_quantization_config("gptq_marlin")

logger = init_logger(__name__)

_GB = 1 << 30


class ModelConfig(ModelConfig):
    """Configuration for the model.

    Args:
        model: Name or path of the huggingface model to use.
        tokenizer: Name or path of the huggingface tokenizer to use.
        tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
            available, and "slow" will always use the slow tokenizer.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        download_dir: Directory to download and load the weights, default to the
            default cache directory of huggingface.
        load_format: The format of the model weights to load:
            "auto" will try to load the weights in the safetensors format and
                fall back to the pytorch bin format if safetensors format is
                not available.
            "pt" will load the weights in the pytorch bin format.
            "safetensors" will load the weights in the safetensors format.
            "npcache" will load the weights in pytorch format and store
                a numpy cache to speed up the loading.
            "dummy" will initialize the weights with random values, which is
                mainly for profiling.
        dtype: Data type for model weights and activations. The "auto" option
            will use FP16 precision for FP32 and FP16 models, and BF16 precision
            for BF16 models.
        seed: Random seed for reproducibility.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id. If unspecified, will use the default
            version.
        code_revision: The specific revision to use for the model code on
            Hugging Face Hub. It can be a branch name, a tag name, or a
            commit id. If unspecified, will use the default version.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id. If unspecified, will use
            the default version.
        max_model_len: Maximum length of a sequence (including prompt and
            output). If None, will be derived from the model.
        quantization: Quantization method that was used to quantize the model
            weights. If None, we assume the model weights are not quantized.
        quantization_param_path: Path to JSON file containing scaling factors.
            Used to load KV cache scaling factors into the model when KV cache
            type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
            be used to load activation and weight scaling factors when the
            model dtype is FP8_E4M3 on ROCm.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode (DEPRECATED. Use max_seq_len_to_capture instead).
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer.
        served_model_name: The model name used in metrics tag `model_name`,
            matches the model name exposed via the APIs. If multiple model
            names provided, the first name will be used. If not specified,
            the model name will be the same as `model`.
    """

    def __init__(
        self,
        hf_config: PretrainedConfig,
        dtype: str,
        seed: int,
        revision: Optional[str] = None,
        code_revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        max_model_len: Optional[int] = None,
        quantization: Optional[str] = None,
        quantization_param_path: Optional[str] = None,
        enforce_eager: bool = False,
        max_context_len_to_capture: Optional[int] = None,
        max_seq_len_to_capture: Optional[int] = None,
        max_logprobs: int = 5,
        skip_tokenizer_init: bool = False,
        served_model_name: Optional[Union[str, List[str]]] = None,
    ) -> None:
        self.model = hf_config._name_or_path
        self.tokenizer = hf_config._name_or_path
        self.seed = seed
        self.revision = revision
        self.code_revision = code_revision
        self.tokenizer_revision = tokenizer_revision
        self.quantization = quantization
        self.quantization_param_path = quantization_param_path
        self.enforce_eager = enforce_eager
        self.max_context_len_to_capture = max_context_len_to_capture
        if self.max_context_len_to_capture is not None:
            raise ValueError("`max_context_len_to_capture` is deprecated. "
                             "Use `max_seq_len_to_capture` instead.")
        self.max_seq_len_to_capture = (max_seq_len_to_capture or max_context_len_to_capture)
        self.max_logprobs = max_logprobs
        self.skip_tokenizer_init = skip_tokenizer_init

        # self.hf_config = get_config(model, trust_remote_code, revision)
        self.hf_config = hf_config
        self.hf_text_config = get_hf_text_config(hf_config)
        # TODO: for multimodal model
        self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
        self.max_model_len = _get_and_verify_max_len(self.hf_config, max_model_len)
        # self.served_model_name = get_served_model_name(model,
        #                                                served_model_name)
        # self._verify_load_format()
        # self._verify_tokenizer_mode()
        self._verify_quantization()
        self._verify_cuda_graph()


class LoadFormat(str, enum.Enum):
    AUTO = 'auto'
    MEGATRON = "megatron"
    HF = "hf"
    DTENSOR = 'dtensor'
    DUMMY_HF = 'dummy_hf'
    DUMMY_MEGATRON = 'dummy_megatron'
    DUMMY_DTENSOR = 'dummy_dtensor'


@dataclass
class LoadConfig:
    """
    download_dir: Directory to download and load the weights, default to the
        default cache directory of huggingface.
    load_format: The format of the model weights to load:
        "auto" will try to load the weights in the safetensors format and
            fall back to the pytorch bin format if safetensors format is
            not available.
        "pt" will load the weights in the pytorch bin format.
        "safetensors" will load the weights in the safetensors format.
        "npcache" will load the weights in pytorch format and store
            a numpy cache to speed up the loading.
        "dummy" will initialize the weights with random values, which is
            mainly for profiling.
        "tensorizer" will use CoreWeave's tensorizer library for
            fast weight loading.
    """

    load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
    download_dir: Optional[str] = None
    model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)

    def __post_init__(self):
        model_loader_extra_config = self.model_loader_extra_config or {}
        if isinstance(model_loader_extra_config, str):
            self.model_loader_extra_config = json.loads(model_loader_extra_config)
        self._verify_load_format()

    def _verify_load_format(self) -> None:
        if not isinstance(self.load_format, str):
            return

        load_format = self.load_format.lower()
        self.load_format = LoadFormat(load_format)

        rocm_not_supported_load_format: List[str] = []
        if is_hip() and load_format in rocm_not_supported_load_format:
            rocm_supported_load_format = [
                f for f in LoadFormat.__members__ if (f not in rocm_not_supported_load_format)
            ]
            raise ValueError(f"load format '{load_format}' is not supported in ROCm. "
                             f"Supported load formats are "
                             f"{rocm_supported_load_format}")
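This LoadConfig adds verl-specific load formats (megatron, hf, dtensor and their dummy variants) on top of vLLM's. A tiny illustration of the normalization done in __post_init__; the values are chosen only for the example:

# Strings are lower-cased and coerced to the LoadFormat enum.
load_config = LoadConfig(load_format="DTensor", download_dir=None)
assert load_config.load_format is LoadFormat.DTENSOR

# An unrecognized format string raises ValueError from the LoadFormat enum.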
269
verl/third_party/vllm/vllm_v_0_4_2/dtensor_weight_loaders.py
vendored
Normal file
@@ -0,0 +1,269 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models

from typing import Dict, Iterable, Tuple
import torch
import torch.nn as nn
from torch.distributed._tensor import DTensor, Shard, Replicate

from vllm.model_executor.layers.linear import *
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.model_loader.weight_utils import default_weight_loader


def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]

    params_dict = dict(vllm_model.named_parameters())
    for name, loaded_weight in actor_weights.items():
        for (param_name, shard_name, shard_id) in stacked_params_mapping:
            if shard_name not in name:
                continue
            stacked_name = name.replace(shard_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if stacked_name.endswith(".bias") and stacked_name not in params_dict:
                continue
            local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
            param = params_dict[stacked_name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
            break
        else:
            # lm_head is not used in vllm as it is tied with embed_token.
            # To prevent errors, skip loading lm_head.weight.
            if "lm_head.weight" in name:
                continue
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # GemmaRMSNorm is different from Llama's in that it multiplies
            # (1 + weight) to the output, instead of just weight.
            if "norm.weight" in name:
                local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)

                norm_weight = local_loaded_weight + 1.0
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, norm_weight.to(dtype=param.dtype))
            else:
                local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, local_loaded_weight.to(dtype=param.dtype))


def gptbigcode_dtensor_load_weights(actor_weights: Dict, vllm_model: nn.Module):
    params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
    for name, loaded_weight in actor_weights.items():
        if "lm_head.weight" in name:
            continue
        if ".attn.bias" in name:
            # Skip attention mask.
            # NOTE: "c_attn.bias" should not be skipped.
            continue
        local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, local_loaded_weight.to(dtype=param.dtype))


def starcoder2_dtensor_load_weights(actor_weights: Dict, vllm_model: nn.Module):
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
    ]

    params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
    for name, loaded_weight in actor_weights.items():
        if "rotary_emb.inv_freq" in name:
            continue

        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
            break
        else:
            if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name:
                continue
            param = params_dict[name]
            local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, local_loaded_weight.to(dtype=param.dtype))


def llama_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
        (".gate_up_proj", ".gate_proj", 0),
        (".gate_up_proj", ".up_proj", 1),
    ]
    params_dict = dict(vllm_model.named_parameters())
    for name, loaded_weight in actor_weights.items():
        if "rotary_emb.inv_freq" in name:
            continue
        if ("rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name):
            # Models trained using ColossalAI may include these tensors in
            # the checkpoint. Skip them.
            continue
        # With tie_word_embeddings, we can skip lm_head.weight
        # The weight might appear unnecessarily in the files if the model is
        # processed with quantization, LoRA, fine-tuning, etc.
        if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name:
            continue
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
            break
        else:
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, local_loaded_weight)


def qwen2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
    for name, loaded_weight in actor_weights.items():
        if "rotary_emb.inv_freq" in name:
            continue
        if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name:
            continue
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
            break
        else:
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, local_loaded_weight.to(dtype=param.dtype))


def gpt2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
    pass


def redistribute_dtensor(param_name: str, loaded_weights: DTensor, parallelize_plan: Dict = None):
    param_name = _process_parameter_names(name=param_name)
    if parallelize_plan is not None:
        assert param_name in parallelize_plan.keys(), \
            f"param name: {param_name} not in parallelize_plan: {parallelize_plan.keys()}"
        placement = parallelize_plan[param_name]
        local_loaded_weights = loaded_weights.redistribute(device_mesh=loaded_weights.device_mesh,
                                                           placements=placement).to_local()
    else:
        local_loaded_weights = loaded_weights.full_tensor()
    return local_loaded_weights


def _process_parameter_names(name):
    # Remove '.weight' if it exists at the end of the string
    if name.endswith(".weight"):
        name = name[:-7]

    # Remove 'model.layers.x.' or 'model.' prefix
    if "model.layers" in name:
        parts = name.split('.')
        # Reconstruct the string without 'model.layers.x.'
        name = '.'.join(parts[3:])  # parts[0] is 'model', parts[1] is 'layers', parts[2] is 'x'
    elif name.startswith("model."):
        name = name[6:]  # Remove 'model.'

    return name


__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__ = {
    'GPT2LMHeadModel': gpt2_dtensor_weight_loader,
    'LlamaForCausalLM': llama_dtensor_weight_loader,
    'LLaMAForCausalLM': llama_dtensor_weight_loader,
    'MistralForCausalLM': llama_dtensor_weight_loader,  # mistral is the same as llama in vLLM
    'InternLMForCausalLM': llama_dtensor_weight_loader,
    'AquilaModel': llama_dtensor_weight_loader,
    'AquilaForCausalLM': llama_dtensor_weight_loader,
    'Phi3ForCausalLM': llama_dtensor_weight_loader,
    'GemmaForCausalLM': gemma_dtensor_weight_loader,
    'GPTBigCodeForCausalLM': gptbigcode_dtensor_load_weights,
    'Starcoder2ForCausalLM': starcoder2_dtensor_load_weights,
    'Qwen2ForCausalLM': qwen2_dtensor_weight_loader
}


# the actor model is .state_dict()
# Load dtensor weights
def load_dtensor_weights(actor_weights: Dict, vllm_model: nn.Module):
    weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__)
    weight_loader(actor_weights, vllm_model)
    # NOTE(sgm) to reduce peak memory usage, we offload the vllm model to cpu
    # after init, and we need this after syncing model weights in the first iteration.
    vllm_model = vllm_model.cuda()


def _get_model_weight_loader(arch: str):
    if arch in __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__:
        return __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__[arch]
    raise ValueError(f"Model architecture {arch} is not supported for now. "
                     f"Supported architectures: {__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__.keys()}")


# NOTE(sgm): we use per-parameter weight loader in each vllm sub
def update_dtensor_weight_loader():
    pass
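To make the flow concrete, a hedged sketch of calling load_dtensor_weights with a DTensor state_dict gathered from an FSDP actor, plus what _process_parameter_names does to typical keys; the state_dict and model objects below are assumptions on the caller's side:

# Illustrative only: `fsdp_state_dict` maps HF-style names to DTensor values and
# `rollout_model` is the matching vLLM nn.Module (e.g. a Qwen2ForCausalLM instance).
load_dtensor_weights(fsdp_state_dict, rollout_model)

# Name normalization used when looking up a parallelize plan:
assert _process_parameter_names("model.layers.0.self_attn.qkv_proj.weight") == "self_attn.qkv_proj"
assert _process_parameter_names("model.embed_tokens.weight") == "embed_tokens"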
91
verl/third_party/vllm/vllm_v_0_4_2/hf_weight_loader.py
vendored
Normal file
@@ -0,0 +1,91 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models

from typing import Dict, Union, Optional, Iterable, Tuple

import torch
import torch.nn as nn

from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.model_loader.weight_utils import default_weight_loader


def update_hf_weight_loader():
    from vllm.model_executor.models.gemma import GemmaForCausalLM
    GemmaForCausalLM.load_weights = gemma_load_weights


def gemma_load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    params_dict = dict(self.named_parameters())
    loaded_params = set()
    for name, loaded_weight in weights:
        for (param_name, shard_name, shard_id) in stacked_params_mapping:
            if shard_name not in name:
                continue
            name = name.replace(shard_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # lm_head is not used in vllm as it is tied with embed_token.
            # To prevent errors, skip loading lm_head.weight.
            if "lm_head.weight" in name:
                continue
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # GemmaRMSNorm is different from Llama's in that it multiplies
            # (1 + weight) to the output, instead of just weight.
            if "norm.weight" in name:
                norm_weight = loaded_weight + 1.0  # prevent in-place modification of the actor weights
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, norm_weight)
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
        loaded_params.add(name)
    unloaded_params = params_dict.keys() - loaded_params
    if unloaded_params:
        raise RuntimeError("Some weights are not initialized from checkpoints: "
                           f"{unloaded_params}")


def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module):
    assert isinstance(actor_weights, Dict)
    with set_default_torch_dtype(next(vllm_model.parameters()).dtype):  # TODO
        vllm_model.load_weights(actor_weights.items())
        for _, module in vllm_model.named_modules():
            quant_method = getattr(module, "quant_method", None)
            if quant_method is not None:
                quant_method.process_weights_after_loading(module)
            # FIXME: Remove this after Mixtral is updated
            # to use quant_method.
            if hasattr(module, "process_weights_after_loading"):
                module.process_weights_after_loading()
    vllm_model = vllm_model.cuda()
||||||
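A brief hedged usage sketch of the two entry points above (the `actor` object is an assumption of the caller; the patch must run before the weights are loaded so Gemma's `(1 + weight)` norm convention is respected):

update_hf_weight_loader()                       # monkey-patch GemmaForCausalLM.load_weights
actor_state_dict = actor.state_dict()           # full, unsharded HF-format weights (assumed)
load_hf_weights(actor_state_dict, vllm_model)   # load into vLLM, then move the model to GPU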
verl/third_party/vllm/vllm_v_0_4_2/llm.py (vendored, new file, 306 lines)
@@ -0,0 +1,306 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py

from typing import Dict, List, Optional, Tuple, Union

from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers import PretrainedConfig
import torch.nn as nn
from .arg_utils import EngineArgs
from .llm_engine_sp import LLMEngine
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter
import torch
from torch.nn.utils.rnn import pad_sequence
from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer


class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    NOTE: This class is intended to be used for offline inference. For online
    serving, use the `AsyncLLMEngine` class instead.
    NOTE: For the comprehensive list of arguments, see `EngineArgs`.

    Args:
        model: A HuggingFace Transformers model instance, or its parameter
            state_dict.
        tokenizer: A HuggingFace Transformers tokenizer instance.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq". If None, we assume the model weights are not
            quantized and use `dtype` to determine the data type of the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode.
        disable_custom_all_reduce: See ParallelConfig
    """

    def __init__(
        self,
        model: Union[nn.Module, Dict],  # model itself or its parameter dict
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer],
        model_hf_config: PretrainedConfig,
        tokenizer_mode: str = "auto",
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: int = 4,
        enforce_eager: bool = False,
        max_context_len_to_capture: Optional[int] = None,
        disable_custom_all_reduce: bool = False,
        load_format: str = 'auto',
        **kwargs,
    ) -> None:
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True
        engine_args = EngineArgs(
            model_hf_config=model_hf_config,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            max_context_len_to_capture=max_context_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            load_format=load_format,
            **kwargs,
        )
        tokenizer_cls = (PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer)
        if not isinstance(tokenizer, tokenizer_cls):
            raise ValueError(
                f"Unexpected tokenizer type: {type(tokenizer)}. Must be "
                "one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.workers.rollout.HybridEngineBaseTokenizer"
            )
        self.llm_engine = LLMEngine.from_engine_args(model, tokenizer, engine_args)
        self.request_counter = Counter()

    def init_cache_engine(self):
        self.llm_engine.init_cache_engine()

    def free_cache_engine(self):
        self.llm_engine.free_cache_engine()

    def get_tokenizer(self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
        return self.llm_engine.tokenizer

    def set_tokenizer(
        self,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    ) -> None:
        self.llm_engine.tokenizer = tokenizer

    def generate(
        self,
        prompts: Optional[Union[str, List[str]]] = None,
        sampling_params: Optional[Union[SamplingParams, List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generates the completions for the input prompts.

        NOTE: This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: A list of prompts to generate completions for.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: A list of token IDs for the prompts. If None, we
                use the tokenizer to convert the prompts to token IDs.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            multi_modal_data: Multi modal data.

        Returns:
            A tuple `(output_token_ids, logprobs)`: the generated completion
            token ids padded into a single tensor, and the corresponding
            log-probabilities, in the same order as the input prompts.
        """
        if prompts is None and prompt_token_ids is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")
        if self.llm_engine.model_config.skip_tokenizer_init \
                and prompts is not None:
            raise ValueError("prompts must be None if skip_tokenizer_init "
                             "is True")
        if isinstance(prompts, str):
            # Convert a single prompt to a list.
            prompts = [prompts]
        if (prompts is not None and prompt_token_ids is not None and len(prompts) != len(prompt_token_ids)):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")

        if prompts is not None:
            num_requests = len(prompts)
        else:
            assert prompt_token_ids is not None
            num_requests = len(prompt_token_ids)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

        elif isinstance(sampling_params, list) and len(sampling_params) != num_requests:
            raise ValueError("The lengths of prompts and sampling_params "
                             "must be the same.")
        if multi_modal_data:
            multi_modal_data.data = multi_modal_data.data.to(torch.float16)

        # Add requests to the engine.
        for i in range(num_requests):
            prompt = prompts[i] if prompts is not None else None
            token_ids = None if prompt_token_ids is None else prompt_token_ids[i]
            if not isinstance(token_ids, list):
                # NOTE(shengguangming): convert the rollout input into List[int]
                token_ids = self._pre_process_inputs(token_ids)
            self._add_request(
                prompt,
                sampling_params[i] if isinstance(sampling_params, list) else sampling_params,
                token_ids,
                lora_request=lora_request,
                # Get ith image while maintaining the batch dim.
                multi_modal_data=MultiModalData(type=multi_modal_data.type, data=multi_modal_data.data[i].unsqueeze(0))
                if multi_modal_data else None,
            )
        return self._run_engine(use_tqdm)

    def _add_request(
        self,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]],
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> None:
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(request_id,
                                    prompt,
                                    sampling_params,
                                    prompt_token_ids,
                                    lora_request=lora_request,
                                    multi_modal_data=multi_modal_data)

    def _run_engine(self, use_tqdm: bool) -> Tuple[torch.Tensor, torch.Tensor]:
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(total=num_requests, desc="Processed prompts", dynamic_ncols=True)
        # Run the engine.
        outputs: List[RequestOutput] = []
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if output.finished:
                    outputs.append(output)
                    if use_tqdm:
                        pbar.update(1)
        if use_tqdm:
            pbar.close()
        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # their previous requests.
        outputs = sorted(outputs, key=lambda x: int(x.request_id))
        # TODO(shengguangming): maybe we can hack the autoregressive logic instead of only applying post-processing, for better performance
        return self._post_process_outputs(outputs)

    # NOTE(shengguangming): added for verl
    # TODO(sgm): we can optimize it by making the dataloader yield List[int] without padding.
    def _pre_process_inputs(self, prompt_token_ids: torch.Tensor) -> List[int]:
        # remove the left padding in the prompt token ids
        pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id
        non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
        token_ids = prompt_token_ids[non_pad_index:].tolist()
        return token_ids

    # NOTE(shengguangming): added for verl
    def _post_process_outputs(self, request_outputs: List[RequestOutput]) -> Tuple[torch.Tensor, torch.Tensor]:
        output_token_ids = []
        logprobs = []
        for request_output in request_outputs:  # List[RequestOutput]
            outputs = request_output.outputs
            for output in outputs:  # List[CompletionOutput], usually len == 1
                output_token_ids.append(torch.tensor(output.token_ids))
                # TODO(shengguangming): can be optimized by rewriting the Sampler._get_logprobs() logic
                logprobs_dicts = output.logprobs
                if logprobs_dicts is not None:
                    logprob = []
                    for logprobs_dict, token_id in zip(logprobs_dicts, output.token_ids):
                        logprob.append(logprobs_dict[token_id].logprob)
                    logprobs.append(torch.tensor(logprob))

        pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id
        output_token_ids = pad_sequence(output_token_ids, batch_first=True, padding_value=pad_token_id)
        if len(logprobs) > 0:
            logprobs = pad_sequence(logprobs, batch_first=True, padding_value=pad_token_id)
        return output_token_ids, logprobs

    def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], load_format: str) -> None:
        self.llm_engine.sync_model_weights(actor_weights=actor_weights, load_format=load_format)

    def offload_model_weights(self) -> None:
        self.llm_engine.offload_model_weights()
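A hedged end-to-end sketch of how this wrapper is typically driven during a rollout. `actor_module`, `tokenizer`, `actor_hf_config`, and `batch_input_ids` are assumptions of the caller, and the keyword-argument values shown (including the 'hf' load format string) are illustrative rather than prescribed:

llm = LLM(model=actor_module,                     # nn.Module or a parameter dict
          tokenizer=tokenizer,
          model_hf_config=actor_hf_config,
          tensor_parallel_size=2,
          dtype='bfloat16',
          gpu_memory_utilization=0.5,
          load_format='hf')
llm.init_cache_engine()
output_token_ids, logprobs = llm.generate(
    prompt_token_ids=batch_input_ids,             # left-padded prompt rows; padding is stripped per prompt
    sampling_params=SamplingParams(temperature=1.0, max_tokens=512))
llm.free_cache_engine()                           # release the KV cache between rollout iterations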
verl/third_party/vllm/vllm_v_0_4_2/llm_engine_sp.py (vendored, new file, 283 lines)
@@ -0,0 +1,283 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/llm_engine.py

import torch
from typing import Dict, Optional, Union, Type

import vllm
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoRAConfig, ParallelConfig, SchedulerConfig,
                         SpeculativeConfig, VisionLanguageConfig)
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import (SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.engine.metrics import StatLogger
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message)
from vllm.utils import Counter
from vllm.engine.llm_engine import _load_generation_config_dict
from vllm.engine.llm_engine import LLMEngine

import torch.nn as nn
from .arg_utils import EngineArgs
from .tokenizer import TokenizerGroup
from .config import ModelConfig, LoadConfig

logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5


class LLMEngine(LLMEngine):
    """An LLM engine that receives requests and generates texts.

    This is the main class for the vLLM engine. It receives requests
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

    The `LLM` class wraps this class for offline batched inference and the
    `AsyncLLMEngine` class wraps this class for online serving.

    NOTE: The config arguments are derived from the `EngineArgs` class. For the
    comprehensive list of arguments, see `EngineArgs`.

    Args:
        model: the actor model initialized outside vLLM (added for verl)
        tokenizer: the initialized tokenizer (added for verl)
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
        distributed_init_method: The initialization method for distributed
            execution. See `torch.distributed.init_process_group` for details.
        placement_group: Ray placement group for distributed execution.
            Required for distributed execution.
        log_stats: Whether to log statistics.
    """

    def __init__(
        self,
        # NOTE(sgm): the first two arguments are added for verl
        model: Union[nn.Module, Dict],  # model itself or its parameter dict
        tokenizer: nn.Module,
        # NOTE(sgm): vllm original arguments
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
        speculative_config: Optional[SpeculativeConfig],
        decoding_config: Optional[DecodingConfig],
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    ) -> None:
        logger.info(
            "Initializing an LLM engine (v%s) with config: "
            "model=%r, speculative_config=%r, tokenizer=%r, "
            "skip_tokenizer_init=%s, revision=%s, "
            "tokenizer_revision=%s, dtype=%s, "
            "max_seq_len=%d, download_dir=%r, load_format=%s, "
            "tensor_parallel_size=%d, disable_custom_all_reduce=%s, "
            "quantization=%s, enforce_eager=%s, kv_cache_dtype=%s, "
            "quantization_param_path=%s, device_config=%s, "
            "decoding_config=%r, seed=%d)",
            vllm.__version__,
            model_config.model,
            speculative_config,
            model_config.tokenizer,
            model_config.skip_tokenizer_init,
            # model_config.tokenizer_mode,
            model_config.revision,
            model_config.tokenizer_revision,
            # model_config.trust_remote_code,
            model_config.dtype,
            model_config.max_model_len,
            load_config.download_dir,
            load_config.load_format,
            parallel_config.tensor_parallel_size,
            parallel_config.disable_custom_all_reduce,
            model_config.quantization,
            model_config.enforce_eager,
            cache_config.cache_dtype,
            model_config.quantization_param_path,
            device_config.device,
            decoding_config,
            model_config.seed,
            # model_config.served_model_name,
        )
        # TODO(woosuk): Print more configs in debug mode.

        self.model_config = model_config  # TODO: currently this is the HF config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.vision_language_config = vision_language_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.speculative_config = speculative_config
        self.load_config = load_config
        self.decoding_config = decoding_config or DecodingConfig()
        self.log_stats = log_stats

        # self.model = model  # should not store the model, it should be deleted
        # TODO(shengguangming): maybe we can choose init here or from arguments
        if not self.model_config.skip_tokenizer_init:
            # TODO: check tokenizer class
            self._init_tokenizer(tokenizer)
            self.detokenizer = Detokenizer(self.tokenizer)
        else:
            self.detokenizer = None
            self.tokenizer = None

        self.seq_counter = Counter()
        # TODO: unclear what this is used for
        self.generation_config_fields = _load_generation_config_dict(model_config)

        self.model_executor = executor_class(
            model=model,  # added for spmd_gpu_executor
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            vision_language_config=vision_language_config,
            speculative_config=speculative_config,
            load_config=load_config,
        )

        # Profile the memory usage and initialize the cache.
        self._initialize_kv_caches()

        # If usage stat is enabled, collect relevant info.
        if is_usage_stats_enabled():
            from vllm.model_executor.model_loader import (get_architecture_class_name)
            usage_message.report_usage(
                get_architecture_class_name(model_config),
                usage_context,
                extra_kvs={
                    # Common configuration
                    "dtype": str(model_config.dtype),
                    "tensor_parallel_size": parallel_config.tensor_parallel_size,
                    "block_size": cache_config.block_size,
                    "gpu_memory_utilization": cache_config.gpu_memory_utilization,

                    # Quantization
                    "quantization": model_config.quantization,
                    "kv_cache_dtype": cache_config.cache_dtype,

                    # Feature flags
                    "enable_lora": bool(lora_config),
                    "enable_prefix_caching": cache_config.enable_prefix_caching,
                    "enforce_eager": model_config.enforce_eager,
                    "disable_custom_all_reduce": parallel_config.disable_custom_all_reduce,
                })

        if self.tokenizer:
            # Ping the tokenizer to ensure liveness if it runs in a
            # different process.
            self.tokenizer.ping()

        # Create the scheduler.
        # NOTE: the cache_config here has been updated with the numbers of
        # GPU and CPU blocks, which are profiled in the distributed executor.
        # NOTE(shengguangming): each process will have an independent scheduler
        self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)

        # Metric Logging.
        if self.log_stats:
            self.stat_logger = StatLogger(local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                                          labels=dict(model_name=model_config.served_model_name),
                                          max_model_len=self.model_config.max_model_len)
            self.stat_logger.info("cache_config", self.cache_config)

        # Create sequence output processor, e.g. for beam search or
        # speculative decoding.
        self.output_processor = (SequenceGroupOutputProcessor.create_output_processor(
            self.scheduler_config,
            self.detokenizer,
            self.scheduler,
            self.seq_counter,
            self.get_tokenizer_for_seq,
            stop_checker=StopChecker(
                self.scheduler_config.max_model_len,
                self.get_tokenizer_for_seq,
            ),
        ))

    # TODO(sgm): added for verl, but we may not need the tokenizer in Rollout
    def _init_tokenizer(self, tokenizer, **tokenizer_init_kwargs):
        init_kwargs = dict(enable_lora=bool(self.lora_config),
                           max_num_seqs=self.scheduler_config.max_num_seqs,
                           max_input_length=None)
        init_kwargs.update(tokenizer_init_kwargs)
        self.tokenizer: TokenizerGroup = TokenizerGroup(tokenizer, **init_kwargs)

    def init_cache_engine(self):
        # TODO: check whether we should rebuild the CUDA graph every iteration when offloading/loading the KV cache;
        # re-capturing the CUDA graph would be time-consuming
        self.model_executor.init_cache_engine()

    def free_cache_engine(self):
        self.model_executor.free_cache_engine()

    # NOTE(sgm): currently, we only support the GPU executor
    # The GPUExecutor removes the Ray dependency
    @classmethod
    def from_engine_args(
        cls,
        model,
        tokenizer,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments."""
        # Create the engine configs.
        engine_config = engine_args.create_engine_config()

        # Initialize the cluster and specify the executor class.
        assert engine_config.device_config.device_type == "cuda", \
            "Currently, the vllm in verl only supports running on GPU"

        if engine_config.parallel_config.world_size == 1:
            engine_config.load_config.load_format = "dummy_hf"

        from .spmd_gpu_executor import SPMDGPUExecutor
        executor_class = SPMDGPUExecutor

        # Create the LLM engine.
        engine = cls(
            model,
            tokenizer,
            **engine_config.to_dict(),
            executor_class=executor_class,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
        )
        return engine

    def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], load_format: str) -> None:
        self.model_executor.sync_model_weights(actor_weights=actor_weights, load_format=load_format)

    def offload_model_weights(self) -> None:
        self.model_executor.offload_model_weights()
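A hedged sketch of the per-iteration lifecycle this engine supports in a hybrid actor/rollout setup. Only the methods shown above are real; `actor_module`, `tokenizer`, `engine_args`, `num_ppo_iterations`, `actor_state_dict`, and the 'dtensor' load-format string are assumptions of the surrounding training code:

engine = LLMEngine.from_engine_args(actor_module, tokenizer, engine_args)
for _ in range(num_ppo_iterations):              # hypothetical training loop
    engine.sync_model_weights(actor_state_dict, load_format='dtensor')
    engine.init_cache_engine()                   # (re)allocate the KV cache for this rollout
    ...                                          # add_request() / step() driven by the LLM wrapper
    engine.free_cache_engine()                   # release KV cache memory for training
    engine.offload_model_weights()               # offload rollout weights between iterations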
verl/third_party/vllm/vllm_v_0_4_2/megatron_weight_loaders.py (vendored, new file, 348 lines)
@@ -0,0 +1,348 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models

from typing import Dict, Optional

import torch
import torch.nn as nn
from torch.nn import Parameter

from vllm.model_executor.layers.linear import *
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead
from vllm.model_executor.layers.activation import ScaledActivation
from vllm.model_executor.models import ModelRegistry
# imports needed by vocab_init below
from vllm.model_executor.utils import set_weight_attrs
from vllm.distributed import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size


# NOTE(shengguangming): replace the original weight loader function in the class
def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    """Parallel Linear weight loader."""
    assert param.size() == loaded_weight.size(), \
        'the parameter size is not aligned with the loaded weight size, param size: {}, loaded_weight size: {}'.format(
            param.size(), loaded_weight.size())
    assert param.data.dtype == loaded_weight.data.dtype, "if we want to share weights, the data types should also be the same"

    param.data = loaded_weight.data


def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    """Default weight loader."""
    assert param.size() == loaded_weight.size()
    assert param.data.dtype == loaded_weight.data.dtype, "if we want to share weights, the data types should also be the same"

    param.data = loaded_weight.data


def gpt2_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> None:
    params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
    for name, loaded_weight in actor_weights.items():
        if "lm_head.weight" in name:
            # GPT-2 ties the weights of the embedding layer and the final
            # linear layer.
            continue
        if ".attn.bias" in name or ".attn.masked_bias" in name:
            # Skip attention mask.
            # NOTE: "c_attn.bias" should not be skipped.
            continue
        if not name.startswith("transformer."):
            name = "transformer." + name
        param = params_dict[name]
        # The HF's GPT-2 implementation uses Conv1D instead of Linear.
        # Because of this, we need to transpose the weights.
        # Note(zhuohan): the logic below might break quantized models.
        for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
            if conv1d_weight_name not in name:
                continue
            if not name.endswith(".weight"):
                continue
            # TODO: check megatron
            loaded_weight = loaded_weight.t()
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, loaded_weight)


def llama_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> None:
    # NOTE(shengguangming): the megatron llama may have this prefix
    params_dict = dict(vllm_model.named_parameters())
    for name, loaded_weight in actor_weights.items():
        if "rotary_emb.inv_freq" in name:
            continue
        else:
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)


def llama_megatron_core_te_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> None:
    params_mapping = [
        # (megatron core gpt model name, vllm model name)
        ("embedding.word_embeddings", "model.embed_tokens"),
        ("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"),
        ("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"),
        ("self_attention.linear_qkv", "self_attn.qkv_proj"),
        ("self_attention.linear_proj", 'self_attn.o_proj'),
        ('pre_mlp_layernorm', 'post_attention_layernorm'),
        ('mlp.linear_fc1.layer_norm_weight', 'post_attention_layernorm.weight'),
        ('mlp.linear_fc1.layer_norm_bias', 'post_attention_layernorm.bias'),
        ('mlp.linear_fc1', 'mlp.gate_up_proj'),
        ('mlp.linear_fc2', 'mlp.down_proj'),
        ('decoder.final_layernorm', 'model.norm'),
        ('output_layer', 'lm_head'),
    ]
    # NOTE(shengguangming): the megatron llama may have this prefix
    params_dict = dict(vllm_model.named_parameters())
    for name, loaded_weight in actor_weights.items():
        name = _replace_name(name, params_mapping)
        if name.endswith('.bias') and name not in params_dict:
            continue
        if "rotary_emb.inv_freq" in name:
            continue
        else:
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)


def llama_megatron_core_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> None:
    params_mapping = [
        # (megatron core gpt model name, vllm model name)
        ("embedding.word_embeddings", "model.embed_tokens"),
        ("self_attention.linear_qkv", "self_attn.qkv_proj"),
        ("self_attention.linear_proj", 'self_attn.o_proj'),
        ('input_layernorm', 'input_layernorm'),
        ('pre_mlp_layernorm', 'post_attention_layernorm'),
        ('mlp.linear_fc1', 'mlp.gate_up_proj'),
        ('mlp.linear_fc2', 'mlp.down_proj'),
        ('decoder.final_layernorm', 'model.norm'),
        ('output_layer', 'lm_head'),
    ]
    # NOTE(shengguangming): the megatron llama may have this prefix
    params_dict = dict(vllm_model.named_parameters())
    for name, loaded_weight in actor_weights.items():
        name = _replace_name(name, params_mapping)
        if name.endswith('.bias') and name not in params_dict:
            continue
        if "rotary_emb.inv_freq" in name:
            continue
        else:
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)


def _replace_name(megatron_name, name_mapping):
    for m_name, v_name in name_mapping:
        if m_name not in megatron_name:
            continue
        if 'layers' in megatron_name:  # deal with decoder layers
            megatron_name = megatron_name.replace('decoder', 'model')
            megatron_name_list = megatron_name.split('.')
            if 'layer_norm_weight' in megatron_name_list or 'layer_norm_bias' in megatron_name_list:
                param_name_list = megatron_name_list[:3]
                param_name_list.append(v_name)
                param_name = '.'.join(param_name_list)
            else:
                param_name_list = megatron_name_list[:3]
                weight_or_bias = megatron_name_list[-1]
                param_name_list.append(v_name)
                param_name_list.append(weight_or_bias)
                param_name = '.'.join(param_name_list)
            return param_name
        else:
            param_name = megatron_name.replace(m_name, v_name)
            return param_name


def mistral_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> None:
    # TODO: need to implement a general way to deal with prefixes
    params_dict = dict(vllm_model.named_parameters())
    for name, loaded_weight in actor_weights.items():
        if "rotary_emb.inv_freq" in name:
            continue
        else:
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)


__LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__ = {
    ColumnParallelLinear: parallel_weight_loader,
    MergedColumnParallelLinear: parallel_weight_loader,
    QKVParallelLinear: parallel_weight_loader,
    RowParallelLinear: parallel_weight_loader,
    VocabParallelEmbedding: parallel_weight_loader,
    ParallelLMHead: parallel_weight_loader
    # "ScaledActivation.weight_loader": ScaledActivation,  # TODO(shengguangming): the latest commit in vllm fixes awq for this function and adds load_weights
    # "default_weight_loader": default_weight_loader
}

# for layer_class, weight_loader in __LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__.items():
#     # setattr(layer_class, 'megatron_weight_loader', weight_loader)
#     layer_class.weight_loader = weight_loader

__MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__ = {
    'GPT2LMHeadModel': gpt2_weight_loader,
    'LlamaForCausalLM': llama_megatron_core_te_weight_loader,  # use the TE backend for open-source megatron
    'LLaMAForCausalLM': llama_megatron_core_te_weight_loader,
    'MistralForCausalLM': mistral_megatron_weight_loader,
}


# The actor model is passed in as its .state_dict()
# Load megatron weights
def load_megatron_weights(actor_weights: Dict, vllm_model: nn.Module):
    weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__)
    weight_loader(actor_weights, vllm_model)
    # NOTE(sgm): to reduce peak memory usage, we offload the vLLM model to CPU after init,
    # so we need to move it back to GPU after syncing the model weights in the first iteration.
    vllm_model = vllm_model.cuda()


def _get_model_weight_loader(arch: str):
    if arch in __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__:
        return __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__[arch]
    raise ValueError(f"Model architecture {arch} is not supported for now. "
                     f"Supported architectures: {ModelRegistry.get_supported_archs()}")


def update_megatron_weight_loader():
    for layer_class, weight_loader in __LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__.items():
        layer_class.weight_loader = weight_loader
    VocabParallelEmbedding.__init__ = vocab_init


# FIXME(shengguangming): the vLLM vocab will pad to 64, which may cause out-of-bounds indexing,
# so we need to rewrite the init function of the vocab embedding
DEFAULT_VOCAB_PADDING_SIZE = 64


def vocab_init(self,
               num_embeddings: int,
               embedding_dim: int,
               params_dtype: Optional[torch.dtype] = None,
               org_num_embeddings: Optional[int] = None,
               padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
    super(VocabParallelEmbedding, self).__init__()

    # Keep the input dimensions.
    # TODO (pad to be divided by 4)
    self.num_embeddings = num_embeddings
    self.org_vocab_size = org_num_embeddings or num_embeddings

    # self.num_embeddings_padded = pad_vocab_size(num_embeddings,
    #                                             padding_size)
    self.embedding_dim = embedding_dim
    if params_dtype is None:
        params_dtype = torch.get_default_dtype()
    self.tp_size = get_tensor_model_parallel_world_size()
    # Divide the weight matrix along the vocabulary dimension.

    # TODO: remove dependencies on megatron
    from megatron.core.tensor_parallel.utils import VocabUtility
    self.vocab_start_index, self.vocab_end_index = (VocabUtility.vocab_range_from_global_vocab_size(
        self.num_embeddings, get_tensor_model_parallel_rank(), self.tp_size))
    self.num_embeddings_per_partition = (self.vocab_end_index - self.vocab_start_index)
    self.weight = Parameter(
        torch.empty(
            self.num_embeddings_per_partition,
            self.embedding_dim,
            # device=torch.cuda.current_device(),
            dtype=params_dtype))
    set_weight_attrs(self.weight, {"parallel_dim": 0, "weight_loader": self.weight_loader})
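To make the name translation above concrete, here is a small hedged example of what `_replace_name` produces for a typical Megatron-Core decoder parameter (the input name and reduced mapping are illustrative):

# 'decoder' is rewritten to 'model', then the mapped vLLM name is spliced back in:
mapping = [("self_attention.linear_qkv", "self_attn.qkv_proj")]
print(_replace_name('decoder.layers.0.self_attention.linear_qkv.weight', mapping))
# -> 'model.layers.0.self_attn.qkv_proj.weight'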
verl/third_party/vllm/vllm_v_0_4_2/model_loader.py (vendored, new file, 265 lines)
@@ -0,0 +1,265 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader
"""Utilities for selecting and loading models."""
from typing import Dict, Union, Optional, Iterable, Tuple

import torch
import torch.nn as nn
from transformers import PreTrainedModel

from vllm.config import (DeviceConfig, LoRAConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.model_executor.model_loader import BaseModelLoader
from vllm.model_executor.model_loader.loader import _initialize_model
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.distributed.communication_op import tensor_model_parallel_all_gather

from .config import ModelConfig, LoadFormat, LoadConfig
from .megatron_weight_loaders import load_megatron_weights, update_megatron_weight_loader
from .dtensor_weight_loaders import load_dtensor_weights, update_dtensor_weight_loader
from .hf_weight_loader import update_hf_weight_loader


def get_model(actor_model: Union[PreTrainedModel, Dict], model_config: ModelConfig, load_config: LoadConfig,
              device_config: DeviceConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig,
              lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
    loader = get_model_loader(load_config)
    if load_config.load_format.startswith('dummy'):
        return loader.load_model(model_config=model_config,
                                 device_config=device_config,
                                 lora_config=lora_config,
                                 vision_language_config=vision_language_config,
                                 parallel_config=parallel_config,
                                 scheduler_config=scheduler_config)
    else:
        return loader.load_model(actor_model=actor_model,
                                 model_config=model_config,
                                 device_config=device_config,
                                 lora_config=lora_config,
                                 vision_language_config=vision_language_config,
                                 parallel_config=parallel_config,
                                 scheduler_config=scheduler_config)


def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
    """Get a model loader based on the load format."""

    if isinstance(load_config.load_format, type):
        return load_config.load_format(load_config)

    if load_config.load_format == LoadFormat.AUTO:
        update_megatron_weight_loader()
        return MegatronLoader(load_config)

    # NOTE(sgm): change the weight_loader function at runtime
    if load_config.load_format == LoadFormat.MEGATRON:
        update_megatron_weight_loader()
        return MegatronLoader(load_config)

    if load_config.load_format == LoadFormat.HF:
        update_hf_weight_loader()
        return HFLoader(load_config)

    if load_config.load_format == LoadFormat.DTENSOR:
        update_dtensor_weight_loader()
        return DTensorLoader(load_config)

    if load_config.load_format == LoadFormat.DUMMY_HF:
        update_hf_weight_loader()
        return DummyModelLoader(load_config)

    if load_config.load_format == LoadFormat.DUMMY_MEGATRON:
        update_megatron_weight_loader()
        return DummyModelLoader(load_config)

    if load_config.load_format == LoadFormat.DUMMY_DTENSOR:
        update_dtensor_weight_loader()
        return DummyModelLoader(load_config)

    raise ValueError('load format not supported in verl: {}, only supports {} and {}'.format(
        load_config.load_format, LoadFormat.MEGATRON, LoadFormat.HF))
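A short hedged example of the dispatch above. The exact `LoadConfig` fields live in the local `.config` module, so the construction shown here is illustrative rather than exact:

load_config = LoadConfig(load_format=LoadFormat.DTENSOR)   # illustrative construction
loader = get_model_loader(load_config)                     # patches vLLM's per-layer weight loaders
assert isinstance(loader, DTensorLoader)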


class DummyModelLoader(BaseModelLoader):
    """Model loader that will set model weights to random values."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig],
                   vision_language_config: Optional[VisionLanguageConfig], parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig) -> nn.Module:
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config, lora_config, vision_language_config)
            # NOTE(woosuk): For accurate performance evaluation, we assign
            # random values to the weights.
            # initialize_dummy_weights(model)
        return model.eval()


class MegatronLoader(BaseModelLoader):
    """Model loader that can load the model weights from a partitioned Megatron model."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _get_weights_iterator(self, actor_model: Union[PreTrainedModel, Dict]):
        # NOTE(shengguangming): load the weights from the actor model
        pass
        # if isinstance(actor_model, nn.Module):
        #     load_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model)
        # else:
        #     load_weights(actor_weights=actor_model, vllm_model=model)
        # return actor_model

    def load_model(self, actor_model: Union[PreTrainedModel, Dict], model_config: ModelConfig,
                   device_config: DeviceConfig, lora_config: Optional[LoRAConfig],
                   vision_language_config: Optional[VisionLanguageConfig], parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig) -> nn.Module:
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config, lora_config, vision_language_config)

            # TODO(sgm): This is a hack, we need to register the load_weight() func for each model in vllm
            if isinstance(actor_model, nn.Module):
                load_megatron_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)),
                                      vllm_model=model)
            else:
                load_megatron_weights(actor_weights=actor_model, vllm_model=model)

            for _, module in model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    quant_method.process_weights_after_loading(module)
                # FIXME: Remove this after Mixtral is updated
                # to use quant_method.
                if hasattr(module, "process_weights_after_loading"):
                    module.process_weights_after_loading()
            # NOTE(sgm): some weights already point to GPU, but we still need this.
            model = model.cuda()  # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage
        return model.eval()


class HFLoader(BaseModelLoader):
    """Model loader that can load the model weights from the model's full (unsharded) params."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _get_weights_iterator(self, actor_model: Union[PreTrainedModel, Dict]):
        if isinstance(actor_model, Dict):
            return actor_model.items()
        elif isinstance(actor_model, nn.Module):
            return dict(actor_model.named_parameters()).items()
        else:
            raise ValueError(f'actor model should be Dict or nn.Module, but got {type(actor_model)}')

    def load_model(self, actor_model: Union[PreTrainedModel, Dict], model_config: ModelConfig,
                   device_config: DeviceConfig, lora_config: Optional[LoRAConfig],
                   vision_language_config: Optional[VisionLanguageConfig], parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig) -> nn.Module:
        with set_default_torch_dtype(model_config.dtype):
            # with torch.device(device_config.device):
            # NOTE(sgm): init the model on cpu
            model = _initialize_model(model_config, self.load_config, lora_config, vision_language_config)
            model.load_weights(self._get_weights_iterator(actor_model))
            for _, module in model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    quant_method.process_weights_after_loading(module)
                # FIXME: Remove this after Mixtral is updated
                # to use quant_method.
                if hasattr(module, "process_weights_after_loading"):
                    module.process_weights_after_loading()
            # NOTE(sgm): some weights already point to GPU, but we still need this.
            model = model.cuda()  # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage
        return model.eval()


class DTensorLoader(BaseModelLoader):
    """Model loader that can load the model weights from a DTensor-partitioned (FSDP) model."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for
|
||||||
|
f"load format {load_config.load_format}")
|
||||||
|
|
||||||
|
def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]):
|
||||||
|
# NOTE(shengguangming) Load the weights from the actor model
|
||||||
|
pass
|
||||||
|
# if isinstance(actor_model, nn.Module):
|
||||||
|
# load_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model)
|
||||||
|
# else:
|
||||||
|
# load_weights(actor_weights=actor_model, vllm_model=model)
|
||||||
|
# return actor_model
|
||||||
|
|
||||||
|
def load_model(self, actor_model: Union[PreTrainedModel,
|
||||||
|
Dict], model_config: ModelConfig, device_config: DeviceConfig,
|
||||||
|
lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig],
|
||||||
|
parallel_config: ParallelConfig, scheduler_config: SchedulerConfig) -> nn.Module:
|
||||||
|
with set_default_torch_dtype(model_config.dtype):
|
||||||
|
with torch.device(device_config.device):
|
||||||
|
model = _initialize_model(model_config, self.load_config, lora_config, vision_language_config)
|
||||||
|
|
||||||
|
# TODO(sgm): This is a hack, we need to register the load_weight() func for each model in vllm
|
||||||
|
if isinstance(actor_model, nn.Module):
|
||||||
|
load_dtensor_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)),
|
||||||
|
vllm_model=model)
|
||||||
|
else:
|
||||||
|
load_dtensor_weights(actor_weights=actor_model, vllm_model=model)
|
||||||
|
|
||||||
|
for _, module in model.named_modules():
|
||||||
|
quant_method = getattr(module, "quant_method", None)
|
||||||
|
if quant_method is not None:
|
||||||
|
quant_method.process_weights_after_loading(module)
|
||||||
|
# FIXME: Remove this after Mixtral is updated
|
||||||
|
# to use quant_method.
|
||||||
|
if hasattr(module, "process_weights_after_loading"):
|
||||||
|
module.process_weights_after_loading()
|
||||||
|
# NOTE(sgm) Some weights are point to gpu, but still need this.
|
||||||
|
model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage
|
||||||
|
return model.eval()
|
||||||
|
|
||||||
|
|
||||||
|
# FIXME(sgm): hack the _get_logits function in vllm v0.4.2
|
||||||
|
# as they use ray, the _get_logits result will only need to return to the driver node,
|
||||||
|
# therefore gather is enough. However, we use SPMD instead of a central scheduler,
|
||||||
|
# all_gather is required (aligned with v0.2.6)
|
||||||
|
def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
|
||||||
|
embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||||
|
# Get the logits for the next tokens.
|
||||||
|
logits = torch.matmul(hidden_states, embedding.t())
|
||||||
|
if embedding_bias is not None:
|
||||||
|
logits += embedding_bias
|
||||||
|
logits = tensor_model_parallel_all_gather(logits)
|
||||||
|
# Remove paddings in vocab (if any).
|
||||||
|
if logits is not None:
|
||||||
|
logits = logits[:, :self.org_vocab_size]
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
|
|
||||||
|
LogitsProcessor._get_logits = _get_logits
|
||||||
281
verl/third_party/vllm/vllm_v_0_4_2/model_runner.py
vendored
Normal file
281
verl/third_party/vllm/vllm_v_0_4_2/model_runner.py
vendored
Normal file
@@ -0,0 +1,281 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/model_runner.py
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from enum import IntEnum
|
||||||
|
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
|
from vllm.attention import (AttentionMetadata, get_attn_backend)
|
||||||
|
from vllm.config import (DeviceConfig, LoRAConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig)
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.lora.layers import LoRAMapping
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
|
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
|
||||||
|
from vllm.model_executor import SamplingMetadata
|
||||||
|
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData, SequenceGroupMetadata)
|
||||||
|
from vllm.utils import (CudaMemoryProfiler, is_hip, is_pin_memory_available)
|
||||||
|
from vllm.worker.model_runner import ModelRunner, CUDAGraphRunner
|
||||||
|
|
||||||
|
from .model_loader import get_model
|
||||||
|
from .config import ModelConfig, LoadConfig
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# How batches are constructed.
|
||||||
|
class BatchType(IntEnum):
|
||||||
|
# Every batch is prefill.
|
||||||
|
PREFILL = 0
|
||||||
|
# Every batch is decode.
|
||||||
|
DECODE = 1
|
||||||
|
# Batch is a mixture of prefill and decode.
|
||||||
|
MIXED = 2
|
||||||
|
|
||||||
|
|
||||||
|
class ModelRunner(ModelRunner):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: Union[nn.Module, Dict], # model itself or its parameter dict
|
||||||
|
model_config: ModelConfig,
|
||||||
|
parallel_config: ParallelConfig,
|
||||||
|
scheduler_config: SchedulerConfig,
|
||||||
|
device_config: DeviceConfig,
|
||||||
|
load_config: LoadConfig,
|
||||||
|
lora_config: Optional[LoRAConfig],
|
||||||
|
kv_cache_dtype: Optional[str] = "auto",
|
||||||
|
vision_language_config: Optional[VisionLanguageConfig] = None,
|
||||||
|
):
|
||||||
|
self.model_config = model_config
|
||||||
|
self.parallel_config = parallel_config
|
||||||
|
self.scheduler_config = scheduler_config
|
||||||
|
self.lora_config = lora_config
|
||||||
|
self.load_config = load_config
|
||||||
|
|
||||||
|
# model_config can be None in tests/samplers/test_sampler.py.
|
||||||
|
# FIXME(woosuk): This is a hack to make the tests work. Refactor this.
|
||||||
|
self.sliding_window = (model_config.get_sliding_window() if model_config is not None else None)
|
||||||
|
self.device_config = (device_config if device_config is not None else DeviceConfig())
|
||||||
|
self.device = self.device_config.device
|
||||||
|
|
||||||
|
# NOTE(sgm): add for verl
|
||||||
|
self.model = model # this will be replaced by get_model()
|
||||||
|
|
||||||
|
# Set after load_model.
|
||||||
|
self.lora_manager: LRUCacheWorkerLoRAManager = None
|
||||||
|
|
||||||
|
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
|
||||||
|
self.graph_memory_pool: Optional[Tuple[int, int]] = None # Set during graph capture.
|
||||||
|
|
||||||
|
self.max_seq_len_to_capture = (self.model_config.max_seq_len_to_capture if self.model_config is not None else 0)
|
||||||
|
|
||||||
|
self.pin_memory = is_pin_memory_available()
|
||||||
|
self.kv_cache_dtype = kv_cache_dtype
|
||||||
|
self.vision_language_config = vision_language_config
|
||||||
|
|
||||||
|
self.attn_backend = get_attn_backend(self.model_config.dtype if model_config is not None else None)
|
||||||
|
|
||||||
|
# Lazy initialization
|
||||||
|
self.block_size: int # Set after initial profiling.
|
||||||
|
# When using CUDA graph, the input block tables must be padded to
|
||||||
|
# max_seq_len_to_capture. However, creating the block table in
|
||||||
|
# Python can be expensive. To optimize this, we cache the block table
|
||||||
|
# in numpy and only copy the actual input content at every iteration.
|
||||||
|
# The shape of the cached block table will be
|
||||||
|
# (max batch size to capture, max context len to capture / block size).
|
||||||
|
self.graph_block_tables: torch.Tensor # Set after initial profiling.
|
||||||
|
|
||||||
|
# Set if the backend is flashinfer.
|
||||||
|
self.flashinfer_workspace_buffer: torch.Tensor
|
||||||
|
|
||||||
|
# NOTE(sgm): initialize model using the actor model
|
||||||
|
def load_model(self) -> None:
|
||||||
|
with CudaMemoryProfiler() as m:
|
||||||
|
self.model = get_model(actor_model=self.model,
|
||||||
|
model_config=self.model_config,
|
||||||
|
device_config=self.device_config,
|
||||||
|
lora_config=self.lora_config,
|
||||||
|
load_config=self.load_config,
|
||||||
|
parallel_config=self.parallel_config,
|
||||||
|
scheduler_config=self.scheduler_config,
|
||||||
|
vision_language_config=self.vision_language_config)
|
||||||
|
self.model_memory_usage = m.consumed_memory
|
||||||
|
logger.info("Loading model weights took %.4f GB", self.model_memory_usage / float(2**30))
|
||||||
|
|
||||||
|
if self.lora_config:
|
||||||
|
assert hasattr(self.model, "supported_lora_modules") and self.model.supported_lora_modules, (
|
||||||
|
"Model does not support LoRA")
|
||||||
|
assert hasattr(self.model, "embedding_modules"), "Model does not have embedding_modules"
|
||||||
|
assert hasattr(self.model, "embedding_padding_modules"), "Model does not have embedding_padding_modules"
|
||||||
|
self.lora_manager = LRUCacheWorkerLoRAManager(self.scheduler_config.max_num_seqs,
|
||||||
|
self.scheduler_config.max_num_batched_tokens, self.vocab_size,
|
||||||
|
self.lora_config, self.device, self.model.embedding_modules,
|
||||||
|
self.model.embedding_padding_modules)
|
||||||
|
self.model = self.lora_manager.create_lora_manager(self.model)
|
||||||
|
|
||||||
|
if self.kv_cache_dtype == "fp8" and is_hip():
|
||||||
|
# Currently scaled KV cache is only enabled on ROCm
|
||||||
|
if self.model_config.quantization_param_path is not None:
|
||||||
|
if callable(getattr(self.model, "load_kv_cache_scales", None)):
|
||||||
|
self.model.load_kv_cache_scales(self.model_config.quantization_param_path)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Using FP8 KV cache and scaling factors provided but "
|
||||||
|
"model %s does not support loading scaling factors.", self.model.__class__)
|
||||||
|
else:
|
||||||
|
logger.warning("Using FP8 KV cache but no scaling factors "
|
||||||
|
"provided. Defaulting to scaling factors of 1.0. "
|
||||||
|
"This may lead to less accurate results!")
|
||||||
|
elif self.model_config.quantization_param_path is not None:
|
||||||
|
logger.warning("KV cache scaling factors provided, "
|
||||||
|
"but the KV cache data type is not FP8. "
|
||||||
|
"KV cache scaling factors will not be used.")
|
||||||
|
|
||||||
|
def prepare_input_tensors(
|
||||||
|
self,
|
||||||
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, Set[LoRARequest], LoRAMapping,
|
||||||
|
torch.Tensor]:
|
||||||
|
# NOTE(sgm): all workers prepare the input in the same way
|
||||||
|
prefill_reqs = []
|
||||||
|
decode_reqs = []
|
||||||
|
for seq_group_meta in seq_group_metadata_list:
|
||||||
|
if seq_group_meta.is_prompt:
|
||||||
|
prefill_reqs.append(seq_group_meta)
|
||||||
|
else:
|
||||||
|
decode_reqs.append(seq_group_meta)
|
||||||
|
|
||||||
|
# Prepare input tensors.
|
||||||
|
(
|
||||||
|
input_tokens,
|
||||||
|
input_positions,
|
||||||
|
prefill_attn_metadata,
|
||||||
|
seq_lens,
|
||||||
|
query_lens,
|
||||||
|
lora_index_mapping,
|
||||||
|
lora_prompt_mapping,
|
||||||
|
lora_requests,
|
||||||
|
multi_modal_input,
|
||||||
|
slot_mapping,
|
||||||
|
) = self._prepare_prompt(prefill_reqs)
|
||||||
|
(
|
||||||
|
decode_input_tokens,
|
||||||
|
decode_input_positions,
|
||||||
|
decode_attn_metadata,
|
||||||
|
decode_lora_index_mapping,
|
||||||
|
decode_lora_prompt_mapping,
|
||||||
|
decode_lora_requests,
|
||||||
|
decode_slot_mapping,
|
||||||
|
) = self._prepare_decode(decode_reqs)
|
||||||
|
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, seq_lens, query_lens, self.device,
|
||||||
|
self.pin_memory)
|
||||||
|
|
||||||
|
if not self.scheduler_config.chunked_prefill_enabled:
|
||||||
|
assert (len(prefill_reqs) and len(decode_reqs)) == 0
|
||||||
|
|
||||||
|
num_prefills = len(seq_lens)
|
||||||
|
num_prefill_tokens = len(input_tokens)
|
||||||
|
num_decode_tokens = len(decode_input_tokens)
|
||||||
|
|
||||||
|
# Coalesce tensors. Note that attn_metadata is currently not
|
||||||
|
# coalesced for simplicity.
|
||||||
|
input_tokens.extend(decode_input_tokens)
|
||||||
|
input_positions.extend(decode_input_positions)
|
||||||
|
slot_mapping.extend(decode_slot_mapping)
|
||||||
|
lora_index_mapping.extend(decode_lora_index_mapping)
|
||||||
|
lora_prompt_mapping.extend(decode_lora_prompt_mapping)
|
||||||
|
lora_requests.update(decode_lora_requests)
|
||||||
|
|
||||||
|
input_tokens = torch.tensor(input_tokens, dtype=torch.long, device=self.device)
|
||||||
|
input_positions = torch.tensor(input_positions, dtype=torch.long, device=self.device)
|
||||||
|
slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device)
|
||||||
|
|
||||||
|
if self.lora_config:
|
||||||
|
lora_mapping = LoRAMapping(
|
||||||
|
lora_index_mapping,
|
||||||
|
lora_prompt_mapping,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
lora_mapping = None
|
||||||
|
|
||||||
|
# Broadcast the metadata.
|
||||||
|
# If batch contains both prefill and decode, it sends 2 broadcasts.
|
||||||
|
# If it only contains 1 type, it triggers a single broadcast.
|
||||||
|
if (prefill_attn_metadata is not None and decode_attn_metadata is not None):
|
||||||
|
batch_type = BatchType.MIXED
|
||||||
|
elif prefill_attn_metadata is not None:
|
||||||
|
batch_type = BatchType.PREFILL
|
||||||
|
else:
|
||||||
|
batch_type = BatchType.DECODE
|
||||||
|
|
||||||
|
attn_metadata = AttentionMetadata(
|
||||||
|
num_prefills=num_prefills,
|
||||||
|
slot_mapping=slot_mapping,
|
||||||
|
num_prefill_tokens=num_prefill_tokens,
|
||||||
|
num_decode_tokens=num_decode_tokens,
|
||||||
|
prefill_metadata=prefill_attn_metadata,
|
||||||
|
decode_metadata=decode_attn_metadata,
|
||||||
|
kv_cache_dtype=self.kv_cache_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
return (input_tokens, input_positions, attn_metadata, sampling_metadata, lora_requests, lora_mapping,
|
||||||
|
multi_modal_input)
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def execute_model(
|
||||||
|
self,
|
||||||
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
|
kv_caches: List[torch.Tensor],
|
||||||
|
) -> Optional[SamplerOutput]:
|
||||||
|
(input_tokens, input_positions, attn_metadata, sampling_metadata, lora_requests, lora_mapping,
|
||||||
|
multi_modal_input) = self.prepare_input_tensors(seq_group_metadata_list)
|
||||||
|
|
||||||
|
if self.lora_config:
|
||||||
|
self.set_active_loras(lora_requests, lora_mapping)
|
||||||
|
|
||||||
|
# Currently cuda graph is only supported by the decode phase.
|
||||||
|
prefill_meta = attn_metadata.prefill_metadata
|
||||||
|
decode_meta = attn_metadata.decode_metadata
|
||||||
|
if prefill_meta is None and decode_meta.use_cuda_graph:
|
||||||
|
graph_batch_size = input_tokens.shape[0]
|
||||||
|
model_executable = self.graph_runners[graph_batch_size]
|
||||||
|
else:
|
||||||
|
model_executable = self.model
|
||||||
|
execute_model_kwargs = {
|
||||||
|
"input_ids": input_tokens,
|
||||||
|
"positions": input_positions,
|
||||||
|
"kv_caches": kv_caches,
|
||||||
|
"attn_metadata": attn_metadata,
|
||||||
|
}
|
||||||
|
if self.vision_language_config:
|
||||||
|
execute_model_kwargs.update({"image_input": multi_modal_input})
|
||||||
|
hidden_states = model_executable(**execute_model_kwargs)
|
||||||
|
|
||||||
|
# Compute the logits.
|
||||||
|
logits = self.model.compute_logits(hidden_states, sampling_metadata)
|
||||||
|
|
||||||
|
# Only perform sampling in the driver worker.
|
||||||
|
# if not self.is_driver_worker:
|
||||||
|
# return None
|
||||||
|
|
||||||
|
# TODO(sgm): perform sampling on rank 0
|
||||||
|
# Sample the next token.
|
||||||
|
output = self.model.sample(
|
||||||
|
logits=logits,
|
||||||
|
sampling_metadata=sampling_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
294
verl/third_party/vllm/vllm_v_0_4_2/parallel_state.py
vendored
Normal file
294
verl/third_party/vllm/vllm_v_0_4_2/parallel_state.py
vendored
Normal file
@@ -0,0 +1,294 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Adapted from
|
||||||
|
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
|
||||||
|
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
"""Model and data parallel groups."""
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import vllm.distributed.parallel_state as ps
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
from torch.distributed.device_mesh import init_device_mesh
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
"""
|
||||||
|
This version is strongly tied with Megatron to implement HybridEngine and weight sharing between vllm and Megatron.
|
||||||
|
- We assume the Megatron tp+dp+pp world is already established before calling this function.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Device mesh for using DTensor
|
||||||
|
_DEVICE_MESH = None
|
||||||
|
|
||||||
|
# Tensor model parallel group that the current rank belongs to.
|
||||||
|
_TP_DEVICE_GROUP = None
|
||||||
|
_TP_CPU_GROUP = None
|
||||||
|
|
||||||
|
|
||||||
|
# This method is for initializing the ParallelGroup when using HybridEngine
|
||||||
|
def initialize_parallel_state(
|
||||||
|
distributed_init_method: str = "env://",
|
||||||
|
backend: str = "nccl",
|
||||||
|
tensor_model_parallel_size: int = 1,
|
||||||
|
num_tp_per_train_tp: int = 1,
|
||||||
|
pipeline_model_parallel_size: int = 1,
|
||||||
|
):
|
||||||
|
# torch.distributed.all_reduce does not free the input tensor until
|
||||||
|
# the synchronization point. This causes the memory usage to grow
|
||||||
|
# as the number of all_reduce calls increases. This env var disables
|
||||||
|
# this behavior.
|
||||||
|
# Related issue:
|
||||||
|
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
|
||||||
|
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
|
||||||
|
|
||||||
|
# NOTE(sgm): Modify for verl, Env vars will be set by TORCHRUN.
|
||||||
|
rank = int(os.getenv("RANK", "-1"))
|
||||||
|
local_rank = int(os.getenv("LOCAL_RANK", "0"))
|
||||||
|
|
||||||
|
# Use the world_size set by TORCHRUN
|
||||||
|
world_size = int(os.getenv("WORLD_SIZE", "-1"))
|
||||||
|
assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN"
|
||||||
|
ps.init_distributed_environment(world_size, rank, distributed_init_method, local_rank, backend)
|
||||||
|
if torch.distributed.get_world_size() > 1:
|
||||||
|
# NOTE: build a sepearate inference group with infer tp & micro dp
|
||||||
|
initialize_model_parallel_for_vllm(tensor_model_parallel_size=tensor_model_parallel_size,
|
||||||
|
num_tensor_model_parallel_groups_per_train_tp=num_tp_per_train_tp)
|
||||||
|
else:
|
||||||
|
initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend)
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_model_parallel_initialized(
|
||||||
|
tensor_model_parallel_size: int,
|
||||||
|
pipeline_model_parallel_size: int = 1,
|
||||||
|
backend: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Helper to initialize model parallel groups if they are not initialized,
|
||||||
|
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
|
||||||
|
values if the model parallel groups are initialized.
|
||||||
|
"""
|
||||||
|
# get the backend of _DEVICE_WORLD_GROUP
|
||||||
|
backend = backend or torch.distributed.get_backend()
|
||||||
|
if not model_parallel_is_initialized():
|
||||||
|
initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend)
|
||||||
|
return
|
||||||
|
|
||||||
|
assert (get_tensor_model_parallel_world_size() == tensor_model_parallel_size), (
|
||||||
|
"tensor parallel group already initialized, but of unexpected size: "
|
||||||
|
f"{get_tensor_model_parallel_world_size()=} vs. "
|
||||||
|
f"{tensor_model_parallel_size=}")
|
||||||
|
# assert (get_pipeline_model_parallel_world_size(
|
||||||
|
# ) == pipeline_model_parallel_size), (
|
||||||
|
# "pipeline parallel group already initialized, but of unexpected size: "
|
||||||
|
# f"{get_pipeline_model_parallel_world_size()=} vs. "
|
||||||
|
# f"{pipeline_model_parallel_size=}")
|
||||||
|
|
||||||
|
|
||||||
|
def model_parallel_is_initialized():
|
||||||
|
"""Check if tensor and pipeline parallel groups are initialized."""
|
||||||
|
return (ps._TP_DEVICE_GROUP is not None)
|
||||||
|
# and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_model_parallel_for_vllm(tensor_model_parallel_size: int,
|
||||||
|
num_tensor_model_parallel_groups_per_train_tp: int = 1) -> None:
|
||||||
|
from torch.distributed import new_group
|
||||||
|
# Get world size and rank. Ensure some consistencies.
|
||||||
|
assert torch.distributed.is_initialized()
|
||||||
|
|
||||||
|
assert isinstance(tensor_model_parallel_size, int)
|
||||||
|
|
||||||
|
# assert num_tensor_model_parallel_groups_per_train_tp == 1 and not different_tp_group
|
||||||
|
# assert num_tensor_model_parallel_groups_per_train_tp > 1 and different_tp_group
|
||||||
|
|
||||||
|
# Build the tensor model-parallel groups.
|
||||||
|
assert ps._TP_DEVICE_GROUP is None, ("tensor model parallel group is already initialized")
|
||||||
|
|
||||||
|
global _TP_DEVICE_GROUP
|
||||||
|
global _TP_CPU_GROUP
|
||||||
|
global _DEVICE_MESH
|
||||||
|
|
||||||
|
world_size: int = torch.distributed.get_world_size()
|
||||||
|
|
||||||
|
rank = torch.distributed.get_rank()
|
||||||
|
|
||||||
|
backend = torch.distributed.get_backend()
|
||||||
|
|
||||||
|
num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
|
||||||
|
|
||||||
|
if num_tensor_model_parallel_groups_per_train_tp == 1:
|
||||||
|
# if tensor_model_parallel_size == train_tensor_parallel_size:
|
||||||
|
# using the same tp group as Megatron/vllm
|
||||||
|
for i in range(num_tensor_model_parallel_groups):
|
||||||
|
ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
|
||||||
|
group = torch.distributed.new_group(ranks, backend=backend)
|
||||||
|
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
|
||||||
|
if rank in ranks:
|
||||||
|
_TP_DEVICE_GROUP = group
|
||||||
|
_TP_CPU_GROUP = cpu_group
|
||||||
|
ps._TP_DEVICE_GROUP = group
|
||||||
|
ps._TP_CPU_GROUP = cpu_group
|
||||||
|
|
||||||
|
# no _MICRO_DATA_PARALLEL_GROUP
|
||||||
|
else:
|
||||||
|
# initialize a micro_dp group and a tp group
|
||||||
|
# assume training tp=4, infer tp=2, then, weight is partitioned as
|
||||||
|
# [1], [2], [3], [4] for training and [1,2], [1,2], [3,4], [3,4] for inference
|
||||||
|
|
||||||
|
# Build the inference tp groups
|
||||||
|
# train_tp = train_tensor_parallel_size
|
||||||
|
train_tp = num_tensor_model_parallel_groups_per_train_tp * tensor_model_parallel_size
|
||||||
|
# num_tensor_model_parallel_groups_per_train_tp = train_tp // tensor_model_parallel_size
|
||||||
|
assert _TP_DEVICE_GROUP is None, ("tensor model parallel group is already initialized")
|
||||||
|
for i in range(num_tensor_model_parallel_groups // num_tensor_model_parallel_groups_per_train_tp):
|
||||||
|
start = train_tp * i
|
||||||
|
end = train_tp * (i + 1)
|
||||||
|
for j in range(num_tensor_model_parallel_groups_per_train_tp):
|
||||||
|
ranks = list(range(start, end, num_tensor_model_parallel_groups_per_train_tp))
|
||||||
|
for i in range(len(ranks)):
|
||||||
|
ranks[i] += j
|
||||||
|
group = torch.distributed.new_group(ranks)
|
||||||
|
cpu_group = torch.distributed.new_group(ranks, backend='gloo')
|
||||||
|
if rank in ranks:
|
||||||
|
_TP_DEVICE_GROUP = group
|
||||||
|
_TP_CPU_GROUP = cpu_group
|
||||||
|
ps._TP_DEVICE_GROUP = _TP_DEVICE_GROUP
|
||||||
|
ps._TP_CPU_GROUP = cpu_group
|
||||||
|
|
||||||
|
# Build the pipeline model-parallel groups.
|
||||||
|
# global _PIPELINE_MODEL_PARALLEL_GROUP
|
||||||
|
# global _PIPELINE_GLOBAL_RANKS
|
||||||
|
# assert ps._PIPELINE_MODEL_PARALLEL_GROUP is None, ("pipeline model parallel group is already initialized")
|
||||||
|
|
||||||
|
# ps._PIPELINE_MODEL_PARALLEL_GROUP = mpu.get_pipeline_model_parallel_group()
|
||||||
|
# ps._PIPELINE_GLOBAL_RANKS = mpu.get_pipeline_model_parallel_ranks()
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_model_parallel(
|
||||||
|
tensor_model_parallel_size: int = 1,
|
||||||
|
pipeline_model_parallel_size: int = 1,
|
||||||
|
backend: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
NOTE: This method is a hack from the open-sourced version without
|
||||||
|
asertion of world_size = tp * pp
|
||||||
|
|
||||||
|
Initialize model parallel groups.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
tensor_model_parallel_size: number of GPUs used for tensor model
|
||||||
|
parallelism.
|
||||||
|
pipeline_model_parallel_size: number of GPUs used for pipeline model
|
||||||
|
parallelism.
|
||||||
|
|
||||||
|
Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
|
||||||
|
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
|
||||||
|
the model pipeline. The present function will
|
||||||
|
create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
|
||||||
|
4 tensor model-parallel groups:
|
||||||
|
[g0, g1], [g2, g3], [g4, g5], [g6, g7]
|
||||||
|
2 pipeline model-parallel groups:
|
||||||
|
[g0, g2, g4, g6], [g1, g3, g5, g7]
|
||||||
|
Note that for efficiency, the caller should make sure adjacent ranks
|
||||||
|
are on the same DGX box. For example if we are using 2 DGX-1 boxes
|
||||||
|
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
|
||||||
|
ranks 8 to 15 belong to the second box.
|
||||||
|
"""
|
||||||
|
# Get world size and rank. Ensure some consistencies.
|
||||||
|
assert torch.distributed.is_initialized()
|
||||||
|
world_size: int = torch.distributed.get_world_size()
|
||||||
|
# get the backend of _DEVICE_WORLD_GROUP
|
||||||
|
backend = backend or torch.distributed.get_backend()
|
||||||
|
|
||||||
|
# NOTE(sgm) we don't assert world_size == tp * pp
|
||||||
|
# DP is not managed by vllm but by the veRL WorkerGroup
|
||||||
|
|
||||||
|
num_tensor_model_parallel_groups: int = (world_size // tensor_model_parallel_size)
|
||||||
|
num_pipeline_model_parallel_groups: int = (world_size // pipeline_model_parallel_size)
|
||||||
|
rank = torch.distributed.get_rank()
|
||||||
|
|
||||||
|
# Build device mesh for TP
|
||||||
|
if num_tensor_model_parallel_groups > 1:
|
||||||
|
device_mesh = init_device_mesh("cuda", (num_tensor_model_parallel_groups, tensor_model_parallel_size),
|
||||||
|
mesh_dim_names=("replicate", "tp_shard"))
|
||||||
|
else:
|
||||||
|
device_mesh = init_device_mesh("cuda", (tensor_model_parallel_size,), mesh_dim_names=["tp_shard"])
|
||||||
|
shard_group = device_mesh.get_group(mesh_dim="tp_shard")
|
||||||
|
|
||||||
|
# Build the tensor model-parallel groups.
|
||||||
|
global _TP_DEVICE_GROUP, _TP_CPU_GROUP
|
||||||
|
global _DEVICE_MESH
|
||||||
|
assert _TP_DEVICE_GROUP is None, ("tensor model parallel group is already initialized")
|
||||||
|
assert _DEVICE_MESH is None, ("device mesh in vllm is already initialized")
|
||||||
|
|
||||||
|
_DEVICE_MESH = device_mesh
|
||||||
|
# for i in range(num_tensor_model_parallel_groups):
|
||||||
|
# ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
|
||||||
|
# group = torch.distributed.new_group(ranks, backend=backend)
|
||||||
|
# cpu_group = torch.distributed.new_group(ranks, backend="gloo")
|
||||||
|
# assert torch.distributed.get_process_group_ranks(shard_group) == torch.distributed.get_process_group_ranks(cpu_group)
|
||||||
|
# ranks = torch.distributed.get_process_group_ranks(shard_group)
|
||||||
|
# cpu_group = torch.distributed.new_group(ranks, backend="gloo") # TODO: this will hang
|
||||||
|
# cpu_group = torch.distributed.new_group(, backend="gloo")
|
||||||
|
# if rank == 0:
|
||||||
|
# print(f'rank: {rank}')
|
||||||
|
# print(f'ranks: {ranks}')
|
||||||
|
# print(f'torch.distributed.get_process_group_ranks(shard_group): {torch.distributed.get_process_group_ranks(shard_group)}')
|
||||||
|
# if rank in ranks:
|
||||||
|
_TP_DEVICE_GROUP = shard_group
|
||||||
|
ps._TP_DEVICE_GROUP = _TP_DEVICE_GROUP
|
||||||
|
# ps._TP_CPU_GROUP = cpu_group # TODO: will hang when used with device mesh
|
||||||
|
|
||||||
|
# TODO: init using device mesh
|
||||||
|
# Build the pipeline model-parallel groups.
|
||||||
|
assert ps._PIPELINE_MODEL_PARALLEL_GROUP is None, ("pipeline model parallel group is already initialized")
|
||||||
|
for i in range(num_pipeline_model_parallel_groups):
|
||||||
|
ranks = range(i, world_size, num_pipeline_model_parallel_groups)
|
||||||
|
group = torch.distributed.new_group(ranks, backend=backend)
|
||||||
|
if rank in ranks:
|
||||||
|
ps._PIPELINE_MODEL_PARALLEL_GROUP = group
|
||||||
|
ps._PIPELINE_GLOBAL_RANKS = ranks
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Device mesh utilities
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_device_mesh():
|
||||||
|
assert _DEVICE_MESH is not None, ("device mesh is not initialized")
|
||||||
|
return _DEVICE_MESH
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Tensor model parallel utilities
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_tensor_model_parallel_group():
|
||||||
|
"""Get the tensor model parallel group the caller rank belongs to."""
|
||||||
|
assert _TP_DEVICE_GROUP is not None, ("tensor model parallel group is not initialized")
|
||||||
|
return _TP_DEVICE_GROUP
|
||||||
|
|
||||||
|
|
||||||
|
def get_tensor_model_parallel_world_size():
|
||||||
|
"""Return world size for the tensor model parallel group."""
|
||||||
|
return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
|
||||||
|
|
||||||
|
|
||||||
|
def get_tensor_model_parallel_rank():
|
||||||
|
"""Return my rank for the tensor model parallel group."""
|
||||||
|
return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
|
||||||
|
|
||||||
|
|
||||||
|
def get_tensor_model_parallel_src_rank():
|
||||||
|
"""Calculate the global rank corresponding to the first local rank
|
||||||
|
in the tensor model parallel group."""
|
||||||
|
global_rank = torch.distributed.get_rank()
|
||||||
|
local_world_size = get_tensor_model_parallel_world_size()
|
||||||
|
return (global_rank // local_world_size) * local_world_size
|
||||||
218
verl/third_party/vllm/vllm_v_0_4_2/spmd_gpu_executor.py
vendored
Normal file
218
verl/third_party/vllm/vllm_v_0_4_2/spmd_gpu_executor.py
vendored
Normal file
@@ -0,0 +1,218 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/executor/gpu_executor.py
|
||||||
|
import os
|
||||||
|
import socket
|
||||||
|
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.executor.executor_base import ExecutorBase, ExecutorAsyncBase
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
|
from vllm.sequence import SamplerOutput, ExecuteModelRequest
|
||||||
|
|
||||||
|
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig,
|
||||||
|
VisionLanguageConfig)
|
||||||
|
from .config import ModelConfig, LoadConfig
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SPMDGPUExecutor(ExecutorBase):
|
||||||
|
"""SPMD-based multi-GPU executor implementations."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model, # pytorch model itself or its parameter dict
|
||||||
|
model_config: ModelConfig,
|
||||||
|
cache_config: CacheConfig,
|
||||||
|
parallel_config: ParallelConfig,
|
||||||
|
scheduler_config: SchedulerConfig,
|
||||||
|
device_config: DeviceConfig,
|
||||||
|
load_config: LoadConfig,
|
||||||
|
lora_config: Optional[LoRAConfig],
|
||||||
|
vision_language_config: Optional[VisionLanguageConfig],
|
||||||
|
speculative_config: Optional[SpeculativeConfig],
|
||||||
|
) -> None:
|
||||||
|
self.model_config = model_config
|
||||||
|
self.cache_config = cache_config
|
||||||
|
self.lora_config = lora_config
|
||||||
|
self.load_config = load_config
|
||||||
|
self.parallel_config = parallel_config
|
||||||
|
self.scheduler_config = scheduler_config
|
||||||
|
self.device_config = device_config
|
||||||
|
self.vision_language_config = vision_language_config
|
||||||
|
self.speculative_config = speculative_config
|
||||||
|
|
||||||
|
distributed_init_method = initialize_cluster(parallel_config)
|
||||||
|
self._init_executor(model, distributed_init_method)
|
||||||
|
|
||||||
|
# TODO(sgm): verl not support speculative decode now
|
||||||
|
def _init_executor(self, model, distributed_init_method) -> None:
|
||||||
|
assert (not self.speculative_config), "Speculative decoding not yet supported for multi-GPU backend."
|
||||||
|
|
||||||
|
# Create the parallel worker for each GPU.
|
||||||
|
self._init_workers_sp(model, distributed_init_method)
|
||||||
|
|
||||||
|
def _init_workers_sp(self, model, distributed_init_method: str):
|
||||||
|
# Lazy import the Worker to avoid importing torch.cuda/xformers
|
||||||
|
# before CUDA_VISIBLE_DEVICES is set in the Worker
|
||||||
|
from .worker import Worker # pylint: disable=import-outside-toplevel
|
||||||
|
|
||||||
|
rank = int(os.getenv("RANK"))
|
||||||
|
local_rank = int(os.getenv("LOCAL_RANK"))
|
||||||
|
print(f'local rank {local_rank}')
|
||||||
|
|
||||||
|
self.worker = Worker(
|
||||||
|
model,
|
||||||
|
self.model_config,
|
||||||
|
self.parallel_config,
|
||||||
|
self.scheduler_config,
|
||||||
|
self.device_config,
|
||||||
|
self.cache_config,
|
||||||
|
self.load_config,
|
||||||
|
local_rank,
|
||||||
|
rank,
|
||||||
|
distributed_init_method,
|
||||||
|
lora_config=self.lora_config,
|
||||||
|
vision_language_config=self.vision_language_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# NOTE(shengguangming): torch.distributed.init_process_group will be called inside the init_model()
|
||||||
|
self.worker.init_device()
|
||||||
|
self.worker.load_model()
|
||||||
|
|
||||||
|
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||||
|
"""Determine the number of available KV blocks.
|
||||||
|
|
||||||
|
This invokes `determine_num_available_blocks` on each worker and takes
|
||||||
|
the min of the results, guaranteeing that the selected cache sizes are
|
||||||
|
compatible with all workers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- tuple[num_gpu_blocks, num_cpu_blocks]
|
||||||
|
"""
|
||||||
|
# Get the maximum number of blocks that can be allocated on GPU and CPU.
|
||||||
|
num_blocks = self.worker.determine_num_available_blocks()
|
||||||
|
|
||||||
|
# NOTE(shengguangming): Now we don't use a shared centralized controler but each process will
|
||||||
|
# have its own scheduler
|
||||||
|
num_gpu_blocks = num_blocks[0]
|
||||||
|
num_cpu_blocks = num_blocks[1]
|
||||||
|
|
||||||
|
return num_gpu_blocks, num_cpu_blocks
|
||||||
|
|
||||||
|
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
|
||||||
|
"""Initialize the KV cache in all workers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# NOTE: We log here to avoid multiple logs when number of workers is
|
||||||
|
# greater than one. We could log in the engine, but not all executors
|
||||||
|
# have GPUs.
|
||||||
|
logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, num_cpu_blocks)
|
||||||
|
|
||||||
|
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||||
|
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||||
|
|
||||||
|
if torch.distributed.get_rank() == 0:
|
||||||
|
print(
|
||||||
|
f'before init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB'
|
||||||
|
)
|
||||||
|
self.worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks)
|
||||||
|
if torch.distributed.get_rank() == 0:
|
||||||
|
print(
|
||||||
|
f'after init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB'
|
||||||
|
)
|
||||||
|
|
||||||
|
# NOTE(sgm): This will not profile & capture the model(CUDAGraph) when rebuilding KVCache
|
||||||
|
def init_cache_engine(self) -> None:
|
||||||
|
self.worker._init_cache_engine()
|
||||||
|
|
||||||
|
def free_cache_engine(self) -> None:
|
||||||
|
self.worker.free_cache_engine()
|
||||||
|
|
||||||
|
def execute_model(self, execute_model_req) -> List[SamplerOutput]:
|
||||||
|
all_outputs = self.worker.execute_model(execute_model_req=execute_model_req)
|
||||||
|
|
||||||
|
# NOTE(sgm):
|
||||||
|
# Each GPU in vllm under verl has its own spmd_gpu_executor, therefore all GPUs should return the outputs
|
||||||
|
# In vllm with ray, only the driver worker returns the sampling results.
|
||||||
|
return all_outputs
|
||||||
|
|
||||||
|
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||||
|
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
|
||||||
|
return self.worker.add_lora(lora_request=lora_request)
|
||||||
|
|
||||||
|
def remove_lora(self, lora_id: int) -> bool:
|
||||||
|
assert lora_id > 0, "lora_id must be greater than 0."
|
||||||
|
return self.worker.remove_lora(lora_id=lora_id)
|
||||||
|
|
||||||
|
def list_loras(self) -> Set[int]:
|
||||||
|
return self.worker.list_loras()
|
||||||
|
|
||||||
|
def check_health(self) -> None:
|
||||||
|
# SPMDExecutor will always be healthy as long as
|
||||||
|
# it's running.
|
||||||
|
return
|
||||||
|
|
||||||
|
# NOTE(sgm): add for verl
|
||||||
|
def offload_model_weights(self) -> None:
|
||||||
|
self.worker.offload_model_weights()
|
||||||
|
|
||||||
|
def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], load_format: str) -> None:
|
||||||
|
self.worker.sync_model_weights(actor_weights=actor_weights, load_format=load_format)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_cluster(
|
||||||
|
parallel_config: ParallelConfig,
|
||||||
|
engine_use_ray: bool = False,
|
||||||
|
ray_address: Optional[str] = None,
|
||||||
|
) -> Tuple[str, Optional[None]]:
|
||||||
|
"""Initialize the distributed cluster probably with Ray.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
parallel_config: The configurations for parallel execution.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The `distributed_init_method` is the address for initializing the
|
||||||
|
distributed backend.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Initialize cluster locally.
|
||||||
|
port = get_open_port()
|
||||||
|
# We need to setup the distributed init method to make sure
|
||||||
|
# the distributed megatron code (e.g., get world size) works correctly.
|
||||||
|
# distributed_init_method = f"tcp://localhost:{port}"
|
||||||
|
distributed_init_method = 'env://'
|
||||||
|
return distributed_init_method
|
||||||
|
|
||||||
|
|
||||||
|
def get_open_port():
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
|
s.bind(("", 0))
|
||||||
|
return s.getsockname()[1]
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(sgm): not implemented async executor yet
|
||||||
|
class SPMDGPUExecutorAsync(SPMDGPUExecutor, ExecutorAsyncBase):
|
||||||
|
|
||||||
|
async def execute_model_async(self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
|
||||||
|
"""Executes one model step on the given sequences."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def check_health_async(self) -> None:
|
||||||
|
"""Checks if the executor is healthy. If not, it should raise an
|
||||||
|
exception."""
|
||||||
|
self.check_health()
|
||||||
77
verl/third_party/vllm/vllm_v_0_4_2/tokenizer.py
vendored
Normal file
77
verl/third_party/vllm/vllm_v_0_4_2/tokenizer.py
vendored
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
|
||||||
|
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast)
|
||||||
|
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
|
from vllm.utils import make_async, LRUCache
|
||||||
|
from vllm.transformers_utils.tokenizers import *
|
||||||
|
|
||||||
|
|
||||||
|
class TokenizerGroup:
|
||||||
|
"""A group of tokenizers that can be used for LoRA adapters."""
|
||||||
|
|
||||||
|
def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int,
|
||||||
|
max_input_length: Optional[int]):
|
||||||
|
self.enable_lora = enable_lora
|
||||||
|
self.max_input_length = max_input_length
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.lora_tokenizers = LRUCache[PreTrainedTokenizer](capacity=max_num_seqs) if enable_lora else None
|
||||||
|
|
||||||
|
def ping(self) -> bool:
|
||||||
|
"""Check if the tokenizer group is alive."""
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_max_input_len(self, lora_request: Optional[LoRARequest] = None) -> Optional[int]:
|
||||||
|
"""Get the maximum input length for the LoRA request."""
|
||||||
|
return self.max_input_length
|
||||||
|
|
||||||
|
def encode(self,
|
||||||
|
prompt: str,
|
||||||
|
request_id: Optional[str] = None,
|
||||||
|
lora_request: Optional[LoRARequest] = None) -> List[int]:
|
||||||
|
tokenizer = self.get_lora_tokenizer(lora_request)
|
||||||
|
return tokenizer.encode(prompt)
|
||||||
|
|
||||||
|
async def encode_async(self,
|
||||||
|
prompt: str,
|
||||||
|
request_id: Optional[str] = None,
|
||||||
|
lora_request: Optional[LoRARequest] = None) -> List[int]:
|
||||||
|
tokenizer = await self.get_lora_tokenizer_async(lora_request)
|
||||||
|
return tokenizer.encode(prompt)
|
||||||
|
|
||||||
|
def get_lora_tokenizer(self, lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
|
||||||
|
if not lora_request or not self.enable_lora:
|
||||||
|
return self.tokenizer
|
||||||
|
if lora_request.lora_int_id not in self.lora_tokenizers:
|
||||||
|
# TODO(sgm): the lora tokenizer is also passed, but may be different
|
||||||
|
tokenizer = self.tokenizer
|
||||||
|
# tokenizer = (get_lora_tokenizer(
|
||||||
|
# lora_request, **self.tokenizer_config) or self.tokenizer)
|
||||||
|
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
|
||||||
|
return tokenizer
|
||||||
|
else:
|
||||||
|
return self.lora_tokenizers.get(lora_request.lora_int_id)
|
||||||
|
|
||||||
|
# FIXME(sgm): for simplicity, we assign the special token here
|
||||||
|
@property
|
||||||
|
def pad_token_id(self):
|
||||||
|
return self.tokenizer.pad_token_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def eos_token_id(self):
|
||||||
|
return self.tokenizer.eos_token_id
|
||||||
292
verl/third_party/vllm/vllm_v_0_4_2/worker.py
vendored
Normal file
292
verl/third_party/vllm/vllm_v_0_4_2/worker.py
vendored
Normal file
@@ -0,0 +1,292 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/worker.py
|
||||||
|
"""A GPU worker class."""
|
||||||
|
import os
|
||||||
|
import gc
|
||||||
|
from typing import Dict, List, Tuple, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig)
|
||||||
|
from vllm.model_executor import set_random_seed
|
||||||
|
from vllm.sequence import SamplerOutput, ExecuteModelRequest
|
||||||
|
from vllm.worker.cache_engine import CacheEngine
|
||||||
|
from vllm.distributed.device_communicators import pynccl_utils
|
||||||
|
from vllm.distributed.device_communicators.custom_all_reduce import (init_custom_ar)
|
||||||
|
# TODO(sgm): check why vllm has similar file in vllm.model_executor.parallel_utils.parallel_state
|
||||||
|
from vllm.distributed import get_tensor_model_parallel_cpu_group, init_distributed_environment, get_tensor_model_parallel_group
|
||||||
|
from vllm.worker.worker import Worker, _check_if_gpu_supports_dtype
|
||||||
|
|
||||||
|
from .model_runner import ModelRunner
|
||||||
|
from .megatron_weight_loaders import load_megatron_weights
|
||||||
|
from .hf_weight_loader import load_hf_weights
|
||||||
|
from .dtensor_weight_loaders import load_dtensor_weights
|
||||||
|
from .parallel_state import (ensure_model_parallel_initialized)
|
||||||
|
from .config import ModelConfig, LoadConfig, LoadFormat
|
||||||
|
|
||||||
|
|
||||||
|
class Worker(Worker):
|
||||||
|
"""A worker class that executes (a partition of) the model on a GPU.
|
||||||
|
|
||||||
|
Each worker is associated with a single GPU. The worker is responsible for
|
||||||
|
maintaining the KV cache and executing the model on the GPU. In case of
|
||||||
|
distributed inference, each worker is assigned a partition of the model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: Union[nn.Module, Dict], # model itself or its parameter dict
|
||||||
|
model_config: ModelConfig,
|
||||||
|
parallel_config: ParallelConfig,
|
||||||
|
scheduler_config: SchedulerConfig,
|
||||||
|
device_config: DeviceConfig,
|
||||||
|
cache_config: CacheConfig,
|
||||||
|
load_config: LoadConfig,
|
||||||
|
local_rank: int,
|
||||||
|
rank: int,
|
||||||
|
distributed_init_method: str,
|
||||||
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
|
vision_language_config: Optional[VisionLanguageConfig] = None,
|
||||||
|
is_driver_worker: bool = False,
|
||||||
|
) -> None:
|
||||||
|
# self.model = model # will be replaced in the init_model
|
||||||
|
self.model_config = model_config
|
||||||
|
self.parallel_config = parallel_config
|
||||||
|
self.scheduler_config = scheduler_config
|
||||||
|
self.device_config = device_config
|
||||||
|
self.cache_config = cache_config
|
||||||
|
self.local_rank = local_rank
|
||||||
|
self.rank = rank
|
||||||
|
self.distributed_init_method = distributed_init_method
|
||||||
|
self.lora_config = lora_config
|
||||||
|
self.load_config = load_config
|
||||||
|
self.is_driver_worker = is_driver_worker
|
||||||
|
if self.is_driver_worker:
|
||||||
|
assert self.rank == 0, "The driver worker must have rank 0."
|
||||||
|
|
||||||
|
self.vision_language_config = vision_language_config
|
||||||
|
        if self.vision_language_config:
            assert not self.lora_config, ("To be tested: vision language model with LoRA settings.")

        self.model_runner = ModelRunner(
            model,
            model_config,
            parallel_config,
            scheduler_config,
            device_config,
            load_config=load_config,
            lora_config=self.lora_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            vision_language_config=vision_language_config,
        )

        # Uninitialized cache engine. Will be initialized by
        # init_cache_engine.
        self.cache_engine: CacheEngine = None
        self.gpu_cache: List[torch.Tensor] = None

        # NOTE(sgm): For offloading inference engine params
        self.cpu_model = None

    def init_device(self) -> None:
        if self.device_config.device.type == "cuda":
            # torch.distributed.all_reduce does not free the input tensor until
            # the synchronization point. This causes the memory usage to grow
            # as the number of all_reduce calls increases. This env var disables
            # this behavior.
            # Related issue:
            # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
            os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

            # NOTE(sgm): Modified for verl. The env vars are set by TORCHRUN.
            self.rank = self.rank if self.rank is not None else int(os.getenv("RANK", "-1"))
            local_rank = int(os.getenv("LOCAL_RANK", "0"))
            self.device = torch.device(f"cuda:{local_rank}")
            if self.rank < 0:
                raise ValueError("Invalid or unspecified rank.")
            torch.cuda.set_device(self.device)

            # Use the world_size set by TORCHRUN
            world_size = int(os.getenv("WORLD_SIZE", "-1"))
            assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN"
            self.parallel_config.world_size = world_size

            _check_if_gpu_supports_dtype(self.model_config.dtype)
            torch.cuda.empty_cache()
            self.init_gpu_memory = torch.cuda.mem_get_info()[0]
        else:
            raise RuntimeError(f"Unsupported device type: {self.device_config.device}")

        # Initialize the distributed environment.
        init_worker_distributed_environment(self.parallel_config, self.rank, self.distributed_init_method,
                                            self.local_rank)
        # Set random seed.
        set_random_seed(self.model_config.seed)
        # self.model = get_model(actor_model=self.model, model_config=self.model_config)

    @torch.inference_mode()
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Profiles the peak memory usage of the model to determine how many
        KV blocks may be allocated without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the maximum possible number of GPU and CPU blocks
        that can be allocated with the remaining free memory.

        .. tip::
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.cuda.empty_cache()
        # torch.cuda.reset_peak_memory_stats()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.cuda.synchronize()
        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
        peak_memory = total_gpu_memory - free_gpu_memory

        assert peak_memory > 0, ("Error in memory profiling. This happens when the GPU memory was "
                                 "not properly cleaned up before initializing the vLLM instance.")

        cache_block_size = self.get_cache_block_size_bytes()

        # NOTE(sgm): use the remaining free memory rather than the peak-memory headroom
        num_gpu_blocks = int((free_gpu_memory * self.cache_config.gpu_memory_utilization) // cache_block_size)
        # num_gpu_blocks = int((total_gpu_memory * self.cache_config.gpu_memory_utilization - peak_memory) // cache_block_size)

        num_cpu_blocks = int(self.cache_config.swap_space_bytes // cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)
        if self.model_runner.lora_manager:
            self.model_runner.remove_all_loras()

        # NOTE(sgm): Added for verl. Synchronize the block counts across all ranks so that
        # every rank in the tensor-parallel group allocates the same KV cache size.
        num_gpu_blocks = torch.tensor([num_gpu_blocks], device='cuda')
        num_cpu_blocks = torch.tensor([num_cpu_blocks], device='cuda')
        torch.distributed.all_reduce(num_gpu_blocks,
                                     op=torch.distributed.ReduceOp.MIN,
                                     group=get_tensor_model_parallel_group())
        torch.distributed.all_reduce(num_cpu_blocks,
                                     op=torch.distributed.ReduceOp.MIN,
                                     group=get_tensor_model_parallel_group())
        num_gpu_blocks = num_gpu_blocks.item()
        num_cpu_blocks = num_cpu_blocks.item()
        gc.collect()
        torch.cuda.empty_cache()
        return num_gpu_blocks, num_cpu_blocks

    def _init_cache_engine(self):
        if self.cache_engine is None and self.gpu_cache is None:
            super()._init_cache_engine()

    def free_cache_engine(self):
        # requires `enforce_eager=True`; otherwise CUDA graphs still hold references to the cache
        self.cache_engine = None
        self.gpu_cache = None

    @torch.inference_mode()
    def execute_model(self, execute_model_req: Optional[ExecuteModelRequest] = None) -> List[SamplerOutput]:
        if execute_model_req is None:
            seq_group_metadata_list = None
        else:
            seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        # NOTE(sgm): each SPMD rank will have identical input
        assert seq_group_metadata_list is not None
        assert execute_model_req is not None
        num_seq_groups = len(seq_group_metadata_list)
        blocks_to_swap_in = execute_model_req.blocks_to_swap_in
        blocks_to_swap_out = execute_model_req.blocks_to_swap_out
        blocks_to_copy = execute_model_req.blocks_to_copy

        self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)

        # If there is no input, we don't need to execute the model.
        if num_seq_groups == 0:
            return []

        output = self.model_runner.execute_model(seq_group_metadata_list, self.gpu_cache)

        # Worker only supports single-step execution. Wrap the output in a list
        # to conform to the interface.
        return [output]

    # The input is assumed to be a .state_dict() of the actor model.
    def sync_model_weights(self, actor_weights: Dict, load_format: str):
        if load_format in [LoadFormat.MEGATRON, LoadFormat.AUTO]:
            load_megatron_weights(actor_weights, self.model_runner.model)
        elif load_format == LoadFormat.HF:
            # full model state dict without sharding
            load_hf_weights(actor_weights, self.model_runner.model)
        elif load_format == LoadFormat.DTENSOR:
            load_dtensor_weights(actor_weights, self.model_runner.model)

    def offload_model_weights(self) -> None:
        if self.cpu_model is None:
            self.cpu_model = {}
            for name, params in self.model_runner.model.named_parameters():
                self.cpu_model[name] = torch.empty_like(params, device='cpu')
                params.data = self.cpu_model[name]
        else:
            for name, params in self.model_runner.model.named_parameters():
                params.data = self.cpu_model[name]


def init_worker_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    distributed_init_method: Optional[str] = "env://",
    local_rank: int = -1,
) -> None:
    """Initialize the distributed environment."""
    # NOTE(sgm): using tcp://localhost:xxxx as the init method will hang in the HF setting without Megatron
    init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank)

    ensure_model_parallel_initialized(tensor_model_parallel_size=parallel_config.tensor_parallel_size,
                                      pipeline_model_parallel_size=parallel_config.pipeline_parallel_size)

    # TODO(sgm): check whether this is needed
    # if pynccl_utils.is_initialized():
    #     pynccl_world_size = pynccl_utils.get_world_size()
    #     if pynccl_world_size != parallel_config.world_size:
    #         raise RuntimeError(
    #             "pynccl is already initialized but the pynccl world "
    #             "size does not match parallel_config.world_size "
    #             f"({pynccl_world_size} vs. {parallel_config.world_size}).")
    # elif parallel_config.world_size > 1:
    #     # NOTE(woosuk): We don't initialize pynccl process group when world size
    #     # is 1.
    #     # NOTE(kaichao): By default, pynccl is initialized for tp group.
    #     pynccl_utils.init_process_group(
    #         group=get_tensor_model_parallel_cpu_group())

    # # Initialize a custom fast all-reduce implementation.
    # if not parallel_config.disable_custom_all_reduce:
    #     init_custom_ar()

    # A small all_reduce for warmup.
    torch.distributed.all_reduce(torch.zeros(1).cuda())
    # if pynccl_utils.is_initialized():
    #     pynccl_utils.all_reduce(torch.zeros(1).cuda())
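For orientation, here is a minimal sketch of how this SPMD worker is typically driven from a rollout loop. It is not part of the diff; the `worker` instance, the `actor_state_dict` source (e.g. an FSDP-wrapped actor), the `execute_model_req` batch, and the inherited `initialize_cache` call are all assumptions used only for illustration:

    # Hypothetical driver code (assumes a `worker` built from the class above, launched under torchrun).
    worker.init_device()                                      # bind cuda:{LOCAL_RANK}, init torch.distributed
    num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks()
    worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)   # standard vLLM WorkerBase API (assumed inherited)
    worker.sync_model_weights(actor_state_dict, load_format=LoadFormat.DTENSOR)
    outputs = worker.execute_model(execute_model_req)         # one single-step SPMD decode
    worker.offload_model_weights()                            # park params on CPU between rollouts
    worker.free_cache_engine()                                # release the KV cache (requires enforce_eager=True)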
13 verl/third_party/vllm/vllm_v_0_5_4/__init__.py vendored Normal file
@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
453 verl/third_party/vllm/vllm_v_0_5_4/arg_utils.py vendored Normal file
@@ -0,0 +1,453 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/arg_utils.py
|
||||||
|
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import dataclasses
|
||||||
|
import json
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
from .config import ModelConfig, LoadConfig
|
||||||
|
|
||||||
|
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, EngineConfig, LoRAConfig, MultiModalConfig,
|
||||||
|
ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig, SpeculativeConfig,
|
||||||
|
TokenizerPoolConfig)
|
||||||
|
from vllm.executor.executor_base import ExecutorBase
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||||
|
from vllm.utils import str_to_int_tuple
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (BaseTokenizerGroup)
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def nullable_str(val: str):
|
||||||
|
if not val or val == "None":
|
||||||
|
return None
|
||||||
|
return val
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EngineArgs:
|
||||||
|
"""Arguments for vLLM engine."""
|
||||||
|
model_hf_config: PretrainedConfig = None # for verl
|
||||||
|
served_model_name = None # TODO(sgm): check this
|
||||||
|
# tokenizer: Optional[str] = None # TODO(sgm): check this
|
||||||
|
skip_tokenizer_init: bool = False
|
||||||
|
tokenizer_mode: str = 'auto'
|
||||||
|
trust_remote_code: bool = False
|
||||||
|
download_dir: Optional[str] = None
|
||||||
|
load_format: str = 'auto'
|
||||||
|
dtype: str = 'auto'
|
||||||
|
kv_cache_dtype: str = 'auto'
|
||||||
|
quantization_param_path: Optional[str] = None
|
||||||
|
seed: int = 0
|
||||||
|
max_model_len: Optional[int] = None
|
||||||
|
worker_use_ray: bool = False
|
||||||
|
# Note: Specifying a custom executor backend by passing a class
|
||||||
|
# is intended for expert use only. The API may change without
|
||||||
|
# notice.
|
||||||
|
distributed_executor_backend: Optional[Union[str, Type[ExecutorBase]]] = None
|
||||||
|
pipeline_parallel_size: int = 1
|
||||||
|
tensor_parallel_size: int = 1
|
||||||
|
max_parallel_loading_workers: Optional[int] = None
|
||||||
|
block_size: int = 16
|
||||||
|
enable_prefix_caching: bool = False
|
||||||
|
disable_sliding_window: bool = False
|
||||||
|
use_v2_block_manager: bool = False
|
||||||
|
swap_space: int = 4 # GiB
|
||||||
|
cpu_offload_gb: int = 0 # GiB
|
||||||
|
gpu_memory_utilization: float = 0.90
|
||||||
|
max_num_batched_tokens: Optional[int] = None
|
||||||
|
max_num_seqs: int = 256
|
||||||
|
max_logprobs: int = 20 # Default value for OpenAI Chat Completions API
|
||||||
|
disable_log_stats: bool = False
|
||||||
|
revision: Optional[str] = None
|
||||||
|
code_revision: Optional[str] = None
|
||||||
|
rope_scaling: Optional[dict] = None
|
||||||
|
rope_theta: Optional[float] = None
|
||||||
|
tokenizer_revision: Optional[str] = None
|
||||||
|
quantization: Optional[str] = None
|
||||||
|
enforce_eager: bool = False
|
||||||
|
max_context_len_to_capture: Optional[int] = None
|
||||||
|
max_seq_len_to_capture: int = 8192
|
||||||
|
disable_custom_all_reduce: bool = False
|
||||||
|
tokenizer_pool_size: int = 0
|
||||||
|
# Note: Specifying a tokenizer pool by passing a class
|
||||||
|
# is intended for expert use only. The API may change without
|
||||||
|
# notice.
|
||||||
|
tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
|
||||||
|
tokenizer_pool_extra_config: Optional[dict] = None
|
||||||
|
enable_lora: bool = False
|
||||||
|
max_loras: int = 1
|
||||||
|
max_lora_rank: int = 16
|
||||||
|
enable_prompt_adapter: bool = False
|
||||||
|
max_prompt_adapters: int = 1
|
||||||
|
max_prompt_adapter_token: int = 0
|
||||||
|
fully_sharded_loras: bool = False
|
||||||
|
lora_extra_vocab_size: int = 256
|
||||||
|
long_lora_scaling_factors: Optional[Tuple[float]] = None
|
||||||
|
lora_dtype: str = 'auto'
|
||||||
|
max_cpu_loras: Optional[int] = None
|
||||||
|
device: str = 'auto'
|
||||||
|
ray_workers_use_nsight: bool = False
|
||||||
|
num_gpu_blocks_override: Optional[int] = None
|
||||||
|
num_lookahead_slots: int = 0
|
||||||
|
model_loader_extra_config: Optional[dict] = None
|
||||||
|
ignore_patterns: Optional[Union[str, List[str]]] = None
|
||||||
|
preemption_mode: Optional[str] = None
|
||||||
|
|
||||||
|
scheduler_delay_factor: float = 0.0
|
||||||
|
enable_chunked_prefill: Optional[bool] = None
|
||||||
|
|
||||||
|
guided_decoding_backend: str = 'outlines'
|
||||||
|
# Speculative decoding configuration.
|
||||||
|
speculative_model: Optional[str] = None
|
||||||
|
speculative_draft_tensor_parallel_size: Optional[int] = None
|
||||||
|
num_speculative_tokens: Optional[int] = None
|
||||||
|
speculative_max_model_len: Optional[int] = None
|
||||||
|
speculative_disable_by_batch_size: Optional[int] = None
|
||||||
|
ngram_prompt_lookup_max: Optional[int] = None
|
||||||
|
ngram_prompt_lookup_min: Optional[int] = None
|
||||||
|
spec_decoding_acceptance_method: str = 'rejection_sampler'
|
||||||
|
typical_acceptance_sampler_posterior_threshold: Optional[float] = None
|
||||||
|
typical_acceptance_sampler_posterior_alpha: Optional[float] = None
|
||||||
|
qlora_adapter_name_or_path: Optional[str] = None
|
||||||
|
disable_logprobs_during_spec_decoding: Optional[bool] = None
|
||||||
|
|
||||||
|
otlp_traces_endpoint: Optional[str] = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
||||||
|
"""Shared CLI arguments for vLLM engine."""
|
||||||
|
# Model arguments
|
||||||
|
# TODO(shengguangming): delete the unused args
|
||||||
|
parser.add_argument('--model',
|
||||||
|
type=str,
|
||||||
|
default='facebook/opt-125m',
|
||||||
|
help='name or path of the huggingface model to use')
|
||||||
|
parser.add_argument('--tokenizer',
|
||||||
|
type=str,
|
||||||
|
default=None,  # NOTE: this EngineArgs variant defines no `tokenizer` field; referencing it would raise AttributeError
|
||||||
|
help='name or path of the huggingface tokenizer to use')
|
||||||
|
parser.add_argument('--revision',
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help='the specific model version to use. It can be a branch '
|
||||||
|
'name, a tag name, or a commit id. If unspecified, will use '
|
||||||
|
'the default version.')
|
||||||
|
parser.add_argument('--tokenizer-revision',
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help='the specific tokenizer version to use. It can be a branch '
|
||||||
|
'name, a tag name, or a commit id. If unspecified, will use '
|
||||||
|
'the default version.')
|
||||||
|
parser.add_argument('--tokenizer-mode',
|
||||||
|
type=str,
|
||||||
|
default=EngineArgs.tokenizer_mode,
|
||||||
|
choices=['auto', 'slow'],
|
||||||
|
help='tokenizer mode. "auto" will use the fast '
|
||||||
|
'tokenizer if available, and "slow" will '
|
||||||
|
'always use the slow tokenizer.')
|
||||||
|
parser.add_argument('--trust-remote-code', action='store_true', help='trust remote code from huggingface')
|
||||||
|
parser.add_argument('--download-dir',
|
||||||
|
type=str,
|
||||||
|
default=EngineArgs.download_dir,
|
||||||
|
help='directory to download and load the weights, '
|
||||||
|
'default to the default cache dir of '
|
||||||
|
'huggingface')
|
||||||
|
parser.add_argument('--load-format',
|
||||||
|
type=str,
|
||||||
|
default=EngineArgs.load_format,
|
||||||
|
choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
|
||||||
|
help='The format of the model weights to load. '
|
||||||
|
'"auto" will try to load the weights in the safetensors format '
|
||||||
|
'and fall back to the pytorch bin format if safetensors format '
|
||||||
|
'is not available. '
|
||||||
|
'"pt" will load the weights in the pytorch bin format. '
|
||||||
|
'"safetensors" will load the weights in the safetensors format. '
|
||||||
|
'"npcache" will load the weights in pytorch format and store '
|
||||||
|
'a numpy cache to speed up the loading. '
|
||||||
|
'"dummy" will initialize the weights with random values, '
|
||||||
|
'which is mainly for profiling.')
|
||||||
|
parser.add_argument('--dtype',
|
||||||
|
type=str,
|
||||||
|
default=EngineArgs.dtype,
|
||||||
|
choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
|
||||||
|
help='data type for model weights and activations. '
|
||||||
|
'The "auto" option will use FP16 precision '
|
||||||
|
'for FP32 and FP16 models, and BF16 precision '
|
||||||
|
'for BF16 models.')
|
||||||
|
parser.add_argument('--max-model-len',
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help='model context length. If unspecified, '
|
||||||
|
'will be automatically derived from the model.')
|
||||||
|
# Parallel arguments
|
||||||
|
parser.add_argument('--worker-use-ray',
|
||||||
|
action='store_true',
|
||||||
|
help='use Ray for distributed serving, will be '
|
||||||
|
'automatically set when using more than 1 GPU')
|
||||||
|
parser.add_argument('--pipeline-parallel-size',
|
||||||
|
'-pp',
|
||||||
|
type=int,
|
||||||
|
default=EngineArgs.pipeline_parallel_size,
|
||||||
|
help='number of pipeline stages')
|
||||||
|
parser.add_argument('--tensor-parallel-size',
|
||||||
|
'-tp',
|
||||||
|
type=int,
|
||||||
|
default=EngineArgs.tensor_parallel_size,
|
||||||
|
help='number of tensor parallel replicas')
|
||||||
|
# KV cache arguments
|
||||||
|
parser.add_argument('--block-size',
|
||||||
|
type=int,
|
||||||
|
default=EngineArgs.block_size,
|
||||||
|
choices=[8, 16, 32],
|
||||||
|
help='token block size')
|
||||||
|
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
|
||||||
|
parser.add_argument('--seed', type=int, default=EngineArgs.seed, help='random seed')
|
||||||
|
parser.add_argument('--swap-space',
|
||||||
|
type=int,
|
||||||
|
default=EngineArgs.swap_space,
|
||||||
|
help='CPU swap space size (GiB) per GPU')
|
||||||
|
parser.add_argument('--gpu-memory-utilization',
|
||||||
|
type=float,
|
||||||
|
default=EngineArgs.gpu_memory_utilization,
|
||||||
|
help='the percentage of GPU memory to be used for '
|
||||||
|
'the model executor')
|
||||||
|
parser.add_argument('--max-num-batched-tokens',
|
||||||
|
type=int,
|
||||||
|
default=EngineArgs.max_num_batched_tokens,
|
||||||
|
help='maximum number of batched tokens per '
|
||||||
|
'iteration')
|
||||||
|
parser.add_argument('--max-num-seqs',
|
||||||
|
type=int,
|
||||||
|
default=EngineArgs.max_num_seqs,
|
||||||
|
help='maximum number of sequences per iteration')
|
||||||
|
parser.add_argument('--disable-log-stats', action='store_true', help='disable logging statistics')
|
||||||
|
# Quantization settings.
|
||||||
|
parser.add_argument('--quantization',
|
||||||
|
'-q',
|
||||||
|
type=str,
|
||||||
|
choices=['awq', None],
|
||||||
|
default=None,
|
||||||
|
help='Method used to quantize the weights')
|
||||||
|
return parser
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
|
||||||
|
# Get the list of attributes of this dataclass.
|
||||||
|
attrs = [attr.name for attr in dataclasses.fields(cls)]
|
||||||
|
# Set the attributes from the parsed arguments.
|
||||||
|
engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
|
||||||
|
return engine_args
|
||||||
|
|
||||||
|
def create_engine_config(
|
||||||
|
self,
|
||||||
|
) -> EngineConfig:
|
||||||
|
# bitsandbytes quantization needs a specific model loader
|
||||||
|
# so we make sure the quant method and the load format are consistent
|
||||||
|
if (self.quantization == "bitsandbytes" or
|
||||||
|
self.qlora_adapter_name_or_path is not None) and \
|
||||||
|
self.load_format != "bitsandbytes":
|
||||||
|
raise ValueError("BitsAndBytes quantization and QLoRA adapter only support "
|
||||||
|
f"'bitsandbytes' load format, but got {self.load_format}")
|
||||||
|
|
||||||
|
if (self.load_format == "bitsandbytes" or
|
||||||
|
self.qlora_adapter_name_or_path is not None) and \
|
||||||
|
self.quantization != "bitsandbytes":
|
||||||
|
raise ValueError("BitsAndBytes load format and QLoRA adapter only support "
|
||||||
|
f"'bitsandbytes' quantization, but got {self.quantization}")
|
||||||
|
|
||||||
|
assert self.cpu_offload_gb >= 0, ("CPU offload space must be non-negative"
|
||||||
|
f", but got {self.cpu_offload_gb}")
|
||||||
|
|
||||||
|
multimodal_config = MultiModalConfig()
|
||||||
|
device_config = DeviceConfig(self.device)
|
||||||
|
# NOTE(sgm): we only modify ModelConfig; the other configs are imported from vLLM
|
||||||
|
model_config = ModelConfig(hf_config=self.model_hf_config,
|
||||||
|
tokenizer_mode=self.tokenizer_mode,
|
||||||
|
trust_remote_code=self.trust_remote_code,
|
||||||
|
dtype=self.dtype,
|
||||||
|
seed=self.seed,
|
||||||
|
revision=self.revision,
|
||||||
|
code_revision=self.code_revision,
|
||||||
|
rope_scaling=self.rope_scaling,
|
||||||
|
rope_theta=self.rope_theta,
|
||||||
|
tokenizer_revision=self.tokenizer_revision,
|
||||||
|
max_model_len=self.max_model_len,
|
||||||
|
quantization=self.quantization,
|
||||||
|
quantization_param_path=self.quantization_param_path,
|
||||||
|
enforce_eager=self.enforce_eager,
|
||||||
|
max_context_len_to_capture=self.max_context_len_to_capture,
|
||||||
|
max_seq_len_to_capture=self.max_seq_len_to_capture,
|
||||||
|
max_logprobs=self.max_logprobs,
|
||||||
|
disable_sliding_window=self.disable_sliding_window,
|
||||||
|
skip_tokenizer_init=self.skip_tokenizer_init,
|
||||||
|
served_model_name=self.served_model_name,
|
||||||
|
multimodal_config=multimodal_config)
|
||||||
|
cache_config = CacheConfig(
|
||||||
|
block_size=self.block_size,
|
||||||
|
gpu_memory_utilization=self.gpu_memory_utilization,
|
||||||
|
swap_space=self.swap_space,
|
||||||
|
cache_dtype=self.kv_cache_dtype,
|
||||||
|
num_gpu_blocks_override=self.num_gpu_blocks_override,
|
||||||
|
sliding_window=model_config.get_sliding_window(),
|
||||||
|
enable_prefix_caching=self.enable_prefix_caching,
|
||||||
|
cpu_offload_gb=self.cpu_offload_gb,
|
||||||
|
)
|
||||||
|
parallel_config = ParallelConfig(pipeline_parallel_size=self.pipeline_parallel_size,
|
||||||
|
tensor_parallel_size=self.tensor_parallel_size,
|
||||||
|
worker_use_ray=self.worker_use_ray,
|
||||||
|
max_parallel_loading_workers=self.max_parallel_loading_workers,
|
||||||
|
disable_custom_all_reduce=self.disable_custom_all_reduce,
|
||||||
|
tokenizer_pool_config=TokenizerPoolConfig.create_config(
|
||||||
|
self.tokenizer_pool_size,
|
||||||
|
self.tokenizer_pool_type,
|
||||||
|
self.tokenizer_pool_extra_config,
|
||||||
|
),
|
||||||
|
ray_workers_use_nsight=self.ray_workers_use_nsight,
|
||||||
|
distributed_executor_backend=self.distributed_executor_backend)
|
||||||
|
|
||||||
|
# NOTE[VERL]: Use the world_size set by TORCHRUN
|
||||||
|
world_size = int(os.getenv("WORLD_SIZE", "-1"))
|
||||||
|
assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN"
|
||||||
|
parallel_config.world_size = world_size
|
||||||
|
|
||||||
|
max_model_len = model_config.max_model_len
|
||||||
|
use_long_context = max_model_len > 32768
|
||||||
|
if self.enable_chunked_prefill is None:
|
||||||
|
# If not explicitly set, enable chunked prefill by default for
|
||||||
|
# long context (> 32K) models. This is to avoid OOM errors in the
|
||||||
|
# initial memory profiling phase.
|
||||||
|
if use_long_context:
|
||||||
|
is_gpu = device_config.device_type == "cuda"
|
||||||
|
use_sliding_window = (model_config.get_sliding_window() is not None)
|
||||||
|
use_spec_decode = self.speculative_model is not None
|
||||||
|
has_seqlen_agnostic_layers = (model_config.contains_seqlen_agnostic_layers(parallel_config))
|
||||||
|
if (is_gpu and not use_sliding_window and not use_spec_decode and not self.enable_lora and
|
||||||
|
not self.enable_prompt_adapter and not self.enable_prefix_caching and
|
||||||
|
not has_seqlen_agnostic_layers):
|
||||||
|
self.enable_chunked_prefill = True
|
||||||
|
logger.warning("Chunked prefill is enabled by default for models with "
|
||||||
|
"max_model_len > 32K. Currently, chunked prefill might "
|
||||||
|
"not work with some features or models. If you "
|
||||||
|
"encounter any issues, please disable chunked prefill "
|
||||||
|
"by setting --enable-chunked-prefill=False.")
|
||||||
|
if self.enable_chunked_prefill is None:
|
||||||
|
self.enable_chunked_prefill = False
|
||||||
|
|
||||||
|
if not self.enable_chunked_prefill and use_long_context:
|
||||||
|
logger.warning(
|
||||||
|
"The model has a long context length (%s). This may cause OOM "
|
||||||
|
"errors during the initial memory profiling phase, or result "
|
||||||
|
"in low performance due to small KV cache space. Consider "
|
||||||
|
"setting --max-model-len to a smaller value.", max_model_len)
|
||||||
|
|
||||||
|
# TODO: spec config
|
||||||
|
speculative_config = SpeculativeConfig.maybe_create_spec_config(
|
||||||
|
target_model_config=model_config,
|
||||||
|
target_parallel_config=parallel_config,
|
||||||
|
target_dtype=self.dtype,
|
||||||
|
speculative_model=self.speculative_model,
|
||||||
|
speculative_draft_tensor_parallel_size = \
|
||||||
|
self.speculative_draft_tensor_parallel_size,
|
||||||
|
num_speculative_tokens=self.num_speculative_tokens,
|
||||||
|
speculative_disable_by_batch_size=self.
|
||||||
|
speculative_disable_by_batch_size,
|
||||||
|
speculative_max_model_len=self.speculative_max_model_len,
|
||||||
|
enable_chunked_prefill=self.enable_chunked_prefill,
|
||||||
|
use_v2_block_manager=self.use_v2_block_manager,
|
||||||
|
disable_log_stats=self.disable_log_stats,
|
||||||
|
ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
|
||||||
|
ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
|
||||||
|
draft_token_acceptance_method=\
|
||||||
|
self.spec_decoding_acceptance_method,
|
||||||
|
typical_acceptance_sampler_posterior_threshold=self.
|
||||||
|
typical_acceptance_sampler_posterior_threshold,
|
||||||
|
typical_acceptance_sampler_posterior_alpha=self.
|
||||||
|
typical_acceptance_sampler_posterior_alpha,
|
||||||
|
disable_logprobs=self.disable_logprobs_during_spec_decoding,
|
||||||
|
)
|
||||||
|
|
||||||
|
scheduler_config = SchedulerConfig(
|
||||||
|
max_num_batched_tokens=self.max_num_batched_tokens,
|
||||||
|
max_num_seqs=self.max_num_seqs,
|
||||||
|
max_model_len=model_config.max_model_len,
|
||||||
|
use_v2_block_manager=self.use_v2_block_manager,
|
||||||
|
num_lookahead_slots=(self.num_lookahead_slots
|
||||||
|
if speculative_config is None else speculative_config.num_lookahead_slots),
|
||||||
|
delay_factor=self.scheduler_delay_factor,
|
||||||
|
enable_chunked_prefill=self.enable_chunked_prefill,
|
||||||
|
embedding_mode=model_config.embedding_mode,
|
||||||
|
preemption_mode=self.preemption_mode,
|
||||||
|
)
|
||||||
|
lora_config = LoRAConfig(max_lora_rank=self.max_lora_rank,
|
||||||
|
max_loras=self.max_loras,
|
||||||
|
fully_sharded_loras=self.fully_sharded_loras,
|
||||||
|
lora_extra_vocab_size=self.lora_extra_vocab_size,
|
||||||
|
long_lora_scaling_factors=self.long_lora_scaling_factors,
|
||||||
|
lora_dtype=self.lora_dtype,
|
||||||
|
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else
|
||||||
|
None) if self.enable_lora else None
|
||||||
|
|
||||||
|
if self.qlora_adapter_name_or_path is not None and \
|
||||||
|
self.qlora_adapter_name_or_path != "":
|
||||||
|
if self.model_loader_extra_config is None:
|
||||||
|
self.model_loader_extra_config = {}
|
||||||
|
self.model_loader_extra_config["qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path
|
||||||
|
|
||||||
|
load_config = LoadConfig(
|
||||||
|
load_format=self.load_format,
|
||||||
|
download_dir=self.download_dir,
|
||||||
|
model_loader_extra_config=self.model_loader_extra_config,
|
||||||
|
ignore_patterns=self.ignore_patterns,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_adapter_config = PromptAdapterConfig(
|
||||||
|
max_prompt_adapters=self.max_prompt_adapters,
|
||||||
|
max_prompt_adapter_token=self.max_prompt_adapter_token) \
|
||||||
|
if self.enable_prompt_adapter else None
|
||||||
|
|
||||||
|
decoding_config = DecodingConfig(guided_decoding_backend=self.guided_decoding_backend)
|
||||||
|
|
||||||
|
observability_config = ObservabilityConfig(otlp_traces_endpoint=self.otlp_traces_endpoint)
|
||||||
|
|
||||||
|
if (model_config.get_sliding_window() is not None and scheduler_config.chunked_prefill_enabled and
|
||||||
|
not scheduler_config.use_v2_block_manager):
|
||||||
|
raise ValueError("Chunked prefill is not supported with sliding window. "
|
||||||
|
"Set --disable-sliding-window to disable sliding window.")
|
||||||
|
|
||||||
|
return EngineConfig(
|
||||||
|
model_config=model_config,
|
||||||
|
cache_config=cache_config,
|
||||||
|
parallel_config=parallel_config,
|
||||||
|
scheduler_config=scheduler_config,
|
||||||
|
device_config=device_config,
|
||||||
|
lora_config=lora_config,
|
||||||
|
multimodal_config=multimodal_config,
|
||||||
|
speculative_config=speculative_config,
|
||||||
|
load_config=load_config,
|
||||||
|
decoding_config=decoding_config,
|
||||||
|
observability_config=observability_config,
|
||||||
|
prompt_adapter_config=prompt_adapter_config,
|
||||||
|
)
|
||||||
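As a point of reference, this config is usually built programmatically in verl rather than from the CLI. The following sketch is illustrative only: the checkpoint path and parallel sizes are made-up values, and the process must be launched with torchrun so that WORLD_SIZE is set before create_engine_config() is called:

    # Hypothetical usage of EngineArgs.create_engine_config() under torchrun.
    from transformers import AutoConfig
    hf_config = AutoConfig.from_pretrained("path/to/actor/checkpoint")
    engine_args = EngineArgs(model_hf_config=hf_config,
                             tensor_parallel_size=2,
                             dtype="bfloat16",
                             gpu_memory_utilization=0.5,
                             enforce_eager=True,           # needed so the KV cache can later be freed
                             load_format="dummy_dtensor")  # real weights are synced from the actor afterwards
    engine_config = engine_args.create_engine_config()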
246 verl/third_party/vllm/vllm_v_0_5_4/config.py vendored Normal file
@@ -0,0 +1,246 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py
|
||||||
|
|
||||||
|
import enum
|
||||||
|
import json
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
from dataclasses import dataclass, field, fields
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.quantization import get_quantization_config
|
||||||
|
from vllm.transformers_utils.config import get_hf_text_config
|
||||||
|
from vllm.utils import is_hip, print_warning_once
|
||||||
|
# Add for verl
|
||||||
|
from vllm.config import ModelConfig, _get_and_verify_dtype, _get_and_verify_max_len, get_served_model_name
|
||||||
|
|
||||||
|
GPTQMarlinConfig = get_quantization_config("gptq_marlin")
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
_GB = 1 << 30
|
||||||
|
|
||||||
|
|
||||||
|
class ModelConfig(ModelConfig):
|
||||||
|
"""Configuration for the model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Name or path of the huggingface model to use.
|
||||||
|
tokenizer: Name or path of the huggingface tokenizer to use.
|
||||||
|
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
|
||||||
|
available, and "slow" will always use the slow tokenizer.
|
||||||
|
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
|
||||||
|
downloading the model and tokenizer.
|
||||||
|
download_dir: Directory to download and load the weights, default to the
|
||||||
|
default cache directory of huggingface.
|
||||||
|
load_format: The format of the model weights to load:
|
||||||
|
"auto" will try to load the weights in the safetensors format and
|
||||||
|
fall back to the pytorch bin format if safetensors format is
|
||||||
|
not available.
|
||||||
|
"pt" will load the weights in the pytorch bin format.
|
||||||
|
"safetensors" will load the weights in the safetensors format.
|
||||||
|
"npcache" will load the weights in pytorch format and store
|
||||||
|
a numpy cache to speed up the loading.
|
||||||
|
"dummy" will initialize the weights with random values, which is
|
||||||
|
mainly for profiling.
|
||||||
|
dtype: Data type for model weights and activations. The "auto" option
|
||||||
|
will use FP16 precision for FP32 and FP16 models, and BF16 precision
|
||||||
|
for BF16 models.
|
||||||
|
seed: Random seed for reproducibility.
|
||||||
|
revision: The specific model version to use. It can be a branch name,
|
||||||
|
a tag name, or a commit id. If unspecified, will use the default
|
||||||
|
version.
|
||||||
|
code_revision: The specific revision to use for the model code on
|
||||||
|
Hugging Face Hub. It can be a branch name, a tag name, or a
|
||||||
|
commit id. If unspecified, will use the default version.
|
||||||
|
tokenizer_revision: The specific tokenizer version to use. It can be a
|
||||||
|
branch name, a tag name, or a commit id. If unspecified, will use
|
||||||
|
the default version.
|
||||||
|
max_model_len: Maximum length of a sequence (including prompt and
|
||||||
|
output). If None, will be derived from the model.
|
||||||
|
quantization: Quantization method that was used to quantize the model
|
||||||
|
weights. If None, we assume the model weights are not quantized.
|
||||||
|
quantization_param_path: Path to JSON file containing scaling factors.
|
||||||
|
Used to load KV cache scaling factors into the model when KV cache
|
||||||
|
type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
|
||||||
|
be used to load activation and weight scaling factors when the
|
||||||
|
model dtype is FP8_E4M3 on ROCm.
|
||||||
|
enforce_eager: Whether to enforce eager execution. If True, we will
|
||||||
|
disable CUDA graph and always execute the model in eager mode.
|
||||||
|
If False, we will use CUDA graph and eager execution in hybrid.
|
||||||
|
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
|
||||||
|
When a sequence has context length larger than this, we fall back
|
||||||
|
to eager mode (DEPRECATED. Use max_seq_len_to_capture instead).
|
||||||
|
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
|
||||||
|
When a sequence has context length larger than this, we fall back
|
||||||
|
to eager mode
|
||||||
|
skip_tokenizer_init: If true, skip initialization of tokenizer and
|
||||||
|
detokenizer.
|
||||||
|
served_model_name: The model name used in metrics tag `model_name`,
|
||||||
|
matches the model name exposed via the APIs. If multiple model
|
||||||
|
names provided, the first name will be used. If not specified,
|
||||||
|
the model name will be the same as `model`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hf_config: PretrainedConfig,
|
||||||
|
tokenizer_mode: str,
|
||||||
|
trust_remote_code: bool,
|
||||||
|
dtype: Union[str, torch.dtype],
|
||||||
|
seed: int,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
code_revision: Optional[str] = None,
|
||||||
|
rope_scaling: Optional[dict] = None,
|
||||||
|
rope_theta: Optional[float] = None,
|
||||||
|
tokenizer_revision: Optional[str] = None,
|
||||||
|
max_model_len: Optional[int] = None,
|
||||||
|
quantization: Optional[str] = None,
|
||||||
|
quantization_param_path: Optional[str] = None,
|
||||||
|
enforce_eager: bool = False,
|
||||||
|
max_context_len_to_capture: Optional[int] = None,
|
||||||
|
max_seq_len_to_capture: Optional[int] = None,
|
||||||
|
max_logprobs: int = 20,
|
||||||
|
disable_sliding_window: bool = False,
|
||||||
|
skip_tokenizer_init: bool = False,
|
||||||
|
served_model_name: Optional[Union[str, List[str]]] = None,
|
||||||
|
multimodal_config: Optional["MultiModalConfig"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.model = hf_config._name_or_path
|
||||||
|
self.tokenizer = hf_config._name_or_path
|
||||||
|
# NOTE(sgm): same as open-sourced
|
||||||
|
self.tokenizer_mode = tokenizer_mode
|
||||||
|
self.trust_remote_code = trust_remote_code
|
||||||
|
self.seed = seed
|
||||||
|
self.revision = revision
|
||||||
|
self.code_revision = code_revision
|
||||||
|
self.rope_scaling = rope_scaling
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
# The tokenizer version is consistent with the model version by default.
|
||||||
|
if tokenizer_revision is None:
|
||||||
|
self.tokenizer_revision = revision
|
||||||
|
else:
|
||||||
|
self.tokenizer_revision = tokenizer_revision
|
||||||
|
self.quantization = quantization
|
||||||
|
self.quantization_param_path = quantization_param_path
|
||||||
|
self.enforce_eager = enforce_eager
|
||||||
|
if max_context_len_to_capture is not None:
|
||||||
|
raise ValueError("`max_context_len_to_capture` is deprecated. "
|
||||||
|
"Use `max_seq_len_to_capture` instead.")
|
||||||
|
self.max_seq_len_to_capture = max_seq_len_to_capture
|
||||||
|
self.max_logprobs = max_logprobs
|
||||||
|
self.disable_sliding_window = disable_sliding_window
|
||||||
|
self.skip_tokenizer_init = skip_tokenizer_init
|
||||||
|
|
||||||
|
# self.hf_config = get_config(model, trust_remote_code, revision)
|
||||||
|
self.hf_config = hf_config
|
||||||
|
self.hf_text_config = get_hf_text_config(hf_config)
|
||||||
|
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
|
||||||
|
# self.served_model_name = get_served_model_name(model,
|
||||||
|
# served_model_name)
|
||||||
|
# self._verify_load_format()
|
||||||
|
# self._verify_tokenizer_mode()
|
||||||
|
if (not self.disable_sliding_window and self.hf_text_config.model_type == "gemma2" and
|
||||||
|
self.hf_text_config.sliding_window is not None):
|
||||||
|
print_warning_once("Gemma 2 uses sliding window attention for every odd layer, "
|
||||||
|
"which is currently not supported by vLLM. Disabling sliding "
|
||||||
|
"window and capping the max length to the sliding window size "
|
||||||
|
f"({self.hf_text_config.sliding_window}).")
|
||||||
|
self.disable_sliding_window = True
|
||||||
|
|
||||||
|
self.max_model_len = _get_and_verify_max_len(hf_config=self.hf_text_config,
|
||||||
|
max_model_len=max_model_len,
|
||||||
|
disable_sliding_window=self.disable_sliding_window,
|
||||||
|
sliding_window_len=self.get_hf_config_sliding_window())
|
||||||
|
self.served_model_name = get_served_model_name(
|
||||||
|
self.model, # str
|
||||||
|
served_model_name)
|
||||||
|
self.multimodal_config = multimodal_config
|
||||||
|
|
||||||
|
if not self.skip_tokenizer_init:
|
||||||
|
self._verify_tokenizer_mode()
|
||||||
|
self._verify_embedding_mode()
|
||||||
|
self._verify_quantization()
|
||||||
|
self._verify_cuda_graph()
|
||||||
|
|
||||||
|
|
||||||
|
class LoadFormat(str, enum.Enum):
|
||||||
|
AUTO = 'auto'
|
||||||
|
MEGATRON = "megatron"
|
||||||
|
HF = "hf"
|
||||||
|
DTENSOR = 'dtensor'
|
||||||
|
DUMMY_HF = 'dummy_hf'
|
||||||
|
DUMMY_MEGATRON = 'dummy_megatron'
|
||||||
|
DUMMY_DTENSOR = 'dummy_dtensor'
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: check whether this is necessary
|
||||||
|
@dataclass
|
||||||
|
class LoadConfig:
|
||||||
|
"""
|
||||||
|
download_dir: Directory to download and load the weights, default to the
|
||||||
|
default cache directory of huggingface.
|
||||||
|
load_format: The format of the model weights to load:
|
||||||
|
"auto" will try to load the weights in the safetensors format and
|
||||||
|
fall back to the pytorch bin format if safetensors format is
|
||||||
|
not available.
|
||||||
|
"pt" will load the weights in the pytorch bin format.
|
||||||
|
"safetensors" will load the weights in the safetensors format.
|
||||||
|
"npcache" will load the weights in pytorch format and store
|
||||||
|
a numpy cache to speed up the loading.
|
||||||
|
"dummy" will initialize the weights with random values, which is
|
||||||
|
mainly for profiling.
|
||||||
|
"tensorizer" will use CoreWeave's tensorizer library for
|
||||||
|
fast weight loading.
|
||||||
|
"bitsandbytes" will load nf4 type weights.
|
||||||
|
ignore_patterns: The list of patterns to ignore when loading the model.
|
||||||
|
Default to "original/**/*" to avoid repeated loading of llama's
|
||||||
|
checkpoints.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
|
||||||
|
download_dir: Optional[str] = None
|
||||||
|
model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
|
||||||
|
ignore_patterns: Optional[Union[List[str], str]] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
model_loader_extra_config = self.model_loader_extra_config or {}
|
||||||
|
if isinstance(model_loader_extra_config, str):
|
||||||
|
self.model_loader_extra_config = json.loads(model_loader_extra_config)
|
||||||
|
self._verify_load_format()
|
||||||
|
|
||||||
|
if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
|
||||||
|
logger.info("Ignoring the following patterns when downloading weights: %s", self.ignore_patterns)
|
||||||
|
else:
|
||||||
|
self.ignore_patterns = ["original/**/*"]
|
||||||
|
|
||||||
|
def _verify_load_format(self) -> None:
|
||||||
|
if not isinstance(self.load_format, str):
|
||||||
|
return
|
||||||
|
|
||||||
|
load_format = self.load_format.lower()
|
||||||
|
self.load_format = LoadFormat(load_format)
|
||||||
|
|
||||||
|
rocm_not_supported_load_format: List[str] = []
|
||||||
|
if is_hip() and load_format in rocm_not_supported_load_format:
|
||||||
|
rocm_supported_load_format = [
|
||||||
|
f for f in LoadFormat.__members__ if (f not in rocm_not_supported_load_format)
|
||||||
|
]
|
||||||
|
raise ValueError(f"load format '{load_format}' is not supported in ROCm. "
|
||||||
|
f"Supported load formats are "
|
||||||
|
f"{rocm_supported_load_format}")
|
||||||
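The extended LoadFormat enum above is what the worker's sync_model_weights() dispatches on; because it is a str enum, plain strings and enum members are interchangeable. A small illustration, not part of the diff:

    # LoadConfig normalizes the string in __post_init__ via _verify_load_format().
    cfg = LoadConfig(load_format="dtensor")
    assert cfg.load_format is LoadFormat.DTENSOR
    # "dummy_*" formats initialize random weights; real weights are synced from the actor later.
    dummy_cfg = LoadConfig(load_format=LoadFormat.DUMMY_MEGATRON)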
340 verl/third_party/vllm/vllm_v_0_5_4/dtensor_weight_loaders.py vendored Normal file
@@ -0,0 +1,340 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models
|
||||||
|
|
||||||
|
from typing import Dict, Iterable, Tuple
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.distributed._tensor import DTensor, Shard, Replicate
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.linear import *
|
||||||
|
from vllm.model_executor.models import ModelRegistry
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
from vllm.model_executor.models.utils import is_pp_missing_parameter
|
||||||
|
|
||||||
|
|
||||||
|
def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
("qkv_proj", "q_proj", "q"),
|
||||||
|
("qkv_proj", "k_proj", "k"),
|
||||||
|
("qkv_proj", "v_proj", "v"),
|
||||||
|
("gate_up_proj", "gate_proj", 0),
|
||||||
|
("gate_up_proj", "up_proj", 1),
|
||||||
|
]
|
||||||
|
params_dict = dict(vllm_model.named_parameters())
|
||||||
|
for name, loaded_weight in actor_weights.items():
|
||||||
|
for (param_name, shard_name, shard_id) in stacked_params_mapping:
|
||||||
|
if shard_name not in name:
|
||||||
|
continue
|
||||||
|
stacked_name = name.replace(shard_name, param_name)
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if stacked_name.endswith(".bias") and stacked_name not in params_dict:
|
||||||
|
continue
|
||||||
|
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
|
||||||
|
param = params_dict[stacked_name]
|
||||||
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
|
weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# lm_head is not used in vllm as it is tied with embed_token.
|
||||||
|
# To prevent errors, skip loading lm_head.weight.
|
||||||
|
if "lm_head.weight" in name:
|
||||||
|
continue
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
|
weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
|
||||||
|
|
||||||
|
|
||||||
|
def gptbigcode_dtensor_load_weights(actor_weights: Dict, vllm_model: nn.Module):
|
||||||
|
params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
|
||||||
|
for name, loaded_weight in actor_weights.items():
|
||||||
|
if "lm_head.weight" in name:
|
||||||
|
continue
|
||||||
|
if ".attn.bias" in name:
|
||||||
|
# Skip attention mask.
|
||||||
|
# NOTE: "c_attn.bias" should not be skipped.
|
||||||
|
continue
|
||||||
|
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
|
weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
|
||||||
|
|
||||||
|
|
||||||
|
def starcoder2_dtensor_load_weights(actor_weights: Dict, vllm_model: nn.Module):
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
("qkv_proj", "q_proj", "q"),
|
||||||
|
("qkv_proj", "k_proj", "k"),
|
||||||
|
("qkv_proj", "v_proj", "v"),
|
||||||
|
]
|
||||||
|
|
||||||
|
params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
|
||||||
|
for name, loaded_weight in actor_weights.items():
|
||||||
|
if "rotary_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name:
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
|
||||||
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
|
weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
|
||||||
|
|
||||||
|
|
||||||
|
def llama_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
(".qkv_proj", ".q_proj", "q"),
|
||||||
|
(".qkv_proj", ".k_proj", "k"),
|
||||||
|
(".qkv_proj", ".v_proj", "v"),
|
||||||
|
(".gate_up_proj", ".gate_proj", 0),
|
||||||
|
(".gate_up_proj", ".up_proj", 1),
|
||||||
|
]
|
||||||
|
params_dict = dict(vllm_model.named_parameters())
|
||||||
|
for name, loaded_weight in actor_weights.items():
|
||||||
|
if "rotary_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
if ("rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name):
|
||||||
|
# Models trained using ColossalAI may include these tensors in
|
||||||
|
# the checkpoint. Skip them.
|
||||||
|
continue
|
||||||
|
# With tie_word_embeddings, we can skip lm_head.weight
|
||||||
|
# The weight might appear unnecessarily in the files if the model is
|
||||||
|
# processed with quantization, LoRA, fine-tuning, etc.
|
||||||
|
if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name:
|
||||||
|
continue
|
||||||
|
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
|
weight_loader(param, local_loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
|
def qwen2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
("qkv_proj", "q_proj", "q"),
|
||||||
|
("qkv_proj", "k_proj", "k"),
|
||||||
|
("qkv_proj", "v_proj", "v"),
|
||||||
|
("gate_up_proj", "gate_proj", 0),
|
||||||
|
("gate_up_proj", "up_proj", 1),
|
||||||
|
]
|
||||||
|
params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
|
||||||
|
for name, loaded_weight in actor_weights.items():
|
||||||
|
if "rotary_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name:
|
||||||
|
continue
|
||||||
|
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
|
||||||
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
|
weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
|
||||||
|
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
|
|
||||||
|
|
||||||
|
def deepseekv2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
("gate_up_proj", "gate_proj", 0),
|
||||||
|
("gate_up_proj", "up_proj", 1),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||||
|
# (param_name, weight_name, expert_id, shard_id)
|
||||||
|
expert_params_mapping = FusedMoE.make_expert_params_mapping(ckpt_gate_proj_name="gate_proj",
|
||||||
|
ckpt_down_proj_name="down_proj",
|
||||||
|
ckpt_up_proj_name="up_proj",
|
||||||
|
num_experts=vllm_model.config.n_routed_experts)
|
||||||
|
|
||||||
|
params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
|
||||||
|
for name, loaded_weight in actor_weights.items():
|
||||||
|
if "rotary_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||||
|
# Skip non-stacked layers and experts (experts handled below).
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
# We have mlp.experts[0].gate_proj in the checkpoint.
|
||||||
|
# Since we handle the experts below in expert_params_mapping,
|
||||||
|
# we need to skip here BEFORE we update the name, otherwise
|
||||||
|
# name will be updated to mlp.experts[0].gate_up_proj, which
|
||||||
|
# will then be updated below in expert_params_mapping
|
||||||
|
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
|
||||||
|
if (("mlp.experts." in name) and name not in params_dict):
|
||||||
|
continue
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if is_pp_missing_parameter(name, vllm_model):
|
||||||
|
continue
|
||||||
|
|
||||||
|
param = params_dict[name]
|
||||||
|
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
|
||||||
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
|
weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
for mapping in expert_params_mapping:
|
||||||
|
param_name, weight_name, expert_id, shard_id = mapping
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
|
||||||
|
if is_pp_missing_parameter(name, vllm_model):
|
||||||
|
continue
|
||||||
|
|
||||||
|
param = params_dict[name]
|
||||||
|
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
|
||||||
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
|
weight_loader(param,
|
||||||
|
local_loaded_weight.to(dtype=param.dtype),
|
||||||
|
weight_name,
|
||||||
|
shard_id=shard_id,
|
||||||
|
expert_id=expert_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if is_pp_missing_parameter(name, vllm_model):
|
||||||
|
continue
|
||||||
|
|
||||||
|
param = params_dict[name]
|
||||||
|
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
|
||||||
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
|
weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
|
||||||
|
|
||||||
|
|
||||||
|
def gpt2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
|
||||||
|
pass


def redistribute_dtensor(param_name: str, loaded_weights: DTensor, parallelize_plan: Dict = None):
    param_name = _process_parameter_names(name=param_name)
    if parallelize_plan is not None:
        assert param_name in parallelize_plan.keys(), \
            f"param name: {param_name} not in parallelize_plan: {parallelize_plan.keys()}"
        placement = parallelize_plan[param_name]
        local_loaded_weights = loaded_weights.redistribute(device_mesh=loaded_weights.device_mesh,
                                                           placements=placement).to_local()
    else:
        local_loaded_weights = loaded_weights.full_tensor()
    return local_loaded_weights
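
# Illustrative sketch (hypothetical plan and keys, assuming the torch.distributed._tensor
# placements API): without a parallelize_plan every rank materializes the full tensor,
#     local = redistribute_dtensor("model.layers.0.mlp.down_proj.weight", w)  # == w.full_tensor()
# while with a plan keyed by the processed parameter name, e.g.
#     plan = {"mlp.down_proj": [Shard(0)]}
#     local = redistribute_dtensor("model.layers.0.mlp.down_proj.weight", w, parallelize_plan=plan)
# the DTensor is first redistributed to that placement and then unwrapped via .to_local().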


def _process_parameter_names(name):
    # Remove '.weight' if it exists at the end of the string
    if name.endswith(".weight"):
        name = name[:-7]

    # Remove 'model.layers.x.' or 'model.' prefix
    if "model.layers" in name:
        parts = name.split('.')
        # Reconstruct the string without 'model.layers.x.'
        name = '.'.join(parts[3:])  # parts[0] is 'model', parts[1] is 'layers', parts[2] is 'x'
    elif name.startswith("model."):
        name = name[6:]  # Remove 'model.'

    return name
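
# For example, given the rules above:
#   _process_parameter_names("model.layers.0.self_attn.qkv_proj.weight")  -> "self_attn.qkv_proj"
#   _process_parameter_names("model.embed_tokens.weight")                 -> "embed_tokens"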


__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__ = {
    'GPT2LMHeadModel': gpt2_dtensor_weight_loader,
    'LlamaForCausalLM': llama_dtensor_weight_loader,
    'LLaMAForCausalLM': llama_dtensor_weight_loader,
    'MistralForCausalLM': llama_dtensor_weight_loader,  # mistral is the same as llama in vLLM
    'InternLMForCausalLM': llama_dtensor_weight_loader,
    'AquilaModel': llama_dtensor_weight_loader,
    'AquilaForCausalLM': llama_dtensor_weight_loader,
    'Phi3ForCausalLM': llama_dtensor_weight_loader,
    'GemmaForCausalLM': gemma_dtensor_weight_loader,
    'Gemma2ForCausalLM': gemma_dtensor_weight_loader,
    'GPTBigCodeForCausalLM': gptbigcode_dtensor_load_weights,
    'Starcoder2ForCausalLM': starcoder2_dtensor_load_weights,
    'Qwen2ForCausalLM': qwen2_dtensor_weight_loader,
    'DeepseekV2ForCausalLM': deepseekv2_dtensor_weight_loader,
}


# The actor weights passed in here are the actor model's .state_dict().
# Load DTensor weights into the vLLM model.
def load_dtensor_weights(actor_weights: Dict, vllm_model: nn.Module):
    weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__)
    weight_loader(actor_weights, vllm_model)
    # NOTE(sgm): to reduce peak memory usage, we offload the vLLM model to CPU after init,
    # so we need to move it back to CUDA here after syncing the model weights in the first iteration.
    vllm_model = vllm_model.cuda()
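
# Usage sketch (hypothetical names, assumes an initialized device mesh and a DTensor-valued
# state dict gathered from the actor side):
#   actor_state_dict = fsdp_actor_module.state_dict()            # values are DTensors
#   load_dtensor_weights(actor_state_dict, rollout_vllm_model)   # dispatched by model class name
# Unsupported architectures raise a ValueError from _get_model_weight_loader below.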


def _get_model_weight_loader(arch: str):
    if arch in __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__:
        return __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__[arch]
    raise ValueError(f"Model architecture {arch} is not supported for now. "
                     f"Supported architectures: {__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__.keys()}")


# NOTE(sgm): we use a per-parameter weight loader in each vLLM sub-module
def update_dtensor_weight_loader():
    pass

44 verl/third_party/vllm/vllm_v_0_5_4/hf_weight_loader.py vendored Normal file
@@ -0,0 +1,44 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models

from typing import Dict, Union, Optional, Iterable, Tuple

import torch
import torch.nn as nn

from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.model_loader.weight_utils import default_weight_loader


def update_hf_weight_loader():
    print('no hf weight loader needs to be updated')
    return


def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module):
    assert isinstance(actor_weights, Dict)
    with set_default_torch_dtype(next(vllm_model.parameters()).dtype):  # TODO
        if vllm_model.config.tie_word_embeddings and "lm_head.weight" in actor_weights.keys():
            del actor_weights["lm_head.weight"]
        vllm_model.load_weights(actor_weights.items())
        for _, module in vllm_model.named_modules():
            quant_method = getattr(module, "quant_method", None)
            if quant_method is not None:
                quant_method.process_weights_after_loading(module)
            # FIXME: Remove this after Mixtral is updated
            # to use quant_method.
            if hasattr(module, "process_weights_after_loading"):
                module.process_weights_after_loading()
    vllm_model = vllm_model.cuda()
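
# Usage sketch (hypothetical names): `actor_state_dict` is a plain, non-DTensor
# HuggingFace-style state dict already gathered on this rank:
#   load_hf_weights(actor_state_dict, rollout_vllm_model)
# This delegates to vLLM's own `load_weights`, re-applies any quantization
# post-processing, and then moves the model back onto the GPU.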

239 verl/third_party/vllm/vllm_v_0_5_4/llm.py vendored Normal file
@@ -0,0 +1,239 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py

from contextlib import contextmanager
from typing import ClassVar, List, Optional, Sequence, Union, cast, overload, Dict, Tuple

from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers import PretrainedConfig
import torch.nn as nn
from .arg_utils import EngineArgs
from .llm_engine_sp import LLMEngine
from vllm import LLM
from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt, parse_and_batch_prompt)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (GuidedDecodingRequest, get_local_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_kwargs
import torch
from torch.nn.utils.rnn import pad_sequence
from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer


class LLM(LLM):
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    NOTE: This class is intended to be used for offline inference. For online
    serving, use the `AsyncLLMEngine` class instead.
    NOTE: For the comprehensive list of arguments, see `EngineArgs`.

    Args:
        model: A HuggingFace Transformers model instance, or its parameter
            state dict.
        tokenizer: A HuggingFace Transformers tokenizer instance.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq". If None, we assume the model weights are not
            quantized and use `dtype` to determine the data type of the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause
            out-of-memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode.
        disable_custom_all_reduce: See ParallelConfig.
    """

    def __init__(
        self,
        model: Union[nn.Module, Dict],  # model itself or its parameter dict
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer],
        model_hf_config: PretrainedConfig,
        tokenizer_mode: str = "auto",
        trust_remote_code: bool = False,
        skip_tokenizer_init: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: int = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        max_context_len_to_capture: Optional[int] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        load_format='auto',
        **kwargs,
    ) -> None:
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True
        engine_args = EngineArgs(
            model_hf_config=model_hf_config,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_context_len_to_capture=max_context_len_to_capture,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            load_format=load_format,
            skip_tokenizer_init=skip_tokenizer_init,
            **kwargs,
        )
        tokenizer_cls = (PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer)
        if not isinstance(tokenizer, tokenizer_cls):
            raise ValueError(f"Unexpected tokenizer type: {type(tokenizer)}. Must be "
                             "one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, "
                             "verl.workers.rollout.HybridEngineBaseTokenizer")
        self.llm_engine = LLMEngine.from_engine_args(model, tokenizer, engine_args)  # TODO: check usagecontext
        self.request_counter = Counter()

    def init_cache_engine(self):
        self.llm_engine.init_cache_engine()

    def free_cache_engine(self):
        self.llm_engine.free_cache_engine()

    def get_tokenizer(self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
        return self.llm_engine.tokenizer

    def set_tokenizer(
        self,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    ) -> None:
        self.llm_engine.tokenizer = tokenizer

    def _run_engine(self, *, use_tqdm: bool) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )
        # Run the engine.
        outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if output.finished:
                    outputs.append(output)
                    if use_tqdm:
                        if isinstance(output, RequestOutput):
                            # Calculate tokens only for RequestOutput
                            total_in_toks += len(output.prompt_token_ids)
                            in_spd = total_in_toks / pbar.format_dict["elapsed"]
                            total_out_toks += sum(len(stp.token_ids) for stp in output.outputs)
                            out_spd = total_out_toks / pbar.format_dict["elapsed"]
                            pbar.postfix = (f"est. speed input: {in_spd:.2f} toks/s, "
                                            f"output: {out_spd:.2f} toks/s")
                        pbar.update(1)
        if use_tqdm:
            pbar.close()
        # Sort the outputs by request ID. This is necessary because some
        # requests may finish earlier than requests submitted before them.
        outputs = sorted(outputs, key=lambda x: int(x.request_id))
        return self._post_process_outputs(outputs)

    # # NOTE(shengguangming): add for verl
    # # TODO(sgm): we can optimize it by making the dataloader yield List[int] without padding.
    # def _pre_process_inputs(self, prompt_token_ids: torch.Tensor) -> List[int]:
    #     # remove the left padding in the prompt token_id
    #     pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id
    #     non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
    #     token_ids = prompt_token_ids[non_pad_index:].tolist()
    #     return token_ids

    # NOTE(shengguangming): add for verl
    def _post_process_outputs(self, request_outputs: List[RequestOutput]) -> Tuple[torch.Tensor, torch.Tensor]:
        output_token_ids = []
        logprobs = []
        for request_output in request_outputs:  # List[RequestOutput]
            outputs = request_output.outputs
            for output in outputs:  # List[CompletionOutput], usually len == 1
                output_token_ids.append(torch.tensor(output.token_ids))
                # TODO(shengguangming): can be optimized by rewriting the Sampler._get_logprobs() logic
                logprobs_dicts = output.logprobs
                if logprobs_dicts is not None:
                    logprob = []
                    for logprobs_dict, id in zip(logprobs_dicts, output.token_ids):
                        logprob.append(logprobs_dict[id].logprob)
                    logprobs.append(torch.tensor(logprob))

        pad_token_id = (self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None
                        else self.llm_engine.tokenizer.eos_token_id)
        output_token_ids = pad_sequence(output_token_ids, batch_first=True, padding_value=pad_token_id)
        if len(logprobs) > 0:
            logprobs = pad_sequence(logprobs, batch_first=True, padding_value=pad_token_id)
        return output_token_ids, logprobs
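
    # A self-contained illustration of the padding applied above (with pad id 0):
    #   >>> import torch
    #   >>> from torch.nn.utils.rnn import pad_sequence
    #   >>> pad_sequence([torch.tensor([5, 6, 7]), torch.tensor([8])], batch_first=True, padding_value=0)
    #   tensor([[5, 6, 7],
    #           [8, 0, 0]])
    # i.e. variable-length responses become one right-padded (batch, max_len) tensor.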

    def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], load_format: str) -> None:
        self.llm_engine.sync_model_weights(actor_weights=actor_weights, load_format=load_format)

    def offload_model_weights(self) -> None:
        self.llm_engine.offload_model_weights()
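
# Rough usage sketch (hypothetical argument values; the rollout worker is what actually
# wires this up in verl):
#   llm = LLM(model=actor_module_or_state_dict,
#             tokenizer=hf_tokenizer,
#             model_hf_config=hf_config,
#             tensor_parallel_size=2,
#             dtype="bfloat16",
#             gpu_memory_utilization=0.5)
#   llm.sync_model_weights(actor_weights=updated_state_dict, load_format="dtensor")
#   ... generate as with the stock vLLM LLM class ...
#   llm.offload_model_weights()
#   llm.free_cache_engine()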
Some files were not shown because too many files have changed in this diff