First merge of the clean code
0
clean/__init__.py
Normal file
273
clean/preprocess_mineru.py
Normal file
@@ -0,0 +1,273 @@
import re
import os
import json
import copy
import requests
import time
import sqlite3
import PyPDF2
import multiprocessing
import mysql.connector

from loguru import logger
from glob import glob
from tqdm import tqdm

from magic_pdf.pipe.UNIPipe import UNIPipe
from magic_pdf.pipe.OCRPipe import OCRPipe
from magic_pdf.pipe.TXTPipe import TXTPipe
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
import magic_pdf.model as model_config

model_config.__use_inside_model__ = True

# Image-hosting (LSKY Pro) configuration
IMGBED_URL = "http://localhost:40027/"
# Make sure the imgbed URL ends with a slash
if not IMGBED_URL.endswith('/'):
    IMGBED_URL += '/'
token_endpoint = f"{IMGBED_URL}api/v1/tokens"
upload_endpoint = f"{IMGBED_URL}api/v1/upload"

# Obtain the token with a request like:
# curl -X POST http://localhost:40027/api/v1/tokens -H "Content-Type: application/json" -d '{"email":"yt.li2@siat.ac.cn", "password":"lyt20000414."}'
IMGBED_TOKEN = "6|QsBh5H7txY3Hd7ju1nzYKOBSdFQeL0YberydSFIH"
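
# A minimal sketch (untested) of fetching the token from token_endpoint with requests
# instead of the curl call above; the response layout shown here is an assumption
# about the LSKY Pro API and is not confirmed anywhere in this repository:
# resp = requests.post(token_endpoint, json={"email": "...", "password": "..."})
# IMGBED_TOKEN = resp.json()["data"]["token"]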

def replace_image_links(md_content: str, images_urls: dict) -> str:
    # Match Markdown image links of the form ![alt](path)
    pattern = r'!\[(.*?)\]\((.*?)\)'

    def replace_link(match):
        # Extract the image path from the current match
        image_path = match.group(2)
        # Check whether this path is in the mapping
        if image_path in images_urls:
            # Take the new URL from the mapping
            new_url = images_urls[image_path]
            return f"![{match.group(1)}]({new_url})"
        return match.group(0)

    # Perform the substitution with re.sub
    updated_md_content = re.sub(pattern, replace_link, md_content)
    return updated_md_content
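
# For reference, a hypothetical before/after pair for replace_image_links
# (the path and URL below are made up for illustration):
#   md_content  = "![](images/fig1.jpg)"
#   images_urls = {"images/fig1.jpg": "http://localhost:40027/i/abc.jpg"}
#   result      = "![](http://localhost:40027/i/abc.jpg)"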

# Upload images to LSKY Pro
def upload_image(img_dir):
    headers = {
        "Authorization": f"Bearer {IMGBED_TOKEN}",
        'Accept': 'application/json'
    }

    image_urls = {}
    os.makedirs(img_dir, exist_ok=True)
    img_names = os.listdir(img_dir)
    for image_name in img_names:
        retry = 0
        image_path = os.path.join(img_dir, image_name)
        while retry < 5:  # maximum number of retries
            try:
                with open(image_path, 'rb') as image_file:  # keep the file open while uploading
                    files = {'file': image_file}

                    # Upload the file
                    response = requests.post(upload_endpoint, headers=headers, files=files)
                    if response.status_code == 200:
                        result = response.json()
                        if result['status']:
                            image_url = result['data']['links']['url']
                            image_urls['images/' + image_name] = image_url
                            break  # upload succeeded, leave the retry loop
                        else:
                            raise Exception(f"Image upload failed: {result['message']}")
                    elif response.status_code == 429:
                        # 429 response: wait a while before retrying
                        wait_time = min(2 ** retry, 60)  # exponential backoff, capped at 60 seconds
                        logger.warning(f"Too many requests, waiting {wait_time} seconds...")
                        time.sleep(wait_time)
                    else:
                        raise Exception(f"HTTP request failed: {response.status_code}")

                    retry += 1  # increment the retry counter
                    time.sleep(1)  # pause briefly after a failed attempt

            except FileNotFoundError:
                logger.error(f"File {image_path} does not exist; please check the path")
                return image_urls  # return what has been uploaded so far instead of None

    return image_urls


def json_md_dump(
        pipe,
        md_writer,
        pdf_name,
        content_list,
        md_content,
):
    # Write the model output to model.json
    orig_model_list = copy.deepcopy(pipe.model_list)
    md_writer.write(
        content=json.dumps(orig_model_list, ensure_ascii=False, indent=4),
        path=f"{pdf_name}_model.json"
    )

    # Write the intermediate result to middle.json
    md_writer.write(
        content=json.dumps(pipe.pdf_mid_data, ensure_ascii=False, indent=4),
        path=f"{pdf_name}_middle.json"
    )

    # Write the text result to content_list.json
    md_writer.write(
        content=json.dumps(content_list, ensure_ascii=False, indent=4),
        path=f"{pdf_name}_content_list.json"
    )

    # Write the Markdown result to the .md file
    md_writer.write(
        content=md_content,
        path=f"{pdf_name}.md"
    )


def pdf_parse_main(
        pdf_path: str,
        parse_method: str = 'auto',
        model_json_path: str = None,
        is_json_md_dump: bool = True,
        output_dir: str = None
):
    """
    Convert a PDF to JSON and Markdown; the .md and .json files are written next to the PDF (or to output_dir).

    :param pdf_path: path to the .pdf file; relative or absolute
    :param parse_method: parsing method, one of auto / ocr / txt; defaults to auto, try ocr if the result is poor
    :param model_json_path: path to an existing model-output file; if empty the built-in model is used; the PDF and model_json must correspond
    :param is_json_md_dump: whether to dump the parsed data to .json and .md files (default True); the intermediate stages go to three separate .json files and the Markdown content goes to a .md file
    :param output_dir: output directory; a folder named after the PDF is created there and all results are saved in it
    """
    try:
        pdf_name = os.path.basename(pdf_path).split("/")[-1].replace(".pdf", "")
        pdf_path_parent = os.path.dirname(pdf_path)

        if output_dir:
            output_path = os.path.join(output_dir, pdf_name)
        else:
            output_path = os.path.join(pdf_path_parent, pdf_name)

        output_image_path = os.path.join(output_path, 'images')

        # Parent directory of the images, used as the relative path stored in the .md and content_list.json files
        image_path_parent = os.path.basename(output_image_path)

        pdf_bytes = open(pdf_path, "rb").read()  # read the raw bytes of the PDF

        if model_json_path:
            # Load the JSON produced by a previous model run for this PDF (a list)
            model_json = json.loads(open(model_json_path, "r", encoding="utf-8").read())
        else:
            model_json = []

        # Run the parsing pipeline
        # image_writer = DiskReaderWriter(output_image_path)
        image_writer, md_writer = DiskReaderWriter(output_image_path), DiskReaderWriter(output_path)

        # Choose the parsing mode
        # jso_useful_key = {"_pdf_type": "", "model_list": model_json}
        # pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer)
        if parse_method == "auto":
            jso_useful_key = {"_pdf_type": "", "model_list": model_json}
            pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer)
        elif parse_method == "txt":
            pipe = TXTPipe(pdf_bytes, model_json, image_writer)
        elif parse_method == "ocr":
            pipe = OCRPipe(pdf_bytes, model_json, image_writer)
        else:
            logger.error("unknown parse method, only auto, ocr, txt allowed")
            exit(1)

        # Classify the document
        pipe.pipe_classify()

        # If no model data was supplied, analyze with the built-in model
        if not model_json:
            if model_config.__use_inside_model__:
                pipe.pipe_analyze()  # analyze
            else:
                logger.error("need model list input")
                exit(1)

        # Parse
        pipe.pipe_parse()

        # Produce the text and Markdown results
        content_list = pipe.pipe_mk_uni_format(image_path_parent, drop_mode="none")
        md_content = pipe.pipe_mk_markdown(image_path_parent, drop_mode="none")
        # Upload the images to the image host
        image_urls = upload_image(output_image_path)
        md_content = replace_image_links(md_content, image_urls)

        if is_json_md_dump:
            json_md_dump(pipe, md_writer, pdf_name, content_list, md_content)
        return 'success'

    except Exception as e:
        logger.exception(e)
        return 'error'


def init_worker(devices, pdfs, gpu_index):
    """
    Initialize a worker process to process a chunk of PDFs with a specific GPU.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_index)
    process_pdf_chunk(pdfs, gpu_index)


def process_pdf_chunk(pdf_paths, worker_id):
    for pdf_path in tqdm(pdf_paths, desc=f"Worker {worker_id} Progress"):
        try:
            with open(pdf_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)  # raises if the PDF cannot be parsed
            print(os.path.basename(pdf_path).replace(".pdf", "").replace('_', '/'))
            status = pdf_parse_main(pdf_path, parse_method='auto', output_dir=output_dir)
        except PyPDF2.errors.PdfReadError:
            logger.error(f"{pdf_path} has been broken")
        except Exception as e:
            logger.error(f"{pdf_path} has an error: {e}")


def multiprocessing_setup(pdf_paths, num_gpus):
    num_processes_per_gpu = 2
    chunk_size = len(pdf_paths) // (num_gpus * num_processes_per_gpu)
    processes = []

    # Create processes for each GPU
    for gpu_id in range(num_gpus):
        for process_id in range(num_processes_per_gpu):
            start_idx = (gpu_id * num_processes_per_gpu + process_id) * chunk_size
            end_idx = None if (gpu_id == num_gpus - 1 and process_id == num_processes_per_gpu - 1) else start_idx + chunk_size
            chunk = pdf_paths[start_idx:end_idx]

            p = multiprocessing.Process(target=init_worker, args=([gpu_id], chunk, gpu_id))
            processes.append(p)
            p.start()

    # Ensure all processes have completed
    for p in processes:
        p.join()


if __name__ == "__main__":
    _cur_dir = os.path.dirname(os.path.abspath(__file__))
    # Change the paths here
    pdf_dir = os.path.join(_cur_dir, "black_phosphorus_wulie/黑磷文献/黑磷文献-任务1-推荐官能团")
    output_dir = os.path.join(_cur_dir, "black_phosphorus_wulie/黑磷文献-任务1-推荐官能团_pdf2md")

    os.makedirs(output_dir, exist_ok=True)
    pdf_paths = sorted(glob(os.path.join(pdf_dir, "*.pdf")))

    print("Number of PDFs:", len(pdf_paths))

    # Number of GPUs
    num_gpus = 8

    # Setup multiprocessing to handle PDFs across multiple GPUs
    # multiprocessing_setup(pdf_paths, num_gpus)

    pdf_path = "/home/ubuntu/sas0/LYT/paper_dataset/black_phosphorus_wulie/黑磷文献/黑磷文献-任务1-推荐官能团/(P-O,P-O-P)Supporting_information.pdf"
    pdf_parse_main(pdf_path, parse_method='auto', output_dir=output_dir)
245
clean/preprocess_mineru_new.py
Normal file
@@ -0,0 +1,245 @@
import re
import os
import requests
import time
import PyPDF2
import multiprocessing as mp
import math
import sys
import torch

from loguru import logger
from glob import glob
from tqdm import tqdm

from magic_pdf.data.data_reader_writer import FileBasedDataWriter, FileBasedDataReader
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.config.enums import SupportedPdfParseMethod

# Image-hosting (LSKY Pro) configuration
IMGBED_URL = "http://localhost:40027/"
# Make sure the imgbed URL ends with a slash
if not IMGBED_URL.endswith('/'):
    IMGBED_URL += '/'
token_endpoint = f"{IMGBED_URL}api/v1/tokens"
upload_endpoint = f"{IMGBED_URL}api/v1/upload"

# Obtain the token with a request like:
# curl -X POST http://localhost:40027/api/v1/tokens -H "Content-Type: application/json" -d '{"email":"yt.li2@siat.ac.cn", "password":"lyt20000414."}'
IMGBED_TOKEN = "6|QsBh5H7txY3Hd7ju1nzYKOBSdFQeL0YberydSFIH"


def replace_image_links(md_content: str, images_urls: dict) -> str:
    # Match Markdown image links of the form ![alt](path)
    pattern = r'!\[(.*?)\]\((.*?)\)'

    def replace_link(match):
        # Extract the image path from the current match
        image_path = match.group(2)
        # Check whether this path is in the mapping
        if image_path in images_urls:
            # Take the new URL from the mapping
            new_url = images_urls[image_path]
            return f"![{match.group(1)}]({new_url})"
        return match.group(0)

    # Perform the substitution with re.sub
    updated_md_content = re.sub(pattern, replace_link, md_content)
    return updated_md_content


# Upload images to LSKY Pro
def upload_image(img_dir):
    headers = {
        "Authorization": f"Bearer {IMGBED_TOKEN}",
        'Accept': 'application/json'
    }

    image_urls = {}
    os.makedirs(img_dir, exist_ok=True)
    img_names = os.listdir(img_dir)
    for image_name in img_names:
        retry = 0
        image_path = os.path.join(img_dir, image_name)
        while retry < 5:  # maximum number of retries
            try:
                with open(image_path, 'rb') as image_file:  # keep the file open while uploading
                    files = {'file': image_file}

                    # Upload the file
                    response = requests.post(upload_endpoint, headers=headers, files=files)
                    if response.status_code == 200:
                        result = response.json()
                        if result['status']:
                            image_url = result['data']['links']['url']
                            image_urls['images/' + image_name] = image_url
                            break  # upload succeeded, leave the retry loop
                        else:
                            raise Exception(f"Image upload failed: {result['message']}")
                    elif response.status_code == 429:
                        # 429 response: wait a while before retrying
                        wait_time = min(2 ** retry, 60)  # exponential backoff, capped at 60 seconds
                        logger.warning(f"Too many requests, waiting {wait_time} seconds...")
                        time.sleep(wait_time)
                    else:
                        raise Exception(f"HTTP request failed: {response.status_code}")

                    retry += 1  # increment the retry counter
                    time.sleep(1)  # pause briefly after a failed attempt

            except FileNotFoundError:
                logger.error(f"File {image_path} does not exist; please check the path")
                return image_urls  # return what has been uploaded so far instead of None

    return image_urls


def pdf_parse_main(
        pdf_path: str,
        output_dir: str = None
):
    try:
        name_without_suff = os.path.basename(pdf_path).replace('.pdf', '')

        # Prepare the output environment
        local_md_dir = os.path.join(output_dir, name_without_suff)
        local_image_dir = os.path.join(local_md_dir, 'images')
        image_dir = str(os.path.basename(local_image_dir))

        os.makedirs(local_image_dir, exist_ok=True)

        image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(
            local_md_dir
        )

        # Read bytes
        reader1 = FileBasedDataReader("")
        pdf_bytes = reader1.read(pdf_path)  # read the pdf content

        # Processing
        ## Create Dataset Instance
        ds = PymuDocDataset(pdf_bytes)
        ## Inference
        if ds.classify() == SupportedPdfParseMethod.OCR:
            infer_result = ds.apply(doc_analyze, ocr=True)
            ## pipeline
            pipe_result = infer_result.pipe_ocr_mode(image_writer)
        else:
            infer_result = ds.apply(doc_analyze, ocr=False)
            ## pipeline
            pipe_result = infer_result.pipe_txt_mode(image_writer)
        ### Draw the model result on each page
        infer_result.draw_model(os.path.join(local_md_dir, f"{name_without_suff}_model.pdf"))
        ### Draw the layout result on each page
        pipe_result.draw_layout(os.path.join(local_md_dir, f"{name_without_suff}_layout.pdf"))
        ### Draw the spans result on each page
        pipe_result.draw_span(os.path.join(local_md_dir, f"{name_without_suff}_spans.pdf"))
        ### Dump markdown
        md_content = pipe_result.dump_md(md_writer, os.path.join(local_md_dir, f"{name_without_suff}.md"), image_dir)
        ### Dump content list
        pipe_result.dump_content_list(md_writer, os.path.join(local_md_dir, f"{name_without_suff}_content_list.json"), image_dir)

        # print(md_content)
        # Upload the images to the image host
        image_urls = upload_image(local_image_dir)
        md_content = replace_image_links(md_content, image_urls)

        md_writer.write_string(os.path.join(local_md_dir, f"{name_without_suff}.md"), md_content)

    except Exception as e:
        logger.exception(e)
        return 'error'
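
# Note (added for readability): for each input PDF, the block above writes
# <name>_model.pdf, <name>_layout.pdf, <name>_spans.pdf, <name>.md and
# <name>_content_list.json under <output_dir>/<name>/, with the extracted images in
# <output_dir>/<name>/images/ before the links are swapped for hosted URLs.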

def init_worker(pdfs, gpu_index, output_dir):  # output_dir is passed through to the workers
    """
    Initialize a worker process to process a chunk of PDFs with a specific GPU.
    """
    try:
        # Bind this process to one CUDA device
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_index)

        import torch
        device = torch.device('cuda:0')

        print(f"Process {os.getpid()} started on GPU {gpu_index}")
        print(f"Processing {len(pdfs)} PDF files")

        process_pdf_chunk(pdfs, device, output_dir)

    except Exception as e:
        print(f"Process {os.getpid()} failed to initialize on GPU {gpu_index}: {str(e)}")
        raise e


def process_pdf_chunk(pdf_paths, worker_id, output_dir):
    for pdf_path in tqdm(pdf_paths, desc=f"Worker {worker_id} Progress"):
        try:
            # Periodically release cached GPU memory
            torch.cuda.empty_cache()

            with open(pdf_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)  # raises if the PDF cannot be parsed
            print(os.path.basename(pdf_path).replace(".pdf", "").replace('_', '/'))
            pdf_parse_main(pdf_path, output_dir=output_dir)
        except PyPDF2.errors.PdfReadError:
            logger.error(f"{pdf_path} has been broken")
        except Exception as e:
            logger.error(f"{pdf_path} has an error: {e}")


def multiprocessing_setup(pdf_paths, num_gpus, output_dir):
    # Number of files handled per GPU
    chunk_size = math.ceil(len(pdf_paths) / num_gpus)
    processes = []

    # Create one process per GPU
    for gpu_id in range(num_gpus):
        start_idx = gpu_id * chunk_size
        end_idx = min(len(pdf_paths), start_idx + chunk_size)
        chunk = pdf_paths[start_idx:end_idx]

        p = mp.Process(target=init_worker, args=(chunk, gpu_id, output_dir))  # pass output_dir along
        processes.append(p)
        p.start()
        time.sleep(2)

    # Wait for all processes to finish
    for p in processes:
        p.join()


if __name__ == "__main__":
    _cur_dir = os.path.dirname(os.path.abspath(__file__))
    # Change the paths here
    # pdf_dir = os.path.join(_cur_dir, "二维材料剥离/二维材料剥离/石墨烯")
    # output_dir = os.path.join(_cur_dir, "二维材料剥离/mds/石墨烯")
    # pdf_dir = os.path.join(_cur_dir, "二维材料剥离/二维材料剥离/黑磷烯")
    # output_dir = os.path.join(_cur_dir, "二维材料剥离/mds/黑磷烯")
    pdf_dir = os.path.join(_cur_dir, "模型评估/模型评估")
    output_dir = os.path.join(_cur_dir, "模型评估/mds")
    # pdf_dir = os.path.join(_cur_dir, "金纳米棒/金纳米棒")
    # output_dir = os.path.join(_cur_dir, "金纳米棒/mds")
    # pdf_dir = os.path.join(_cur_dir, "钙钛矿/钙钛矿-复合材料")
    # output_dir = os.path.join(_cur_dir, "钙钛矿/mds/复合材料")
    # pdf_dir = os.path.join(_cur_dir, "钙钛矿/钙钛矿-LAPR/PDF论文")
    # output_dir = os.path.join(_cur_dir, "钙钛矿/mds/LAPR")

    os.makedirs(output_dir, exist_ok=True)
    pdf_paths = sorted(glob(os.path.join(pdf_dir, "*.pdf")))
    print("Number of PDFs:", len(pdf_paths))

    # Count the .md files already present in the output directory
    md_paths = sorted(glob(os.path.join(output_dir, "**", "*.md"), recursive=True))
    md_names = [os.path.basename(md_path) for md_path in md_paths]
    pdf_paths = [pdf_path for pdf_path in pdf_paths if os.path.basename(pdf_path).replace('.pdf', '.md') not in md_names]
    print("Number of PDFs after filtering:", len(pdf_paths))

    # # Set the number of GPUs
    # num_gpus = 2  # test with 2 GPUs first

    # # Set the multiprocessing start method
    # mp.set_start_method('spawn', force=True)

    # try:
    #     multiprocessing_setup(pdf_paths, num_gpus, output_dir)
    # except Exception as e:
    #     print(f"Program error: {str(e)}")

    # pdf_path = "black_phosphorus/参考文献/2015.03-ACS Nano-Barbaros Özyilmaz-石墨烯接触、全封装的超薄黑磷基场效应晶体管中的空气稳定传输.pdf"
    for pdf_path in tqdm(pdf_paths):
        pdf_parse_main(pdf_path, output_dir=output_dir)
319
clean/reparagraph.py
Executable file
@@ -0,0 +1,319 @@
"""
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""

import os
import re
import json
from tqdm import tqdm
import logging
from openai import OpenAI
from config import ReparagraphConfig
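
# ReparagraphConfig lives in config.py, which is not part of this commit.
# Judging only from the attributes used below, it presumably looks roughly like
# the following sketch (attribute names are real, defaults are hypothetical):
#
# from dataclasses import dataclass
#
# @dataclass
# class ReparagraphConfig:
#     openai_api_key: str = "..."
#     openai_base_url: str = "..."
#     model_name: str = "..."
#     max_retries: int = 3
#     remove_refs: bool = False
#     dry_run: bool = False
#     parallel: bool = False
#     task_name: str = "reparagraph"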

# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('reparagraph.log'),
        logging.StreamHandler()
    ]
)


def get_true_level(title_info: list, config: ReparagraphConfig):
    source_title = json.dumps(title_info)
    instruction = """
你是一个论文目录重排助手。
有如下的JSON格式的目录信息,已知目录中每级标题的内容和行号。
<PLACEHOLDER>
请你重排该论文的目录层级,并为每级标题的level字段给出正确的层级关系,其中层级关系用数字(1,2,3,4)表示,数字越小,层级越高。
注意:重排序目录要求多个1级标题的样式, 而非单一1级目录的样式。也就说level为1的标题数量必须大于1。
通常情况下位于一级标题的有可能是:
1. 论文的题目
2. 论文的摘要(Abstract)
3. 论文的介绍(Introduction)
4. 论文的方法或实验(Methods or Experiment)
5. 论文的结果或讨论(Result or Discussion)
6. 论文的结论(Conclusion)
7. 论文的参考文献(References)
8. 论文的致谢(Acknowledgments)
9. 论文的附录(Appendix)
10. 论文的支撑信息(Supporting Information)
有时候目录中存在序号,这时则优先使用序号顺序重建目录。

返回结果的时候严格遵守下列示例JSON格式:
{ 'data': [
    { 'title': 'A hierarchically porous MOF confined CsPbBr3 quantum dots: Fluorescence switching probe for detecting Cu (II) and melamine in food samples', 'line_num': 1, 'level': 1},
    ...
]
}
"""
    # Create the OpenAI client
    client = OpenAI(api_key=config.openai_api_key, base_url=config.openai_base_url)
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": instruction.replace("<PLACEHOLDER>", source_title)}
    ]
    attempt = 0
    while attempt < config.max_retries:
        try:
            completion = client.chat.completions.create(
                model=config.model_name,
                stream=False,  # disable streaming
                messages=messages,
                response_format={
                    'type': 'json_object'
                }
            )

            response = completion.choices[0].message.content
            response = json.loads(response)
            count_level_1 = sum(1 for item in response['data'] if item['level'] == 1)
            if count_level_1 == 1:
                attempt += 1
                messages.append({"role": "assistant", "content": str(response)})
                messages.append({"role": "user", "content": "上述目录中仅有1个1级标题, 请重新生成目录, 并保证目录中至少有两个1级标题。"})
                continue
            return response['data']

        except (json.JSONDecodeError, Exception) as e:
            logging.error(f"Attempt {attempt + 1}/{config.max_retries} failed: {str(e)}")
            if attempt == config.max_retries - 1:
                logging.error("Maximum number of retries reached, giving up")
                return "Error"
            attempt += 1  # count the failed attempt so the loop cannot spin forever


def read_file_content(file_path: str):
    """Read the file content."""
    with open(file_path, 'r', encoding='utf-8') as file:
        return file.readlines()


def write_file_content(file_path: str, content: list):
    """Write the file content."""
    with open(file_path, 'w', encoding='utf-8') as file:
        file.writelines(content)


def extract_headings(lines: list):
    """Extract every line starting with # together with its line number."""
    headings = []
    for line_num, line in enumerate(lines, 1):
        if re.match(r'^#', line.strip()):
            headings.append((line_num, line.strip()))
    return headings


def extract_references(lines: list, headings: list, remove_refs: bool = False):
    """Extract the references section from the file content.

    Args:
        lines: list of file lines
        headings: list of heading information
        remove_refs: whether to strip the references content

    Returns:
        dict: start, end and content of the references section
        {
            'start': ref_start,
            'end': ref_end,
            'content': references,
            'updated_headings': updated_headings
        }
    """
    # Look for a REFERENCE heading
    ref_heading = None
    for line_num, heading in headings:
        if "REFERENCE" in heading.upper().replace(" ", ""):
            ref_heading = (line_num, heading)
            break

    # Otherwise fall back to an ACKNOWLEDGEMENT heading (this checks the last heading seen above)
    if not ref_heading and headings and "ACKNOWLEDGEMENT" in heading.upper().replace(" ", ""):
        ref_heading = (line_num, heading)

    if not ref_heading:
        # Match common citation formats with a regex and drop those lines,
        # covering the [number], number. and (number) styles
        ref_pattern = r'^(\[\d+\]|\d+\.|\(\d+\))'
        lines = [line for line in lines if not re.match(ref_pattern, line.strip())]
        return {
            'start': -1,
            'end': -1,
            'content': None
        }, lines

    ref_start = ref_heading[0] - 1  # convert to a 0-based index

    # Find the next heading or the end of the file
    ref_end = len(lines)
    for i in range(ref_start + 1, len(lines)):
        if re.match(r'^#', lines[i].strip()):
            ref_end = i
            break

    # Extract the references content
    references = ''.join(lines[ref_start:ref_end])

    # Strip the content if requested
    if remove_refs:
        lines[ref_start:ref_end] = []

    # # Update the headings if needed
    # updated_headings = headings
    # if remove_refs and ref_heading:
    #     # Remove the Reference heading from the headings
    #     updated_headings = [h for h in headings if h[1].upper() != ref_heading[1].upper()]

    return {
        'start': ref_start,
        'end': ref_end,
        'content': references,
        # 'updated_headings': updated_headings
    }, lines


def update_headings(lines: list, heading_data: list):
    """Update the Markdown content according to the supplied heading data."""
    # Count the number of level == 1 headings in heading_data
    # count_level_1 = sum(1 for item in heading_data if item['level'] == 1)
    # flag = 2 if count_level_1 > 1 else 3  # 2 when there are several level-1 headings, otherwise 3

    for heading in heading_data:
        line_num = heading['line_num'] - 1
        if heading['level'] >= 2:  # flag:
            lines[line_num] = "**" + lines[line_num].replace("#", "").strip() + "**\n"
    return lines


def detect_file_encoding(file_path: str):
    """Detect the file encoding."""
    import chardet
    with open(file_path, 'rb') as f:
        raw_data = f.read(1024)
        result = chardet.detect(raw_data)
        return result['encoding']


# def read_file_content(file_path: str, config: ReparagraphConfig):
#     """Read the file content, with a size check and encoding detection."""
#     file_size = os.path.getsize(file_path)
#     if file_size > config.max_file_size:
#         logging.warning(f"File {file_path} exceeds the limit of {config.max_file_size} bytes, skipping")
#         return None

#     encoding = detect_file_encoding(file_path)
#     try:
#         with open(file_path, 'r', encoding=encoding) as file:
#             return file.readlines()
#     except UnicodeDecodeError:
#         logging.error(f"Could not decode file {file_path}, falling back to utf-8")
#         with open(file_path, 'r', encoding='utf-8') as file:
#             return file.readlines()


def process_single_file(file_path: str, config: ReparagraphConfig):
    """Process a single file and return the processed content."""
    # Read the file content
    lines = read_file_content(file_path)
    if lines is None:
        return None

    # Extract and update the headings
    headings = extract_headings(lines)
    title_info = [{"title": heading, "line_num": line_num, "level": "unknown"}
                  for line_num, heading in headings]

    # Extract the references
    ref_info, lines = extract_references(lines, headings, remove_refs=config.remove_refs)
    if ref_info:
        logging.info("Extracted references:")
        logging.info(f"Start line: {ref_info['start'] + 1}")
        logging.info(f"End line: {ref_info['end']}")
        logging.info("Content:")
        logging.info(ref_info['content'])
        # Update the headings
        # headings = ref_info['updated_headings']
    else:
        logging.warning("No references section found")

    # Removing the references may shift the heading line numbers, so re-index them
    headings = extract_headings(lines)
    title_info = [{"title": heading, "line_num": line_num, "level": "unknown"}
                  for line_num, heading in headings]

    new_headings = get_true_level(title_info, config)
    updated_lines = update_headings(lines, new_headings)

    logging.info(f"Finished processing file: {file_path}")
    return updated_lines


def create_output_dir(input_path: str, config: ReparagraphConfig):
    """Create the output directory."""
    import os
    from datetime import datetime

    # Parent directory of the input path
    parent_dir = os.path.dirname(input_path)

    # Create a timestamped output directory
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = os.path.join(parent_dir, f"{config.task_name}_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)

    return output_dir


def save_processed_file(file_path: str, content: list, output_dir: str, input_path: str):
    """Save the processed file."""
    import os

    # Single-file input
    if os.path.isfile(input_path):
        output_path = os.path.join(output_dir, os.path.basename(file_path))
    else:
        # Preserve the directory structure
        relative_path = os.path.relpath(file_path, input_path)
        output_path = os.path.join(output_dir, relative_path)
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        f.writelines(content)
    logging.info(f"Saved processed file: {output_path}")


def reparagraph_file(path: str, config: ReparagraphConfig = None):
    """Process a single file or every .md file in a folder.

    Args:
        path: file path or folder path
        config: ReparagraphConfig instance holding the processing options

    Returns:
        str: path of the output directory
    """
    import os
    from concurrent.futures import ThreadPoolExecutor

    if config is None:
        config = ReparagraphConfig()

    # Create the output directory
    output_dir = create_output_dir(path, config)
    logging.info(f"Output directory: {output_dir}")

    # For a folder, collect all .md files recursively
    if os.path.isdir(path):
        files = []
        for root, _, filenames in os.walk(path):
            for filename in filenames:
                if filename.endswith('.md'):
                    files.append(os.path.join(root, filename))
    else:
        files = [path]

    def process_and_save(file_path: str):
        content = process_single_file(file_path, config)
        if content is not None and not config.dry_run:
            save_processed_file(file_path, content, output_dir, path)

    if config.parallel:
        # Process in parallel with a thread pool
        with ThreadPoolExecutor() as executor:
            list(tqdm(executor.map(process_and_save, files), total=len(files), desc="Processing files"))
    else:
        # Process sequentially
        for file_path in tqdm(files, desc="Processing files"):
            process_and_save(file_path)

    logging.info(f"Processed {len(files)} files in total")
    return output_dir
33
clean/step0_pdfs2sql.py
Normal file
@@ -0,0 +1,33 @@
import os
import tqdm
import sqlite3
import mysql.connector


def main():
    cur_path = os.path.dirname(os.path.abspath(__file__))

    TABLE_NAME = 'mp_cif_info'

    mysql_connection = mysql.connector.connect(
        host='100.84.94.73',
        user='metadata_mat_papers',
        password='siat-mic',
        database='metadata_mat_papers'
    )
    mysql_cursor = mysql_connection.cursor()

    pdf_list = os.listdir(os.path.join(cur_path, 'mp_cif/pdfs'))

    doi_list = [pdf.replace('.pdf', '') for pdf in pdf_list]

    try:
        for doi in doi_list:
            sql = f"INSERT INTO {TABLE_NAME} (doi) VALUES (%s)"
            mysql_cursor.execute(sql, (doi,))
            mysql_connection.commit()
    finally:
        mysql_connection.close()


if __name__ == "__main__":
    main()
88
clean/step1_modify_status_with_database.py
Normal file
@@ -0,0 +1,88 @@
import os
import tqdm
import sqlite3
import mysql.connector
import PyPDF2


def read_dois_from_db(db_path):
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute(f"SELECT doi FROM doi_status;")
    dois = [row[0] for row in cursor.fetchall()]
    conn.close()
    return dois


def main():
    cur_path = os.path.dirname(os.path.abspath(__file__))
    # db_path = os.path.join(cur_path, 'psk_high_cited', 'doi_status.db')
    # dois_db = read_dois_from_db(db_path)

    # for doi in tqdm.tqdm(dois_db):
    #     pdf = doi.replace('/','_').replace('<','_').replace('>','_').replace(':','_') + '.pdf'
    #     pdf_path = os.path.join(cur_path, 'psk_high_cited/pdfs', pdf)
    #     if os.path.exists(pdf_path):
    #         conn = sqlite3.connect(db_path)
    #         cursor = conn.cursor()
    #         cursor.execute(f"UPDATE doi_status SET status = 'success' WHERE doi = '{doi}';")
    #         conn.close()

    ###########################################################################################

    TABLE_NAME = 'mp_cif_info'

    mysql_connection = mysql.connector.connect(
        host='100.84.94.73',
        user='metadata_mat_papers',
        password='siat-mic',
        database='metadata_mat_papers'
    )
    mysql_cursor = mysql_connection.cursor()

    try:
        # Fetch all DOIs
        mysql_cursor.execute(f"SELECT doi FROM {TABLE_NAME};")
        dois = [row[0] for row in mysql_cursor.fetchall()]

        for doi in tqdm.tqdm(dois):
            # pdf = doi.replace('/','_').replace('<','_').replace('>','_').replace(':','_') + '.pdf'
            pdf = doi + '.pdf'

            # Change this to your own PDF directory
            pdf_path = os.path.join(cur_path, 'mp_cif/pdfs', pdf)

            if os.path.exists(pdf_path):
                try:
                    # Try to open the PDF file
                    with open(pdf_path, 'rb') as file:
                        pdf_reader = PyPDF2.PdfReader(file)  # may raise if the file cannot be parsed

                    # The file opened and parsed successfully, so mark the row as 'success'
                    query = f"UPDATE {TABLE_NAME} SET scihub_downloaded = %s WHERE doi = %s"
                    mysql_cursor.execute(query, ('success', doi))
                    mysql_connection.commit()

                except (PyPDF2.errors.PdfReadError, PyPDF2.errors.PdfStreamError):
                    # Parsing failed, so set scihub_downloaded back to NULL
                    query = f"UPDATE {TABLE_NAME} SET scihub_downloaded = %s WHERE doi = %s"
                    mysql_cursor.execute(query, (None, doi))  # None maps to SQL NULL
                    mysql_connection.commit()

                except Exception as e:
                    # Any other error
                    print(f"Unexpected error while handling PDF {doi}: {e}")
                    query = f"UPDATE {TABLE_NAME} SET scihub_downloaded = %s WHERE doi = %s"
                    mysql_cursor.execute(query, (None, doi))
                    mysql_connection.commit()

    except mysql.connector.Error as error:
        print("Failed to insert record into MySQL table: {}".format(error))
        # Roll back the transaction on error
        mysql_connection.rollback()

    finally:
        # Close the cursor and the connection
        mysql_cursor.close()
        mysql_connection.close()


if __name__ == "__main__":
    main()
47
clean/step2_reserve_success_pdf_with_database.py
Normal file
@@ -0,0 +1,47 @@
import sqlite3
import mysql.connector
import tqdm
import os

TABLE_NAME = 'mp_synthesis_papers_info'
input('TABLE_NAME = {} ?'.format(TABLE_NAME))

cur_dir = os.path.dirname(os.path.abspath(__file__))

# MySQL connection setup
mysql_connection = mysql.connector.connect(
    host='100.84.94.73',
    user='metadata_mat_papers',
    password='siat-mic',
    database='metadata_mat_papers'
)

try:
    mysql_cursor = mysql_connection.cursor()

    # Build the query (note: the column really is spelled scihub_downlowded in this table)
    # query = f"SELECT pdf_url FROM {TABLE_NAME} WHERE scihub_downlowded IN ('broken', 'timeout', 'failed') and pdf_url IS NOT NULL;"
    query = f"SELECT pdf_url FROM {TABLE_NAME} WHERE scihub_downlowded IS NULL AND pdf_url IS NOT NULL;"
    mysql_cursor.execute(query)
    records = mysql_cursor.fetchall()

    for record in tqdm.tqdm(records):
        # pdf_path = os.path.join(cur_dir, record[0])
        # if os.path.exists(pdf_path):
        #     os.remove(pdf_path)
        query = f"UPDATE {TABLE_NAME} SET pdf_url = NULL WHERE pdf_url = '{record[0]}';"
        mysql_cursor.execute(query)
        mysql_connection.commit()

    # Commit the changes to the database
    mysql_connection.commit()

except mysql.connector.Error as error:
    print("Failed to insert record into MySQL table: {}".format(error))
    # Roll back the transaction on error
    mysql_connection.rollback()

finally:
    # Close the cursor and the connection
    mysql_cursor.close()
    mysql_connection.close()
52
clean/step3_path_change_with_database.py
Normal file
@@ -0,0 +1,52 @@
import sqlite3
import mysql.connector
import tqdm
import os

TABLE_NAME = 'mp_cif_info'
input('TABLE_NAME = {} ?'.format(TABLE_NAME))

cur_dir = os.path.dirname(os.path.abspath(__file__))

# MySQL connection setup
mysql_connection = mysql.connector.connect(
    host='100.84.94.73',
    user='metadata_mat_papers',
    password='siat-mic',
    database='metadata_mat_papers'
)

try:
    mysql_cursor = mysql_connection.cursor()

    # Fetch every DOI whose download status is 'success'
    query = f"SELECT doi, pdf_url FROM {TABLE_NAME} WHERE scihub_downloaded = 'success';"
    mysql_cursor.execute(query)
    results = mysql_cursor.fetchall()
    dois = [row[0] for row in results]
    pdf_urls = [row[1] for row in results]

    for doi, pdf_url in tqdm.tqdm(zip(dois, pdf_urls), total=len(dois)):
        # Skip rows whose path has already been rewritten
        if pdf_url is not None and pdf_url.split('/')[0] == 'mp_cif' and pdf_url.split('/')[1] == 'pdfs':
            continue
        # pdf = doi.replace('/','_').replace('<','_').replace('>','_').replace(':','_') + '.pdf'
        pdf = doi + '.pdf'
        # New path
        pdf_path = os.path.join('mp_cif/pdfs', pdf)
        query = f"UPDATE {TABLE_NAME} SET pdf_url = '{pdf_path}' WHERE doi = '{doi}';"
        mysql_cursor.execute(query)
        mysql_connection.commit()
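        # Added note: the UPDATE above interpolates doi and pdf_path directly into the
        # SQL string; a parameterized form, as already used in step1 of this commit,
        # would avoid quoting problems with unusual DOIs, e.g. (sketch only):
        # mysql_cursor.execute(
        #     f"UPDATE {TABLE_NAME} SET pdf_url = %s WHERE doi = %s", (pdf_path, doi)
        # )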

    # Commit the changes to the database
    mysql_connection.commit()

except mysql.connector.Error as error:
    print("Failed to insert record into MySQL table: {}".format(error))
    # Roll back the transaction on error
    mysql_connection.rollback()

finally:
    # Close the cursor and the connection
    mysql_cursor.close()
    mysql_connection.close()
51
clean/step4.2_modify_md_with_database.py
Normal file
@@ -0,0 +1,51 @@
import mysql.connector
import tqdm
import os

TABLE_NAME = 'phosphorus_synthesis_info_new'
input('TABLE_NAME = {} ?'.format(TABLE_NAME))

cur_dir = os.path.dirname(os.path.abspath(__file__))

# MySQL connection setup
mysql_connection = mysql.connector.connect(
    host='100.84.94.73',
    user='metadata_mat_papers',
    password='siat-mic',
    database='metadata_mat_papers'
)

try:
    mysql_cursor = mysql_connection.cursor()

    # Fetch every DOI that has already been converted
    query = f"SELECT doi, md_url FROM {TABLE_NAME} WHERE en_text_content IS NOT NULL;"
    mysql_cursor.execute(query)
    results = mysql_cursor.fetchall()
    dois = [row[0] for row in results]
    md_urls = [row[1] for row in results]

    for doi, md_url in tqdm.tqdm(zip(dois, md_urls), total=len(dois)):
        # Skip rows whose path has already been rewritten
        dir_name = 'phosphorus'
        if md_url is not None and md_url.split('/')[0] == dir_name and md_url.split('/')[1] == 'mds':
            continue
        md_name = doi.replace('/', '_').replace('<', '_').replace('>', '_').replace(':', '_')
        md = md_name + '.md'
        md_path = os.path.join(dir_name + '/mds', md_name, md)
        query = f"UPDATE {TABLE_NAME} SET md_url = '{md_path}', convert2md = 'success' WHERE doi = '{doi}';"
        mysql_cursor.execute(query)
        mysql_connection.commit()

    # Commit the changes to the database
    mysql_connection.commit()

except mysql.connector.Error as error:
    print("Failed to insert record into MySQL table: {}".format(error))
    # Roll back the transaction on error
    mysql_connection.rollback()

finally:
    # Close the cursor and the connection
    mysql_cursor.close()
    mysql_connection.close()
424
clean/step4_preprocess_mineru_multi_with_database.py
Normal file
@@ -0,0 +1,424 @@
import re
import os
import json
import copy
import requests
import time
import shutil
import uuid
import sqlite3
import PyPDF2
import multiprocessing
import mysql.connector
from concurrent.futures import ThreadPoolExecutor, as_completed

from loguru import logger
from glob import glob
from tqdm import tqdm
from datetime import datetime
import asyncio

from magic_pdf.pipe.UNIPipe import UNIPipe
from magic_pdf.pipe.OCRPipe import OCRPipe
from magic_pdf.pipe.TXTPipe import TXTPipe
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
import magic_pdf.model as model_config

model_config.__use_inside_model__ = True


# Image-hosting (LSKY Pro) configuration
# IMGBED_URL = "http://localhost:40027/"
IMGBED_URL = "http://172.20.103.171:40027/"
# Make sure the imgbed URL ends with a slash
if not IMGBED_URL.endswith('/'):
    IMGBED_URL += '/'
token_endpoint = f"{IMGBED_URL}api/v1/tokens"
upload_endpoint = f"{IMGBED_URL}api/v1/upload"

# Obtain the token with a request like:
# curl -X POST http://localhost:40027/api/v1/tokens -H "Content-Type: application/json" -d '{"email":"yt.li2@siat.ac.cn", "password":"lyt20000414."}'
IMGBED_TOKEN = "6|QsBh5H7txY3Hd7ju1nzYKOBSdFQeL0YberydSFIH"


def replace_image_links(md_content: str, images_urls: dict) -> str:
    # Match Markdown image links of the form ![alt](path)
    pattern = r'!\[(.*?)\]\((.*?)\)'

    def replace_link(match):
        # Extract the image path from the current match
        image_path = match.group(2)
        # Check whether this path is in the mapping
        if image_path in images_urls:
            # Take the new URL from the mapping
            new_url = images_urls[image_path]
            return f"![{match.group(1)}]({new_url})"
        return match.group(0)

    # Perform the substitution with re.sub
    updated_md_content = re.sub(pattern, replace_link, md_content)
    return updated_md_content


# Upload images to LSKY Pro
def upload_image(img_dir):
    headers = {
        "Authorization": f"Bearer {IMGBED_TOKEN}",
        'Accept': 'application/json'
    }

    image_urls = {}
    img_names = os.listdir(img_dir)
    for image_name in img_names:
        retry = 0
        image_path = os.path.join(img_dir, image_name)
        while retry < 5:  # maximum number of retries
            try:
                with open(image_path, 'rb') as image_file:  # keep the file open while uploading
                    files = {'file': image_file}

                    # Upload the file
                    response = requests.post(upload_endpoint, headers=headers, files=files)
                    if response.status_code == 200:
                        result = response.json()
                        if result['status']:
                            image_url = result['data']['links']['url']
                            image_urls['images/' + image_name] = image_url
                            print(f"Image uploaded successfully: {image_url}")
                            break  # upload succeeded, leave the retry loop
                        else:
                            raise Exception(f"Image upload failed: {result['message']}")
                    elif response.status_code == 429:
                        # 429 response: wait a while before retrying
                        wait_time = 3
                        # wait_time = min(2 ** retry, 10)  # exponential backoff, capped at 10 seconds
                        # logger.warning(f"Too many requests, waiting {wait_time} seconds...")
                        print(f"Too many requests, waiting {wait_time} seconds...")
                        time.sleep(wait_time)
                    else:
                        raise Exception(f"HTTP request failed: {response.status_code}")

                    retry += 1  # increment the retry counter
                    time.sleep(1)  # pause briefly after a failed attempt

            except FileNotFoundError:
                logger.error(f"File {image_path} does not exist; please check the path")
                return image_urls  # return what has been uploaded so far instead of None

    return image_urls


# Save images locally, making sure the generated file names are unique
def save_images_locally(img_dir, target_dir):
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    image_urls = {}

    img_names = os.listdir(img_dir)

    # Copy every image into the target folder
    for image_name in img_names:
        image_path = os.path.join(img_dir, image_name)

        # Use a UUID to generate a unique file name, keeping the original extension
        unique_name = f"{uuid.uuid4()}{os.path.splitext(image_name)[1]}"
        save_path = os.path.join(target_dir, unique_name)

        try:
            # Copy the file into the target directory
            shutil.copy2(image_path, save_path)
            # Record the image name together with its saved path
            image_urls[f'images/{unique_name}'] = save_path
            print(f"Image saved successfully: {save_path}")
        except FileNotFoundError:
            print(f"File {image_path} does not exist, skipping this image")
        except Exception as e:
            print(f"Error while saving image {image_name}: {e}")

    return image_urls


def json_md_dump(
        pipe,
        md_writer,
        pdf_name,
        content_list,
        md_content,
):
    # Write the model output to model.json
    orig_model_list = copy.deepcopy(pipe.model_list)
    md_writer.write(
        content=json.dumps(orig_model_list, ensure_ascii=False, indent=4),
        path=f"{pdf_name}_model.json"
    )

    # Write the intermediate result to middle.json
    md_writer.write(
        content=json.dumps(pipe.pdf_mid_data, ensure_ascii=False, indent=4),
        path=f"{pdf_name}_middle.json"
    )

    # Write the text result to content_list.json
    md_writer.write(
        content=json.dumps(content_list, ensure_ascii=False, indent=4),
        path=f"{pdf_name}_content_list.json"
    )

    # Write the Markdown result to the .md file
    md_writer.write(
        content=md_content,
        path=f"{pdf_name}.md"
    )


def pdf_parse_main(
        pdf_path: str,
        parse_method: str = 'auto',
        model_json_path: str = None,
        is_json_md_dump: bool = True,
        output_dir: str = None
):
    """
    Convert a PDF to JSON and Markdown; the .md and .json files are written next to the PDF (or to output_dir).

    :param pdf_path: path to the .pdf file; relative or absolute
    :param parse_method: parsing method, one of auto / ocr / txt; defaults to auto, try ocr if the result is poor
    :param model_json_path: path to an existing model-output file; if empty the built-in model is used; the PDF and model_json must correspond
    :param is_json_md_dump: whether to dump the parsed data to .json and .md files (default True); the intermediate stages go to three separate .json files and the Markdown content goes to a .md file
    :param output_dir: output directory; a folder named after the PDF is created there and all results are saved in it
    """
    try:
        pdf_name = os.path.basename(pdf_path).split("/")[-1].replace(".pdf", "")
        pdf_path_parent = os.path.dirname(pdf_path)

        if output_dir:
            output_path = os.path.join(output_dir, pdf_name)
        else:
            output_path = os.path.join(pdf_path_parent, pdf_name)

        output_image_path = os.path.join(output_path, 'images')

        # Parent directory of the images, used as the relative path stored in the .md and content_list.json files
        image_path_parent = os.path.basename(output_image_path)

        pdf_bytes = open(pdf_path, "rb").read()  # read the raw bytes of the PDF

        if model_json_path:
            # Load the JSON produced by a previous model run for this PDF (a list)
            model_json = json.loads(open(model_json_path, "r", encoding="utf-8").read())
        else:
            model_json = []

        # Run the parsing pipeline
        # image_writer = DiskReaderWriter(output_image_path)
        image_writer, md_writer = DiskReaderWriter(output_image_path), DiskReaderWriter(output_path)

        # Choose the parsing mode
        # jso_useful_key = {"_pdf_type": "", "model_list": model_json}
        # pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer)
        if parse_method == "auto":
            jso_useful_key = {"_pdf_type": "", "model_list": model_json}
            pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer)
        elif parse_method == "txt":
            pipe = TXTPipe(pdf_bytes, model_json, image_writer)
        elif parse_method == "ocr":
            pipe = OCRPipe(pdf_bytes, model_json, image_writer)
        else:
            logger.error("unknown parse method, only auto, ocr, txt allowed")
            exit(1)

        # Classify the document
        pipe.pipe_classify()

        # If no model data was supplied, analyze with the built-in model
        if not model_json:
            if model_config.__use_inside_model__:
                pipe.pipe_analyze()  # analyze
            else:
                logger.error("need model list input")
                exit(1)

        # Parse
        pipe.pipe_parse()

        # Produce the text and Markdown results
        content_list = pipe.pipe_mk_uni_format(image_path_parent, drop_mode="none")
        md_content = pipe.pipe_mk_markdown(image_path_parent, drop_mode="none")
        # Upload the images to the image host
        # image_urls = upload_image(output_image_path)
        # Save the images locally instead
        target_dir = "mp_cif/images"
        image_urls = save_images_locally(output_image_path, target_dir)
        md_content = replace_image_links(md_content, image_urls)

        mysql_connection = mysql.connector.connect(
            host='100.84.94.73',
            user='metadata_mat_papers',
            password='siat-mic',
            database='metadata_mat_papers',
            charset="utf8mb4",  # use utf8mb4 for the connection
            collation="utf8mb4_unicode_ci"  # use a matching collation
        )
        mysql_cursor = mysql_connection.cursor()

        table = 'mp_cif_info'
        # path = 'phosphorus/pdfs/' + pdf_name + '.pdf'
        # print("path:", path)
        doi = os.path.basename(pdf_path).replace(".pdf", "").replace('_', '/')
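        # Added note with a made-up example of the file-name <-> DOI convention used here:
        # a PDF saved as "10.1021_acsnano.0c00000.pdf" (slashes replaced by underscores
        # at download time) maps back to the DOI "10.1021/acsnano.0c00000".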

        try:
            # Build the query
            query = f"UPDATE {table} SET en_text_content = %s WHERE doi = %s"
            mysql_cursor.execute(query, (md_content, doi))
            print(f"{doi}: Markdown saved successfully")

            # Commit the changes to the database
            mysql_connection.commit()

        except mysql.connector.Error as error:
            print("Failed to insert record into MySQL table: {}".format(error))
            # Roll back the transaction on error
            mysql_connection.rollback()

        finally:
            # Close the cursor and the connection
            mysql_cursor.close()
            mysql_connection.close()

        if is_json_md_dump:
            json_md_dump(pipe, md_writer, pdf_name, content_list, md_content)
        return 'success'

    except Exception as e:
        logger.exception(e)
        return 'error'


def check_doi_not_in_db(pdf_name, cursor):
    query = f"SELECT * FROM doi_status WHERE doi = ? AND convert_status = ? "
    cursor.execute(query, (pdf_name, 'unprocessed'))
    res = cursor.fetchone()
    if res:
        return True
    else:
        return False


def init_worker(devices, pdfs, gpu_index, process_id):
    """
    Initialize a worker process to process a chunk of PDFs with a specific GPU.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_index)
    process_pdf_chunk(pdfs, gpu_index, process_id)


def get_converted2md_dois():
    table = 'mp_cif_info'

    dois = []

    mysql_connection = mysql.connector.connect(
        host='100.84.94.73',
        user='metadata_mat_papers',
        password='siat-mic',
        database='metadata_mat_papers',
        charset="utf8mb4",  # use utf8mb4 for the connection
        collation="utf8mb4_unicode_ci"  # use a matching collation
    )
    mysql_cursor = mysql_connection.cursor()

    try:
        # Build the query
        query = f"SELECT doi FROM {table} WHERE en_text_content IS NOT NULL;"
        mysql_cursor.execute(query)
        res = mysql_cursor.fetchall()
        dois = [row[0] for row in res if row]
    except mysql.connector.Error as error:
        # Roll back the transaction on error
        mysql_connection.rollback()
    finally:
        # Close the cursor and the connection
        mysql_cursor.close()
        mysql_connection.close()
    return dois


def is_within_operational_hours(start_hour, end_hour):
    now = datetime.now().time()  # current time (without the date)
    current_hour = now.hour  # current hour

    # Handle windows that cross midnight (e.g. evening until the next morning)
    if start_hour > end_hour:
        return (current_hour >= start_hour or current_hour < end_hour)  # window spans midnight
    else:
        return start_hour <= current_hour < end_hour
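
# For reference, with the values used below (start_hour=15, end_hour=9) the window
# crosses midnight, so for example:
#   is_within_operational_hours(15, 9)  ->  True at 16:00 or 03:00, False at 10:00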

def process_pdf_chunk(pdf_paths, gpu_index, process_id):
    for pdf_path in tqdm(pdf_paths, desc=f"Worker {gpu_index}_{process_id} Progress"):
        # Only run the task inside the allowed time window
        start_hour = 15  # 15:00 (3 pm)
        end_hour = 9  # 9:00 the next morning

        # Check whether the current time falls inside the allowed window
        while True:
            if is_within_operational_hours(start_hour, end_hour):
                print("Current time is inside the operating window, processing PDF files...")
                try:
                    with open(pdf_path, 'rb') as file:
                        pdf_reader = PyPDF2.PdfReader(file)  # raises if the PDF cannot be parsed
                    print(os.path.basename(pdf_path).replace(".pdf", "").replace('_', '/'))
                    status = pdf_parse_main(pdf_path, parse_method='auto', output_dir=output_dir)
                    break  # done, leave the loop
                except PyPDF2.errors.PdfReadError:
                    logger.error(f"{pdf_path} has been broken")
                    break  # error, leave the loop
                except Exception as e:
                    logger.error(f"{pdf_path} has an error: {e}")
                    break  # error, leave the loop
            else:
                # Outside the allowed window: block the task for a while
                print("Current time is outside the operating window, retrying later...")
                time.sleep(60 * 60)  # sleep for an hour before checking again


def multiprocessing_setup(pdf_paths, num_gpus):
    num_processes_per_gpu = 3
    chunk_size = len(pdf_paths) // (num_gpus * num_processes_per_gpu)
    processes = []

    # Create processes for each GPU
    for gpu_id in range(num_gpus):
        for process_id in range(num_processes_per_gpu):
            start_idx = (gpu_id * num_processes_per_gpu + process_id) * chunk_size
            end_idx = None if (gpu_id == num_gpus - 1 and process_id == num_processes_per_gpu - 1) else start_idx + chunk_size
            chunk = pdf_paths[start_idx:end_idx]

            p = multiprocessing.Process(target=init_worker, args=([gpu_id], chunk, gpu_id, process_id))
            processes.append(p)
            p.start()

    # Ensure all processes have completed
    for p in processes:
        p.join()


if __name__ == '__main__':
    _cur_dir = os.path.dirname(os.path.abspath(__file__))
    # Change the paths here
    pdf_dir = os.path.join(_cur_dir, "mp_cif/pdfs")
    output_dir = os.path.join(_cur_dir, "mp_cif/mds")

    os.makedirs(output_dir, exist_ok=True)
    pdf_paths = sorted(glob(os.path.join(pdf_dir, "*.pdf")))

    dois = get_converted2md_dois()
    print(len(dois))
    new_pdf_paths = pdf_paths[:]
    for path in tqdm(pdf_paths):
        doi = os.path.basename(path).replace(".pdf", "").replace('_', '/')
        if doi in dois:
            new_pdf_paths.remove(path)
    print(len(new_pdf_paths))

    # Number of GPUs
    num_gpus = 8

    # Setup multiprocessing to handle PDFs across multiple GPUs
    multiprocessing_setup(new_pdf_paths, num_gpus)
160
clean/stp1_bib2sql.py
Normal file
@@ -0,0 +1,160 @@
import os
import glob
import mysql.connector
import bibtexparser
import tqdm


TABLE_NAME = 'phosphorus_synthesis_info'
input('Are you sure TABLE_NAME is {}?'.format(TABLE_NAME))

# phosphorus_synthesis
bibs_dir = os.path.join(os.path.dirname(__file__), 'synthesis23-25')
if_file_path = os.path.join(os.path.dirname(__file__), '2023JCR.xlsx')
input('Are you sure the import folder is {}?'.format(bibs_dir))

# MySQL connection setup
connection = mysql.connector.connect(
    host='localhost',
    user='metadata_mat_papers',
    password='siat-mic',
    database='metadata_mat_papers'
)
cursor = connection.cursor()


# Function to check if a table exists
def check_table_exists(table_name):
    cursor.execute(f"""
        SELECT COUNT(*)
        FROM information_schema.tables
        WHERE table_schema = DATABASE()
        AND table_name = '{table_name}'
    """)
    return cursor.fetchone()[0] == 1


# Function to create the table if it doesn't exist
def create_table(table_name):
    if not check_table_exists(table_name):
        query = f"""
        CREATE TABLE IF NOT EXISTS `{table_name}` (
            doi VARCHAR(255) PRIMARY KEY,
            unique_id VARCHAR(255),
            author TEXT,
            title TEXT,
            journal VARCHAR(255),
            year INT,
            volume VARCHAR(50),
            number VARCHAR(50),
            pages VARCHAR(50),
            month VARCHAR(50),
            issn VARCHAR(50),
            eissn VARCHAR(50),
            researcher_id TEXT,
            if2023 VARCHAR(50),
            if5 VARCHAR(50),
            journal_index VARCHAR(50),
            jcr_quartile VARCHAR(50),
            orcid TEXT,
            early_access_date VARCHAR(50),
            scihub_downlowded VARCHAR(50),
            convert2md VARCHAR(50),
            pdf_url TEXT,
            md_url TEXT,
            abstract TEXT,
            image_url JSON,
            text_content LONGTEXT
        );
        """
        cursor.execute(query)


def record_exists(doi, table_name):
    query = f"SELECT COUNT(*) FROM `{table_name}` WHERE doi = %s"
    cursor.execute(query, (doi,))
    count = cursor.fetchone()[0]
    return count > 0


# Function to insert a record into the MySQL database
def insert_record(entry, table_name):
    # Column names
    columns = [
        'doi', 'unique_id', 'author', 'title', 'journal', 'year', 'volume',
        'number', 'pages', 'month', 'issn', 'eissn', 'researcher_id', 'if2023', 'if5', 'journal_index', 'jcr_quartile',
        'orcid', 'early_access_date', 'scihub_downlowded', 'convert2md', 'pdf_url', 'md_url', 'abstract', 'image_url', 'text_content'
    ]

    # Build the SQL statement
    placeholders = ', '.join(['%s'] * len(columns))
    query = f"""
    INSERT INTO `{table_name}` ({', '.join(columns)})
    VALUES ({placeholders})
    """

    values = (
        entry.get('doi'),
        entry.get('unique-id'),
        entry.get('author'),
        entry.get('title'),
        entry.get('journal'),
        entry.get('year'),
        entry.get('volume'),
        entry.get('number', None),
        entry.get('pages', None),
        entry.get('month', None),
        entry.get('issn', None),
        entry.get('eissn', None),
        entry.get('researcherid-numbers', None),
        entry.get('if2023', None),
        entry.get('if5', None),
        entry.get('journal_index', None),
        entry.get('jcr_quartile', None),
        entry.get('ocrid-numbers', None),  # note: probably intended to be 'orcid-numbers'
        entry.get('earlyaccessdate', None),
        entry.get('scihub_downlowded', None),
        entry.get('convert2md', None),
        entry.get('pdf_url', None),
        entry.get('md_url', None),
        entry.get('abstract', None),
        entry.get('image_url', None),
        entry.get('text_content', None)
    )
    cursor.execute(query, values)
|
||||
|
||||
|
||||
# 用pandas打开excel文件
|
||||
import pandas as pd
|
||||
df = pd.read_excel(if_file_path)
|
||||
# 替换所有的nan为None
|
||||
df = df.replace({pd.NA: None})
|
||||
|
||||
# Create the table if it doesn't exist
|
||||
create_table(TABLE_NAME)
|
||||
|
||||
bib_files = sorted(glob.glob(os.path.join(bibs_dir, '*.bib')))
|
||||
for bib_file in tqdm.tqdm(bib_files):
|
||||
# Read and parse the .bib file
|
||||
with open(bib_file, 'r') as bibtex_file:
|
||||
bib_database = bibtexparser.load(bibtex_file)
|
||||
for entry in bib_database.entries:
|
||||
entry = {k.lower(): v for k, v in entry.items()}
|
||||
journal = entry.get('journal')
|
||||
if journal is not None:
|
||||
journal_lower = journal.lower() # 将期刊名称转为小写以进行不区分大小写的匹配
|
||||
matching_journal = df[df['JournalName'].str.lower() == journal_lower] # 在DataFrame中查找该期刊
|
||||
if not matching_journal.empty:
|
||||
entry['if2023'] = matching_journal['IF2023'].values[0]
|
||||
entry['if5'] = matching_journal['IF5'].values[0]
|
||||
entry['journal_index'] = matching_journal['INDEX'].values[0]
|
||||
entry['jcr_quartile'] = matching_journal['Quartile'].values[0]
|
||||
|
||||
doi = entry.get('doi')
|
||||
# 先检查记录是否存在,同时doi不能为空
|
||||
if not record_exists(doi, TABLE_NAME) and doi is not None:
|
||||
insert_record(entry, TABLE_NAME)
|
||||
|
||||
# Commit the changes and close the connection
|
||||
connection.commit()
|
||||
cursor.close()
|
||||
connection.close()
|
||||
print("Data has been inserted into the database!")
|
||||
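The MySQL credentials are hard-coded in this and the following scripts. A small sketch of how the same connection could be built from environment variables instead; the variable names are illustrative, not ones the repository currently defines.

    import os
    import mysql.connector

    # Read connection settings from the environment, falling back to the current defaults.
    connection = mysql.connector.connect(
        host=os.environ.get("PAPERS_DB_HOST", "localhost"),
        user=os.environ.get("PAPERS_DB_USER", "metadata_mat_papers"),
        password=os.environ["PAPERS_DB_PASSWORD"],  # no default: fail loudly if unset
        database=os.environ.get("PAPERS_DB_NAME", "metadata_mat_papers"),
    )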
193
clean/stp1_excel2sql.py
Normal file
@@ -0,0 +1,193 @@
import os
import mysql.connector


TABLE_NAME = 'crispr_papers_info'
input('Are you sure TABLE_NAME is {}? '.format(TABLE_NAME))

# phosphorus_synthesis
excels_dir = os.path.join(os.path.dirname(__file__), 'CRISPR/CRISPR_engineered')
if_file_path = os.path.join(os.path.dirname(__file__), 'CRISPR/2023JCR.xlsx')
input('Are you sure the import directory is {}? '.format(excels_dir))

# MySQL connection setup
connection = mysql.connector.connect(
    host='100.84.94.73',
    user='metadata_mat_papers',
    password='siat-mic',
    database='metadata_mat_papers'
)
cursor = connection.cursor()


# Function to check if a table exists
def check_table_exists(table_name):
    cursor.execute(f"""
        SELECT COUNT(*)
        FROM information_schema.tables
        WHERE table_schema = DATABASE()
        AND table_name = '{table_name}'
    """)
    return cursor.fetchone()[0] == 1


# Function to create the table if it doesn't exist
def create_table(table_name):
    if not check_table_exists(table_name):
        query = f"""
        CREATE TABLE IF NOT EXISTS `{table_name}` (
            doi VARCHAR(255) PRIMARY KEY,
            unique_id VARCHAR(255),
            author TEXT,
            title TEXT,
            journal VARCHAR(255),
            year INT,
            volume VARCHAR(50),
            number VARCHAR(50),
            pages VARCHAR(50),
            month VARCHAR(50),
            issn VARCHAR(50),
            eissn VARCHAR(50),
            researcher_id TEXT,
            if2023 VARCHAR(50),
            if5 VARCHAR(50),
            journal_index VARCHAR(50),
            jcr_quartile VARCHAR(50),
            orcid TEXT,
            early_access_date VARCHAR(50),
            scihub_downlowded VARCHAR(50),
            convert2md VARCHAR(50),
            pdf_url TEXT,
            md_url TEXT,
            abstract TEXT,
            image_url JSON,
            en_text_content LONGTEXT,
            cited_reference_count INT,
            doi_link TEXT,
            research_areas TEXT,
            unique_wos_id VARCHAR(255)
        );
        """
        cursor.execute(query)


def record_exists(doi, table_name):
    query = f"SELECT COUNT(*) FROM `{table_name}` WHERE doi = %s"
    cursor.execute(query, (doi,))
    count = cursor.fetchone()[0]
    return count > 0


# Function to insert a record into the MySQL database
def insert_record(entry, table_name):
    # Column names for the INSERT statement
    # (en_text_content matches the column defined in create_table above)
    columns = [
        'doi', 'unique_id', 'author', 'title', 'journal', 'year', 'volume',
        'number', 'pages', 'month', 'issn', 'eissn', 'researcher_id', 'if2023', 'if5', 'journal_index', 'jcr_quartile',
        'orcid', 'early_access_date', 'scihub_downlowded', 'convert2md', 'pdf_url', 'md_url', 'abstract', 'image_url',
        'en_text_content', 'cited_reference_count', 'doi_link', 'research_areas', 'unique_wos_id'
    ]

    # Build the parameterized SQL query
    placeholders = ', '.join(['%s'] * len(columns))
    query = f"""
    INSERT INTO `{table_name}` ({', '.join(columns)})
    VALUES ({placeholders})
    """

    values = (
        entry.get('doi'),
        entry.get('unique-id'),
        entry.get('author'),
        entry.get('title'),
        entry.get('journal'),
        entry.get('year'),
        entry.get('volume'),
        entry.get('number', None),
        entry.get('pages', None),
        entry.get('month', None),
        entry.get('issn', None),
        entry.get('eissn', None),
        entry.get('researcherid-numbers', None),
        entry.get('if2023', None),
        entry.get('if5', None),
        entry.get('journal_index', None),
        entry.get('jcr_quartile', None),
        entry.get('ocrid-numbers', None),
        entry.get('earlyaccessdate', None),
        entry.get('scihub_downlowded', None),
        entry.get('convert2md', None),
        entry.get('pdf_url', None),
        entry.get('md_url', None),
        entry.get('abstract', None),
        entry.get('image_url', None),
        entry.get('text_content', None),
        entry.get('cited_reference_count', None),
        entry.get('doi_link', None),
        entry.get('research_areas', None),
        entry.get('unique_wos_id', None)
    )
    cursor.execute(query, values)


# Open the Excel files with pandas
import pandas as pd

df = pd.read_excel(if_file_path)
# Replace all NaN values with None
df = df.replace({pd.NA: None})

# Create the table if it doesn't exist
create_table(TABLE_NAME)

excels_file_list = []
for file in os.listdir(excels_dir):  # previously os.listdir('溶剂热文献-230505-swx-V3')
    if file.endswith('.xls'):
        excels_file_list.append(os.path.splitext(file)[0])


for excels_file in excels_file_list:
    print(os.path.join(excels_dir, excels_file + '.xls'))

    # Path of the Excel file to import
    file_path = os.path.join(excels_dir, excels_file + '.xls')

    # Read the Excel file
    excel_df = pd.read_excel(file_path)
    # Replace all NaN values with None
    excel_df = excel_df.replace({pd.NA: None})

    # Preview the first few rows of the DataFrame if needed
    # print(df.head(5))
    for i in range(len(excel_df)):
        entry = dict()
        entry['doi'] = str(excel_df.loc[i, 'DOI'])
        entry['title'] = str(excel_df.loc[i, 'Article Title'])
        entry['journal'] = str(excel_df.loc[i, 'Source Title'])
        entry['abstract'] = str(excel_df.loc[i, 'Abstract'])
        entry['cited_reference_count'] = int(excel_df.loc[i, 'Cited Reference Count'])
        entry['year'] = int(excel_df.loc[i, 'Publication Year'])
        entry['doi_link'] = str(excel_df.loc[i, 'DOI Link'])
        entry['research_areas'] = str(excel_df.loc[i, 'Research Areas'])
        entry['unique_wos_id'] = str(excel_df.loc[i, 'UT (Unique WOS ID)'])

        journal = entry.get('journal')
        if journal is not None:
            journal_lower = journal.lower()  # Lower-case the journal name for case-insensitive matching
            matching_journal = df[df['JournalName'].str.lower() == journal_lower]  # Look up the journal in the DataFrame
            if not matching_journal.empty:
                entry['if2023'] = matching_journal['IF2023'].values[0]
                entry['if5'] = matching_journal['IF5'].values[0]
                entry['journal_index'] = matching_journal['INDEX'].values[0]
                entry['jcr_quartile'] = matching_journal['Quartile'].values[0]

        doi = entry.get('doi')
        # Only insert when the DOI is present and the record does not already exist
        if doi is not None and not record_exists(doi, TABLE_NAME):
            insert_record(entry, TABLE_NAME)

# Commit the changes and close the connection
connection.commit()
cursor.close()
connection.close()
print("Data has been inserted into the database!")
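One caveat with the NaN handling above: pd.read_excel yields numpy.nan for empty numeric cells, which replace({pd.NA: None}) may not catch on older pandas versions. A defensive alternative with the same intent, sketched here as a helper (assuming only that pandas is available):

    import pandas as pd

    def nan_to_none(frame: pd.DataFrame) -> pd.DataFrame:
        # Cast to object so None survives, then blank out every missing value
        # (covers both numpy.nan and pd.NA).
        return frame.astype(object).where(pd.notnull(frame), None)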
65
clean/stp2.1_migrate_download_sqlite2mysql.py
Normal file
@@ -0,0 +1,65 @@
# This script migrates data from the SQLite database into the MySQL database.
# It only targets code written during the SQLite stage; if you later work
# directly against MySQL, do not use this script.

import sqlite3
import mysql.connector

TABLE_NAME = 'phosphorus_synthesis_info'
input('Are you sure TABLE_NAME is {}? '.format(TABLE_NAME))

# SQLite setup
sqlite_connection = sqlite3.connect('/home/ubuntu/workplace/LYT/llm-agent/phosphorus/doi_status.db')  # Ensure this is your actual SQLite database file
sqlite_cursor = sqlite_connection.cursor()

# MySQL connection setup
mysql_connection = mysql.connector.connect(
    host='100.84.94.73',
    user='metadata_mat_papers',
    password='siat-mic',
    database='metadata_mat_papers'
)
mysql_cursor = mysql_connection.cursor()

# Define the SQLite query to retrieve data
sqlite_query = "SELECT doi, status, pdf_url FROM doi_status"  # Ensure these field names match your SQLite table

# Function to check if a record exists in the MySQL database
def record_exists(doi, table_name):
    query = f"SELECT COUNT(*) FROM `{table_name}` WHERE doi = %s"
    mysql_cursor.execute(query, (doi,))
    count = mysql_cursor.fetchone()[0]
    return count > 0

# Function to update a record in the MySQL database
def update_record(doi, scihub_downlowded, pdf_url, table_name):
    query = f"""
    UPDATE `{table_name}`
    SET scihub_downlowded = %s, pdf_url = %s
    WHERE doi = %s
    """
    mysql_cursor.execute(query, (scihub_downlowded, pdf_url, doi))

# Fetch data from SQLite
sqlite_cursor.execute(sqlite_query)
rows = sqlite_cursor.fetchall()

# Iterate over SQLite rows and update MySQL records
for row in rows:
    doi, scihub_downlowded, pdf_url = row
    if record_exists(doi, TABLE_NAME):  # Replace with your actual MySQL table name
        update_record(doi, scihub_downlowded, pdf_url, TABLE_NAME)  # Adjust table name if necessary
    else:
        # You can choose to handle non-existent DOI entries differently if necessary
        print(f"Record with DOI {doi} does not exist in MySQL database.")


# Commit the changes to the MySQL database
mysql_connection.commit()

# Close connections
sqlite_cursor.close()
sqlite_connection.close()
mysql_cursor.close()
mysql_connection.close()

print("Data migration from SQLite to MySQL completed successfully!")
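For large SQLite exports, the row-by-row record_exists/update_record loop above can be slow. mysql.connector cursors also support executemany, so a batched variant is possible; the sketch below reuses the script's rows, mysql_cursor, and mysql_connection and could replace the loop. It updates blindly (unknown DOIs simply match no rows), so it trades the per-row existence report for speed.

    # Batched alternative to the per-row loop: one executemany call for all rows.
    update_query = f"""
    UPDATE `{TABLE_NAME}`
    SET scihub_downlowded = %s, pdf_url = %s
    WHERE doi = %s
    """
    params = [(status, pdf_url, doi) for (doi, status, pdf_url) in rows]
    mysql_cursor.executemany(update_query, params)
    mysql_connection.commit()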
28
clean/stp2.2_remove_broken_pdf.py
Normal file
@@ -0,0 +1,28 @@
import sqlite3
import mysql.connector
import tqdm
import os

TABLE_NAME = 'phosphorus_synthesis_info'
input('TABLE_NAME = {} ?'.format(TABLE_NAME))

cur_dir = os.path.dirname(os.path.abspath(__file__))

# MySQL connection setup
mysql_connection = mysql.connector.connect(
    host='100.84.94.73',
    user='metadata_mat_papers',
    password='siat-mic',
    database='metadata_mat_papers'
)
mysql_cursor = mysql_connection.cursor()


# Select all PDFs that were flagged as broken during download
query = f"SELECT pdf_url FROM {TABLE_NAME} WHERE scihub_downlowded = 'broken'"
mysql_cursor.execute(query)
records = mysql_cursor.fetchall()

for record in tqdm.tqdm(records):
    pdf_path = os.path.join(cur_dir, record[0])
    os.remove(pdf_path)
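os.remove raises FileNotFoundError when a broken PDF has already been cleaned up by hand, and the row keeps its 'broken' status afterwards. A hedged sketch of a more defensive loop body that could replace the one above; the 'removed' status value is an assumption, not one the other scripts currently use.

    for record in tqdm.tqdm(records):
        pdf_path = os.path.join(cur_dir, record[0])
        if os.path.exists(pdf_path):
            os.remove(pdf_path)
        # Optionally mark the row so the file is not looked for again.
        mysql_cursor.execute(
            f"UPDATE {TABLE_NAME} SET scihub_downlowded = %s WHERE pdf_url = %s",
            ("removed", record[0]),
        )
    mysql_connection.commit()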
211
clean/stp2_down_ipidea_multi.py
Normal file
@@ -0,0 +1,211 @@
import os
import re
import time
import tqdm
import requests
import subprocess
import concurrent.futures
import sqlite3
from scidownl import scihub_download
import logging
import pymupdf


NUM_PROCESSES = 32  # Number of concurrent worker processes
SCIHUB_URLS = [
    "https://sci-hub.st/",
    "https://sci-hub.se/",
    "https://sci-hub.ru/"
]
PROXY_SERVICE_URL = f"http://api.proxy.ipidea.io/getProxyIp?num={NUM_PROCESSES}&tag=static_balance&return_type=txt&lb=1&sb=0&flow=1&protocol=http"
SINGLE_PROXY_SERVICE_URL = f"http://api.proxy.ipidea.io/getProxyIp?num=1&tag=static_balance&return_type=txt&lb=1&sb=0&flow=1&protocol=http"
DOI_PATTERN = re.compile(r"DOI\s*=\s*\{(10\.\d{4,9}/[-._;()/:A-Z0-9]+)\}", re.IGNORECASE)

logging.basicConfig(level=logging.INFO, format='[%(levelname)s] | %(asctime)s | %(message)s')
logger = logging.getLogger(__name__)

def get_directories(bib_dir_name, output_dirname):
    current_path = os.path.dirname(os.path.abspath(__file__))
    output_dir = os.path.join(current_path, output_dirname)
    bib_dir_path = os.path.join(current_path, bib_dir_name)
    db_path = os.path.join(current_path, "doi_status.db")
    return output_dir, bib_dir_path, db_path

def create_directory_if_not_exists(directory):
    os.makedirs(directory, exist_ok=True)

def fetch_proxies():
    proxies = []
    try:
        response = requests.get(PROXY_SERVICE_URL)
        if response.status_code == 200:
            proxy_list = response.text.strip().split('\r\n')
            for proxy in proxy_list:
                proxies.append({
                    "http": f"http://{proxy}",
                    "https": f"http://{proxy}",
                })
        if proxies:
            logger.info(f"Fetched proxies: {proxies}")
            return proxies
    except Exception as e:
        logger.error(f"Error fetching proxies: {e}")
    return None

def fetch_proxy():
    proxies = []
    try:
        response = requests.get(SINGLE_PROXY_SERVICE_URL)
        if response.status_code == 200:
            proxy_list = response.text.strip().split('\r\n')
            for proxy in proxy_list:
                proxies.append({
                    "http": f"http://{proxy}",
                    "https": f"http://{proxy}",
                })
        if proxies:
            logger.info(f"Fetched proxies: {proxies}")
            return proxies
    except Exception as e:
        logger.error(f"Error fetching proxies: {e}")
    return None


def read_dois_from_files(bib_dir_path):
    all_dois = []
    for bib_file_name in sorted(os.listdir(bib_dir_path)):
        if bib_file_name.endswith(".bib"):
            with open(os.path.join(bib_dir_path, bib_file_name), "r") as file:
                dois = DOI_PATTERN.findall(file.read())
                logger.info(f"{bib_file_name} has {len(dois)} doi(s)")
                all_dois.extend(dois)
    return list(set(all_dois))

def filter_downloaded_dois(all_dois, output_dir):
    for doi in os.listdir(output_dir):
        if doi.endswith(".pdf"):
            doi = doi.replace(".pdf", "").replace("_", "/")
            if doi in all_dois:
                all_dois.remove(doi)
    return all_dois

def read_dois_from_db(db_path, status):
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute(f"SELECT doi FROM doi_status WHERE status = '{status}'")
    dois = [row[0] for row in cursor.fetchall()]
    conn.close()
    return dois

def write_doi_to_db(db_path, doi, output_dirname, status):
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute("INSERT OR REPLACE INTO doi_status (doi, status, pdf_url) VALUES (?, ?, ?)", (doi, status, f"{output_dirname}/{doi.replace('/', '_')}.pdf"))
    conn.commit()
    conn.close()

def initialize_db(db_path):
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS doi_status (
            doi TEXT PRIMARY KEY,
            status TEXT,
            pdf_url TEXT
        )
    ''')
    conn.commit()
    cursor.execute("PRAGMA journal_mode=WAL")
    conn.commit()
    conn.close()

def download_doi(doi, output_dir, proxy, scihub_urls, db_path):
    success_dois, broken_dois, failed_dois, timeout_dois = [], [], [], []
    output_dirname = output_dir.split("/")[-1]
    for scihub_url in scihub_urls:
        output_path = os.path.join(output_dir, f"{doi.replace('/', '_')}.pdf")
        proxy_url = "https=" + proxy['https']

        try:
            result = subprocess.run(
                ['scidownl', 'download', '--doi', doi, '--out', output_path, '--scihub-url', scihub_url, '--proxy', proxy_url],
                check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
            )
            logger.info(result.stderr)

            if "No pdf tag" in result.stderr:
                timeout_dois.append(doi)
                write_doi_to_db(db_path, doi, output_dirname, 'timeout')
                break
            elif "403" in result.stderr or "Unable to connect to proxy" in result.stderr or "504" in result.stderr or 'crawling_failed, error: HTTPSConnectionPool' in result.stderr:
                logger.warning("Proxy error detected, fetching new proxy.")
                # Note: fetch_proxy() may return None if the proxy service is unreachable.
                proxy = fetch_proxy()[0]
                # time.sleep(2)
                continue
            elif result.stdout.strip() != '':
                try:
                    # Try to open the PDF to verify that it is not corrupted
                    with pymupdf.open(output_path) as pdf:
                        logger.info(f"Downloaded {doi} successfully.")
                        write_doi_to_db(db_path, doi, output_dirname, 'success')
                        success_dois.append(doi)
                except:
                    write_doi_to_db(db_path, doi, output_dirname, 'broken')
                    logger.info(f"{doi}.pdf has been broken!")
                    broken_dois.append(doi)
                break
            else:
                write_doi_to_db(db_path, doi, output_dirname, 'failed')
                break

        except subprocess.CalledProcessError as e:
            logger.error(f"Error: {e}")
            failed_dois.append(doi)
            write_doi_to_db(db_path, doi, output_dirname, 'failed')
            continue

    return success_dois, broken_dois, failed_dois, timeout_dois

def download_dois(all_dois, output_dir, db_path):
    success_dois, broken_dois, failed_dois, timeout_dois = [], [], [], []
    proxies = fetch_proxies()

    with concurrent.futures.ProcessPoolExecutor(max_workers=NUM_PROCESSES) as executor:
        futures = []
        for i, doi in enumerate(all_dois):
            proxy = proxies[i % len(proxies)]
            futures.append(executor.submit(download_doi, doi, output_dir, proxy, SCIHUB_URLS, db_path))

        for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc='Downloading DOIs', unit='doi'):
            result = future.result()
            if result:
                success, broken, failed, timeout = result
                success_dois.extend(success)
                broken_dois.extend(broken)
                failed_dois.extend(failed)
                timeout_dois.extend(timeout)

    logger.info(f"Success: {len(success_dois)}, Broken: {len(broken_dois)}, Failed: {len(failed_dois)}, Timeout: {len(timeout_dois)}")

def main():
    bib_dir_name = "synthesis23-25"
    output_dirname = "synthesis23-25_pdfs"
    input('Are you sure the directories are {} and {}? '.format(bib_dir_name, output_dirname))
    output_dir, bib_dir_path, db_path = get_directories(bib_dir_name, output_dirname)
    create_directory_if_not_exists(output_dir)

    initialize_db(db_path)

    all_dois = read_dois_from_files(bib_dir_path)
    logger.info(f"Total {len(all_dois)} doi(s)")

    all_dois = filter_downloaded_dois(all_dois, output_dir)

    all_dois = [doi for doi in all_dois if doi not in read_dois_from_db(db_path, 'success')]
    all_dois = [doi for doi in all_dois if doi not in read_dois_from_db(db_path, 'failed')]
    all_dois = [doi for doi in all_dois if doi not in read_dois_from_db(db_path, 'timeout')]

    download_dois(all_dois, output_dir, db_path)

if __name__ == "__main__":
    main()
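The integrity check inside download_doi (opening the file with pymupdf) is also useful on its own, e.g. for re-scanning an existing output directory. A small standalone sketch of the same idea; the helper name and the page-count criterion are assumptions.

    import pymupdf

    def is_valid_pdf(path: str) -> bool:
        # Treat a PDF as valid if pymupdf can open it and it has at least one page.
        try:
            with pymupdf.open(path) as doc:
                return doc.page_count > 0
        except Exception:
            return False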