# datapipe/clean/step4_preprocess_mineru_multi_with_database.py

import re
import os
import json
import copy
import requests
import time
import shutil
import uuid
import sqlite3
import PyPDF2
import multiprocessing
import mysql.connector
from concurrent.futures import ThreadPoolExecutor, as_completed
from loguru import logger
from glob import glob
from tqdm import tqdm
from datetime import datetime
from magic_pdf.pipe.UNIPipe import UNIPipe
from magic_pdf.pipe.OCRPipe import OCRPipe
from magic_pdf.pipe.TXTPipe import TXTPipe
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
import magic_pdf.model as model_config
model_config.__use_inside_model__ = True
# Image-hosting (imgbed) configuration
# IMGBED_URL = "http://localhost:40027/"
IMGBED_URL = "http://172.20.103.171:40027/"
# Make sure the imgbed URL ends with a slash
if not IMGBED_URL.endswith('/'):
    IMGBED_URL += '/'
token_endpoint = f"{IMGBED_URL}api/v1/tokens"
upload_endpoint = f"{IMGBED_URL}api/v1/upload"
# A token can be obtained as follows:
# curl -X POST http://localhost:40027/api/v1/tokens -H "Content-Type: application/json" -d '{"email":"yt.li2@siat.ac.cn", "password":"lyt20000414."}'
IMGBED_TOKEN = "6|QsBh5H7txY3Hd7ju1nzYKOBSdFQeL0YberydSFIH"
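# A minimal sketch of fetching the token in Python instead of curl. It assumes the
# Lsky Pro v1 response carries the token under data["data"]["token"]; that shape is
# an assumption, not verified here, so the sketch is kept commented out:
# def fetch_imgbed_token(email: str, password: str) -> str:
#     resp = requests.post(token_endpoint, json={"email": email, "password": password})
#     resp.raise_for_status()
#     return resp.json()["data"]["token"]  # assumed response layout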
def replace_image_links(md_content: str, images_urls: dict) -> str:
    # Match Markdown image links of the form ![alt text](image_path)
    pattern = r'!\[(.*?)\]\((.*?)\)'

    def replace_link(match):
        # Extract the image path from the current match
        image_path = match.group(2)
        # Check whether this path is in the mapping
        if image_path in images_urls:
            # Look up the new URL in the mapping
            new_url = images_urls[image_path]
            return f"![]({new_url})"
        return match.group(0)

    # Substitute every matched link
    updated_md_content = re.sub(pattern, replace_link, md_content)
    return updated_md_content
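# Usage sketch (hypothetical paths):
#   replace_image_links("![fig1](images/a.png)", {"images/a.png": "/data/imgs/123.png"})
#   returns "![](/data/imgs/123.png)"; links with no entry in the mapping are left untouched.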
# Upload images to Lsky Pro
def upload_image(img_dir):
    headers = {
        "Authorization": f"Bearer {IMGBED_TOKEN}",
        'Accept': 'application/json'
    }
    image_urls = {}
    img_names = os.listdir(img_dir)
    for image_name in img_names:
        retry = 0
        image_path = os.path.join(img_dir, image_name)
        while retry < 5:  # maximum number of retries
            try:
                with open(image_path, 'rb') as image_file:  # keep the file open while uploading
                    files = {'file': image_file}
                    # Upload the file
                    response = requests.post(upload_endpoint, headers=headers, files=files)
                if response.status_code == 200:
                    result = response.json()
                    if result['status']:
                        image_url = result['data']['links']['url']
                        image_urls['images/' + image_name] = image_url
                        print(f"Image uploaded successfully: {image_url}")
                        break  # upload succeeded, leave the retry loop
                    else:
                        raise Exception(f"Image upload failed: {result['message']}")
                elif response.status_code == 429:
                    # Rate limited: wait a while before retrying
                    wait_time = 3
                    # wait_time = min(2 ** retry, 10)  # exponential backoff, capped at 10 s
                    print(f"Too many requests, waiting {wait_time} s...")
                    time.sleep(wait_time)
                else:
                    raise Exception(f"HTTP request failed: {response.status_code}")
                retry += 1  # count this attempt
                time.sleep(1)  # brief pause after a failed attempt
            except FileNotFoundError:
                logger.error(f"File {image_path} does not exist, please check the path")
                break  # skip this image instead of aborting the whole upload
    return image_urls
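# Usage sketch: upload_image("out/paper1/images") yields a mapping like
# {"images/fig1.png": "http://<imgbed>/fig1.png"} (hypothetical names), in the same
# key format that replace_image_links expects.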
# Save images locally, making sure the generated file names are unique
def save_images_locally(img_dir, target_dir):
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    image_urls = {}
    img_names = os.listdir(img_dir)
    # Copy each image into the target folder
    for image_name in img_names:
        image_path = os.path.join(img_dir, image_name)
        # Use a UUID to build a unique file name, keeping the original extension
        unique_name = f"{uuid.uuid4()}{os.path.splitext(image_name)[1]}"
        save_path = os.path.join(target_dir, unique_name)
        try:
            # Copy the file into the target directory
            shutil.copy2(image_path, save_path)
            # Key by the original relative path (as referenced in the markdown) so that
            # replace_image_links can find it; the value is the new unique path
            image_urls[f'images/{image_name}'] = save_path
            print(f"Image saved successfully: {save_path}")
        except FileNotFoundError:
            print(f"File {image_path} does not exist, skipping this image")
        except Exception as e:
            print(f"Error while saving image {image_name}: {e}")
    return image_urls
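# Usage sketch: save_images_locally("out/paper1/images", "mp_cif/images") returns
# {"images/fig1.png": "mp_cif/images/<uuid>.png"} (hypothetical names), which
# replace_image_links then substitutes into the markdown.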
def json_md_dump(
        pipe,
        md_writer,
        pdf_name,
        content_list,
        md_content,
):
    # Write the model output to model.json
    orig_model_list = copy.deepcopy(pipe.model_list)
    md_writer.write(
        content=json.dumps(orig_model_list, ensure_ascii=False, indent=4),
        path=f"{pdf_name}_model.json"
    )
    # Write intermediate results to middle.json
    md_writer.write(
        content=json.dumps(pipe.pdf_mid_data, ensure_ascii=False, indent=4),
        path=f"{pdf_name}_middle.json"
    )
    # Write the text results to content_list.json
    md_writer.write(
        content=json.dumps(content_list, ensure_ascii=False, indent=4),
        path=f"{pdf_name}_content_list.json"
    )
    # Write the markdown content to the .md file
    md_writer.write(
        content=md_content,
        path=f"{pdf_name}.md"
    )
def pdf_parse_main(
        pdf_path: str,
        parse_method: str = 'auto',
        model_json_path: str = None,
        is_json_md_dump: bool = True,
        output_dir: str = None
):
    """
    Convert a PDF into json and md, writing the results next to the PDF (or into output_dir).

    :param pdf_path: path to the .pdf file; may be relative or absolute
    :param parse_method: parsing method, one of auto, ocr, txt; defaults to auto. Try ocr if results are poor
    :param model_json_path: path to an existing model-output file; if empty, the built-in model is used. The pdf and model_json must correspond
    :param is_json_md_dump: whether to dump the parsed data to .json and .md files, default True. Data from the different stages goes into three separate .json files, and the markdown content into a .md file
    :param output_dir: output directory; a folder named after the PDF is created there to hold all results
    """
    try:
        pdf_name = os.path.splitext(os.path.basename(pdf_path))[0]
        pdf_path_parent = os.path.dirname(pdf_path)
        if output_dir:
            output_path = os.path.join(output_dir, pdf_name)
        else:
            output_path = os.path.join(pdf_path_parent, pdf_name)
        output_image_path = os.path.join(output_path, 'images')
        # Basename of the image directory, so that images are referenced by a
        # relative path in the .md and content_list.json files
        image_path_parent = os.path.basename(output_image_path)
        with open(pdf_path, "rb") as f:
            pdf_bytes = f.read()  # read the raw bytes of the PDF
        if model_json_path:
            # Load the already model-parsed json data for this PDF (a list)
            with open(model_json_path, "r", encoding="utf-8") as f:
                model_json = json.load(f)
        else:
            model_json = []
        # Set up the parsing pipeline
        image_writer, md_writer = DiskReaderWriter(output_image_path), DiskReaderWriter(output_path)
        # Choose the parsing method
        if parse_method == "auto":
            jso_useful_key = {"_pdf_type": "", "model_list": model_json}
            pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer)
        elif parse_method == "txt":
            pipe = TXTPipe(pdf_bytes, model_json, image_writer)
        elif parse_method == "ocr":
            pipe = OCRPipe(pdf_bytes, model_json, image_writer)
        else:
            logger.error("unknown parse method, only auto, ocr, txt allowed")
            exit(1)
        # Classify the document
        pipe.pipe_classify()
        # If no model data was supplied, parse with the built-in model
        if not model_json:
            if model_config.__use_inside_model__:
                pipe.pipe_analyze()  # run the analysis model
            else:
                logger.error("need model list input")
                exit(1)
        # Run the parser
        pipe.pipe_parse()
        # Produce the text and md results
        content_list = pipe.pipe_mk_uni_format(image_path_parent, drop_mode="none")
        md_content = pipe.pipe_mk_markdown(image_path_parent, drop_mode="none")
        # Upload images to the imgbed
        # image_urls = upload_image(output_image_path)
        # Save images locally instead
        target_dir = "mp_cif/images"
        image_urls = save_images_locally(output_image_path, target_dir)
        md_content = replace_image_links(md_content, image_urls)
        mysql_connection = mysql.connector.connect(
            host='100.84.94.73',
            user='metadata_mat_papers',
            password='siat-mic',
            database='metadata_mat_papers',
            charset="utf8mb4",  # use utf8mb4 for the connection
            collation="utf8mb4_unicode_ci"  # with a matching collation
        )
        mysql_cursor = mysql_connection.cursor()
        table = 'mp_cif_info'
        # Recover the DOI from the file name (underscores encode slashes)
        doi = os.path.basename(pdf_path).replace(".pdf", "").replace('_', '/')
        try:
            # Build and run the update query
            query = f"UPDATE {table} SET en_text_content = %s WHERE doi = %s"
            mysql_cursor.execute(query, (md_content, doi))
            print(f"{doi}: markdown saved successfully")
            # Commit the change to the database
            mysql_connection.commit()
        except mysql.connector.Error as error:
            print("Failed to insert record into MySQL table: {}".format(error))
            # Roll back the transaction on error
            mysql_connection.rollback()
        finally:
            # Close the cursor and the connection
            mysql_cursor.close()
            mysql_connection.close()
        if is_json_md_dump:
            json_md_dump(pipe, md_writer, pdf_name, content_list, md_content)
        return 'success'
    except Exception as e:
        logger.exception(e)
        return 'error'
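# Usage sketch (hypothetical path):
#   status = pdf_parse_main("mp_cif/pdfs/10.1000_xyz123.pdf", parse_method='auto',
#                           output_dir="mp_cif/mds")
# status is 'success' when the pipeline completes and 'error' if anything raised.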
def check_doi_not_in_db(pdf_name, cursor):
    # Note: the '?' placeholders mean this expects a sqlite3 cursor,
    # unlike the MySQL connections used elsewhere in this script
    query = "SELECT * FROM doi_status WHERE doi = ? AND convert_status = ?"
    cursor.execute(query, (pdf_name, 'unprocessed'))
    res = cursor.fetchone()
    return res is not None
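# Usage sketch, assuming a sqlite database (hypothetical file name) with a
# doi_status(doi, convert_status) table:
#   conn = sqlite3.connect("doi_status.db")
#   if check_doi_not_in_db("10.1000/xyz123", conn.cursor()):
#       ...  # DOI is still marked 'unprocessed'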
def init_worker(devices, pdfs, gpu_index, process_id):
    """
    Initialize a worker process to process a chunk of PDFs with a specific GPU.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_index)
    process_pdf_chunk(pdfs, gpu_index, process_id)
def get_converted2md_dois():
    table = 'mp_cif_info'
    dois = []
    mysql_connection = mysql.connector.connect(
        host='100.84.94.73',
        user='metadata_mat_papers',
        password='siat-mic',
        database='metadata_mat_papers',
        charset="utf8mb4",  # use utf8mb4 for the connection
        collation="utf8mb4_unicode_ci"  # with a matching collation
    )
    mysql_cursor = mysql_connection.cursor()
    try:
        # Fetch every DOI whose markdown has already been stored
        query = f"SELECT doi FROM {table} WHERE en_text_content IS NOT NULL;"
        mysql_cursor.execute(query)
        res = mysql_cursor.fetchall()
        dois = [row[0] for row in res if row]
    except mysql.connector.Error as error:
        logger.error(f"Failed to fetch converted DOIs: {error}")
        # Roll back the transaction on error
        mysql_connection.rollback()
    finally:
        # Close the cursor and the connection
        mysql_cursor.close()
        mysql_connection.close()
    return dois
def is_within_operational_hours(start_hour, end_hour):
    now = datetime.now().time()  # current time of day (no date)
    current_hour = now.hour
    # Handle windows that wrap past midnight (e.g. evening until the next morning)
    if start_hour > end_hour:
        return current_hour >= start_hour or current_hour < end_hour  # wraps midnight
    else:
        return start_hour <= current_hour < end_hour
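# Worked example: with start_hour=15 and end_hour=9 the window wraps midnight,
# so the function returns True at 16:00 and at 03:00, and False at 10:00.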
def process_pdf_chunk(pdf_paths, gpu_index, process_id):
    for pdf_path in tqdm(pdf_paths, desc=f"Worker {gpu_index}_{process_id} Progress"):
        # Only run the task inside the allowed time window
        start_hour = 15  # window opens at 15:00
        end_hour = 9     # window closes at 09:00 the next morning
        # Check whether the current time is inside the allowed window
        while True:
            if is_within_operational_hours(start_hour, end_hour):
                print("Current time is inside the operating window, processing PDF...")
                try:
                    with open(pdf_path, 'rb') as file:
                        pdf_reader = PyPDF2.PdfReader(file)  # sanity-check that the PDF is readable
                    print(os.path.basename(pdf_path).replace(".pdf", "").replace('_', '/'))
                    status = pdf_parse_main(pdf_path, parse_method='auto', output_dir=output_dir)
                    break  # finished, leave the loop
                except PyPDF2.errors.PdfReadError:
                    logger.error(f"{pdf_path} is broken")
                    break  # error, leave the loop
                except Exception as e:
                    logger.error(f"{pdf_path} has an error: {e}")
                    break  # error, leave the loop
            else:
                # Outside the allowed window: block and retry later
                print("Current time is outside the operating window, retrying later...")
                time.sleep(60 * 60)  # sleep one hour before checking again
def multiprocessing_setup(pdf_paths, num_gpus):
    num_processes_per_gpu = 3
    chunk_size = len(pdf_paths) // (num_gpus * num_processes_per_gpu)
    processes = []
    # Create processes for each GPU
    for gpu_id in range(num_gpus):
        for process_id in range(num_processes_per_gpu):
            start_idx = (gpu_id * num_processes_per_gpu + process_id) * chunk_size
            # The very last worker takes any remainder left over by integer division
            end_idx = None if (gpu_id == num_gpus - 1 and process_id == num_processes_per_gpu - 1) else start_idx + chunk_size
            chunk = pdf_paths[start_idx:end_idx]
            p = multiprocessing.Process(target=init_worker, args=([gpu_id], chunk, gpu_id, process_id))
            processes.append(p)
            p.start()
    # Ensure all processes have completed
    for p in processes:
        p.join()
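# Chunking example: with 100 PDFs, 8 GPUs and 3 processes per GPU there are 24
# workers and chunk_size = 100 // 24 = 4; workers 0-22 take 4 PDFs each and the
# final worker takes the remaining 8 (indices 92-99).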
if __name__ == '__main__':
    _cur_dir = os.path.dirname(os.path.abspath(__file__))
    # Adjust the paths here
    pdf_dir = os.path.join(_cur_dir, "mp_cif/pdfs")
    output_dir = os.path.join(_cur_dir, "mp_cif/mds")
    os.makedirs(output_dir, exist_ok=True)
    pdf_paths = sorted(glob(os.path.join(pdf_dir, "*.pdf")))
    # Skip PDFs whose markdown is already stored in the database
    dois = get_converted2md_dois()
    print(len(dois))
    converted = set(dois)
    new_pdf_paths = [
        path for path in tqdm(pdf_paths)
        if os.path.basename(path).replace(".pdf", "").replace('_', '/') not in converted
    ]
    print(len(new_pdf_paths))
    # Number of GPUs
    num_gpus = 8
    # Setup multiprocessing to handle PDFs across multiple GPUs
    # (workers read the module-level output_dir, so this relies on the fork start method)
    multiprocessing_setup(new_pdf_paths, num_gpus)