第一次合并clean代码

This commit is contained in:
2025-01-18 17:09:51 +08:00
parent e33a8b069e
commit a0f5ca9a35
21 changed files with 2252 additions and 375 deletions

319
clean/reparagraph.py Executable file
View File

@@ -0,0 +1,319 @@
"""
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""
import os
import re
import json
from tqdm import tqdm
import logging
from openai import OpenAI
from config import ReparagraphConfig
# 配置logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler('reparagraph.log'),
logging.StreamHandler()
]
)
def get_true_level(title_info: list, config: ReparagraphConfig):
source_title = json.dumps(title_info)
instruction = """
你是一个论文目录重排助手。
有如下的JSON格式的目录信息,已知目录中每级标题的内容和行号。
<PLACEHOLDER>
请你重排该论文的目录层级,并为每级标题的level字段给出正确的层级关系,其中层级关系用数字(1,2,3,4)表示,数字越小,层级越高。
注意:重排序目录要求多个1级标题的样式, 而非单一1级目录的样式。也就说level为1的标题数量必须大于1。
通常情况下位于一级标题的有可能是:
1. 论文的题目
2. 论文的摘要(Abstract)
3. 论文的介绍(Introduction)
4. 论文的方法或实验(Methods or Experiment)
5. 论文的结果或讨论(Result or Discussion)
6. 论文的结论(Conclusion)
7. 论文的参考文献(References)
8. 论文的致谢(Acknowledgments)
9. 论文的附录(Appendix)
10. 论文的支撑信息(Supporting Information)
有时候目录中存在序号,这时则优先使用序号顺序重建目录。
返回结果的时候严格遵守下列示例JSON格式:
{ 'data': [
{ 'title': 'A hierarchically porous MOF confined CsPbBr3 quantum dots: Fluorescence switching probe for detecting Cu (II) and melamine in food samples', 'line_num': 1, 'level': 1},
...
]
"""
# 创建 OpenAI 客户端
client = OpenAI(api_key=config.openai_api_key, base_url=config.openai_base_url)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": instruction.replace("<PLACEHOLDER>", source_title)}
]
attempt = 0
while attempt < config.max_retries:
try:
completion = client.chat.completions.create(
model=config.model_name,
stream=False, # 关闭流模式
messages=messages,
response_format={
'type': 'json_object'
}
)
response = completion.choices[0].message.content
response = json.loads(response)
count_level_1 = sum(1 for item in response['data'] if item['level'] == 1)
if count_level_1 == 1:
attempt += 1
messages.append({"role": "assistant", "content": str(response)})
messages.append({"role": "user", "content": "上述目录中仅有1个1级标题, 请重新生成目录, 并保证目录中至少有两个1级标题。"})
continue
return response['data']
except (json.JSONDecodeError, Exception) as e:
logging.error(f"尝试 {attempt + 1}/{config.max_retries} 失败: {str(e)}")
if attempt == config.max_retries - 1:
logging.error("达到最大重试次数,放弃操作")
return "Error"
def read_file_content(file_path: str):
"""读取文件内容"""
with open(file_path, 'r', encoding='utf-8') as file:
return file.readlines()
def write_file_content(file_path: str, content: list):
"""写入文件内容"""
with open(file_path, 'w', encoding='utf-8') as file:
file.writelines(content)
def extract_headings(lines: list):
"""从文件内容中提取所有以#开头的行及其行号"""
headings = []
for line_num, line in enumerate(lines, 1):
if re.match(r'^#', line.strip()):
headings.append((line_num, line.strip()))
return headings
def extract_references(lines: list, headings: list, remove_refs: bool = False):
"""从文件内容中提取参考文献部分
Args:
lines: 文件内容列表
headings: 标题信息列表
remove_refs: 是否抹去参考文献内容
Returns:
dict: 包含起始点、结束点和内容的信息
{
'start': ref_start,
'end': ref_end,
'content': references,
'updated_headings': updated_headings
}
"""
# 在标题中查找REFERENCE
ref_heading = None
for line_num, heading in headings:
if "REFERENCE" in heading.upper().replace(" ", ""):
ref_heading = (line_num, heading)
break
if not ref_heading and "ACKNOWLEDGEMENT" in heading.upper().replace(" ", ""):
ref_heading = (line_num, heading)
if not ref_heading:
# 用正则匹配常见的引用格式并删除
# 包括:[数字]、数字.、(数字) 格式
ref_pattern = r'^(\[\d+\]|\d+\.|\(\d+\))'
lines = [line for line in lines if not re.match(ref_pattern, line.strip())]
return {
'start': -1,
'end': -1,
'content': None
}, lines
ref_start = ref_heading[0] - 1 # 转换为0-based索引
# 查找下一个标题或文件结尾
ref_end = len(lines)
for i in range(ref_start + 1, len(lines)):
if re.match(r'^#', lines[i].strip()):
ref_end = i
break
# 提取参考文献内容
references = ''.join(lines[ref_start:ref_end])
# 如果需要抹去内容
if remove_refs:
lines[ref_start:ref_end] = []
# # 如果需要更新headings
# updated_headings = headings
# if remove_refs and ref_heading:
# # 从headings中移除Reference标题
# updated_headings = [h for h in headings if h[1].upper() != ref_heading[1].upper()]
return {
'start': ref_start,
'end': ref_end,
'content': references,
#'updated_headings': updated_headings
}, lines
def update_headings(lines: list, heading_data: list):
"""根据提供的标题数据更新Markdown文件内容"""
# 统计heading_data中level==1的数量
# count_level_1 = sum(1 for item in heading_data if item['level'] == 1)
# flag = 2 if count_level_1 > 1 else 3 # 存在多个一级标题是为2否则为3
for heading in heading_data:
line_num = heading['line_num'] - 1
if heading['level'] >= 2:#flag:
lines[line_num] = "**" + lines[line_num].replace("#", "").strip() + "**\n"
return lines
def detect_file_encoding(file_path: str):
"""检测文件编码"""
import chardet
with open(file_path, 'rb') as f:
raw_data = f.read(1024)
result = chardet.detect(raw_data)
return result['encoding']
# def read_file_content(file_path: str, config: ReparagraphConfig):
# """读取文件内容,带大小检查和编码检测"""
# file_size = os.path.getsize(file_path)
# if file_size > config.max_file_size:
# logging.warning(f"文件 {file_path} 超过最大限制 {config.max_file_size} bytes跳过处理")
# return None
# encoding = detect_file_encoding(file_path)
# try:
# with open(file_path, 'r', encoding=encoding) as file:
# return file.readlines()
# except UnicodeDecodeError:
# logging.error(f"无法解码文件 {file_path}尝试使用utf-8")
# with open(file_path, 'r', encoding='utf-8') as file:
# return file.readlines()
def process_single_file(file_path: str, config: ReparagraphConfig):
"""处理单个文件并返回处理后的内容"""
# 读取文件内容
lines = read_file_content(file_path)
if lines is None:
return None
# 提取并更新标题
headings = extract_headings(lines)
title_info = [{"title": heading, "line_num": line_num, "level": "unknown"}
for line_num, heading in headings]
# 提取参考文献
ref_info, lines = extract_references(lines, headings, remove_refs=config.remove_refs)
if ref_info:
logging.info("提取的参考文献:")
logging.info(f"起始行: {ref_info['start'] + 1}")
logging.info(f"结束行: {ref_info['end']}")
logging.info("内容:")
logging.info(ref_info['content'])
# 更新headings
# headings = ref_info['updated_headings']
else:
logging.warning("未找到参考文献部分")
# 删除reference后可能会导致标题的行号变化重新索引
headings = extract_headings(lines)
title_info = [{"title": heading, "line_num": line_num, "level": "unknown"}
for line_num, heading in headings]
new_headings = get_true_level(title_info, config)
updated_lines = update_headings(lines, new_headings)
logging.info(f"文件处理完成: {file_path}")
return updated_lines
def create_output_dir(input_path: str, config: ReparagraphConfig):
"""创建输出目录"""
import os
from datetime import datetime
# 获取输入路径的父目录
parent_dir = os.path.dirname(input_path)
# 创建带时间戳的输出目录
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = os.path.join(parent_dir, f"{config.task_name}_{timestamp}")
os.makedirs(output_dir, exist_ok=True)
return output_dir
def save_processed_file(file_path: str, content: list, output_dir: str, input_path: str):
"""保存处理后的文件"""
import os
# 如果是单个文件
if os.path.isfile(input_path):
output_path = os.path.join(output_dir, os.path.basename(file_path))
else:
# 保持目录结构
relative_path = os.path.relpath(file_path, input_path)
output_path = os.path.join(output_dir, relative_path)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
f.writelines(content)
logging.info(f"已保存处理后的文件: {output_path}")
def reparagraph_file(path: str, config:ReparagraphConfig=None):
"""处理单个文件或文件夹中的所有.md文件
Args:
path: 文件路径或文件夹路径
config: ReparagraphConfig实例包含处理配置
Returns:
str: 输出目录路径
"""
import os
from concurrent.futures import ThreadPoolExecutor
if config is None:
config = ReparagraphConfig()
# 创建输出目录
output_dir = create_output_dir(path, config)
logging.info(f"输出目录: {output_dir}")
# 如果是文件夹,递归获取所有.md文件
if os.path.isdir(path):
files = []
for root, _, filenames in os.walk(path):
for filename in filenames:
if filename.endswith('.md'):
files.append(os.path.join(root, filename))
else:
files = [path]
def process_and_save(file_path: str):
content = process_single_file(file_path, config)
if content is not None and not config.dry_run:
save_processed_file(file_path, content, output_dir, path)
if config.parallel:
# 使用线程池并行处理
with ThreadPoolExecutor() as executor:
list(tqdm(executor.map(process_and_save, files), total=len(files), desc="Processing files"))
else:
# 顺序处理
for file_path in tqdm(files, desc="Processing files"):
process_and_save(file_path)
logging.info(f"处理完成,共处理 {len(files)} 个文件")
return output_dir