"""
Author: Yutang LI
Institution: SIAT-MIC
Contact: yt.li2@siat.ac.cn
"""

import os
import re
import json
import logging
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime

from tqdm import tqdm
from openai import OpenAI

from config import ReparagraphConfig

# Configure logging: INFO and above go both to 'reparagraph.log' and the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('reparagraph.log'),
        logging.StreamHandler()
    ]
)

# Token in the prompt template that is replaced by the JSON-encoded heading list.
_TITLE_PLACEHOLDER = "<<TITLE_INFO>>"

# Prompt sent to the LLM; <<TITLE_INFO>> is substituted in get_true_level().
_INSTRUCTION_TEMPLATE = """
你是一个论文目录重排助手。
有如下的JSON格式的目录信息,已知目录中每级标题的内容和行号。
<<TITLE_INFO>>
请你重排该论文的目录层级,并为每级标题的level字段给出正确的层级关系,其中层级关系用数字(1,2,3,4)表示,数字越小,层级越高。
注意:重排序目录要求多个1级标题的样式, 而非单一1级目录的样式。也就说level为1的标题数量必须大于1。
通常情况下位于一级标题的有可能是:
1. 论文的题目
2. 论文的摘要(Abstract)
3. 论文的介绍(Introduction)
4. 论文的方法或实验(Methods or Experiment)
5. 论文的结果或讨论(Result or Discussion)
6. 论文的结论(Conclusion)
7. 论文的参考文献(References)
8. 论文的致谢(Acknowledgments)
9. 论文的附录(Appendix)
10. 论文的支撑信息(Supporting Information)
有时候目录中存在序号,这时则优先使用序号顺序重建目录。
返回结果的时候严格遵守下列示例JSON格式:
{
    "data": [
        {"title": "A hierarchically porous MOF confined CsPbBr3 quantum dots: Fluorescence switching probe for detecting Cu (II) and melamine in food samples", "line_num": 1, "level": 1},
        ...
    ]
}
"""

# Lines that look like numbered citation entries: "[12] ...", "12. ..." or "(12) ...".
_REF_ENTRY_RE = re.compile(r'^(\[\d+\]|\d+\.|\(\d+\))')


def _is_heading(line: str) -> bool:
    """Return True if *line* (ignoring surrounding whitespace) starts with '#'."""
    return line.strip().startswith('#')


def get_true_level(title_info: list, config: ReparagraphConfig):
    """Ask the LLM to assign the correct hierarchy level to every heading.

    The prompt demands more than one level-1 heading; answers that violate
    this are sent back to the model and it is asked to try again.

    Args:
        title_info: Heading descriptors, each a dict with "title",
            "line_num" (1-based) and "level" keys.
        config: Supplies ``openai_api_key``, ``openai_base_url``,
            ``model_name`` and ``max_retries``.

    Returns:
        The model's ``data`` list (dicts with "title", "line_num", "level")
        on success. If every parseable answer had too few level-1 headings,
        the last such answer is returned as a best effort. The string
        ``"Error"`` is returned when no attempt produced a usable answer.
    """
    # ensure_ascii=False keeps non-ASCII titles readable for the model.
    source_title = json.dumps(title_info, ensure_ascii=False)
    instruction = _INSTRUCTION_TEMPLATE.replace(_TITLE_PLACEHOLDER, source_title)

    # Create the OpenAI client
    client = OpenAI(api_key=config.openai_api_key, base_url=config.openai_base_url)
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": instruction},
    ]

    # A document with a single heading can never have two level-1 headings.
    required_level_1 = min(2, len(title_info))
    last_data = None

    for attempt in range(config.max_retries):
        try:
            completion = client.chat.completions.create(
                model=config.model_name,
                stream=False,  # disable streaming: the full JSON document is needed
                messages=messages,
                response_format={'type': 'json_object'},
            )
            raw = completion.choices[0].message.content
            data = json.loads(raw)['data']
            if not isinstance(data, list):
                raise ValueError("'data' is not a list")
            count_level_1 = sum(1 for item in data if item['level'] == 1)
        except Exception as e:
            # Network errors, malformed JSON and missing keys each consume one attempt.
            logging.error(f"尝试 {attempt + 1}/{config.max_retries} 失败: {str(e)}")
            continue

        if count_level_1 >= required_level_1:
            return data

        # Keep the answer as a fallback and ask the model to fix it.
        last_data = data
        messages.append({"role": "assistant", "content": raw})
        messages.append({
            "role": "user",
            "content": f"上述目录中仅有{count_level_1}个1级标题, 请重新生成目录, 并保证目录中至少有两个1级标题。",
        })

    if last_data is not None:
        logging.warning("达到最大重试次数, 使用最后一次生成的目录")
        return last_data
    logging.error("达到最大重试次数,放弃操作")
    return "Error"


def read_file_content(file_path: str):
    """Read a UTF-8 text file and return its lines (newlines kept)."""
    with open(file_path, 'r', encoding='utf-8') as file:
        return file.readlines()


def write_file_content(file_path: str, content: list):
    """Write *content* (a list of lines) to *file_path* as UTF-8."""
    with open(file_path, 'w', encoding='utf-8') as file:
        file.writelines(content)


def extract_headings(lines: list):
    """Return ``(line_num, stripped_line)`` for every line starting with '#'.

    ``line_num`` is 1-based.
    """
    return [
        (line_num, line.strip())
        for line_num, line in enumerate(lines, 1)
        if _is_heading(line)
    ]


def extract_references(lines: list, headings: list, remove_refs: bool = False):
    """Locate (and optionally remove) the reference section of a document.

    The first heading containing "REFERENCE" (case/space-insensitive) marks
    the section. If there is none, an acknowledgement heading is used
    instead. The section runs up to the next heading or end of file.

    Args:
        lines: File content as a list of lines. When a reference heading is
            found and ``remove_refs`` is True, this list is modified in place.
        headings: ``(line_num, heading)`` pairs as from ``extract_headings``.
        remove_refs: Whether to remove the reference content.

    Returns:
        tuple: ``(info, lines)``. ``info`` is a dict with keys
        'start' (0-based inclusive), 'end' (exclusive) and 'content' (str).
        When no section is found, 'start' and 'end' are -1 and 'content' is
        None. If ``remove_refs`` is True in that case, citation-like lines
        (``[n]``, ``n.``, ``(n)``) are dropped from the returned lines instead.
    """
    ref_heading = None
    for line_num, heading in headings:
        normalized = heading.upper().replace(" ", "")
        if "REFERENCE" in normalized:
            ref_heading = (line_num, heading)
            break
        # Fall back to the first acknowledgement heading (British or American spelling).
        if ref_heading is None and ("ACKNOWLEDGEMENT" in normalized or "ACKNOWLEDGMENT" in normalized):
            ref_heading = (line_num, heading)

    if ref_heading is None:
        if remove_refs:
            # No section heading: drop lines that look like numbered citations.
            lines = [line for line in lines if not _REF_ENTRY_RE.match(line.strip())]
        return {
            'start': -1,
            'end': -1,
            'content': None
        }, lines

    ref_start = ref_heading[0] - 1  # convert to a 0-based index

    # The section ends at the next heading, or at end of file.
    ref_end = next(
        (i for i in range(ref_start + 1, len(lines)) if _is_heading(lines[i])),
        len(lines),
    )

    references = ''.join(lines[ref_start:ref_end])

    if remove_refs:
        lines[ref_start:ref_end] = []

    return {
        'start': ref_start,
        'end': ref_end,
        'content': references,
    }, lines


def update_headings(lines: list, heading_data: list):
    """Demote headings of level >= 2 to bold text, in place.

    Level-1 headings keep their original Markdown marker. Entries with
    unusable data are skipped with a warning. This covers a missing or
    non-integer line number/level, an out-of-range line, or a target line
    that is not a heading.

    Args:
        lines: File content as a list of lines (modified in place).
        heading_data: Dicts with "line_num" (1-based) and "level" keys.

    Returns:
        The same ``lines`` list.
    """
    for heading in heading_data:
        try:
            idx = int(heading['line_num']) - 1
            level = int(heading['level'])
        except (KeyError, TypeError, ValueError):
            logging.warning(f"忽略无效的标题数据: {heading}")
            continue

        # Guard against line numbers the model got wrong.
        if not 0 <= idx < len(lines) or not _is_heading(lines[idx]):
            logging.warning(f"忽略无效的标题行号: {heading}")
            continue

        if level >= 2:
            # Strip only the leading '#' marker so '#' inside the title survives.
            text = lines[idx].strip().lstrip('#').strip()
            lines[idx] = "**" + text + "**\n"
    return lines


def detect_file_encoding(file_path: str):
    """Guess a file's encoding from its first 1024 bytes using ``chardet``.

    Returns whatever ``chardet.detect`` reports as 'encoding' (may be None).
    """
    import chardet
    with open(file_path, 'rb') as f:
        raw_data = f.read(1024)
    result = chardet.detect(raw_data)
    return result['encoding']


def process_single_file(file_path: str, config: ReparagraphConfig):
    """Process one Markdown file and return the updated lines.

    Steps:
        1. Extract (and optionally remove) the reference section.
        2. Re-index the headings.
        3. Ask the LLM for their levels and demote non-top-level headings.

    Returns:
        The processed lines, or None if the file could not be read or the
        heading levels could not be obtained (the file is then skipped).
    """
    lines = read_file_content(file_path)
    if lines is None:
        return None

    headings = extract_headings(lines)

    ref_info, lines = extract_references(lines, headings, remove_refs=config.remove_refs)
    if ref_info['content'] is not None:
        logging.info("提取的参考文献:")
        logging.info(f"起始行: {ref_info['start'] + 1}")
        logging.info(f"结束行: {ref_info['end']}")
        logging.info("内容:")
        logging.info(ref_info['content'])
    else:
        logging.warning("未找到参考文献部分")

    # Removing references may shift line numbers, so re-index the headings.
    headings = extract_headings(lines)
    if headings:
        title_info = [
            {"title": heading, "line_num": line_num, "level": "unknown"}
            for line_num, heading in headings
        ]
        new_headings = get_true_level(title_info, config)
        if not isinstance(new_headings, list):
            logging.error(f"标题层级重排失败, 跳过文件: {file_path}")
            return None
        lines = update_headings(lines, new_headings)

    logging.info(f"文件处理完成: {file_path}")
    return lines


def create_output_dir(input_path: str, config: ReparagraphConfig):
    """Create ``<parent of input_path>/<task_name>_<YYYYmmdd_HHMMSS>``.

    The parent is computed from the absolute path. This keeps the output
    outside the input directory even for inputs like "dir/" or ".".

    Returns:
        The output directory path.
    """
    parent_dir = os.path.dirname(os.path.abspath(input_path))

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = os.path.join(parent_dir, f"{config.task_name}_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)
    return output_dir


def save_processed_file(file_path: str, content: list, output_dir: str, input_path: str):
    """Write processed *content* for *file_path* below *output_dir*.

    For a single-file input, only the file name is kept. For a directory
    input, the path relative to *input_path* is preserved.
    """
    if os.path.isfile(input_path):
        output_path = os.path.join(output_dir, os.path.basename(file_path))
    else:
        relative_path = os.path.relpath(file_path, input_path)
        output_path = os.path.join(output_dir, relative_path)
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

    write_file_content(output_path, content)
    logging.info(f"已保存处理后的文件: {output_path}")


def reparagraph_file(path: str, config: ReparagraphConfig = None):
    """Process a single file, or every .md file under a directory.

    Args:
        path: File path or directory path.
        config: ReparagraphConfig instance; a default one is created if None.

    Returns:
        str: The output directory path.
    """
    if config is None:
        config = ReparagraphConfig()

    output_dir = create_output_dir(path, config)
    logging.info(f"输出目录: {output_dir}")

    if os.path.isdir(path):
        # Collect .md files recursively; sort for a deterministic order.
        files = sorted(
            os.path.join(root, filename)
            for root, _, filenames in os.walk(path)
            for filename in filenames
            if filename.endswith('.md')
        )
    else:
        files = [path]

    def process_and_save(file_path: str):
        # One failing file should not abort the whole batch.
        try:
            content = process_single_file(file_path, config)
            if content is not None and not config.dry_run:
                save_processed_file(file_path, content, output_dir, path)
        except Exception:
            logging.exception(f"处理文件失败: {file_path}")

    if config.parallel:
        # Process files concurrently in a thread pool.
        with ThreadPoolExecutor() as executor:
            list(tqdm(executor.map(process_and_save, files), total=len(files), desc="Processing files"))
    else:
        # Process files sequentially.
        for file_path in tqdm(files, desc="Processing files"):
            process_and_save(file_path)

    logging.info(f"处理完成,共处理 {len(files)} 个文件")
    return output_dir