"""
|
||
Author: Yutang LI
|
||
Institution: SIAT-MIC
|
||
Contact: yt.li2@siat.ac.cn
|
||
"""
|
||
|
||
import os
import re
import json
import logging

from tqdm import tqdm
from openai import OpenAI

from config import ReparagraphConfig

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('reparagraph.log'),
        logging.StreamHandler()
    ]
)
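
# NOTE: ReparagraphConfig lives in config.py. A minimal sketch of the fields
# this module relies on, inferred from usage below (an assumption, not the
# actual definition):
#
#     class ReparagraphConfig:
#         openai_api_key: str    # key for the OpenAI-compatible endpoint
#         openai_base_url: str   # base URL of that endpoint
#         model_name: str        # chat model queried in get_true_level
#         max_retries: int       # retry budget for get_true_level
#         remove_refs: bool      # strip the references section if True
#         task_name: str         # prefix of the timestamped output directory
#         dry_run: bool          # process files without writing output
#         parallel: bool         # process files with a thread pool
#         max_file_size: int     # size cap (used by the commented-out reader)
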
def get_true_level(title_info: list, config: ReparagraphConfig):
    source_title = json.dumps(title_info)
    instruction = """
You are an assistant that reorders the heading hierarchy of a paper.

Given the following table of contents in JSON format, where the text and line number of every heading are known:

<PLACEHOLDER>

Reorder the heading hierarchy of this paper and assign the correct level to each heading's "level" field. Levels are numbers (1, 2, 3, 4); the smaller the number, the higher the level.

Note: the reordered table of contents must contain multiple level-1 headings, not a single level-1 heading. In other words, the number of headings with level 1 must be greater than 1.

Headings that are usually level 1 include:
1. The paper's title
2. Abstract
3. Introduction
4. Methods or Experiment
5. Result or Discussion
6. Conclusion
7. References
8. Acknowledgments
9. Appendix
10. Supporting Information

Sometimes the headings carry section numbers; in that case, prefer the numbering order when rebuilding the hierarchy.

Return the result strictly in the following example JSON format:
{ "data": [
    { "title": "A hierarchically porous MOF confined CsPbBr3 quantum dots: Fluorescence switching probe for detecting Cu (II) and melamine in food samples", "line_num": 1, "level": 1},
    ...
] }
"""
    # Create the OpenAI client
    client = OpenAI(api_key=config.openai_api_key, base_url=config.openai_base_url)
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": instruction.replace("<PLACEHOLDER>", source_title)}
    ]
    attempt = 0
    while attempt < config.max_retries:
        try:
            completion = client.chat.completions.create(
                model=config.model_name,
                stream=False,  # disable streaming
                messages=messages,
                response_format={
                    'type': 'json_object'
                }
            )

            response = completion.choices[0].message.content
            response = json.loads(response)
            count_level_1 = sum(1 for item in response['data'] if item['level'] == 1)
            if count_level_1 == 1:
                attempt += 1
                messages.append({"role": "assistant", "content": str(response)})
                messages.append({"role": "user", "content": "The table of contents above has only one level-1 heading. Please regenerate it and ensure it contains at least two level-1 headings."})
                continue
            return response['data']

        except Exception as e:
            # json.JSONDecodeError is a subclass of Exception, so one clause suffices;
            # count the failed attempt so the loop cannot retry forever
            attempt += 1
            logging.error(f"Attempt {attempt}/{config.max_retries} failed: {str(e)}")
            if attempt == config.max_retries:
                logging.error("Maximum retries reached, giving up")
                return "Error"
    # Retry budget exhausted without a valid multi-top-level result
    return "Error"
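
# Example (hypothetical values): given title_info such as
#     [{"title": "# Abstract", "line_num": 3, "level": "unknown"}, ...]
# get_true_level returns the model's "data" list, e.g.
#     [{"title": "# Abstract", "line_num": 3, "level": 1}, ...]
# or the string "Error" once every retry has failed.
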
def read_file_content(file_path: str):
    """Read the file content."""
    with open(file_path, 'r', encoding='utf-8') as file:
        return file.readlines()

def write_file_content(file_path: str, content: list):
    """Write the file content."""
    with open(file_path, 'w', encoding='utf-8') as file:
        file.writelines(content)

def extract_headings(lines: list):
    """Extract every line starting with '#' (and its line number) from the file content."""
    headings = []
    for line_num, line in enumerate(lines, 1):
        if re.match(r'^#', line.strip()):
            headings.append((line_num, line.strip()))
    return headings
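
# Example (hypothetical input): for
#     lines = ["# Title\n", "some text\n", "## 2. Methods\n"]
# extract_headings returns [(1, "# Title"), (3, "## 2. Methods")]:
# 1-based line numbers paired with the stripped heading text.
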
def extract_references(lines: list, headings: list, remove_refs: bool = False):
    """Extract the references section from the file content.

    Args:
        lines: list of file content lines
        headings: list of heading info
        remove_refs: whether to erase the references content

    Returns:
        tuple: (info, lines), where info holds the start point, end point and content:
        {
            'start': ref_start,
            'end': ref_end,
            'content': references
        }
        and lines is the (possibly modified) file content.
    """
    # Look for a REFERENCE heading
    ref_heading = None
    for line_num, heading in headings:
        if "REFERENCE" in heading.upper().replace(" ", ""):
            ref_heading = (line_num, heading)
            break

    # Fall back to an ACKNOWLEDGEMENT heading when no REFERENCE heading exists
    if not ref_heading:
        for line_num, heading in headings:
            if "ACKNOWLEDGEMENT" in heading.upper().replace(" ", ""):
                ref_heading = (line_num, heading)
                break

    if not ref_heading:
        # Match common citation formats with a regex and remove them,
        # covering: [number], number., (number)
        ref_pattern = r'^(\[\d+\]|\d+\.|\(\d+\))'
        lines = [line for line in lines if not re.match(ref_pattern, line.strip())]
        return {
            'start': -1,
            'end': -1,
            'content': None
        }, lines

    ref_start = ref_heading[0] - 1  # convert to 0-based index

    # Find the next heading or the end of the file
    ref_end = len(lines)
    for i in range(ref_start + 1, len(lines)):
        if re.match(r'^#', lines[i].strip()):
            ref_end = i
            break

    # Extract the references content
    references = ''.join(lines[ref_start:ref_end])

    # Erase the content if requested
    if remove_refs:
        lines[ref_start:ref_end] = []

    # # Update headings if needed
    # updated_headings = headings
    # if remove_refs and ref_heading:
    #     # Remove the Reference heading from headings
    #     updated_headings = [h for h in headings if h[1].upper() != ref_heading[1].upper()]

    return {
        'start': ref_start,
        'end': ref_end,
        'content': references,
        # 'updated_headings': updated_headings
    }, lines
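
# Example (hypothetical layout): if "# References" sits on line 40 and the next
# heading "# Appendix" on line 80, extract_references returns
# ({'start': 39, 'end': 79, 'content': <lines 40-79 joined>}, lines), and with
# remove_refs=True those lines are dropped from the returned `lines`.
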
def update_headings(lines: list, heading_data: list):
    """Update the Markdown file content according to the provided heading data."""
    # Count the entries in heading_data with level == 1
    # count_level_1 = sum(1 for item in heading_data if item['level'] == 1)
    # flag = 2 if count_level_1 > 1 else 3  # 2 when there are multiple level-1 headings, else 3

    for heading in heading_data:
        line_num = heading['line_num'] - 1
        if heading['level'] >= 2:  # flag:
            # Demote sub-headings: strip the '#' markers and render them as bold text
            lines[line_num] = "**" + lines[line_num].replace("#", "").strip() + "**\n"
    return lines


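# Example (hypothetical values): for the entry {'line_num': 5, 'level': 2},
# line 5 "## 2.1 Synthesis\n" becomes "**2.1 Synthesis**\n"; level-1 headings
# are left as Markdown '#' headings.
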
def detect_file_encoding(file_path: str):
    """Detect the file encoding."""
    import chardet  # local import: chardet is only needed here
    with open(file_path, 'rb') as f:
        raw_data = f.read(1024)
    result = chardet.detect(raw_data)
    return result['encoding']

# def read_file_content(file_path: str, config: ReparagraphConfig):
#     """Read the file content, with a size check and encoding detection."""
#     file_size = os.path.getsize(file_path)
#     if file_size > config.max_file_size:
#         logging.warning(f"File {file_path} exceeds the limit of {config.max_file_size} bytes, skipping")
#         return None

#     encoding = detect_file_encoding(file_path)
#     try:
#         with open(file_path, 'r', encoding=encoding) as file:
#             return file.readlines()
#     except UnicodeDecodeError:
#         logging.error(f"Failed to decode file {file_path}, falling back to utf-8")
#         with open(file_path, 'r', encoding='utf-8') as file:
#             return file.readlines()

def process_single_file(file_path: str, config: ReparagraphConfig):
    """Process a single file and return the processed content."""
    # Read the file content
    lines = read_file_content(file_path)
    if lines is None:
        return None

    # Extract the headings
    headings = extract_headings(lines)

    # Extract the references
    ref_info, lines = extract_references(lines, headings, remove_refs=config.remove_refs)
    if ref_info['start'] != -1:
        logging.info("Extracted references:")
        logging.info(f"Start line: {ref_info['start'] + 1}")
        logging.info(f"End line: {ref_info['end']}")
        logging.info("Content:")
        logging.info(ref_info['content'])
        # Update headings
        # headings = ref_info['updated_headings']
    else:
        logging.warning("No references section found")

    # Removing the references may shift heading line numbers, so re-index
    headings = extract_headings(lines)
    title_info = [{"title": heading, "line_num": line_num, "level": "unknown"}
                  for line_num, heading in headings]

    new_headings = get_true_level(title_info, config)
    if new_headings == "Error":
        logging.error(f"Failed to rebuild heading levels for {file_path}")
        return None
    updated_lines = update_headings(lines, new_headings)

    logging.info(f"File processed: {file_path}")
    return updated_lines

def create_output_dir(input_path: str, config: ReparagraphConfig):
    """Create the output directory."""
    from datetime import datetime

    # Get the parent directory of the input path
    parent_dir = os.path.dirname(input_path)

    # Create a timestamped output directory
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = os.path.join(parent_dir, f"{config.task_name}_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)

    return output_dir
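
# Example (hypothetical paths): for input_path "/data/papers" and
# config.task_name "reparagraph", the output directory is something like
# "/data/reparagraph_20250101_120000".
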
def save_processed_file(file_path: str, content: list, output_dir: str, input_path: str):
    """Save the processed file."""
    # Single file
    if os.path.isfile(input_path):
        output_path = os.path.join(output_dir, os.path.basename(file_path))
    else:
        # Preserve the directory structure
        relative_path = os.path.relpath(file_path, input_path)
        output_path = os.path.join(output_dir, relative_path)
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        f.writelines(content)
    logging.info(f"Saved processed file: {output_path}")

def reparagraph_file(path: str, config: ReparagraphConfig = None):
    """Process a single file or all .md files in a folder.

    Args:
        path: file path or folder path
        config: ReparagraphConfig instance holding the processing options

    Returns:
        str: output directory path
    """
    from concurrent.futures import ThreadPoolExecutor

    if config is None:
        config = ReparagraphConfig()

    # Create the output directory
    output_dir = create_output_dir(path, config)
    logging.info(f"Output directory: {output_dir}")

    # If the path is a folder, collect all .md files recursively
    if os.path.isdir(path):
        files = []
        for root, _, filenames in os.walk(path):
            for filename in filenames:
                if filename.endswith('.md'):
                    files.append(os.path.join(root, filename))
    else:
        files = [path]

    def process_and_save(file_path: str):
        content = process_single_file(file_path, config)
        if content is not None and not config.dry_run:
            save_processed_file(file_path, content, output_dir, path)

    if config.parallel:
        # Process files in parallel with a thread pool
        with ThreadPoolExecutor() as executor:
            list(tqdm(executor.map(process_and_save, files), total=len(files), desc="Processing files"))
    else:
        # Process sequentially
        for file_path in tqdm(files, desc="Processing files"):
            process_and_save(file_path)

    logging.info(f"Done, processed {len(files)} files in total")
    return output_dir
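

# Hypothetical entry point: the file is marked executable, so a minimal CLI
# wrapper is sketched here as an assumption; adjust to the project's actual usage.
if __name__ == "__main__":
    import sys
    if len(sys.argv) != 2:
        print("Usage: python reparagraph.py <markdown_file_or_directory>")
        sys.exit(1)
    reparagraph_file(sys.argv[1])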