""" This script generates questions and answers from a given set of CIFs. It uses the OpenAI API and MySQL for storing and retrieving data. @author: Yutang Li """ import multiprocessing import sqlite3 import tiktoken import re from fractions import Fraction import numpy as np import glob import tqdm import copy import json import time import random from openai import OpenAI, APIError, RateLimitError from mysql.connector import pooling, Error from collections import Counter def get_response_from_deepseek_r1(messages: list[dict], prefix: bool = False, max_retries: int = 3, initial_backoff: float = 1.0): """ Get response from DeepSeek API with retry mechanism. Args: messages: List of message dictionaries prefix: Whether to use the prefix URL max_retries: Maximum number of retry attempts initial_backoff: Initial backoff time in seconds Returns: Tuple of (reasoning_content, content) or error messages """ retries = 0 while retries <= max_retries: try: base_url = "https://api.deepseek.com/beta" if prefix else "https://vip.apiyi.com/v1" api_key = "sk-59279cc16ec740089146ef9aef9c1671" if prefix else "sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d" client = OpenAI(api_key=api_key, base_url=base_url) # messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}] response = client.chat.completions.create( model="deepseek-r1", messages=messages, temperature=0.6 ) # reasoning_content = "null" if prefix else "\n" + response.choices[0].message.model_extra['reasoning_content'] + "\n\n" reasoning_content = response.choices[0].message.content.split("\n")[0].split("\n")[-1] content = response.choices[0].message.content.split("\n")[-1] return reasoning_content, content except RateLimitError as rate_error: retries += 1 if retries > max_retries: print(f"Max retries exceeded for RateLimitError: {rate_error}") return 'apierror', 'apierror' # Exponential backoff with jitter backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random()) print(f"Rate limit hit, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries})") time.sleep(backoff_time) except APIError as api_error: retries += 1 if retries > max_retries: print(f"Max retries exceeded for APIError: {api_error}") return 'apierror', 'apierror' # Check if the error is retryable error_str = str(api_error) if "timeout" in error_str.lower() or "connection" in error_str.lower() or "server" in error_str.lower(): # Exponential backoff with jitter backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random()) print(f"API error, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries}): {api_error}") time.sleep(backoff_time) else: # Non-retryable API error print(f"Non-retryable API error: {api_error}") return 'apierror', 'apierror' except Exception as e: print(f"generate_design_question Unexpected error: {e}") return 'unexpectederror', 'unexpectederror' def get_response_from_llm(messages: list[dict], model_name: str, tools: list = None, max_retries: int = 3, initial_backoff: float = 1.0): """ Get response from LLM API with retry mechanism. Args: messages: List of message dictionaries model_name: Name of the model to use tools: Optional list of tools to use max_retries: Maximum number of retry attempts initial_backoff: Initial backoff time in seconds Returns: Content of the response or error message """ retries = 0 while retries <= max_retries: try: client = OpenAI(api_key="sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d", base_url="https://vip.apiyi.com/v1") # messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}] if tools is None: response = client.chat.completions.create( model=model_name, messages=messages, ) else: response = client.chat.completions.create( model=model_name, messages=messages, tools=tools, tool_choice='auto', parallel_tool_calls=True ) content = response.choices[0].message.content return content except RateLimitError as rate_error: retries += 1 if retries > max_retries: print(f"Max retries exceeded for RateLimitError: {rate_error}") return 'apierror' # Exponential backoff with jitter backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random()) print(f"Rate limit hit, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries})") time.sleep(backoff_time) except APIError as api_error: retries += 1 if retries > max_retries: print(f"Max retries exceeded for APIError: {api_error}") return 'apierror' # Check if the error is retryable error_str = str(api_error) if "timeout" in error_str.lower() or "connection" in error_str.lower() or "server" in error_str.lower(): # Exponential backoff with jitter backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random()) print(f"API error, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries}): {api_error}") time.sleep(backoff_time) else: # Non-retryable API error print(f"Non-retryable API error: {api_error}") return 'apierror' except Exception as e: print(f"generate_design_question Unexpected error: {e}") return 'unexpectederror' def get_response_from_qwq(messages: list[dict], model_name: str, tools: list = None, max_retries: int = 3, initial_backoff: float = 1.0): """ Get response from LLM API with retry mechanism. Args: messages: List of message dictionaries model_name: Name of the model to use tools: Optional list of tools to use max_retries: Maximum number of retry attempts initial_backoff: Initial backoff time in seconds Returns: Content of the response or error message """ retries = 0 while retries <= max_retries: try: # client = OpenAI(api_key="sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d", base_url="https://vip.apiyi.com/v1") # client = OpenAI(api_key="sk-df98afdc6b5b48db8195dcb4a68e804b", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1") import random if random.random() > 0.5: client = OpenAI(api_key="sk-124748a0bdb24f4aa5ec2776e97cea2e", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1") else: client = OpenAI(api_key="sk-f3dddc436b054ed1bb524d544bcb8f0f", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1") # messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}] if tools is None: response = client.chat.completions.create( model=model_name, messages=messages, stream=True ) else: response = client.chat.completions.create( model=model_name, messages=messages, tools=tools, tool_choice='auto', parallel_tool_calls=True, stream=True ) reasoning_content = "" # 定义完整思考过程 answer_content = "" # 定义完整回复 tool_info = [] # 存储工具调用信息 is_answering = False # 判断是否结束思考过程并开始回复 # print("="*20+"思考过程"+"="*20) for chunk in response: # if not chunk.choices: # # 处理用量统计信息 # print("\n"+"="*20+"Usage"+"="*20) # print(chunk.usage) # else: delta = chunk.choices[0].delta # 处理AI的思考过程(链式推理) if hasattr(delta, 'reasoning_content') and delta.reasoning_content is not None: reasoning_content += delta.reasoning_content # print(delta.reasoning_content,end="",flush=True) # 实时输出思考过程 # 处理最终回复内容 else: if not is_answering: # 首次进入回复阶段时打印标题 is_answering = True # print("\n"+"="*20+"回复内容"+"="*20) if delta.content is not None: answer_content += delta.content # print(delta.content,end="",flush=True) # 流式输出回复内容 # 处理工具调用信息(支持并行工具调用) if delta.tool_calls is not None: for tool_call in delta.tool_calls: index = tool_call.index # 工具调用索引,用于并行调用 # 动态扩展工具信息存储列表 while len(tool_info) <= index: tool_info.append({}) # 收集工具调用ID(用于后续函数调用) # if tool_call.id: # tool_info[index]['id'] = tool_info[index].get('id', '') + tool_call.id # 收集函数名称(用于后续路由到具体函数) if tool_call.function and tool_call.function.name: tool_info[index]['name'] = tool_info[index].get('name', '') + tool_call.function.name # 收集函数参数(JSON字符串格式,需要后续解析) if tool_call.function and tool_call.function.arguments: tool_info[index]['arguments'] = tool_info[index].get('arguments', '') + tool_call.function.arguments tools_response = "" for tool in tool_info: tools_response += ("\n" + json.dumps(tool, ensure_ascii=False) + "\n\n") response = "\n" + reasoning_content + "\n\n" + "\n" + answer_content + tools_response + "\n\n" return response, tool_info # return reasoning_content, answer_content, tool_info except RateLimitError as rate_error: retries += 1 if retries > max_retries: print(f"Max retries exceeded for RateLimitError: {rate_error}") return 'apierror', [] # Exponential backoff with jitter backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random()) print(f"Rate limit hit, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries})") time.sleep(backoff_time) except APIError as api_error: retries += 1 if retries > max_retries: print(f"Max retries exceeded for APIError: {api_error}") return 'apierror', [] # Check if the error is retryable error_str = str(api_error) if "timeout" in error_str.lower() or "connection" in error_str.lower() or "server" in error_str.lower(): # Exponential backoff with jitter backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random()) print(f"API error, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries}): {api_error}") time.sleep(backoff_time) else: # Non-retryable API error print(f"Non-retryable API error: {api_error}") return 'apierror', [] except Exception as e: print(f"generate_design_question Unexpected error: {e}") return 'unexpectederror', [] def read_json_file(file_path): """Read the json file and return its content.""" try: with open(file_path, 'r', encoding='utf-8') as f: return json.load(f) except Exception as e: print(f"Error reading file {file_path}: {e}") return None ################################## utils def clean_all_repetitions_with_details(text, min_length=10, threshold=10): """ 综合清理文本中的各种重复内容,并返回详细信息 参数: - text: 要清理的文本 - min_length: 最小重复片段长度 - threshold: 重复内容的阈值 返回: - cleaned_text: 清理后的文本 - is_repetitive: 是否检测到重复 - repetition_details: 重复内容的详细信息 """ original_text = text is_repetitive = False repetition_details = [] # 1. 首先处理有换行符的重复 if '\n' in text: lines = text.split('\n') unique_lines = [] line_counts = {} for i, line in enumerate(lines): normalized = line.strip().lower() if not normalized: unique_lines.append(line) continue line_counts[normalized] = line_counts.get(normalized, 0) + 1 if line_counts[normalized] <= threshold: unique_lines.append(line) # 如果这是第一次超过阈值,记录重复详情 if line_counts[normalized] == threshold + 1: # 找到原始形式(保留大小写) original_form = None for l in lines[:i]: if l.strip().lower() == normalized: original_form = l break if original_form is None: original_form = line repetition_details.append({ 'type': 'line_repetition', 'repeated_string': original_form, 'repeat_count': line_counts[normalized] }) if any(count > threshold for count in line_counts.values()): text = '\n'.join(unique_lines) is_repetitive = True # 2. 处理同一行内的连续重复模式 for length in range(min_length, 101): pattern = r'(.{' + str(length) + r'})(\1)+' while True: match = re.search(pattern, text) if not match: break repeated_part = match.group(1) full_match = match.group(0) # 计算重复次数 repeat_count = len(full_match) // len(repeated_part) # 记录重复详情 repetition_details.append({ 'type': 'inline_repetition', 'repeated_string': repeated_part, 'repeat_count': repeat_count, 'total_length': len(full_match), 'position': match.start() }) text = text.replace(full_match, repeated_part) is_repetitive = True # 3. 处理句子级别的重复 sentences = re.split(r'(?<=[.!?。?!])\s+', text) if len(sentences) > 1: sentence_counter = Counter(sentences) for sentence, count in sentence_counter.items(): if count > threshold: repetition_details.append({ 'type': 'sentence_repetition', 'repeated_string': sentence, 'repeat_count': count }) if any(count > threshold for count in sentence_counter.values()): unique_sentences = [] seen_sentences = {} for sentence in sentences: seen_sentences[sentence] = seen_sentences.get(sentence, 0) + 1 if seen_sentences[sentence] <= threshold: unique_sentences.append(sentence) # 重新组合文本 text = ' '.join(unique_sentences) is_repetitive = True # 4. 处理更短的重复(如果前面的方法没有检测到重复) if not is_repetitive and min_length > 5: for length in range(5, min_length): pattern = r'(.{' + str(length) + r'})(\1){2,}' # 至少重复3次才处理 while True: match = re.search(pattern, text) if not match: break repeated_part = match.group(1) full_match = match.group(0) # 计算重复次数 repeat_count = len(full_match) // len(repeated_part) # 记录重复详情 repetition_details.append({ 'type': 'short_repetition', 'repeated_string': repeated_part, 'repeat_count': repeat_count, 'total_length': len(full_match), 'position': match.start() }) text = text.replace(full_match, repeated_part) is_repetitive = True # 按重复类型和长度排序 repetition_details.sort(key=lambda x: (-len(x['repeated_string']), x['type'])) return text, is_repetitive or text != original_text, repetition_details def create_table(table_name, connection_pool): """Create the required MySQL table if it does not exist.""" db = connection_pool.get_connection() cursor = db.cursor() create_table_query = f""" CREATE TABLE IF NOT EXISTS {table_name} ( id INT AUTO_INCREMENT PRIMARY KEY, mp_id TEXT, question_model TEXT, question TEXT, answer_model TEXT, answer TEXT, answer_len INT ) """ cursor.execute(create_table_query) db.commit() cursor.close() db.close() def record_exists(mp_id, table_name, connection_pool): """Check if a mp_id already exists in the table.""" db = connection_pool.get_connection() cursor = db.cursor() query = f"SELECT * FROM {table_name} WHERE mp_id = %s" cursor.execute(query, (mp_id,)) result = cursor.fetchone() cursor.fetchall() # Ensure all results are processed cursor.close() db.close() return result is not None def insert_record(entry, table_name, connection_pool): """Insert a record into the MySQL table.""" db = None cursor = None try: db = connection_pool.get_connection() cursor = db.cursor() insert_query = f""" INSERT INTO {table_name} (mp_id, question_model, question, answer_model, answer, answer_len) VALUES (%s, %s, %s, %s, %s, %s) """ values = ( entry["mp_id"], entry["question_model"], entry["question"], entry["answer_model"], entry["answer"], entry["answer_len"], ) cursor.execute(insert_query, values) db.commit() except Error as e: print(f"Error: {e}") db.rollback() finally: # Ensure cursor is closed if cursor: cursor.close() # Ensure connection is returned to the pool if db: db.close() # Initialize SQLite database connection def initialize_db(): conn = sqlite3.connect('multi_turns_data.db', check_same_thread=False) cursor = conn.cursor() cursor.execute(''' CREATE TABLE IF NOT EXISTS conversations ( mp_id TEXT PRIMARY KEY, sample TEXT, token_num INTEGER ) ''') conn.commit() return conn # Save sample to SQLite database def save_to_db(conn, mp_id, sample, total_token): cursor = conn.cursor() cursor.execute(''' INSERT OR REPLACE INTO conversations (mp_id, sample, token_num) VALUES (?, ?, ?) ''', (mp_id, str(sample), total_token)) conn.commit() def read_cif_txt_file(file_path): """Read the markdown file and return its content.""" try: with open(file_path, 'r', encoding='utf-8') as f: return f.read() except Exception as e: print(f"Error reading file {file_path}: {e}") return None def round_values(data, precision=3): """ 递归地将字典中的所有值保留三位小数 """ if isinstance(data, dict): # 如果是字典 return {key: round_values(value) for key, value in data.items()} elif isinstance(data, list): # 如果是列表,递归处理每个元素 return [round_values(item) for item in data] elif isinstance(data, (int, float)): # 如果是数字,保留三位小数 return round(data, precision) else: # 对其他类型,直接返回 return data def decimal_to_fraction(decimal_value, max_denominator=1000): """ 将小数转换为分数表示 参数: decimal_value: 要转换的小数 max_denominator: 分母的最大值,用于控制精度 返回: 分数表示的字符串 """ frac = Fraction(decimal_value).limit_denominator(max_denominator) return f"{frac.numerator}/{frac.denominator}" def poscar_to_fractional_representation(poscar_content, max_denominator=1000): """ 将POSCAR文件中的数值转换为分数表示 参数: poscar_content: POSCAR文件内容 max_denominator: 分母的最大值,用于控制精度 返回: 转换后的POSCAR内容,数值以分数表示 """ lines = poscar_content.strip().split('\n') result_lines = [] # 保留系统名称 result_lines.append(lines[0]) # 保留缩放因子 scaling_factor = float(lines[1]) result_lines.append(lines[1]) # 处理晶格向量 for i in range(2, 5): vector = [float(x) for x in lines[i].split()] # 将每个分量转换为分数 fractional_vector = [decimal_to_fraction(x, max_denominator) for x in vector] result_lines.append(" " + " ".join(fractional_vector)) # 保留元素类型和数量 if len(lines) > 5: result_lines.append(lines[5]) if len(lines) > 6: result_lines.append(lines[6]) # 保留坐标类型 if len(lines) > 7: result_lines.append(lines[7]) # 处理原子坐标 for i in range(8, len(lines)): parts = lines[i].split() if len(parts) >= 3: # 将坐标转换为分数 coords = [float(parts[j]) for j in range(3)] fractional_coords = [decimal_to_fraction(x, max_denominator) for x in coords] # 构建新行 new_line = " " + " ".join(fractional_coords) if len(parts) > 3: new_line += " " + " ".join(parts[3:]) result_lines.append(new_line) else: # 保留非坐标行 result_lines.append(lines[i]) return "\n".join(result_lines) def remove_symmetry_equiv_xyz(cif_content): """ 删除CIF文件中的对称性操作部分 参数: cif_content: CIF文件内容字符串 返回: 清理后的CIF内容字符串 """ lines = cif_content.split('\n') output_lines = [] i = 0 while i < len(lines): line = lines[i].strip() # 检测循环开始 if line == 'loop_': # 查看下一行,检查是否是对称性循环 next_lines = [] j = i + 1 while j < len(lines) and lines[j].strip().startswith('_'): next_lines.append(lines[j].strip()) j += 1 # 检查是否包含对称性操作标签 if any('_symmetry_equiv_pos_as_xyz' in tag for tag in next_lines): # 跳过整个循环块 while i < len(lines): if i + 1 >= len(lines): break next_line = lines[i + 1].strip() # 检查是否到达下一个循环或数据块 if next_line == 'loop_' or next_line.startswith('data_'): break # 检查是否到达原子位置部分 if next_line.startswith('_atom_site_'): break i += 1 else: # 不是对称性循环,保留loop_行 output_lines.append(lines[i]) else: # 非循环开始行,直接保留 output_lines.append(lines[i]) i += 1 return '\n'.join(output_lines) def remove_null_values(d): """ Recursively remove key-value pairs with null (None) values from a dictionary. Args: d (dict): The dictionary to clean. Returns: dict: A new dictionary without null values. """ if not isinstance(d, dict): raise ValueError("Input must be a dictionary") _d = copy.deepcopy(d) def recursive_remove(d): cleaned_dict = {} for key, value in d.items(): if isinstance(value, dict): # Recursively clean nested dictionaries nested_cleaned = recursive_remove(value) if nested_cleaned: # Only add non-empty dictionaries cleaned_dict[key] = nested_cleaned elif value is not None and key != 'version': cleaned_dict[key] = value return cleaned_dict clean_dict = recursive_remove(d) if _d['cbm'] is None and _d['vbm'] is None and _d['band_gap'] is not None: # clean_dict['band_gap'] = None clean_dict.pop('band_gap') return clean_dict def get_extra_cif_info(path: str, fields_name: list): """Extract specific fields from the CIF description.""" basic_fields = ['formula_pretty', 'chemsys', 'composition', 'elements', 'symmetry', 'nsites', 'volume', 'density'] energy_electronic_fields = ['formation_energy_per_atom', 'energy_above_hull', 'is_stable', 'efermi', 'cbm', 'vbm', 'band_gap', 'is_gap_direct'] metal_magentic_fields = ['is_metal', 'is_magnetic', "ordering", 'total_magnetization', 'num_magnetic_sites'] # metal_magentic_fields = ['is_metal', 'is_magnetic', "ordering", 'total_magnetization', 'total_magnetization_normalized_vol', 'total_magnetization_normalized_formula_units', 'num_magnetic_sites', 'num_unique_magnetic_sites', 'types_of_magnetic_species', "decomposes_to"] selected_fields = [] if fields_name[0] == 'all_fields': selected_fields = basic_fields + energy_electronic_fields + metal_magentic_fields # selected_fields = energy_electronic_fields + metal_magentic_fields else: for field in fields_name: selected_fields.extend(locals().get(field, [])) with open(path, 'r') as f: docs = json.load(f) new_docs = {} for field_name in selected_fields: new_docs[field_name] = docs.get(field_name, '') # new_docs['structure'] = {"lattice": docs['structure']['lattice']} return new_docs def extract_json(text): """Extract JSON content from a block of text using regex.""" json_pattern = re.compile(r'\\{(?:[^{}]|(?R))*\\}') matches = json_pattern.search(text) if matches: json_str = matches.group(0) try: return json.loads(json_str) except json.JSONDecodeError: return None return None def extract_and_parse_json(response): """Extract and parse JSON from a response.""" json_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', response) json_str = json_match.group(1) if json_match else response.strip() json_str = re.sub(r'(\$[^\$]*\$)', lambda m: m.group(1).replace('\\', '\\\\'), json_str) json_str = json_str.replace('\\"', '"').replace("\\'", "'") try: return json.loads(json_str) except json.JSONDecodeError as e: print(f"JSON parse error: {e}") return 'errformat' # 计算输入消息的tokens def count_message_tokens(messages, model_name): encoding = tiktoken.encoding_for_model(model_name) num_tokens = 0 num_tokens += len(encoding.encode(messages)) return num_tokens def make_multi_turns_sharegpt_sample(humans: list[str], gpts: list[str], system: str="{SYSTEM}"): sample = {} conversations = [] if system is not None and system != "": sample["system"] = system assert len(humans) !=0, "human cannot be None" assert len(gpts) == len(humans), "human and gpt must have the same length" for human, gpt in zip(humans, gpts): if human is not None and human != "": assert gpt is not None, "gpt cannot be None" assert gpt != "", "gpt cannot be empty" # 下列顺序不可改 conversations.append({"from": "human", "value": human}) conversations.append({"from": "gpt", "value": gpt}) sample["conversations"] = conversations return sample ##################################### utils