# datapipe/clean/preprocess_mineru_new.py
import re
import os
import requests
import time
import PyPDF2
import multiprocessing as mp
import math
import sys
import torch
from loguru import logger
from glob import glob
from tqdm import tqdm
from magic_pdf.data.data_reader_writer import FileBasedDataWriter, FileBasedDataReader
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.config.enums import SupportedPdfParseMethod

# Image-hosting (Lsky Pro) configuration
IMGBED_URL = "http://localhost:40027/"
# Make sure the image-bed URL ends with a trailing slash
if not IMGBED_URL.endswith('/'):
    IMGBED_URL += '/'
token_endpoint = f"{IMGBED_URL}api/v1/tokens"
upload_endpoint = f"{IMGBED_URL}api/v1/upload"

# Obtain a token like this:
# curl -X POST http://localhost:40027/api/v1/tokens -H "Content-Type: application/json" -d '{"email":"yt.li2@siat.ac.cn", "password":"lyt20000414."}'
IMGBED_TOKEN = "6|QsBh5H7txY3Hd7ju1nzYKOBSdFQeL0YberydSFIH"
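# A minimal sketch of fetching the token programmatically instead of hardcoding
# it, mirroring the curl command above. The response shape
# ({"status": ..., "data": {"token": ...}, "message": ...}) is an assumption
# based on the Lsky Pro v1 API; adjust if the actual payload differs.
def fetch_token(email: str, password: str) -> str:
    response = requests.post(
        token_endpoint,
        headers={"Content-Type": "application/json", "Accept": "application/json"},
        json={"email": email, "password": password},
    )
    response.raise_for_status()
    result = response.json()
    if not result.get("status"):
        raise RuntimeError(f"Token request failed: {result.get('message')}")
    return result["data"]["token"]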
def replace_image_links(md_content: str, images_urls: dict) -> str:
    # Match Markdown image links of the form ![alt text](image_path)
    pattern = r'!\[(.*?)\]\((.*?)\)'

    def replace_link(match):
        # Extract the image path from the current match
        image_path = match.group(2)
        # If this path has an uploaded URL, substitute it
        if image_path in images_urls:
            new_url = images_urls[image_path]
            return f"![]({new_url})"
        return match.group(0)

    # Apply the substitution across the whole document
    updated_md_content = re.sub(pattern, replace_link, md_content)
    return updated_md_content
# Upload images to Lsky Pro
def upload_image(img_dir):
    headers = {
        "Authorization": f"Bearer {IMGBED_TOKEN}",
        'Accept': 'application/json'
    }
    image_urls = {}
    os.makedirs(img_dir, exist_ok=True)
    img_names = os.listdir(img_dir)
    for image_name in img_names:
        retry = 0
        image_path = os.path.join(img_dir, image_name)
        while retry < 5:  # maximum number of retries
            try:
                with open(image_path, 'rb') as image_file:  # keep the file open during the upload
                    files = {'file': image_file}
                    # Upload the file
                    response = requests.post(upload_endpoint, headers=headers, files=files)
                if response.status_code == 200:
                    result = response.json()
                    if result['status']:
                        image_url = result['data']['links']['url']
                        image_urls['images/' + image_name] = image_url
                        break  # upload succeeded, stop retrying
                    else:
                        raise Exception(f"Image upload failed: {result['message']}")
                elif response.status_code == 429:
                    # Rate-limited: back off before retrying
                    wait_time = min(2 ** retry, 60)  # exponential backoff, capped at 60 s
                    logger.warning(f"Too many requests, waiting {wait_time} s...")
                    time.sleep(wait_time)
                else:
                    raise Exception(f"HTTP request failed: {response.status_code}")
                retry += 1  # count this attempt
                time.sleep(1)  # brief pause before the next retry
            except FileNotFoundError:
                logger.error(f"File {image_path} does not exist; check the path")
                break  # skip this image instead of aborting the whole batch
    return image_urls
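# Note: the returned dict is keyed by the relative path used in the markdown
# ('images/<filename>'), which is exactly what replace_image_links matches against.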
def pdf_parse_main(
    pdf_path: str,
    output_dir: str = None
):
    try:
        name_without_suff = os.path.basename(pdf_path).replace('.pdf', '')

        # Prepare the output layout: <output_dir>/<name>/images/
        local_md_dir = os.path.join(output_dir, name_without_suff)
        local_image_dir = os.path.join(local_md_dir, 'images')
        image_dir = str(os.path.basename(local_image_dir))
        os.makedirs(local_image_dir, exist_ok=True)
        image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(
            local_md_dir
        )

        # Read the PDF bytes
        reader1 = FileBasedDataReader("")
        pdf_bytes = reader1.read(pdf_path)

        ## Create a dataset instance
        ds = PymuDocDataset(pdf_bytes)

        ## Inference: pick OCR or text mode based on the classifier
        if ds.classify() == SupportedPdfParseMethod.OCR:
            infer_result = ds.apply(doc_analyze, ocr=True)
            pipe_result = infer_result.pipe_ocr_mode(image_writer)
        else:
            infer_result = ds.apply(doc_analyze, ocr=False)
            pipe_result = infer_result.pipe_txt_mode(image_writer)

        ### Draw the model result on each page
        infer_result.draw_model(os.path.join(local_md_dir, f"{name_without_suff}_model.pdf"))
        ### Draw the layout result on each page
        pipe_result.draw_layout(os.path.join(local_md_dir, f"{name_without_suff}_layout.pdf"))
        ### Draw the spans result on each page
        pipe_result.draw_span(os.path.join(local_md_dir, f"{name_without_suff}_spans.pdf"))

        ### Dump markdown
        md_content = pipe_result.dump_md(md_writer, os.path.join(local_md_dir, f"{name_without_suff}.md"), image_dir)
        ### Dump the content list
        pipe_result.dump_content_list(md_writer, os.path.join(local_md_dir, f"{name_without_suff}_content_list.json"), image_dir)
        # print(md_content)

        # Upload the extracted images to the image bed and rewrite the links
        image_urls = upload_image(local_image_dir)
        md_content = replace_image_links(md_content, image_urls)
        md_writer.write_string(os.path.join(local_md_dir, f"{name_without_suff}.md"), md_content)
    except Exception as e:
        logger.exception(e)
        return 'error'
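# Example invocation (hypothetical paths): pdf_parse_main("papers/example.pdf", output_dir="out")
# produces out/example/example.md (with image links rewritten to the image bed),
# out/example/images/, plus the *_model.pdf / *_layout.pdf / *_spans.pdf debug renders.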
def init_worker(pdfs, gpu_index, output_dir):
    """
    Initialize a worker process that handles one chunk of PDFs on a specific GPU.
    """
    try:
        # Pin this process to one GPU; this must happen before CUDA is initialized,
        # which holds here because 'spawn' workers re-import the module lazily.
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_index)
        print(f"Process {os.getpid()} started on GPU {gpu_index}")
        print(f"Processing {len(pdfs)} PDF files")
        process_pdf_chunk(pdfs, gpu_index, output_dir)
    except Exception as e:
        print(f"Process {os.getpid()} failed to initialize on GPU {gpu_index}: {str(e)}")
        raise e
def process_pdf_chunk(pdf_paths, worker_id, output_dir):
    for pdf_path in tqdm(pdf_paths, desc=f"Worker {worker_id} Progress"):
        try:
            # Periodically release cached GPU memory
            torch.cuda.empty_cache()
            # Open the file first to catch broken PDFs early
            with open(pdf_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
            print(os.path.basename(pdf_path).replace(".pdf", "").replace('_', '/'))
            pdf_parse_main(pdf_path, output_dir=output_dir)
        except PyPDF2.errors.PdfReadError:
            logger.error(f"{pdf_path} is broken")
        except Exception as e:
            logger.error(f"{pdf_path} has an error: {e}")
def multiprocessing_setup(pdf_paths, num_gpus, output_dir):
    # Number of files assigned to each GPU
    chunk_size = math.ceil(len(pdf_paths) / num_gpus)
    processes = []
    # One process per GPU
    for gpu_id in range(num_gpus):
        start_idx = gpu_id * chunk_size
        end_idx = min(len(pdf_paths), start_idx + chunk_size)
        chunk = pdf_paths[start_idx:end_idx]
        p = mp.Process(target=init_worker, args=(chunk, gpu_id, output_dir))
        processes.append(p)
        p.start()
        time.sleep(2)  # stagger start-up so the GPUs don't all initialize at once
    # Wait for all processes to finish
    for p in processes:
        p.join()
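# Example: 10 PDFs across 2 GPUs gives chunk_size = ceil(10 / 2) = 5,
# so GPU 0 gets pdf_paths[0:5] and GPU 1 gets pdf_paths[5:10].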
if __name__ == "__main__":
    _cur_dir = os.path.dirname(os.path.abspath(__file__))

    # Change the input/output paths here
    # pdf_dir = os.path.join(_cur_dir, "二维材料剥离/二维材料剥离/石墨烯")
    # output_dir = os.path.join(_cur_dir, "二维材料剥离/mds/石墨烯")
    # pdf_dir = os.path.join(_cur_dir, "二维材料剥离/二维材料剥离/黑磷烯")
    # output_dir = os.path.join(_cur_dir, "二维材料剥离/mds/黑磷烯")
    pdf_dir = os.path.join(_cur_dir, "模型评估/模型评估")
    output_dir = os.path.join(_cur_dir, "模型评估/mds")
    # pdf_dir = os.path.join(_cur_dir, "金纳米棒/金纳米棒")
    # output_dir = os.path.join(_cur_dir, "金纳米棒/mds")
    # pdf_dir = os.path.join(_cur_dir, "钙钛矿/钙钛矿-复合材料")
    # output_dir = os.path.join(_cur_dir, "钙钛矿/mds/复合材料")
    # pdf_dir = os.path.join(_cur_dir, "钙钛矿/钙钛矿-LAPR/PDF论文")
    # output_dir = os.path.join(_cur_dir, "钙钛矿/mds/LAPR")
    os.makedirs(output_dir, exist_ok=True)

    pdf_paths = sorted(glob(os.path.join(pdf_dir, "*.pdf")))
    print("Number of PDFs:", len(pdf_paths))

    # Skip PDFs that already have a markdown file in the output directory
    md_paths = sorted(glob(os.path.join(output_dir, "**", "*.md"), recursive=True))
    md_names = [os.path.basename(md_path) for md_path in md_paths]
    pdf_paths = [pdf_path for pdf_path in pdf_paths if os.path.basename(pdf_path).replace('.pdf', '.md') not in md_names]
    print("Number of PDFs after filtering:", len(pdf_paths))

    # # Multi-GPU mode
    # num_gpus = 2  # test with 2 GPUs first
    # # Set the multiprocessing start method
    # mp.set_start_method('spawn', force=True)
    # try:
    #     multiprocessing_setup(pdf_paths, num_gpus, output_dir)
    # except Exception as e:
    #     print(f"Program error: {str(e)}")

    # pdf_path = "black_phosphorus/参考文献/2015.03-ACS Nano-Barbaros Özyilmaz-石墨烯接触、全封装的超薄黑磷基场效应晶体管中的空气稳定传输.pdf"
    for pdf_path in tqdm(pdf_paths):
        pdf_parse_main(pdf_path, output_dir=output_dir)