# datapipe/clean/reparagraph.py
"""
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import os
import re
import json
from tqdm import tqdm
import logging
from openai import OpenAI
from config import ReparagraphConfig

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('reparagraph.log'),
        logging.StreamHandler()
    ]
)

def get_true_level(title_info: list, config: ReparagraphConfig):
    source_title = json.dumps(title_info)
    instruction = """
You are an assistant that reorganizes the table of contents of a research paper.
The table of contents below is given in JSON format; the text and line number of every heading are known.
<PLACEHOLDER>
Please rebuild the heading hierarchy of this paper and fill in the correct value for each heading's
'level' field. Levels are written as numbers (1, 2, 3, 4); the smaller the number, the higher the level.
Note: the reordered table of contents must contain multiple level-1 headings rather than a single
level-1 heading, i.e. the number of headings with level 1 must be greater than 1.
Headings that are usually at level 1 include:
1. The paper title
2. The abstract (Abstract)
3. The introduction (Introduction)
4. The methods or experiments (Methods or Experiment)
5. The results or discussion (Result or Discussion)
6. The conclusion (Conclusion)
7. The references (References)
8. The acknowledgments (Acknowledgments)
9. The appendix (Appendix)
10. The supporting information (Supporting Information)
If the headings carry section numbers, prefer the numbering order when rebuilding the table of contents.
Return the result strictly in the following JSON format:
{ 'data': [
    { 'title': 'A hierarchically porous MOF confined CsPbBr3 quantum dots: Fluorescence switching probe for detecting Cu (II) and melamine in food samples', 'line_num': 1, 'level': 1},
    ...
]}
"""
    # Create the OpenAI client
    client = OpenAI(api_key=config.openai_api_key, base_url=config.openai_base_url)
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": instruction.replace("<PLACEHOLDER>", source_title)}
    ]
    attempt = 0
    while attempt < config.max_retries:
        try:
            completion = client.chat.completions.create(
                model=config.model_name,
                stream=False,  # disable streaming
                messages=messages,
                response_format={
                    'type': 'json_object'
                }
            )
            response = completion.choices[0].message.content
            response = json.loads(response)
            count_level_1 = sum(1 for item in response['data'] if item['level'] == 1)
            if count_level_1 == 1:
                # Only one level-1 heading: feed the answer back and ask for a retry
                attempt += 1
                messages.append({"role": "assistant", "content": json.dumps(response)})
                messages.append({"role": "user", "content": "The table of contents above has only one level-1 heading. Please regenerate it and make sure it contains at least two level-1 headings."})
                continue
            return response['data']
        except Exception as e:
            attempt += 1
            logging.error(f"Attempt {attempt}/{config.max_retries} failed: {str(e)}")
            if attempt >= config.max_retries:
                logging.error("Maximum number of retries reached, giving up")
                return "Error"
    return "Error"

def read_file_content(file_path: str):
    """Read the file content."""
    with open(file_path, 'r', encoding='utf-8') as file:
        return file.readlines()

def write_file_content(file_path: str, content: list):
    """Write content back to a file."""
    with open(file_path, 'w', encoding='utf-8') as file:
        file.writelines(content)

def extract_headings(lines: list):
    """Extract every line starting with '#' together with its line number."""
    headings = []
    for line_num, line in enumerate(lines, 1):
        if re.match(r'^#', line.strip()):
            headings.append((line_num, line.strip()))
    return headings

def extract_references(lines: list, headings: list, remove_refs: bool = False):
    """Extract the references section from the file content.

    Args:
        lines: list of file lines
        headings: list of (line_num, heading) tuples
        remove_refs: whether to strip the references content from the lines

    Returns:
        tuple: (info, lines), where info describes the references section:
            {
                'start': ref_start,
                'end': ref_end,
                'content': references
            }
        and lines is the (possibly modified) list of file lines.
    """
    # Look for a REFERENCE heading
    ref_heading = None
    for line_num, heading in headings:
        if "REFERENCE" in heading.upper().replace(" ", ""):
            ref_heading = (line_num, heading)
            break
        if not ref_heading and "ACKNOWLEDGEMENT" in heading.upper().replace(" ", ""):
            ref_heading = (line_num, heading)
    if not ref_heading:
        # No references heading found: drop lines that match common citation
        # formats, i.e. the [number], number. and (number) styles
        ref_pattern = r'^(\[\d+\]|\d+\.|\(\d+\))'
        lines = [line for line in lines if not re.match(ref_pattern, line.strip())]
        return {
            'start': -1,
            'end': -1,
            'content': None
        }, lines
    ref_start = ref_heading[0] - 1  # convert to a 0-based index
    # Find the next heading or the end of the file
    ref_end = len(lines)
    for i in range(ref_start + 1, len(lines)):
        if re.match(r'^#', lines[i].strip()):
            ref_end = i
            break
    # Extract the references content
    references = ''.join(lines[ref_start:ref_end])
    # Strip the references from the document if requested
    if remove_refs:
        lines[ref_start:ref_end] = []
    # # If the headings need to be updated
    # updated_headings = headings
    # if remove_refs and ref_heading:
    #     # Remove the References heading from headings
    #     updated_headings = [h for h in headings if h[1].upper() != ref_heading[1].upper()]
    return {
        'start': ref_start,
        'end': ref_end,
        'content': references,
        # 'updated_headings': updated_headings
    }, lines

def update_headings(lines: list, heading_data: list):
    """Rewrite the Markdown content according to the given heading levels."""
    # Count how many entries in heading_data have level == 1
    # count_level_1 = sum(1 for item in heading_data if item['level'] == 1)
    # flag = 2 if count_level_1 > 1 else 3  # 2 when there are multiple level-1 headings, otherwise 3
    for heading in heading_data:
        line_num = heading['line_num'] - 1
        # Demote everything below level 1 from a Markdown heading to bold text
        if heading['level'] >= 2:  # flag:
            lines[line_num] = "**" + lines[line_num].replace("#", "").strip() + "**\n"
    return lines

def detect_file_encoding(file_path: str):
    """Detect the file encoding."""
    import chardet
    with open(file_path, 'rb') as f:
        raw_data = f.read(1024)
    result = chardet.detect(raw_data)
    return result['encoding']

# def read_file_content(file_path: str, config: ReparagraphConfig):
#     """Read the file content, with a size check and encoding detection."""
#     file_size = os.path.getsize(file_path)
#     if file_size > config.max_file_size:
#         logging.warning(f"File {file_path} exceeds the maximum size of {config.max_file_size} bytes, skipping")
#         return None
#     encoding = detect_file_encoding(file_path)
#     try:
#         with open(file_path, 'r', encoding=encoding) as file:
#             return file.readlines()
#     except UnicodeDecodeError:
#         logging.error(f"Failed to decode file {file_path}, falling back to utf-8")
#         with open(file_path, 'r', encoding='utf-8') as file:
#             return file.readlines()

def process_single_file(file_path: str, config: ReparagraphConfig):
    """Process a single file and return the processed content."""
    # Read the file content
    lines = read_file_content(file_path)
    if lines is None:
        return None
    # Extract the headings
    headings = extract_headings(lines)
    title_info = [{"title": heading, "line_num": line_num, "level": "unknown"}
                  for line_num, heading in headings]
    # Extract the references
    ref_info, lines = extract_references(lines, headings, remove_refs=config.remove_refs)
    if ref_info:
        logging.info("Extracted references:")
        logging.info(f"Start line: {ref_info['start'] + 1}")
        logging.info(f"End line: {ref_info['end']}")
        logging.info("Content:")
        logging.info(ref_info['content'])
        # Update the headings
        # headings = ref_info['updated_headings']
    else:
        logging.warning("References section not found")
    # Removing the references may shift heading line numbers, so re-index them
    headings = extract_headings(lines)
    title_info = [{"title": heading, "line_num": line_num, "level": "unknown"}
                  for line_num, heading in headings]
    new_headings = get_true_level(title_info, config)
    if new_headings == "Error":
        logging.error(f"Failed to rebuild the heading levels for {file_path}, keeping the original content")
        return lines
    updated_lines = update_headings(lines, new_headings)
    logging.info(f"Finished processing file: {file_path}")
    return updated_lines

def create_output_dir(input_path: str, config: ReparagraphConfig):
    """Create the output directory."""
    import os
    from datetime import datetime
    # Parent directory of the input path
    parent_dir = os.path.dirname(input_path)
    # Create a timestamped output directory
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = os.path.join(parent_dir, f"{config.task_name}_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)
    return output_dir

def save_processed_file(file_path: str, content: list, output_dir: str, input_path: str):
    """Save the processed file."""
    import os
    # A single input file
    if os.path.isfile(input_path):
        output_path = os.path.join(output_dir, os.path.basename(file_path))
    else:
        # Preserve the directory structure of the input folder
        relative_path = os.path.relpath(file_path, input_path)
        output_path = os.path.join(output_dir, relative_path)
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        f.writelines(content)
    logging.info(f"Saved processed file: {output_path}")

def reparagraph_file(path: str, config: ReparagraphConfig = None):
    """Process a single file or all .md files in a folder.

    Args:
        path: a file path or a folder path
        config: a ReparagraphConfig instance with the processing options

    Returns:
        str: the output directory path
    """
    import os
    from concurrent.futures import ThreadPoolExecutor
    if config is None:
        config = ReparagraphConfig()
    # Create the output directory
    output_dir = create_output_dir(path, config)
    logging.info(f"Output directory: {output_dir}")
    # If the path is a folder, collect all .md files recursively
    if os.path.isdir(path):
        files = []
        for root, _, filenames in os.walk(path):
            for filename in filenames:
                if filename.endswith('.md'):
                    files.append(os.path.join(root, filename))
    else:
        files = [path]

    def process_and_save(file_path: str):
        content = process_single_file(file_path, config)
        if content is not None and not config.dry_run:
            save_processed_file(file_path, content, output_dir, path)

    if config.parallel:
        # Process the files in parallel with a thread pool
        with ThreadPoolExecutor() as executor:
            list(tqdm(executor.map(process_and_save, files), total=len(files), desc="Processing files"))
    else:
        # Process the files sequentially
        for file_path in tqdm(files, desc="Processing files"):
            process_and_save(file_path)
    logging.info(f"Finished processing {len(files)} files")
    return output_dir
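
if __name__ == "__main__":
    # Minimal usage sketch. It assumes ReparagraphConfig() can be constructed
    # without arguments (as reparagraph_file already does) and that config.py
    # supplies the fields used above (openai_api_key, openai_base_url,
    # model_name, max_retries, task_name, remove_refs, dry_run, parallel).
    # The input path below is a placeholder; adjust it to a real .md file or folder.
    cfg = ReparagraphConfig()
    out_dir = reparagraph_file("path/to/markdown_or_folder", cfg)
    logging.info(f"Re-paragraphed output written to {out_dir}")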