第一次合并clean代码
This commit is contained in:
319
clean/reparagraph.py
Executable file
319
clean/reparagraph.py
Executable file
@@ -0,0 +1,319 @@
|
||||
"""
|
||||
Author: Yutang LI
|
||||
Institution: SIAT-MIC
|
||||
Contact: yt.li2@siat.ac.cn
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
import logging
|
||||
from openai import OpenAI
|
||||
from config import ReparagraphConfig
|
||||
|
||||
# Configure logging: INFO level, timestamped records, mirrored to both
# the reparagraph.log file and the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('reparagraph.log'),
        logging.StreamHandler()
    ]
)
|
||||
|
||||
|
||||
def get_true_level(title_info: list, config: ReparagraphConfig):
    """Ask the LLM to assign true hierarchy levels to the extracted headings.

    Args:
        title_info: List of ``{'title', 'line_num', 'level'}`` dicts describing
            the headings found in the document.
        config: ReparagraphConfig carrying the OpenAI credentials, model name
            and retry budget.

    Returns:
        list: Heading dicts with corrected integer ``level`` fields, or the
        sentinel string ``"Error"`` when every retry is exhausted.
    """
    source_title = json.dumps(title_info)
    instruction = """
你是一个论文目录重排助手。
有如下的JSON格式的目录信息,已知目录中每级标题的内容和行号。
<PLACEHOLDER>
请你重排该论文的目录层级,并为每级标题的level字段给出正确的层级关系,其中层级关系用数字(1,2,3,4)表示,数字越小,层级越高。
注意:重排序目录要求多个1级标题的样式, 而非单一1级目录的样式。也就说level为1的标题数量必须大于1。
通常情况下位于一级标题的有可能是:
1. 论文的题目
2. 论文的摘要(Abstract)
3. 论文的介绍(Introduction)
4. 论文的方法或实验(Methods or Experiment)
5. 论文的结果或讨论(Result or Discussion)
6. 论文的结论(Conclusion)
7. 论文的参考文献(References)
8. 论文的致谢(Acknowledgments)
9. 论文的附录(Appendix)
10. 论文的支撑信息(Supporting Information)
有时候目录中存在序号,这时则优先使用序号顺序重建目录。

返回结果的时候严格遵守下列示例JSON格式:
{ 'data': [
{ 'title': 'A hierarchically porous MOF confined CsPbBr3 quantum dots: Fluorescence switching probe for detecting Cu (II) and melamine in food samples', 'line_num': 1, 'level': 1},
...
]
"""
    # Create the OpenAI client.
    client = OpenAI(api_key=config.openai_api_key, base_url=config.openai_base_url)
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": instruction.replace("<PLACEHOLDER>", source_title)}
    ]
    attempt = 0
    while attempt < config.max_retries:
        try:
            completion = client.chat.completions.create(
                model=config.model_name,
                stream=False,  # streaming disabled
                messages=messages,
                response_format={
                    'type': 'json_object'
                }
            )

            response = completion.choices[0].message.content
            response = json.loads(response)
            count_level_1 = sum(1 for item in response['data'] if item['level'] == 1)
            if count_level_1 == 1:
                # The model collapsed everything under a single top-level
                # heading; feed the result back and ask for a redo.
                attempt += 1
                messages.append({"role": "assistant", "content": str(response)})
                messages.append({"role": "user", "content": "上述目录中仅有1个1级标题, 请重新生成目录, 并保证目录中至少有两个1级标题。"})
                continue
            return response['data']

        # json.JSONDecodeError is an Exception subclass, so a single
        # `except Exception` is sufficient.
        except Exception as e:
            logging.error(f"尝试 {attempt + 1}/{config.max_retries} 失败: {str(e)}")
            # Bug fix: the counter must advance on failure too, otherwise a
            # persistent API/JSON error retried forever.
            attempt += 1
            if attempt >= config.max_retries:
                logging.error("达到最大重试次数,放弃操作")
                return "Error"
    # All retries produced a single level-1 heading: report the same error
    # sentinel instead of the original's implicit None.
    return "Error"
|
||||
|
||||
|
||||
def read_file_content(file_path: str):
    """Return every line of a UTF-8 text file as a list of strings."""
    with open(file_path, 'r', encoding='utf-8') as handle:
        all_lines = handle.readlines()
    return all_lines
|
||||
|
||||
def write_file_content(file_path: str, content: list):
    """Write the given lines to *file_path* as UTF-8 text."""
    with open(file_path, mode='w', encoding='utf-8') as out:
        for chunk in content:
            out.write(chunk)
|
||||
|
||||
def extract_headings(lines: list):
    """Collect every markdown heading ('#'-prefixed line) with its 1-based line number."""
    return [
        (number, text.strip())
        for number, text in enumerate(lines, 1)
        if text.strip().startswith('#')
    ]
|
||||
|
||||
def extract_references(lines: list, headings: list, remove_refs: bool = False):
|
||||
"""从文件内容中提取参考文献部分
|
||||
Args:
|
||||
lines: 文件内容列表
|
||||
headings: 标题信息列表
|
||||
remove_refs: 是否抹去参考文献内容
|
||||
Returns:
|
||||
dict: 包含起始点、结束点和内容的信息
|
||||
{
|
||||
'start': ref_start,
|
||||
'end': ref_end,
|
||||
'content': references,
|
||||
'updated_headings': updated_headings
|
||||
}
|
||||
"""
|
||||
# 在标题中查找REFERENCE
|
||||
ref_heading = None
|
||||
for line_num, heading in headings:
|
||||
if "REFERENCE" in heading.upper().replace(" ", ""):
|
||||
ref_heading = (line_num, heading)
|
||||
break
|
||||
|
||||
if not ref_heading and "ACKNOWLEDGEMENT" in heading.upper().replace(" ", ""):
|
||||
ref_heading = (line_num, heading)
|
||||
|
||||
if not ref_heading:
|
||||
# 用正则匹配常见的引用格式并删除
|
||||
# 包括:[数字]、数字.、(数字) 格式
|
||||
ref_pattern = r'^(\[\d+\]|\d+\.|\(\d+\))'
|
||||
lines = [line for line in lines if not re.match(ref_pattern, line.strip())]
|
||||
return {
|
||||
'start': -1,
|
||||
'end': -1,
|
||||
'content': None
|
||||
}, lines
|
||||
|
||||
ref_start = ref_heading[0] - 1 # 转换为0-based索引
|
||||
|
||||
# 查找下一个标题或文件结尾
|
||||
ref_end = len(lines)
|
||||
for i in range(ref_start + 1, len(lines)):
|
||||
if re.match(r'^#', lines[i].strip()):
|
||||
ref_end = i
|
||||
break
|
||||
|
||||
# 提取参考文献内容
|
||||
references = ''.join(lines[ref_start:ref_end])
|
||||
|
||||
# 如果需要抹去内容
|
||||
if remove_refs:
|
||||
lines[ref_start:ref_end] = []
|
||||
|
||||
# # 如果需要更新headings
|
||||
# updated_headings = headings
|
||||
# if remove_refs and ref_heading:
|
||||
# # 从headings中移除Reference标题
|
||||
# updated_headings = [h for h in headings if h[1].upper() != ref_heading[1].upper()]
|
||||
|
||||
return {
|
||||
'start': ref_start,
|
||||
'end': ref_end,
|
||||
'content': references,
|
||||
#'updated_headings': updated_headings
|
||||
}, lines
|
||||
|
||||
def update_headings(lines: list, heading_data: list):
    """Rewrite sub-headings (level >= 2) as bold text, leaving level-1 headings alone."""
    for entry in heading_data:
        index = entry['line_num'] - 1  # 1-based line number -> 0-based index
        if entry['level'] >= 2:
            bare_text = lines[index].replace("#", "").strip()
            lines[index] = "**" + bare_text + "**\n"
    return lines
|
||||
|
||||
|
||||
def detect_file_encoding(file_path: str):
    """Guess a file's character encoding from a 1 KiB sample of its bytes."""
    import chardet
    with open(file_path, 'rb') as handle:
        sample = handle.read(1024)
    return chardet.detect(sample)['encoding']
|
||||
|
||||
# def read_file_content(file_path: str, config: ReparagraphConfig):
|
||||
# """读取文件内容,带大小检查和编码检测"""
|
||||
# file_size = os.path.getsize(file_path)
|
||||
# if file_size > config.max_file_size:
|
||||
# logging.warning(f"文件 {file_path} 超过最大限制 {config.max_file_size} bytes,跳过处理")
|
||||
# return None
|
||||
|
||||
# encoding = detect_file_encoding(file_path)
|
||||
# try:
|
||||
# with open(file_path, 'r', encoding=encoding) as file:
|
||||
# return file.readlines()
|
||||
# except UnicodeDecodeError:
|
||||
# logging.error(f"无法解码文件 {file_path},尝试使用utf-8")
|
||||
# with open(file_path, 'r', encoding='utf-8') as file:
|
||||
# return file.readlines()
|
||||
|
||||
def process_single_file(file_path: str, config: ReparagraphConfig):
    """Process a single markdown file and return the updated lines.

    Pipeline: read -> extract headings -> strip references -> re-index
    headings -> ask the LLM for true heading levels -> demote sub-headings
    to bold text.

    Args:
        file_path: Path of the markdown file to process.
        config: ReparagraphConfig controlling reference removal and the LLM.

    Returns:
        list | None: The updated lines, or None when the file could not be
        read. On LLM failure the lines are returned without level rewrites.
    """
    lines = read_file_content(file_path)
    if lines is None:
        return None

    # Extract headings once so the reference section can be located.
    # (The original also built an unused title_info here; it was dead code
    # because it is recomputed after reference removal below.)
    headings = extract_headings(lines)

    # Extract (and optionally remove) the reference section.
    ref_info, lines = extract_references(lines, headings, remove_refs=config.remove_refs)
    # Bug fix: extract_references always returns a non-empty dict, so the
    # original `if ref_info:` could never reach the warning branch; test the
    # content instead.
    if ref_info['content'] is not None:
        logging.info("提取的参考文献:")
        logging.info(f"起始行: {ref_info['start'] + 1}")
        logging.info(f"结束行: {ref_info['end']}")
        logging.info("内容:")
        logging.info(ref_info['content'])
    else:
        logging.warning("未找到参考文献部分")

    # Removing the references shifts line numbers, so re-index the headings.
    headings = extract_headings(lines)
    title_info = [{"title": heading, "line_num": line_num, "level": "unknown"}
                  for line_num, heading in headings]

    new_headings = get_true_level(title_info, config)
    # Robustness: get_true_level returns the sentinel "Error" after exhausting
    # its retries; passing that string into update_headings would crash.
    if new_headings == "Error":
        logging.error(f"标题层级重排失败,保留原始内容: {file_path}")
        return lines
    updated_lines = update_headings(lines, new_headings)

    logging.info(f"文件处理完成: {file_path}")
    return updated_lines
|
||||
|
||||
def create_output_dir(input_path: str, config: ReparagraphConfig):
    """Create a timestamped output directory next to the input path.

    Args:
        input_path: File or directory path the run started from.
        config: ReparagraphConfig providing ``task_name`` for the dir name.

    Returns:
        str: Path of the newly created directory
        (``<parent>/<task_name>_<YYYYmmdd_HHMMSS>``).
    """
    # Note: the redundant local `import os` was removed — os is already
    # imported at module level. datetime is only needed here.
    from datetime import datetime

    # Place the output directory beside the input path.
    parent_dir = os.path.dirname(input_path)

    # Timestamp keeps repeated runs from clobbering each other.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = os.path.join(parent_dir, f"{config.task_name}_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)

    return output_dir
|
||||
|
||||
def save_processed_file(file_path: str, content: list, output_dir: str, input_path: str):
    """Write processed content under *output_dir*, mirroring the input layout.

    Args:
        file_path: Path of the source file whose processed content is saved.
        content: Processed lines to write.
        output_dir: Root directory for the output tree.
        input_path: Original input path (file or directory) the run started
            from; decides whether the directory structure is preserved.
    """
    # Note: the redundant local `import os` was removed — os is already
    # imported at module level.

    # Single-file input: write straight into the output directory.
    if os.path.isfile(input_path):
        output_path = os.path.join(output_dir, os.path.basename(file_path))
    else:
        # Directory input: reproduce the source's relative layout.
        relative_path = os.path.relpath(file_path, input_path)
        output_path = os.path.join(output_dir, relative_path)
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        f.writelines(content)
    logging.info(f"已保存处理后的文件: {output_path}")
|
||||
|
||||
def reparagraph_file(path: str, config: ReparagraphConfig = None):
    """Process a single file or every .md file inside a directory.

    Args:
        path: File path or directory path.
        config: ReparagraphConfig instance with processing options; a default
            instance is created when None.

    Returns:
        str: Path of the output directory.
    """
    # Note: the redundant local `import os` was removed — os is already
    # imported at module level. The executor import stays local to keep the
    # module import light when parallelism is unused.
    from concurrent.futures import ThreadPoolExecutor

    if config is None:
        config = ReparagraphConfig()

    # Create the timestamped output directory.
    output_dir = create_output_dir(path, config)
    logging.info(f"输出目录: {output_dir}")

    # Directory input: recursively collect every .md file; otherwise treat
    # the path as a single target file.
    if os.path.isdir(path):
        files = []
        for root, _, filenames in os.walk(path):
            for filename in filenames:
                if filename.endswith('.md'):
                    files.append(os.path.join(root, filename))
    else:
        files = [path]

    def process_and_save(file_path: str):
        # Process one file and persist it unless this is a dry run.
        content = process_single_file(file_path, config)
        if content is not None and not config.dry_run:
            save_processed_file(file_path, content, output_dir, path)

    if config.parallel:
        # The work is I/O- and API-bound, so a thread pool overlaps waits;
        # list(...) drains the map iterator so tqdm sees every completion.
        with ThreadPoolExecutor() as executor:
            list(tqdm(executor.map(process_and_save, files), total=len(files), desc="Processing files"))
    else:
        # Sequential processing.
        for file_path in tqdm(files, desc="Processing files"):
            process_and_save(file_path)

    logging.info(f"处理完成,共处理 {len(files)} 个文件")
    return output_dir
|
||||
Reference in New Issue
Block a user