First merge of the clean code

2025-01-18 17:09:51 +08:00
parent e33a8b069e
commit a0f5ca9a35
21 changed files with 2252 additions and 375 deletions

0  clean/__init__.py  Normal file

273  clean/preprocess_mineru.py  Normal file

@@ -0,0 +1,273 @@
import re
import os
import json
import copy
import requests
import time
import sqlite3
import PyPDF2
import multiprocessing
import mysql.connector
from loguru import logger
from glob import glob
from tqdm import tqdm
from magic_pdf.pipe.UNIPipe import UNIPipe
from magic_pdf.pipe.OCRPipe import OCRPipe
from magic_pdf.pipe.TXTPipe import TXTPipe
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
import magic_pdf.model as model_config
model_config.__use_inside_model__ = True
# 图床配置
IMGBED_URL = "http://localhost:40027/"
# 检查imgbed url是否以/结尾
if not IMGBED_URL.endswith('/'):
IMGBED_URL += '/'
token_endpoint = f"{IMGBED_URL}api/v1/tokens"
upload_endpoint = f"{IMGBED_URL}api/v1/upload"
# 通过如下方式获取token
# curl -X POST http://localhost:40027/api/v1/tokens -H "Content-Type: application/json" -d '{"email":"yt.li2@siat.ac.cn", "password":"lyt20000414."}'
IMGBED_TOKEN = "6|QsBh5H7txY3Hd7ju1nzYKOBSdFQeL0YberydSFIH"
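# A minimal sketch of fetching the token programmatically instead of hard-coding it,
# assuming the image bed (a Lsky Pro style API, as the curl example above suggests)
# answers POST api/v1/tokens with JSON shaped like {"status": true, "data": {"token": "..."}};
# adjust the field names if the deployed version differs.
def fetch_imgbed_token(email: str, password: str) -> str:
    response = requests.post(
        token_endpoint,
        json={"email": email, "password": password},
        headers={"Accept": "application/json"},
        timeout=30,
    )
    response.raise_for_status()
    payload = response.json()
    if not payload.get("status"):
        raise RuntimeError(f"Token request failed: {payload.get('message')}")
    return payload["data"]["token"]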
def replace_image_links(md_content: str, images_urls: dict) -> str:
# 匹配 Markdown 中的图像链接形式,即: ![alt text](image_path)
pattern = r'!\[(.*?)\]\((.*?)\)'
def replace_link(match):
# 提取出当前匹配到的图片路径
image_path = match.group(2)
# 检查该路径是否在字典中
if image_path in images_urls:
# 从字典中获取新的 URL
new_url = images_urls[image_path]
return f"![]({new_url})"
return match.group(0)
# 使用 sub 函数进行替换
updated_md_content = re.sub(pattern, replace_link, md_content)
return updated_md_content
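# A tiny illustration (hypothetical values) of the rewriting above: paths found in the
# mapping are swapped for their uploaded URL and the alt text is dropped, while unknown
# paths are left untouched.
def _demo_replace_image_links() -> str:
    md = "See ![fig 1](images/fig1.png) and ![fig 2](images/missing.png)."
    urls = {"images/fig1.png": "http://imgbed.example/abc.png"}
    # Returns "See ![](http://imgbed.example/abc.png) and ![fig 2](images/missing.png)."
    return replace_image_links(md, urls)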
# 上传图片到LSKY Pro
def upload_image(img_dir):
headers = {
"Authorization": f"Bearer {IMGBED_TOKEN}",
'Accept': 'application/json'
}
image_urls = {}
os.makedirs(img_dir, exist_ok=True)
img_names = os.listdir(img_dir)
for image_name in img_names:
retry = 0
image_path = os.path.join(img_dir, image_name)
while retry < 5: # 最大重试次数
try:
with open(image_path, 'rb') as image_file: # 确保文件在上传时是打开状态
files = {'file': image_file}
# 上传文件
response = requests.post(upload_endpoint, headers=headers, files=files)
if response.status_code == 200:
result = response.json()
if result['status']:
image_url = result['data']['links']['url']
image_urls['images/'+image_name] = image_url
break # 上传成功,退出重试循环
else:
raise Exception(f"图片上传失败: {result['message']}")
elif response.status_code == 429:
# 429 响应,等待一段时间再重试
wait_time = min(2 ** retry, 60) # 指数退避,最大等待 60 秒
logger.warning(f"请求过于频繁,等待 {wait_time} 秒...")
time.sleep(wait_time)
else:
raise Exception(f"HTTP请求出错: {response.status_code}")
retry += 1 # 增加重试次数
time.sleep(1) # 在重试失败后稍等一下
except FileNotFoundError:
logger.error(f"文件 {image_path} 不存在,请检查路径是否正确")
return
return image_urls
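# Backoff used above for HTTP 429, as a worked example: wait_time = min(2 ** retry, 60)
# yields 1, 2, 4, 8 and 16 seconds for retry = 0..4, and the `while retry < 5` guard
# means each image is attempted at most five times before the loop moves on.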
def json_md_dump(
pipe,
md_writer,
pdf_name,
content_list,
md_content,
):
# 写入模型结果到 model.json
orig_model_list = copy.deepcopy(pipe.model_list)
md_writer.write(
content=json.dumps(orig_model_list, ensure_ascii=False, indent=4),
path=f"{pdf_name}_model.json"
)
# 写入中间结果到 middle.json
md_writer.write(
content=json.dumps(pipe.pdf_mid_data, ensure_ascii=False, indent=4),
path=f"{pdf_name}_middle.json"
)
# text文本结果写入到 conent_list.json
md_writer.write(
content=json.dumps(content_list, ensure_ascii=False, indent=4),
path=f"{pdf_name}_content_list.json"
)
# 写入结果到 .md 文件中
md_writer.write(
content=md_content,
path=f"{pdf_name}.md"
)
def pdf_parse_main(
pdf_path: str,
parse_method: str = 'auto',
model_json_path: str = None,
is_json_md_dump: bool = True,
output_dir: str = None
):
"""
执行从 pdf 转换到 json、md 的过程,输出 md 和 json 文件到 pdf 文件所在的目录
:param pdf_path: .pdf 文件的路径,可以是相对路径,也可以是绝对路径
:param parse_method: 解析方法, 共 auto、ocr、txt 三种,默认 auto如果效果不好可以尝试 ocr
:param model_json_path: 已经存在的模型数据文件如果为空则使用内置模型pdf 和 model_json 务必对应
:param is_json_md_dump: 是否将解析后的数据写入到 .json 和 .md 文件中,默认 True会将不同阶段的数据写入到不同的 .json 文件中共3个.json文件md内容会保存到 .md 文件中
:param output_dir: 输出结果的目录地址,会生成一个以 pdf 文件名命名的文件夹并保存所有结果
"""
try:
pdf_name = os.path.basename(pdf_path).split("/")[-1].replace(".pdf", "")
pdf_path_parent = os.path.dirname(pdf_path)
if output_dir:
output_path = os.path.join(output_dir, pdf_name)
else:
output_path = os.path.join(pdf_path_parent, pdf_name)
output_image_path = os.path.join(output_path, 'images')
# 获取图片的父路径,为的是以相对路径保存到 .md 和 conent_list.json 文件中
image_path_parent = os.path.basename(output_image_path)
pdf_bytes = open(pdf_path, "rb").read() # 读取 pdf 文件的二进制数据
if model_json_path:
# 读取已经被模型解析后的pdf文件的 json 原始数据list 类型
model_json = json.loads(open(model_json_path, "r", encoding="utf-8").read())
else:
model_json = []
# 执行解析步骤
# image_writer = DiskReaderWriter(output_image_path)
image_writer, md_writer = DiskReaderWriter(output_image_path), DiskReaderWriter(output_path)
# 选择解析方式
# jso_useful_key = {"_pdf_type": "", "model_list": model_json}
# pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer)
if parse_method == "auto":
jso_useful_key = {"_pdf_type": "", "model_list": model_json}
pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer)
elif parse_method == "txt":
pipe = TXTPipe(pdf_bytes, model_json, image_writer)
elif parse_method == "ocr":
pipe = OCRPipe(pdf_bytes, model_json, image_writer)
else:
logger.error("unknown parse method, only auto, ocr, txt allowed")
exit(1)
# 执行分类
pipe.pipe_classify()
# 如果没有传入模型数据,则使用内置模型解析
if not model_json:
if model_config.__use_inside_model__:
pipe.pipe_analyze() # 解析
else:
logger.error("need model list input")
exit(1)
# 执行解析
pipe.pipe_parse()
# 保存 text 和 md 格式的结果
content_list = pipe.pipe_mk_uni_format(image_path_parent, drop_mode="none")
md_content = pipe.pipe_mk_markdown(image_path_parent, drop_mode="none")
# 上传图像到图床
image_urls = upload_image(output_image_path)
md_content = replace_image_links(md_content, image_urls)
if is_json_md_dump:
json_md_dump(pipe, md_writer, pdf_name, content_list, md_content)
return 'success'
except Exception as e:
logger.exception(e)
return 'error'
def init_worker(devices, pdfs, gpu_index):
"""
Initialize a worker process to process a chunk of PDFs with a specific GPU.
"""
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_index)
process_pdf_chunk(pdfs, gpu_index)
def process_pdf_chunk(pdf_paths, worker_id):
for pdf_path in tqdm(pdf_paths, desc=f"Worker {worker_id} Progress"):
try:
with open(pdf_path, 'rb') as file:
pdf_reader = PyPDF2.PdfReader(file)
print(os.path.basename(pdf_path).replace(".pdf", "").replace('_', '/'))
status = pdf_parse_main(pdf_path, parse_method='auto', output_dir=output_dir)
except PyPDF2.errors.PdfReadError:
logger.error(f"{pdf_path} has been broken")
except Exception as e:
logger.error(f"{pdf_path} has an error: {e}")
def multiprocessing_setup(pdf_paths, num_gpus):
num_processes_per_gpu = 2
chunk_size = len(pdf_paths) // (num_gpus * num_processes_per_gpu)
processes = []
# Create processes for each GPU
for gpu_id in range(num_gpus):
for process_id in range(num_processes_per_gpu):
start_idx = (gpu_id * num_processes_per_gpu + process_id) * chunk_size
end_idx = None if (gpu_id == num_gpus - 1 and process_id == num_processes_per_gpu - 1) else start_idx + chunk_size
chunk = pdf_paths[start_idx:end_idx]
p = multiprocessing.Process(target=init_worker, args=([gpu_id], chunk, gpu_id))
processes.append(p)
p.start()
# Ensure all processes have completed
for p in processes:
p.join()
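# Worked example of the chunking above (illustrative numbers): with 100 PDFs,
# num_gpus = 8 and num_processes_per_gpu = 2, chunk_size = 100 // 16 = 6, so the
# worker (gpu_id=0, process_id=0) gets pdf_paths[0:6], (0, 1) gets pdf_paths[6:12],
# and so on; only the very last worker (7, 1) gets end_idx = None, i.e.
# pdf_paths[90:], which also absorbs the 4 leftover files.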
if __name__ == "__main__":
_cur_dir = os.path.dirname(os.path.abspath(__file__))
# 此处更改路径
pdf_dir = os.path.join(_cur_dir, "black_phosphorus_wulie/黑磷文献/黑磷文献-任务1-推荐官能团")
output_dir = os.path.join(_cur_dir, "black_phosphorus_wulie/黑磷文献-任务1-推荐官能团_pdf2md")
os.makedirs(output_dir, exist_ok=True)
pdf_paths = sorted(glob(os.path.join(pdf_dir, "*.pdf")))
print("pdf数量", len(pdf_paths))
# Number of GPUs
num_gpus = 8
# Setup multiprocessing to handle PDFs across multiple GPUs
# multiprocessing_setup(pdf_paths, num_gpus)
pdf_path = "/home/ubuntu/sas0/LYT/paper_dataset/black_phosphorus_wulie/黑磷文献/黑磷文献-任务1-推荐官能团/P-O,P-O-PSupporting_information.pdf"
pdf_parse_main(pdf_path, parse_method='auto', output_dir=output_dir)


@@ -0,0 +1,245 @@
import re
import os
import requests
import time
import PyPDF2
import multiprocessing as mp
import math
import sys
import torch
from loguru import logger
from glob import glob
from tqdm import tqdm
from magic_pdf.data.data_reader_writer import FileBasedDataWriter, FileBasedDataReader
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.config.enums import SupportedPdfParseMethod
# 图床配置
IMGBED_URL = "http://localhost:40027/"
# 检查imgbed url是否以/结尾
if not IMGBED_URL.endswith('/'):
IMGBED_URL += '/'
token_endpoint = f"{IMGBED_URL}api/v1/tokens"
upload_endpoint = f"{IMGBED_URL}api/v1/upload"
# 通过如下方式获取token
# curl -X POST http://localhost:40027/api/v1/tokens -H "Content-Type: application/json" -d '{"email":"yt.li2@siat.ac.cn", "password":"lyt20000414."}'
IMGBED_TOKEN = "6|QsBh5H7txY3Hd7ju1nzYKOBSdFQeL0YberydSFIH"
def replace_image_links(md_content: str, images_urls: dict) -> str:
# 匹配 Markdown 中的图像链接形式,即: ![alt text](image_path)
pattern = r'!\[(.*?)\]\((.*?)\)'
def replace_link(match):
# 提取出当前匹配到的图片路径
image_path = match.group(2)
# 检查该路径是否在字典中
if image_path in images_urls:
# 从字典中获取新的 URL
new_url = images_urls[image_path]
return f"![]({new_url})"
return match.group(0)
# 使用 sub 函数进行替换
updated_md_content = re.sub(pattern, replace_link, md_content)
return updated_md_content
# 上传图片到LSKY Pro
def upload_image(img_dir):
headers = {
"Authorization": f"Bearer {IMGBED_TOKEN}",
'Accept': 'application/json'
}
image_urls = {}
os.makedirs(img_dir, exist_ok=True)
img_names = os.listdir(img_dir)
for image_name in img_names:
retry = 0
image_path = os.path.join(img_dir, image_name)
while retry < 5: # 最大重试次数
try:
with open(image_path, 'rb') as image_file: # 确保文件在上传时是打开状态
files = {'file': image_file}
# 上传文件
response = requests.post(upload_endpoint, headers=headers, files=files)
if response.status_code == 200:
result = response.json()
if result['status']:
image_url = result['data']['links']['url']
image_urls['images/'+image_name] = image_url
break # 上传成功,退出重试循环
else:
raise Exception(f"图片上传失败: {result['message']}")
elif response.status_code == 429:
# 429 响应,等待一段时间再重试
wait_time = min(2 ** retry, 60) # 指数退避,最大等待 60 秒
logger.warning(f"请求过于频繁,等待 {wait_time} 秒...")
time.sleep(wait_time)
else:
raise Exception(f"HTTP请求出错: {response.status_code}")
retry += 1 # 增加重试次数
time.sleep(1) # 在重试失败后稍等一下
except FileNotFoundError:
logger.error(f"文件 {image_path} 不存在,请检查路径是否正确")
return
return image_urls
def pdf_parse_main(
pdf_path: str,
output_dir: str = None
):
try:
name_without_suff = os.path.basename(pdf_path).replace('.pdf', '')
# prepare env
local_md_dir = os.path.join(output_dir, name_without_suff)
local_image_dir = os.path.join(local_md_dir, 'images')
image_dir = str(os.path.basename(local_image_dir))
os.makedirs(local_image_dir, exist_ok=True)
image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(
local_md_dir
)
# read bytes
reader1 = FileBasedDataReader("")
pdf_bytes = reader1.read(pdf_path) # read the pdf content
# proc
## Create Dataset Instance
ds = PymuDocDataset(pdf_bytes)
## inference
if ds.classify() == SupportedPdfParseMethod.OCR:
infer_result = ds.apply(doc_analyze, ocr=True)
## pipeline
pipe_result = infer_result.pipe_ocr_mode(image_writer)
else:
infer_result = ds.apply(doc_analyze, ocr=False)
## pipeline
pipe_result = infer_result.pipe_txt_mode(image_writer)
### draw model result on each page
infer_result.draw_model(os.path.join(local_md_dir, f"{name_without_suff}_model.pdf"))
### draw layout result on each page
pipe_result.draw_layout(os.path.join(local_md_dir, f"{name_without_suff}_layout.pdf"))
### draw spans result on each page
pipe_result.draw_span(os.path.join(local_md_dir, f"{name_without_suff}_spans.pdf"))
### dump markdown
md_content = pipe_result.dump_md(md_writer, os.path.join(local_md_dir, f"{name_without_suff}.md"), image_dir)
### dump content list
pipe_result.dump_content_list(md_writer, os.path.join(local_md_dir, f"{name_without_suff}_content_list.json"), image_dir)
# print(md_content)
# 上传图像到图床
image_urls = upload_image(local_image_dir)
md_content = replace_image_links(md_content, image_urls)
md_writer.write_string(os.path.join(local_md_dir, f"{name_without_suff}.md"), md_content)
except Exception as e:
logger.exception(e)
return 'error'
def init_worker(pdfs, gpu_index, output_dir): # 添加output_dir参数
"""
Initialize a worker process to process a chunk of PDFs with a specific GPU.
"""
try:
# 设置CUDA设备
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_index)
import torch
device = torch.device('cuda:0')
print(f"进程 {os.getpid()} 启动于GPU {gpu_index}")
print(f"处理 {len(pdfs)} 个PDF文件")
process_pdf_chunk(pdfs, device, output_dir) # 传递output_dir
except Exception as e:
print(f"进程 {os.getpid()} 在GPU {gpu_index} 上初始化失败: {str(e)}")
raise e
def process_pdf_chunk(pdf_paths, worker_id, output_dir):
for pdf_path in tqdm(pdf_paths, desc=f"Worker {worker_id} Progress"):
try:
# 定期清理GPU内存
torch.cuda.empty_cache()
with open(pdf_path, 'rb') as file:
pdf_reader = PyPDF2.PdfReader(file)
print(os.path.basename(pdf_path).replace(".pdf", "").replace('_', '/'))
pdf_parse_main(pdf_path, output_dir=output_dir)
except PyPDF2.errors.PdfReadError:
logger.error(f"{pdf_path} has been broken")
except Exception as e:
logger.error(f"{pdf_path} has an error: {e}")
def multiprocessing_setup(pdf_paths, num_gpus, output_dir):
# 计算每个GPU处理的文件数量
chunk_size = math.ceil(len(pdf_paths) / num_gpus)
processes = []
# 为每个GPU创建一个进程
for gpu_id in range(num_gpus):
start_idx = gpu_id * chunk_size
end_idx = min(len(pdf_paths), start_idx + chunk_size)
chunk = pdf_paths[start_idx:end_idx]
p = mp.Process(target=init_worker, args=(chunk, gpu_id, output_dir)) # 传递output_dir
processes.append(p)
p.start()
time.sleep(2)
# 等待所有进程完成
for p in processes:
p.join()
if __name__ == "__main__":
_cur_dir = os.path.dirname(os.path.abspath(__file__))
# 此处更改路径
# pdf_dir = os.path.join(_cur_dir, "二维材料剥离/二维材料剥离/石墨烯")
# output_dir = os.path.join(_cur_dir, "二维材料剥离/mds/石墨烯")
# pdf_dir = os.path.join(_cur_dir, "二维材料剥离/二维材料剥离/黑磷烯")
# output_dir = os.path.join(_cur_dir, "二维材料剥离/mds/黑磷烯")
pdf_dir = os.path.join(_cur_dir, "模型评估/模型评估")
output_dir = os.path.join(_cur_dir, "模型评估/mds")
# pdf_dir = os.path.join(_cur_dir, "金纳米棒/金纳米棒")
# output_dir = os.path.join(_cur_dir, "金纳米棒/mds")
# pdf_dir = os.path.join(_cur_dir, "钙钛矿/钙钛矿-复合材料")
# output_dir = os.path.join(_cur_dir, "钙钛矿/mds/复合材料")
# pdf_dir = os.path.join(_cur_dir, "钙钛矿/钙钛矿-LAPR/PDF论文")
# output_dir = os.path.join(_cur_dir, "钙钛矿/mds/LAPR")
os.makedirs(output_dir, exist_ok=True)
pdf_paths = sorted(glob(os.path.join(pdf_dir, "*.pdf")))
print("pdf数量", len(pdf_paths))
# 输出目录中md文件的数量
md_paths = sorted(glob(os.path.join(output_dir, "**", "*.md"), recursive=True))
md_names = [os.path.basename(md_path) for md_path in md_paths]
pdf_paths = [pdf_path for pdf_path in pdf_paths if os.path.basename(pdf_path).replace('.pdf', '.md') not in md_names]
print("过滤后pdf数量", len(pdf_paths))
# # 设置GPU数量
# num_gpus = 2 # 先用2个GPU测试
# # 设置多进程启动方法
# mp.set_start_method('spawn', force=True)
# try:
# multiprocessing_setup(pdf_paths, num_gpus, output_dir)
# except Exception as e:
# print(f"程序执行出错: {str(e)}")
# pdf_path = "black_phosphorus/参考文献/2015.03-ACS Nano-Barbaros Özyilmaz-石墨烯接触、全封装的超薄黑磷基场效应晶体管中的空气稳定传输.pdf"
for pdf_path in tqdm(pdf_paths):
pdf_parse_main(pdf_path, output_dir=output_dir)

319  clean/reparagraph.py  Executable file

@@ -0,0 +1,319 @@
"""
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import os
import re
import json
from tqdm import tqdm
import logging
from openai import OpenAI
from config import ReparagraphConfig
# 配置logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler('reparagraph.log'),
logging.StreamHandler()
]
)
def get_true_level(title_info: list, config: ReparagraphConfig):
source_title = json.dumps(title_info)
instruction = """
你是一个论文目录重排助手。
有如下的JSON格式的目录信息,已知目录中每级标题的内容和行号。
<PLACEHOLDER>
请你重排该论文的目录层级,并为每级标题的level字段给出正确的层级关系,其中层级关系用数字(1,2,3,4)表示,数字越小,层级越高。
注意:重排序目录要求多个1级标题的样式, 而非单一1级目录的样式。也就说level为1的标题数量必须大于1。
通常情况下位于一级标题的有可能是:
1. 论文的题目
2. 论文的摘要(Abstract)
3. 论文的介绍(Introduction)
4. 论文的方法或实验(Methods or Experiment)
5. 论文的结果或讨论(Result or Discussion)
6. 论文的结论(Conclusion)
7. 论文的参考文献(References)
8. 论文的致谢(Acknowledgments)
9. 论文的附录(Appendix)
10. 论文的支撑信息(Supporting Information)
有时候目录中存在序号,这时则优先使用序号顺序重建目录。
返回结果的时候严格遵守下列示例JSON格式:
{ 'data': [
{ 'title': 'A hierarchically porous MOF confined CsPbBr3 quantum dots: Fluorescence switching probe for detecting Cu (II) and melamine in food samples', 'line_num': 1, 'level': 1},
...
]
"""
# 创建 OpenAI 客户端
client = OpenAI(api_key=config.openai_api_key, base_url=config.openai_base_url)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": instruction.replace("<PLACEHOLDER>", source_title)}
]
attempt = 0
while attempt < config.max_retries:
try:
completion = client.chat.completions.create(
model=config.model_name,
stream=False, # 关闭流模式
messages=messages,
response_format={
'type': 'json_object'
}
)
response = completion.choices[0].message.content
response = json.loads(response)
count_level_1 = sum(1 for item in response['data'] if item['level'] == 1)
if count_level_1 == 1:
attempt += 1
messages.append({"role": "assistant", "content": str(response)})
messages.append({"role": "user", "content": "上述目录中仅有1个1级标题, 请重新生成目录, 并保证目录中至少有两个1级标题。"})
continue
return response['data']
except (json.JSONDecodeError, Exception) as e:
logging.error(f"尝试 {attempt + 1}/{config.max_retries} 失败: {str(e)}")
if attempt == config.max_retries - 1:
logging.error("达到最大重试次数,放弃操作")
return "Error"
def read_file_content(file_path: str):
"""读取文件内容"""
with open(file_path, 'r', encoding='utf-8') as file:
return file.readlines()
def write_file_content(file_path: str, content: list):
"""写入文件内容"""
with open(file_path, 'w', encoding='utf-8') as file:
file.writelines(content)
def extract_headings(lines: list):
"""从文件内容中提取所有以#开头的行及其行号"""
headings = []
for line_num, line in enumerate(lines, 1):
if re.match(r'^#', line.strip()):
headings.append((line_num, line.strip()))
return headings
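# Minimal sketch (hypothetical lines) of the heading extraction above: line numbers
# are 1-based and non-heading lines are skipped.
def _demo_extract_headings() -> list:
    lines = [
        "# Title\n",
        "Some body text.\n",
        "## Methods\n",
    ]
    # Returns [(1, "# Title"), (3, "## Methods")]
    return extract_headings(lines)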
def extract_references(lines: list, headings: list, remove_refs: bool = False):
"""从文件内容中提取参考文献部分
Args:
lines: 文件内容列表
headings: 标题信息列表
remove_refs: 是否抹去参考文献内容
Returns:
dict: 包含起始点、结束点和内容的信息
{
'start': ref_start,
'end': ref_end,
'content': references,
'updated_headings': updated_headings
}
"""
# 在标题中查找REFERENCE
ref_heading = None
for line_num, heading in headings:
if "REFERENCE" in heading.upper().replace(" ", ""):
ref_heading = (line_num, heading)
break
if not ref_heading and "ACKNOWLEDGEMENT" in heading.upper().replace(" ", ""):
ref_heading = (line_num, heading)
if not ref_heading:
# 用正则匹配常见的引用格式并删除
# 包括:[数字]、数字.、(数字) 格式
ref_pattern = r'^(\[\d+\]|\d+\.|\(\d+\))'
lines = [line for line in lines if not re.match(ref_pattern, line.strip())]
return {
'start': -1,
'end': -1,
'content': None
}, lines
ref_start = ref_heading[0] - 1 # 转换为0-based索引
# 查找下一个标题或文件结尾
ref_end = len(lines)
for i in range(ref_start + 1, len(lines)):
if re.match(r'^#', lines[i].strip()):
ref_end = i
break
# 提取参考文献内容
references = ''.join(lines[ref_start:ref_end])
# 如果需要抹去内容
if remove_refs:
lines[ref_start:ref_end] = []
# # 如果需要更新headings
# updated_headings = headings
# if remove_refs and ref_heading:
# # 从headings中移除Reference标题
# updated_headings = [h for h in headings if h[1].upper() != ref_heading[1].upper()]
return {
'start': ref_start,
'end': ref_end,
'content': references,
#'updated_headings': updated_headings
}, lines
def update_headings(lines: list, heading_data: list):
"""根据提供的标题数据更新Markdown文件内容"""
# 统计heading_data中level==1的数量
# count_level_1 = sum(1 for item in heading_data if item['level'] == 1)
# flag = 2 if count_level_1 > 1 else 3 # 存在多个一级标题是为2否则为3
for heading in heading_data:
line_num = heading['line_num'] - 1
if heading['level'] >= 2:#flag:
lines[line_num] = "**" + lines[line_num].replace("#", "").strip() + "**\n"
return lines
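# Illustrative effect of update_headings (hypothetical data): given
#   lines = ["# Title\n", "## 2.1 Synthesis\n"]
#   heading_data = [{"title": "# Title", "line_num": 1, "level": 1},
#                   {"title": "## 2.1 Synthesis", "line_num": 2, "level": 2}]
# only the level >= 2 entry is rewritten, so lines becomes
#   ["# Title\n", "**2.1 Synthesis**\n"]
# i.e. sub-headings are demoted to bold text while level-1 headings keep their "#".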
def detect_file_encoding(file_path: str):
"""检测文件编码"""
import chardet
with open(file_path, 'rb') as f:
raw_data = f.read(1024)
result = chardet.detect(raw_data)
return result['encoding']
# def read_file_content(file_path: str, config: ReparagraphConfig):
# """读取文件内容,带大小检查和编码检测"""
# file_size = os.path.getsize(file_path)
# if file_size > config.max_file_size:
# logging.warning(f"文件 {file_path} 超过最大限制 {config.max_file_size} bytes跳过处理")
# return None
# encoding = detect_file_encoding(file_path)
# try:
# with open(file_path, 'r', encoding=encoding) as file:
# return file.readlines()
# except UnicodeDecodeError:
# logging.error(f"无法解码文件 {file_path}尝试使用utf-8")
# with open(file_path, 'r', encoding='utf-8') as file:
# return file.readlines()
def process_single_file(file_path: str, config: ReparagraphConfig):
"""处理单个文件并返回处理后的内容"""
# 读取文件内容
lines = read_file_content(file_path)
if lines is None:
return None
# 提取并更新标题
headings = extract_headings(lines)
title_info = [{"title": heading, "line_num": line_num, "level": "unknown"}
for line_num, heading in headings]
# 提取参考文献
ref_info, lines = extract_references(lines, headings, remove_refs=config.remove_refs)
if ref_info:
logging.info("提取的参考文献:")
logging.info(f"起始行: {ref_info['start'] + 1}")
logging.info(f"结束行: {ref_info['end']}")
logging.info("内容:")
logging.info(ref_info['content'])
# 更新headings
# headings = ref_info['updated_headings']
else:
logging.warning("未找到参考文献部分")
# 删除reference后可能会导致标题的行号变化重新索引
headings = extract_headings(lines)
title_info = [{"title": heading, "line_num": line_num, "level": "unknown"}
for line_num, heading in headings]
new_headings = get_true_level(title_info, config)
updated_lines = update_headings(lines, new_headings)
logging.info(f"文件处理完成: {file_path}")
return updated_lines
def create_output_dir(input_path: str, config: ReparagraphConfig):
"""创建输出目录"""
import os
from datetime import datetime
# 获取输入路径的父目录
parent_dir = os.path.dirname(input_path)
# 创建带时间戳的输出目录
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = os.path.join(parent_dir, f"{config.task_name}_{timestamp}")
os.makedirs(output_dir, exist_ok=True)
return output_dir
def save_processed_file(file_path: str, content: list, output_dir: str, input_path: str):
"""保存处理后的文件"""
import os
# 如果是单个文件
if os.path.isfile(input_path):
output_path = os.path.join(output_dir, os.path.basename(file_path))
else:
# 保持目录结构
relative_path = os.path.relpath(file_path, input_path)
output_path = os.path.join(output_dir, relative_path)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
f.writelines(content)
logging.info(f"已保存处理后的文件: {output_path}")
def reparagraph_file(path: str, config:ReparagraphConfig=None):
"""处理单个文件或文件夹中的所有.md文件
Args:
path: 文件路径或文件夹路径
config: ReparagraphConfig实例包含处理配置
Returns:
str: 输出目录路径
"""
import os
from concurrent.futures import ThreadPoolExecutor
if config is None:
config = ReparagraphConfig()
# 创建输出目录
output_dir = create_output_dir(path, config)
logging.info(f"输出目录: {output_dir}")
# 如果是文件夹,递归获取所有.md文件
if os.path.isdir(path):
files = []
for root, _, filenames in os.walk(path):
for filename in filenames:
if filename.endswith('.md'):
files.append(os.path.join(root, filename))
else:
files = [path]
def process_and_save(file_path: str):
content = process_single_file(file_path, config)
if content is not None and not config.dry_run:
save_processed_file(file_path, content, output_dir, path)
if config.parallel:
# 使用线程池并行处理
with ThreadPoolExecutor() as executor:
list(tqdm(executor.map(process_and_save, files), total=len(files), desc="Processing files"))
else:
# 顺序处理
for file_path in tqdm(files, desc="Processing files"):
process_and_save(file_path)
logging.info(f"处理完成,共处理 {len(files)} 个文件")
return output_dir
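# A minimal usage sketch, assuming ReparagraphConfig() builds with defaults (as it
# does above) and that the fields read in this module (remove_refs, parallel,
# dry_run, ...) can be overridden as attributes; the input path is hypothetical.
def _demo_reparagraph() -> str:
    cfg = ReparagraphConfig()
    cfg.remove_refs = True   # strip the References section before re-leveling
    cfg.parallel = False     # process the files sequentially
    return reparagraph_file("papers_md/", config=cfg)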

33  clean/step0_pdfs2sql.py  Normal file

@@ -0,0 +1,33 @@
import os
import tqdm
import sqlite3
import mysql.connector
def main():
cur_path = os.path.dirname(os.path.abspath(__file__))
TABLE_NAME = 'mp_cif_info'
mysql_connection = mysql.connector.connect(
host='100.84.94.73',
user='metadata_mat_papers',
password='siat-mic',
database='metadata_mat_papers'
)
mysql_cursor = mysql_connection.cursor()
pdf_list = os.listdir(os.path.join(cur_path, 'mp_cif/pdfs'))
doi_list = [pdf.replace('.pdf', '') for pdf in pdf_list]
try:
for doi in doi_list:
sql = f"INSERT INTO {TABLE_NAME} (doi) VALUES (%s)"
mysql_cursor.execute(sql, (doi,))
mysql_connection.commit()
finally:
mysql_connection.close()
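# Equivalent batched form of the insert loop above, shown only as a sketch:
# cursor.executemany() runs the same parameterized INSERT for every DOI and a
# single commit then covers the whole batch.
def insert_dois_batched(mysql_cursor, mysql_connection, table_name, doi_list):
    sql = f"INSERT INTO {table_name} (doi) VALUES (%s)"
    mysql_cursor.executemany(sql, [(doi,) for doi in doi_list])
    mysql_connection.commit()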
if __name__ == "__main__":
main()


@@ -0,0 +1,88 @@
import os
import tqdm
import sqlite3
import mysql.connector
import PyPDF2
def read_dois_from_db(db_path):
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.execute(f"SELECT doi FROM doi_status;")
dois = [row[0] for row in cursor.fetchall()]
conn.close()
return dois
def main():
cur_path = os.path.dirname(os.path.abspath(__file__))
# db_path = os.path.join(cur_path, 'psk_high_cited', 'doi_status.db')
# dois_db = read_dois_from_db(db_path)
# for doi in tqdm.tqdm(dois_db):
# pdf = doi.replace('/','_').replace('<','_').replace('>','_').replace(':','_') + '.pdf'
# pdf_path = os.path.join(cur_path, 'psk_high_cited/pdfs', pdf)
# if os.path.exists(pdf_path):
# conn = sqlite3.connect(db_path)
# cursor = conn.cursor()
# cursor.execute(f"UPDATE doi_status SET status = 'success' WHERE doi = '{doi}';")
# conn.close()
###########################################################################################
TABLE_NAME = 'mp_cif_info'
mysql_connection = mysql.connector.connect(
host='100.84.94.73',
user='metadata_mat_papers',
password='siat-mic',
database='metadata_mat_papers'
)
mysql_cursor = mysql_connection.cursor()
try:
# 获取所有 doi
mysql_cursor.execute(f"SELECT doi FROM {TABLE_NAME};")
dois = [row[0] for row in mysql_cursor.fetchall()]
for doi in tqdm.tqdm(dois):
# pdf = doi.replace('/','_').replace('<','_').replace('>','_').replace(':','_') + '.pdf'
pdf = doi + '.pdf'
# 需要更改为你的pdf路径
pdf_path = os.path.join(cur_path, 'mp_cif/pdfs', pdf)
if os.path.exists(pdf_path):
try:
# 尝试打开PDF文件
with open(pdf_path, 'rb') as file:
pdf_reader = PyPDF2.PdfReader(file) # 如果无法解析,可能抛出异常
# 如果文件成功打开和解析,更新数据库状态为 'success'
query = f"UPDATE {TABLE_NAME} SET scihub_downloaded = %s WHERE doi = %s"
mysql_cursor.execute(query, ('success', doi))
mysql_connection.commit()
except (PyPDF2.errors.PdfReadError, PyPDF2.errors.PdfStreamError):
# 如果 PDF 解析失败,将 scihub_downlowded 设置为 NULL
query = f"UPDATE {TABLE_NAME} SET scihub_downloaded = %s WHERE doi = %s"
mysql_cursor.execute(query, (None, doi)) # None 会映射为 SQL 中的 NULL
mysql_connection.commit()
except Exception as e:
# 其他异常处理
print(f"处理 PDF {doi} 时出现未知错误: {e}")
query = f"UPDATE {TABLE_NAME} SET scihub_downloaded = %s WHERE doi = %s"
mysql_cursor.execute(query, (None, doi))
mysql_connection.commit()
except mysql.connector.Error as error:
print("Failed to insert record into MySQL table: {}".format(error))
# 如果发生错误,撤回事务
mysql_connection.rollback()
finally:
# 关闭游标和连接
mysql_cursor.close()
mysql_connection.close()
if __name__ == "__main__":
main()


@@ -0,0 +1,47 @@
import sqlite3
import mysql.connector
import tqdm
import os
TABLE_NAME = 'mp_synthesis_papers_info'
input('TABLE_NAME = {} ?'.format(TABLE_NAME))
cur_dir = os.path.dirname(os.path.abspath(__file__))
# MySQL connection setup
mysql_connection = mysql.connector.connect(
host='100.84.94.73',
user='metadata_mat_papers',
password='siat-mic',
database='metadata_mat_papers'
)
try:
mysql_cursor = mysql_connection.cursor()
# 编写query语句
# query = f"SELECT pdf_url FROM {TABLE_NAME} WHERE scihub_downlowded IN ('broken', 'timeout', 'failed') and pdf_url IS NOT NULL;"
query = f"SELECT pdf_url FROM {TABLE_NAME} WHERE scihub_downlowded IS NULL AND pdf_url IS NOT NULL;"
mysql_cursor.execute(query)
records = mysql_cursor.fetchall()
for record in tqdm.tqdm(records):
# pdf_path = os.path.join(cur_dir, record[0])
# if os.path.exists(pdf_path):
# os.remove(pdf_path)
query = f"UPDATE {TABLE_NAME} SET pdf_url = NULL WHERE pdf_url = '{record[0]}';"
mysql_cursor.execute(query)
mysql_connection.commit()
# 提交更改到数据库
mysql_connection.commit()
except mysql.connector.Error as error:
print("Failed to insert record into MySQL table: {}".format(error))
# 如果发生错误,撤回事务
mysql_connection.rollback()
finally:
# 关闭游标和连接
mysql_cursor.close()
mysql_connection.close()


@@ -0,0 +1,52 @@
import sqlite3
import mysql.connector
import tqdm
import os
TABLE_NAME = 'mp_cif_info'
input('TABLE_NAME = {} ?'.format(TABLE_NAME))
cur_dir = os.path.dirname(os.path.abspath(__file__))
# MySQL connection setup
mysql_connection = mysql.connector.connect(
host='100.84.94.73',
user='metadata_mat_papers',
password='siat-mic',
database='metadata_mat_papers'
)
try:
mysql_cursor = mysql_connection.cursor()
# 获取所有下载为 success 的 doi
query = f"SELECT doi, pdf_url FROM {TABLE_NAME} WHERE scihub_downloaded = 'success';"
mysql_cursor.execute(query)
results = mysql_cursor.fetchall()
dois = [row[0] for row in results]
pdf_urls = [row[1] for row in results]
for doi, pdf_url in tqdm.tqdm(zip(dois, pdf_urls), total=len(dois)):
# 若是已经修改过的,则直接跳过
if pdf_url is not None and pdf_url.split('/')[0] == 'mp_cif' and pdf_url.split('/')[1] == 'pdfs':
continue
# pdf = doi.replace('/','_').replace('<','_').replace('>','_').replace(':','_') + '.pdf'
pdf = doi + '.pdf'
# 新的路径
pdf_path = os.path.join('mp_cif/pdfs', pdf)
query = f"UPDATE {TABLE_NAME} SET pdf_url = '{pdf_path}' WHERE doi = '{doi}';"
mysql_cursor.execute(query)
mysql_connection.commit()
# 提交更改到数据库
mysql_connection.commit()
except mysql.connector.Error as error:
print("Failed to insert record into MySQL table: {}".format(error))
# 如果发生错误,撤回事务
mysql_connection.rollback()
finally:
# 关闭游标和连接
mysql_cursor.close()
mysql_connection.close()


@@ -0,0 +1,51 @@
import mysql.connector
import tqdm
import os
TABLE_NAME = 'phosphorus_synthesis_info_new'
input('TABLE_NAME = {} ?'.format(TABLE_NAME))
cur_dir = os.path.dirname(os.path.abspath(__file__))
# MySQL connection setup
mysql_connection = mysql.connector.connect(
host='100.84.94.73',
user='metadata_mat_papers',
password='siat-mic',
database='metadata_mat_papers'
)
try:
mysql_cursor = mysql_connection.cursor()
# 获取所有已转换的 doi
query = f"SELECT doi, md_url FROM {TABLE_NAME} WHERE en_text_content IS NOT NULL;"
mysql_cursor.execute(query)
results = mysql_cursor.fetchall()
dois = [row[0] for row in results]
md_urls = [row[1] for row in results]
for doi, md_url in tqdm.tqdm(zip(dois, md_urls), total=len(dois)):
# 若是已经修改过的,则直接跳过
dir_name = 'phosphorus'
if md_url is not None and md_url.split('/')[0] == dir_name and md_url.split('/')[1] == 'mds':
continue
md_name = doi.replace('/','_').replace('<','_').replace('>','_').replace(':','_')
md = md_name + '.md'
md_path = os.path.join(dir_name+'/mds', md_name, md)
query = f"UPDATE {TABLE_NAME} SET md_url = '{md_path}', convert2md = 'success' WHERE doi = '{doi}';"
mysql_cursor.execute(query)
mysql_connection.commit()
# 提交更改到数据库
mysql_connection.commit()
except mysql.connector.Error as error:
print("Failed to insert record into MySQL table: {}".format(error))
# 如果发生错误,撤回事务
mysql_connection.rollback()
finally:
# 关闭游标和连接
mysql_cursor.close()
mysql_connection.close()


@@ -0,0 +1,424 @@
import re
import os
import json
import copy
import requests
import time
import shutil
import uuid
import sqlite3
import PyPDF2
import multiprocessing
import mysql.connector
from concurrent.futures import ThreadPoolExecutor, as_completed
from loguru import logger
from glob import glob
from tqdm import tqdm
from datetime import datetime
import asyncio
from magic_pdf.pipe.UNIPipe import UNIPipe
from magic_pdf.pipe.OCRPipe import OCRPipe
from magic_pdf.pipe.TXTPipe import TXTPipe
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
import magic_pdf.model as model_config
model_config.__use_inside_model__ = True
# 图床配置
# IMGBED_URL = "http://localhost:40027/"
IMGBED_URL = "http://172.20.103.171:40027/"
# 检查imgbed url是否以/结尾
if not IMGBED_URL.endswith('/'):
IMGBED_URL += '/'
token_endpoint = f"{IMGBED_URL}api/v1/tokens"
upload_endpoint = f"{IMGBED_URL}api/v1/upload"
# 通过如下方式获取token
# curl -X POST http://localhost:40027/api/v1/tokens -H "Content-Type: application/json" -d '{"email":"yt.li2@siat.ac.cn", "password":"lyt20000414."}'
IMGBED_TOKEN = "6|QsBh5H7txY3Hd7ju1nzYKOBSdFQeL0YberydSFIH"
def replace_image_links(md_content: str, images_urls: dict) -> str:
# 匹配 Markdown 中的图像链接形式,即: ![alt text](image_path)
pattern = r'!\[(.*?)\]\((.*?)\)'
def replace_link(match):
# 提取出当前匹配到的图片路径
image_path = match.group(2)
# 检查该路径是否在字典中
if image_path in images_urls:
# 从字典中获取新的 URL
new_url = images_urls[image_path]
return f"![]({new_url})"
return match.group(0)
# 使用 sub 函数进行替换
updated_md_content = re.sub(pattern, replace_link, md_content)
return updated_md_content
# 上传图片到LSKY Pro
def upload_image(img_dir):
headers = {
"Authorization": f"Bearer {IMGBED_TOKEN}",
'Accept': 'application/json'
}
image_urls = {}
img_names = os.listdir(img_dir)
for image_name in img_names:
retry = 0
image_path = os.path.join(img_dir, image_name)
while retry < 5: # 最大重试次数
try:
with open(image_path, 'rb') as image_file: # 确保文件在上传时是打开状态
files = {'file': image_file}
# 上传文件
response = requests.post(upload_endpoint, headers=headers, files=files)
if response.status_code == 200:
result = response.json()
if result['status']:
image_url = result['data']['links']['url']
image_urls['images/'+image_name] = image_url
print(f"图片上传成功: {image_url}")
break # 上传成功,退出重试循环
else:
raise Exception(f"图片上传失败: {result['message']}")
elif response.status_code == 429:
# 429 响应,等待一段时间再重试
wait_time = 3
# wait_time = min(2 ** retry, 10) # 指数退避,最大等待 10 秒
# logger.warning(f"请求过于频繁,等待 {wait_time} 秒...")
print(f"请求过于频繁,等待 {wait_time} 秒...")
time.sleep(wait_time)
else:
raise Exception(f"HTTP请求出错: {response.status_code}")
retry += 1 # 增加重试次数
time.sleep(1) # 在重试失败后稍等一下
except FileNotFoundError:
logger.error(f"文件 {image_path} 不存在,请检查路径是否正确")
return
return image_urls
# 保存图片到本地,并确保生成的文件名唯一
def save_images_locally(img_dir, target_dir):
if not os.path.exists(target_dir):
os.makedirs(target_dir)
image_urls = {}
img_names = os.listdir(img_dir)
# 遍历图片并保存到目标文件夹
for image_name in img_names:
image_path = os.path.join(img_dir, image_name)
# 使用UUID生成唯一的文件名以保持图片名称的唯一性
unique_name = f"{uuid.uuid4()}{os.path.splitext(image_name)[1]}" # 保留原扩展名
save_path = os.path.join(target_dir, unique_name)
try:
# 复制文件到目标目录
shutil.copy2(image_path, save_path)
# Map the image path as it appears in the Markdown (original name) to the renamed local copy
image_urls[f'images/{image_name}'] = save_path
print(f"图片保存成功: {save_path}")
except FileNotFoundError:
print(f"文件 {image_path} 不存在,跳过该图片")
except Exception as e:
print(f"保存图片 {image_name} 过程中发生错误: {e}")
return image_urls
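# Illustrative return value of save_images_locally (hypothetical names): the keys
# keep the "images/<name>" form that the generated Markdown refers to, while the
# values point at the uniquely renamed local copies, e.g.
#   {"images/fig_page1.jpg": "mp_cif/images/3f2a9c1e-....jpg"}
# so replace_image_links() can rewrite the Markdown just like the upload path does.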
def json_md_dump(
pipe,
md_writer,
pdf_name,
content_list,
md_content,
):
# 写入模型结果到 model.json
orig_model_list = copy.deepcopy(pipe.model_list)
md_writer.write(
content=json.dumps(orig_model_list, ensure_ascii=False, indent=4),
path=f"{pdf_name}_model.json"
)
# 写入中间结果到 middle.json
md_writer.write(
content=json.dumps(pipe.pdf_mid_data, ensure_ascii=False, indent=4),
path=f"{pdf_name}_middle.json"
)
# text文本结果写入到 conent_list.json
md_writer.write(
content=json.dumps(content_list, ensure_ascii=False, indent=4),
path=f"{pdf_name}_content_list.json"
)
# 写入结果到 .md 文件中
md_writer.write(
content=md_content,
path=f"{pdf_name}.md"
)
def pdf_parse_main(
pdf_path: str,
parse_method: str = 'auto',
model_json_path: str = None,
is_json_md_dump: bool = True,
output_dir: str = None
):
"""
执行从 pdf 转换到 json、md 的过程,输出 md 和 json 文件到 pdf 文件所在的目录
:param pdf_path: .pdf 文件的路径,可以是相对路径,也可以是绝对路径
:param parse_method: 解析方法, 共 auto、ocr、txt 三种,默认 auto如果效果不好可以尝试 ocr
:param model_json_path: 已经存在的模型数据文件如果为空则使用内置模型pdf 和 model_json 务必对应
:param is_json_md_dump: 是否将解析后的数据写入到 .json 和 .md 文件中,默认 True会将不同阶段的数据写入到不同的 .json 文件中共3个.json文件md内容会保存到 .md 文件中
:param output_dir: 输出结果的目录地址,会生成一个以 pdf 文件名命名的文件夹并保存所有结果
"""
try:
pdf_name = os.path.basename(pdf_path).split("/")[-1].replace(".pdf", "")
pdf_path_parent = os.path.dirname(pdf_path)
if output_dir:
output_path = os.path.join(output_dir, pdf_name)
else:
output_path = os.path.join(pdf_path_parent, pdf_name)
output_image_path = os.path.join(output_path, 'images')
# 获取图片的父路径,为的是以相对路径保存到 .md 和 conent_list.json 文件中
image_path_parent = os.path.basename(output_image_path)
pdf_bytes = open(pdf_path, "rb").read() # 读取 pdf 文件的二进制数据
if model_json_path:
# 读取已经被模型解析后的pdf文件的 json 原始数据list 类型
model_json = json.loads(open(model_json_path, "r", encoding="utf-8").read())
else:
model_json = []
# 执行解析步骤
# image_writer = DiskReaderWriter(output_image_path)
image_writer, md_writer = DiskReaderWriter(output_image_path), DiskReaderWriter(output_path)
# 选择解析方式
# jso_useful_key = {"_pdf_type": "", "model_list": model_json}
# pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer)
if parse_method == "auto":
jso_useful_key = {"_pdf_type": "", "model_list": model_json}
pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer)
elif parse_method == "txt":
pipe = TXTPipe(pdf_bytes, model_json, image_writer)
elif parse_method == "ocr":
pipe = OCRPipe(pdf_bytes, model_json, image_writer)
else:
logger.error("unknown parse method, only auto, ocr, txt allowed")
exit(1)
# 执行分类
pipe.pipe_classify()
# 如果没有传入模型数据,则使用内置模型解析
if not model_json:
if model_config.__use_inside_model__:
pipe.pipe_analyze() # 解析
else:
logger.error("need model list input")
exit(1)
# 执行解析
pipe.pipe_parse()
# 保存 text 和 md 格式的结果
content_list = pipe.pipe_mk_uni_format(image_path_parent, drop_mode="none")
md_content = pipe.pipe_mk_markdown(image_path_parent, drop_mode="none")
# 上传图像到图床
# image_urls = upload_image(output_image_path)
# 保存图像到本地
target_dir = "mp_cif/images"
image_urls = save_images_locally(output_image_path, target_dir)
md_content = replace_image_links(md_content, image_urls)
mysql_connection = mysql.connector.connect(
host='100.84.94.73',
user='metadata_mat_papers',
password='siat-mic',
database='metadata_mat_papers',
charset="utf8mb4", # 设置连接使用 utf8mb4
collation="utf8mb4_unicode_ci" # 使用适当的 collation
)
mysql_cursor = mysql_connection.cursor()
table = 'mp_cif_info'
# path = 'phosphorus/pdfs/' + pdf_name + '.pdf'
# print("path:", path)
doi = os.path.basename(pdf_path).replace(".pdf", "").replace('_', '/')
try:
# 编写query语句
query = f"UPDATE {table} SET en_text_content = %s WHERE doi = %s"
mysql_cursor.execute(query, (md_content, doi))
print(f"{doi},md保存成功")
# 提交更改到数据库
mysql_connection.commit()
except mysql.connector.Error as error:
print("Failed to insert record into MySQL table: {}".format(error))
# 如果发生错误,撤回事务
mysql_connection.rollback()
finally:
# 关闭游标和连接
mysql_cursor.close()
mysql_connection.close()
if is_json_md_dump:
json_md_dump(pipe, md_writer, pdf_name, content_list, md_content)
return 'success'
except Exception as e:
logger.exception(e)
return 'error'
def check_doi_not_in_db(pdf_name, cursor):
query = f"SELECT * FROM doi_status WHERE doi = ? AND convert_status = ? "
cursor.execute(query, (pdf_name, 'unprocessed'))
res = cursor.fetchone()
if res:
return True
else:
return False
def init_worker(devices, pdfs, gpu_index, process_id):
"""
Initialize a worker process to process a chunk of PDFs with a specific GPU.
"""
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_index)
process_pdf_chunk(pdfs, gpu_index, process_id)
def get_converted2md_dois():
table = 'mp_cif_info'
dois = []
mysql_connection = mysql.connector.connect(
host='100.84.94.73',
user='metadata_mat_papers',
password='siat-mic',
database='metadata_mat_papers',
charset="utf8mb4", # 设置连接使用 utf8mb4
collation="utf8mb4_unicode_ci" # 使用适当的 collation
)
mysql_cursor = mysql_connection.cursor()
try:
# 编写query语句
query = f"SELECT doi FROM {table} WHERE en_text_content IS NOT NULL;"
mysql_cursor.execute(query)
res = mysql_cursor.fetchall()
dois = [row[0] for row in res if row]
except mysql.connector.Error as error:
# 如果发生错误,撤回事务
mysql_connection.rollback()
finally:
# 关闭游标和连接
mysql_cursor.close()
mysql_connection.close()
return dois
def is_within_operational_hours(start_hour, end_hour):
now = datetime.now().time() # current time of day (date dropped)
current_hour = now.hour # current hour
# Check whether the current hour falls inside the window; the window may span midnight
if start_hour > end_hour:
return (current_hour >= start_hour or current_hour < end_hour) # window wraps past midnight
else:
return start_hour <= current_hour < end_hour
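# Worked example of the window check above: with start_hour=15 and end_hour=9 the
# window wraps past midnight, so hours 16 and 3 both return True while hour 10
# returns False; with start_hour=9 and end_hour=15 only 9 <= hour < 15 is True.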
def process_pdf_chunk(pdf_paths, gpu_index, process_id):
for pdf_path in tqdm(pdf_paths, desc=f"Worker {gpu_index}_{process_id} Progress"):
# Only run the task inside the allowed time window
start_hour = 15 # window opens at 15:00
end_hour = 9 # window closes at 09:00 the next morning
# Check whether the current time falls inside the allowed window
while True:
if is_within_operational_hours(start_hour, end_hour):
print("当前时间在任务运行区间内开始处理PDF文件...")
try:
with open(pdf_path, 'rb') as file:
pdf_reader = PyPDF2.PdfReader(file)
print(os.path.basename(pdf_path).replace(".pdf", "").replace('_', '/'))
status = pdf_parse_main(pdf_path, parse_method='auto', output_dir=output_dir)
break # 执行结束,跳出循环
except PyPDF2.errors.PdfReadError:
logger.error(f"{pdf_path} has been broken")
break # 执行异常,跳出循环
except Exception as e:
logger.error(f"{pdf_path} has an error: {e}")
break # 执行异常,跳出循环
else:
# 当前时间不在允许的时间范围,阻塞任务
print("当前时间不在运行区间,稍后重试...")
time.sleep(60 * 60) # 沉睡1小时后再次检查
def multiprocessing_setup(pdf_paths, num_gpus):
num_processes_per_gpu = 3
chunk_size = len(pdf_paths) // (num_gpus * num_processes_per_gpu)
processes = []
# Create processes for each GPU
for gpu_id in range(num_gpus):
for process_id in range(num_processes_per_gpu):
start_idx = (gpu_id * num_processes_per_gpu + process_id) * chunk_size
end_idx = None if (gpu_id == num_gpus - 1 and process_id == num_processes_per_gpu - 1) else start_idx + chunk_size
chunk = pdf_paths[start_idx:end_idx]
p = multiprocessing.Process(target=init_worker, args=([gpu_id], chunk, gpu_id, process_id))
processes.append(p)
p.start()
# Ensure all processes have completed
for p in processes:
p.join()
if __name__ == '__main__':
_cur_dir = os.path.dirname(os.path.abspath(__file__))
# 此处更改路径
pdf_dir = os.path.join(_cur_dir, "mp_cif/pdfs")
output_dir = os.path.join(_cur_dir, "mp_cif/mds")
os.makedirs(output_dir, exist_ok=True)
pdf_paths = sorted(glob(os.path.join(pdf_dir, "*.pdf")))
dois = get_converted2md_dois()
print(len(dois))
new_pdf_paths = pdf_paths[:]
for path in tqdm(pdf_paths):
doi = os.path.basename(path).replace(".pdf", "").replace('_', '/')
if doi in dois:
new_pdf_paths.remove(path)
print(len(new_pdf_paths))
# Number of GPUs
num_gpus = 8
# Setup multiprocessing to handle PDFs across multiple GPUs
multiprocessing_setup(new_pdf_paths, num_gpus)

160  clean/stp1_bib2sql.py  Normal file

@@ -0,0 +1,160 @@
import os
import glob
import mysql.connector
import bibtexparser
import tqdm
TABLE_NAME = 'phosphorus_synthesis_info'
input('你确定TABLE_NAME是{}吗?'.format(TABLE_NAME))
# phosphorus_synthesis
bibs_dir = os.path.join(os.path.dirname(__file__), 'synthesis23-25')
if_file_path = os.path.join(os.path.dirname(__file__), '2023JCR.xlsx')
input('你确定导入文件夹是{}吗?'.format(bibs_dir))
# MySQL connection setup
connection = mysql.connector.connect(
host='localhost',
user='metadata_mat_papers',
password='siat-mic',
database='metadata_mat_papers'
)
cursor = connection.cursor()
# Function to check if a table exists
def check_table_exists(table_name):
cursor.execute(f"""
SELECT COUNT(*)
FROM information_schema.tables
WHERE table_schema = DATABASE()
AND table_name = '{table_name}'
""")
return cursor.fetchone()[0] == 1
# Function to create the table if it doesn't exist
def create_table(table_name):
if not check_table_exists(table_name):
query = f"""
CREATE TABLE IF NOT EXISTS `{table_name}` (
doi VARCHAR(255) PRIMARY KEY,
unique_id VARCHAR(255),
author TEXT,
title TEXT,
journal VARCHAR(255),
year INT,
volume VARCHAR(50),
number VARCHAR(50),
pages VARCHAR(50),
month VARCHAR(50),
issn VARCHAR(50),
eissn VARCHAR(50),
researcher_id TEXT,
if2023 VARCHAR(50),
if5 VARCHAR(50),
journal_index VARCHAR(50),
jcr_quartile VARCHAR(50),
orcid TEXT,
early_access_date VARCHAR(50),
scihub_downlowded VARCHAR(50),
convert2md VARCHAR(50),
pdf_url TEXT,
md_url TEXT,
abstract TEXT,
image_url JSON,
text_content LONGTEXT
);
"""
cursor.execute(query)
def record_exists(doi, table_name):
query = f"SELECT COUNT(*) FROM `{table_name}` WHERE doi = %s"
cursor.execute(query, (doi,))
count = cursor.fetchone()[0]
return count > 0
# Function to insert a record into the MySQL database
def insert_record(entry, table_name):
# 定义列名列表
columns = [
'doi', 'unique_id', 'author', 'title', 'journal', 'year', 'volume',
'number', 'pages', 'month', 'issn', 'eissn', 'researcher_id', 'if2023', 'if5', 'journal_index', 'jcr_quartile',
'orcid', 'early_access_date', 'scihub_downlowded', 'convert2md', 'pdf_url', 'md_url', 'abstract', 'image_url', 'text_content'
]
# 构建SQL查询语句
placeholders = ', '.join(['%s'] * len(columns))
query = f"""
INSERT INTO `{table_name}` ({', '.join(columns)})
VALUES ({placeholders})
"""
values = (
entry.get('doi'),
entry.get('unique-id'),
entry.get('author'),
entry.get('title'),
entry.get('journal'),
entry.get('year'),
entry.get('volume'),
entry.get('number', None),
entry.get('pages', None),
entry.get('month', None),
entry.get('issn', None),
entry.get('eissn', None),
entry.get('researcherid-numbers', None),
entry.get('if2023', None),
entry.get('if5', None),
entry.get('journal_index', None),
entry.get('jcr_quartile', None),
entry.get('orcid-numbers', None),
entry.get('earlyaccessdate', None),
entry.get('scihub_downlowded', None),
entry.get('convert2md', None),
entry.get('pdf_url', None),
entry.get('md_url', None),
entry.get('abstract', None),
entry.get('image_url', None),
entry.get('text_content', None)
)
cursor.execute(query, values)
# 用pandas打开excel文件
import pandas as pd
df = pd.read_excel(if_file_path)
# 替换所有的nan为None
df = df.replace({pd.NA: None})
# Create the table if it doesn't exist
create_table(TABLE_NAME)
bib_files = sorted(glob.glob(os.path.join(bibs_dir, '*.bib')))
for bib_file in tqdm.tqdm(bib_files):
# Read and parse the .bib file
with open(bib_file, 'r') as bibtex_file:
bib_database = bibtexparser.load(bibtex_file)
for entry in bib_database.entries:
entry = {k.lower(): v for k, v in entry.items()}
journal = entry.get('journal')
if journal is not None:
journal_lower = journal.lower() # 将期刊名称转为小写以进行不区分大小写的匹配
matching_journal = df[df['JournalName'].str.lower() == journal_lower] # 在DataFrame中查找该期刊
if not matching_journal.empty:
entry['if2023'] = matching_journal['IF2023'].values[0]
entry['if5'] = matching_journal['IF5'].values[0]
entry['journal_index'] = matching_journal['INDEX'].values[0]
entry['jcr_quartile'] = matching_journal['Quartile'].values[0]
doi = entry.get('doi')
# 先检查记录是否存在同时doi不能为空
if not record_exists(doi, TABLE_NAME) and doi is not None:
insert_record(entry, TABLE_NAME)
# Commit the changes and close the connection
connection.commit()
cursor.close()
connection.close()
print("Data has been inserted into the database!")

193  clean/stp1_excel2sql.py  Normal file

@@ -0,0 +1,193 @@
import os
import mysql.connector
TABLE_NAME = 'crispr_papers_info'
input('你确定TABLE_NAME是{}吗?'.format(TABLE_NAME))
# phosphorus_synthesis
excels_dir = os.path.join(os.path.dirname(__file__), 'CRISPR/CRISPR_engineered')
if_file_path = os.path.join(os.path.dirname(__file__), 'CRISPR/2023JCR.xlsx')
input('你确定导入文件夹是{}吗?'.format(excels_dir))
# MySQL connection setup
connection = mysql.connector.connect(
host='100.84.94.73',
user='metadata_mat_papers',
password='siat-mic',
database='metadata_mat_papers'
)
cursor = connection.cursor()
# Function to check if a table exists
def check_table_exists(table_name):
cursor.execute(f"""
SELECT COUNT(*)
FROM information_schema.tables
WHERE table_schema = DATABASE()
AND table_name = '{table_name}'
""")
return cursor.fetchone()[0] == 1
# Function to create the table if it doesn't exist
def create_table(table_name):
if not check_table_exists(table_name):
query = f"""
CREATE TABLE IF NOT EXISTS `{table_name}` (
doi VARCHAR(255) PRIMARY KEY,
unique_id VARCHAR(255),
author TEXT,
title TEXT,
journal VARCHAR(255),
year INT,
volume VARCHAR(50),
number VARCHAR(50),
pages VARCHAR(50),
month VARCHAR(50),
issn VARCHAR(50),
eissn VARCHAR(50),
researcher_id TEXT,
if2023 VARCHAR(50),
if5 VARCHAR(50),
journal_index VARCHAR(50),
jcr_quartile VARCHAR(50),
orcid TEXT,
early_access_date VARCHAR(50),
scihub_downlowded VARCHAR(50),
convert2md VARCHAR(50),
pdf_url TEXT,
md_url TEXT,
abstract TEXT,
image_url JSON,
en_text_content LONGTEXT,
cited_reference_count INT,
doi_link TEXT,
research_areas TEXT,
unique_wos_id VARCHAR(255)
);
"""
cursor.execute(query)
def record_exists(doi, table_name):
query = f"SELECT COUNT(*) FROM `{table_name}` WHERE doi = %s"
cursor.execute(query, (doi,))
count = cursor.fetchone()[0]
return count > 0
# Function to insert a record into the MySQL database
def insert_record(entry, table_name):
# 定义列名列表
columns = [
'doi', 'unique_id', 'author', 'title', 'journal', 'year', 'volume',
'number', 'pages', 'month', 'issn', 'eissn', 'researcher_id', 'if2023', 'if5', 'journal_index', 'jcr_quartile',
'orcid', 'early_access_date', 'scihub_downlowded', 'convert2md', 'pdf_url', 'md_url', 'abstract', 'image_url',
'en_text_content', 'cited_reference_count', 'doi_link', 'research_areas', 'unique_wos_id'
]
# 构建SQL查询语句
placeholders = ', '.join(['%s'] * len(columns))
query = f"""
INSERT INTO `{table_name}` ({', '.join(columns)})
VALUES ({placeholders})
"""
values = (
entry.get('doi'),
entry.get('unique-id'),
entry.get('author'),
entry.get('title'),
entry.get('journal'),
entry.get('year'),
entry.get('volume'),
entry.get('number', None),
entry.get('pages', None),
entry.get('month', None),
entry.get('issn', None),
entry.get('eissn', None),
entry.get('researcherid-numbers', None),
entry.get('if2023', None),
entry.get('if5', None),
entry.get('journal_index', None),
entry.get('jcr_quartile', None),
entry.get('orcid-numbers', None),
entry.get('earlyaccessdate', None),
entry.get('scihub_downlowded', None),
entry.get('convert2md', None),
entry.get('pdf_url', None),
entry.get('md_url', None),
entry.get('abstract', None),
entry.get('image_url', None),
entry.get('text_content', None),
entry.get('cited_reference_count', None),
entry.get('doi_link', None),
entry.get('research_areas', None),
entry.get('unique_wos_id', None)
)
cursor.execute(query, values)
# 用pandas打开excel文件
import pandas as pd
df = pd.read_excel(if_file_path)
# 替换所有的nan为None
df = df.replace({pd.NA: None})
# Create the table if it doesn't exist
create_table(TABLE_NAME)
excels_file_list = []
for file in os.listdir(excels_dir): # os.listdir('溶剂热文献-230505-swx-V3')
if file.endswith('.xls'):
excels_file_list.append(os.path.splitext(file)[0])
for excels_file in excels_file_list:
print(os.path.join(excels_dir, excels_file + '.xls'))
# 指定Excel文件路径
file_path = os.path.join(excels_dir, excels_file + '.xls')
# 读取Excel文件
excel_df = pd.read_excel(file_path)
# 替换所有的nan为None
excel_df = excel_df.replace({pd.NA: None})
# 显示DataFrame的前几行
# print(df.head(5))
for i in range(len(excel_df)):
entry = dict()
entry['doi'] = str(excel_df.loc[i, 'DOI'])
entry['title'] = str(excel_df.loc[i, 'Article Title'])
entry['journal'] = str(excel_df.loc[i, 'Source Title'])
entry['abstract'] = str(excel_df.loc[i, 'Abstract'])
entry['cited_reference_count'] = int(excel_df.loc[i, 'Cited Reference Count'])
entry['year'] = int(excel_df.loc[i, 'Publication Year'])
entry['doi_link'] = str(excel_df.loc[i, 'DOI Link'])
entry['research_areas'] = str(excel_df.loc[i, 'Research Areas'])
entry['unique_wos_id'] = str(excel_df.loc[i, 'UT (Unique WOS ID)'])
journal = entry.get('journal')
if journal is not None:
journal_lower = journal.lower() # 将期刊名称转为小写以进行不区分大小写的匹配
matching_journal = df[df['JournalName'].str.lower() == journal_lower] # 在DataFrame中查找该期刊
if not matching_journal.empty:
entry['if2023'] = matching_journal['IF2023'].values[0]
entry['if5'] = matching_journal['IF5'].values[0]
entry['journal_index'] = matching_journal['INDEX'].values[0]
entry['jcr_quartile'] = matching_journal['Quartile'].values[0]
doi = entry.get('doi')
# 先检查记录是否存在同时doi不能为空
if not record_exists(doi, TABLE_NAME) and doi is not None:
insert_record(entry, TABLE_NAME)
# Commit the changes and close the connection
connection.commit()
cursor.close()
connection.close()
print("Data has been inserted into the database!")


@@ -0,0 +1,65 @@
# This script migrates data from the SQLite database into the MySQL database.
# It is only for code written during the SQLite stage; do not use it once everything operates on MySQL directly.
import sqlite3
import mysql.connector
TABLE_NAME = 'phosphorus_synthesis_info'
input('你确定TABLE_NAME是{}吗?'.format(TABLE_NAME))
# SQLite setup
sqlite_connection = sqlite3.connect('/home/ubuntu/workplace/LYT/llm-agent/phosphorus/doi_status.db') # Ensure this is your actual SQLite database file
sqlite_cursor = sqlite_connection.cursor()
# MySQL connection setup
mysql_connection = mysql.connector.connect(
host='100.84.94.73',
user='metadata_mat_papers',
password='siat-mic',
database='metadata_mat_papers'
)
mysql_cursor = mysql_connection.cursor()
# Define the SQLite query to retrieve data
sqlite_query = "SELECT doi, status, pdf_url FROM doi_status" # Ensure these field names match your SQLite table
# Function to check if a record exists in the MySQL database
def record_exists(doi, table_name):
query = f"SELECT COUNT(*) FROM `{table_name}` WHERE doi = %s"
mysql_cursor.execute(query, (doi,))
count = mysql_cursor.fetchone()[0]
return count > 0
# Function to update a record in the MySQL database
def update_record(doi, scihub_downlowded, pdf_url, table_name):
query = f"""
UPDATE `{table_name}`
SET scihub_downlowded = %s, pdf_url = %s
WHERE doi = %s
"""
mysql_cursor.execute(query, (scihub_downlowded, pdf_url, doi))
# Fetch data from SQLite
sqlite_cursor.execute(sqlite_query)
rows = sqlite_cursor.fetchall()
# Iterate over SQLite rows and update MySQL records
for row in rows:
doi, scihub_downlowded, pdf_url = row
if record_exists(doi, TABLE_NAME): # Replace with your actual MySQL table name
update_record(doi, scihub_downlowded, pdf_url, TABLE_NAME) # Adjust table name if necessary
else:
# You can choose to handle non-existent DOI entries differently if necessary
print(f"Record with DOI {doi} does not exist in MySQL database.")
# Commit the changes to the MySQL database
mysql_connection.commit()
# Close connections
sqlite_cursor.close()
sqlite_connection.close()
mysql_cursor.close()
mysql_connection.close()
print("Data migration from SQLite to MySQL completed successfully!")


@@ -0,0 +1,28 @@
import sqlite3
import mysql.connector
import tqdm
import os
TABLE_NAME = 'phosphorus_synthesis_info'
input('TABLE_NAME = {} ?'.format(TABLE_NAME))
cur_dir = os.path.dirname(os.path.abspath(__file__))
# MySQL connection setup
mysql_connection = mysql.connector.connect(
host='100.84.94.73',
user='metadata_mat_papers',
password='siat-mic',
database='metadata_mat_papers'
)
mysql_cursor = mysql_connection.cursor()
# 编写query语句
query = f"SELECT pdf_url FROM {TABLE_NAME} WHERE scihub_downlowded = 'broken'"
mysql_cursor.execute(query)
records = mysql_cursor.fetchall()
for record in tqdm.tqdm(records):
pdf_path = os.path.join(cur_dir, record[0])
os.remove(pdf_path)


@@ -0,0 +1,211 @@
import os
import re
import time
import tqdm
import requests
import subprocess
import concurrent.futures
import sqlite3
from scidownl import scihub_download
import logging
import pymupdf
NUM_PROCESSES = 32 # 设置并发进程数
SCIHUB_URLS = [
"https://sci-hub.st/",
"https://sci-hub.se/",
"https://sci-hub.ru/"
]
PROXY_SERVICE_URL = f"http://api.proxy.ipidea.io/getProxyIp?num={NUM_PROCESSES}&tag=static_balance&return_type=txt&lb=1&sb=0&flow=1&protocol=http"
SINGLE_PROXY_SERVICE_URL = f"http://api.proxy.ipidea.io/getProxyIp?num=1&tag=static_balance&return_type=txt&lb=1&sb=0&flow=1&protocol=http"
DOI_PATTERN = re.compile(r"DOI\s*=\s*\{(10\.\d{4,9}/[-._;()/:A-Z0-9]+)\}", re.IGNORECASE)
logging.basicConfig(level=logging.INFO, format='[%(levelname)s] | %(asctime)s | %(message)s')
logger = logging.getLogger(__name__)
def get_directories(bib_dir_name, output_dirname):
current_path = os.path.dirname(os.path.abspath(__file__))
output_dir = os.path.join(current_path, output_dirname)
bib_dir_path = os.path.join(current_path, bib_dir_name)
db_path = os.path.join(current_path, "doi_status.db")
return output_dir, bib_dir_path, db_path
def create_directory_if_not_exists(directory):
os.makedirs(directory, exist_ok=True)
def fetch_proxies():
proxies = []
try:
response = requests.get(PROXY_SERVICE_URL)
if response.status_code == 200:
proxy_list = response.text.strip().split('\r\n')
for proxy in proxy_list:
proxies.append({
"http": f"http://{proxy}",
"https": f"http://{proxy}",
})
if proxies:
logger.info(f"Fetched proxies: {proxies}")
return proxies
except Exception as e:
logger.error(f"Error fetching proxies: {e}")
return None
def fetch_proxy():
proxies = []
try:
response = requests.get(SINGLE_PROXY_SERVICE_URL)
if response.status_code == 200:
proxy_list = response.text.strip().split('\r\n')
for proxy in proxy_list:
proxies.append({
"http": f"http://{proxy}",
"https": f"http://{proxy}",
})
if proxies:
logger.info(f"Fetched proxies: {proxies}")
return proxies
except Exception as e:
logger.error(f"Error fetching proxies: {e}")
return None
def read_dois_from_files(bib_dir_path):
all_dois = []
for bib_file_name in sorted(os.listdir(bib_dir_path)):
if bib_file_name.endswith(".bib"):
with open(os.path.join(bib_dir_path, bib_file_name), "r") as file:
dois = DOI_PATTERN.findall(file.read())
logger.info(f"{bib_file_name} has {len(dois)} doi(s)")
all_dois.extend(dois)
return list(set(all_dois))
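# Illustrative match for DOI_PATTERN above (hypothetical entry): the line
#   DOI = {10.1021/acsnano.0c01234},
# from a Web of Science .bib export yields the group "10.1021/acsnano.0c01234";
# the match is case-insensitive and captures only the text inside the braces.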
def filter_downloaded_dois(all_dois, output_dir):
for doi in os.listdir(output_dir):
if doi.endswith(".pdf"):
doi = doi.replace(".pdf", "").replace("_", "/")
if doi in all_dois:
all_dois.remove(doi)
return all_dois
def read_dois_from_db(db_path, status):
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.execute(f"SELECT doi FROM doi_status WHERE status = '{status}'")
dois = [row[0] for row in cursor.fetchall()]
conn.close()
return dois
def write_doi_to_db(db_path, doi, output_dirname, status):
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.execute("INSERT OR REPLACE INTO doi_status (doi, status, pdf_url) VALUES (?, ?, ?)", (doi, status, f"{output_dirname}/{doi.replace('/', '_')}.pdf"))
conn.commit()
conn.close()
def initialize_db(db_path):
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.execute('''
CREATE TABLE IF NOT EXISTS doi_status (
doi TEXT PRIMARY KEY,
status TEXT,
pdf_url TEXT
)
''')
conn.commit()
cursor.execute("PRAGMA journal_mode=WAL")
conn.commit()
conn.close()
def download_doi(doi, output_dir, proxy, scihub_urls, db_path):
success_dois, broken_dois, failed_dois, timeout_dois = [], [], [], []
output_dirname = output_dir.split("/")[-1]
for scihub_url in scihub_urls:
output_path = os.path.join(output_dir, f"{doi.replace('/', '_')}.pdf")
proxy_url = "https=" + proxy['https']
try:
result = subprocess.run(
['scidownl', 'download', '--doi', doi, '--out', output_path, '--scihub-url', scihub_url, '--proxy', proxy_url],
check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
)
logger.info(result.stderr)
if "No pdf tag" in result.stderr:
timeout_dois.append(doi)
write_doi_to_db(db_path, doi, output_dirname, 'timeout')
break
elif "403" in result.stderr or "Unable to connect to proxy" in result.stderr or "504" in result.stderr or 'crawling_failed, error: HTTPSConnectionPool' in result.stderr:
logger.warning("Proxy error detected, fetching new proxy.")
proxy = fetch_proxy()[0]
# time.sleep(2)
continue
elif result.stdout.strip() != '':
try:
# 尝试打开pdf文件
with pymupdf.open(output_path) as pdf:
logger.info(f"Downloaded {doi} successfully.")
write_doi_to_db(db_path, doi, output_dirname, 'success')
success_dois.append(doi)
except:
write_doi_to_db(db_path, doi, output_dirname, 'broken')
logger.info(f"{doi}.pdf has been broken!")
broken_dois.append(doi)
break
else:
write_doi_to_db(db_path, doi, output_dirname, 'failed')
break
except subprocess.CalledProcessError as e:
logger.error(f"Error: {e}")
failed_dois.append(doi)
write_doi_to_db(db_path, doi, output_dirname, 'failed')
continue
return success_dois, broken_dois, failed_dois, timeout_dois
def download_dois(all_dois, output_dir, db_path):
success_dois, broken_dois, failed_dois, timeout_dois = [], [], [], []
proxies = fetch_proxies()
with concurrent.futures.ProcessPoolExecutor(max_workers=NUM_PROCESSES) as executor:
futures = []
for i, doi in enumerate(all_dois):
proxy = proxies[i % len(proxies)]
futures.append(executor.submit(download_doi, doi, output_dir, proxy, SCIHUB_URLS, db_path))
for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc='Downloading DOIs', unit='doi'):
result = future.result()
if result:
success, broken, failed, timeout = result
success_dois.extend(success)
broken_dois.extend(broken)
failed_dois.extend(failed)
timeout_dois.extend(timeout)
logger.info(f"Success: {len(success_dois)}, Broken: {len(broken_dois)}, Failed: {len(failed_dois)}, Timeout: {len(timeout_dois)}")
def main():
bib_dir_name = "synthesis23-25"
output_dirname = "synthesis23-25_pdfs"
input('你确定是文件夹{}{}吗?'.format(bib_dir_name, output_dirname))
output_dir, bib_dir_path, db_path = get_directories(bib_dir_name, output_dirname)
create_directory_if_not_exists(output_dir)
initialize_db(db_path)
all_dois = read_dois_from_files(bib_dir_path)
logger.info(f"Total {len(all_dois)} doi(s)")
all_dois = filter_downloaded_dois(all_dois, output_dir)
all_dois = [doi for doi in all_dois if doi not in read_dois_from_db(db_path, 'success')]
all_dois = [doi for doi in all_dois if doi not in read_dois_from_db(db_path, 'failed')]
all_dois = [doi for doi in all_dois if doi not in read_dois_from_db(db_path, 'timeout')]
download_dois(all_dois, output_dir, db_path)
if __name__ == "__main__":
main()