"""
This script generates questions and answers from a given set of CIFs.
It uses OpenAI-compatible LLM APIs for question/answer generation and MySQL/SQLite for storing and retrieving data.
@author: Yutang Li
"""
import copy
import glob
import json
import multiprocessing
import os
import random
import re
import sqlite3
import time
from collections import Counter
from fractions import Fraction

import jsonlines
import numpy as np
import tiktoken
import tqdm
from mysql.connector import Error, pooling
from openai import APIError, OpenAI, RateLimitError
def get_response_from_deepseek_r1(messages: list[dict], prefix: bool = False, max_retries: int = 3, initial_backoff: float = 1.0):
"""
Get response from DeepSeek API with retry mechanism.
Args:
messages: List of message dictionaries
        prefix: If True, call the DeepSeek beta endpoint directly; otherwise use the API proxy
max_retries: Maximum number of retry attempts
initial_backoff: Initial backoff time in seconds
Returns:
Tuple of (reasoning_content, content) or error messages
"""
retries = 0
while retries <= max_retries:
try:
base_url = "https://api.deepseek.com/beta" if prefix else "https://vip.apiyi.com/v1"
api_key = "sk-59279cc16ec740089146ef9aef9c1671" if prefix else "sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d"
client = OpenAI(api_key=api_key, base_url=base_url)
# messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
response = client.chat.completions.create(
model="deepseek-r1",
messages=messages,
temperature=0.6
)
#print("response",response)
reasoning_content = response.choices[0].message.content.split("</think>\n")[0].split("<think>\n")[-1]
if reasoning_content=='':
reasoning_content=response.choices[0].message.content.split("</think>\n")[1]
# while reasoning_content == "" :
# if retries<max_retries:
# response = client.chat.completions.create(
# model="deepseek-r1",
# messages=messages,
# temperature=0.6
# )
# retries+=1
# else:
# print(f"Max retries exceeded for RateLimitError: {rate_error}")
# return 'apierror', 'apierror'
# reasoning_content = response.choices[0].message.content.split("</think>\n")[0].split("<think>\n")[-1]
content = response.choices[0].message.content.split("</think>\n")[-1]
return reasoning_content, content
except RateLimitError as rate_error:
retries += 1
if retries > max_retries:
print(f"Max retries exceeded for RateLimitError: {rate_error}")
return 'apierror', 'apierror'
# Exponential backoff with jitter
backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
print(f"Rate limit hit, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries})")
time.sleep(backoff_time)
except APIError as api_error:
retries += 1
if retries > max_retries:
print(f"Max retries exceeded for APIError: {api_error}")
return 'apierror', 'apierror'
# Check if the error is retryable
error_str = str(api_error)
if "timeout" in error_str.lower() or "connection" in error_str.lower() or "server" in error_str.lower():
# Exponential backoff with jitter
backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
print(f"API error, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries}): {api_error}")
time.sleep(backoff_time)
else:
# Non-retryable API error
print(f"Non-retryable API error: {api_error}")
return 'apierror', 'apierror'
except Exception as e:
print(f"generate_design_question Unexpected error: {e}")
return 'unexpectederror', 'unexpectederror'
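
# Illustrative usage of the DeepSeek helper (not part of the pipeline itself);
# assumes DEEPSEEK_API_KEY / APIYI_API_KEY are set in the environment:
#
#     reasoning, answer = get_response_from_deepseek_r1(
#         [{"role": "user", "content": "Explain why NaCl adopts the rock-salt structure."}]
#     )
#     if answer not in ("apierror", "unexpectederror"):
#         print(reasoning)  # the <think> block
#         print(answer)     # the final answer
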
def get_response_from_llm(messages: list[dict], model_name: str, tools: list = None, max_retries: int = 3, initial_backoff: float = 1.0):
"""
Get response from LLM API with retry mechanism.
Args:
messages: List of message dictionaries
model_name: Name of the model to use
tools: Optional list of tools to use
max_retries: Maximum number of retry attempts
initial_backoff: Initial backoff time in seconds
Returns:
Content of the response or error message
"""
retries = 0
while retries <= max_retries:
try:
client = OpenAI(api_key="sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d", base_url="https://vip.apiyi.com/v1")
# messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
if tools is None:
response = client.chat.completions.create(
model=model_name,
messages=messages,
)
else:
response = client.chat.completions.create(
model=model_name,
messages=messages,
tools=tools,
tool_choice='auto',
parallel_tool_calls=True
)
content = response.choices[0].message.content
return content
except RateLimitError as rate_error:
retries += 1
if retries > max_retries:
print(f"Max retries exceeded for RateLimitError: {rate_error}")
return 'apierror'
# Exponential backoff with jitter
backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
print(f"Rate limit hit, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries})")
time.sleep(backoff_time)
except APIError as api_error:
retries += 1
if retries > max_retries:
print(f"Max retries exceeded for APIError: {api_error}")
return 'apierror'
# Check if the error is retryable
error_str = str(api_error)
if "timeout" in error_str.lower() or "connection" in error_str.lower() or "server" in error_str.lower():
# Exponential backoff with jitter
backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
print(f"API error, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries}): {api_error}")
time.sleep(backoff_time)
else:
# Non-retryable API error
print(f"Non-retryable API error: {api_error}")
return 'apierror'
except Exception as e:
print(f"generate_design_question Unexpected error: {e}")
return 'unexpectederror'
def get_response_from_qwq(messages: list[dict], model_name: str, tools: list = None, max_retries: int = 3, initial_backoff: float = 1.0):
"""
Get response from LLM API with retry mechanism.
Args:
messages: List of message dictionaries
model_name: Name of the model to use
tools: Optional list of tools to use
max_retries: Maximum number of retry attempts
initial_backoff: Initial backoff time in seconds
Returns:
Content of the response or error message
"""
retries = 0
while retries <= max_retries:
try:
client = OpenAI(api_key="sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d", base_url="https://vip.apiyi.com/v1")
# client = OpenAI(api_key="sk-df98afdc6b5b48db8195dcb4a68e804b", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
# import random
# if random.random() > 0.5:
# client = OpenAI(api_key="sk-124748a0bdb24f4aa5ec2776e97cea2e", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
# else:
# client = OpenAI(api_key="sk-f3dddc436b054ed1bb524d544bcb8f0f", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
# messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
if tools is None:
response = client.chat.completions.create(
model=model_name,
messages=messages,
stream=True
)
else:
response = client.chat.completions.create(
model=model_name,
messages=messages,
tools=tools,
tool_choice='auto',
parallel_tool_calls=True,
stream=True
)
reasoning_content = "" # 定义完整思考过程
answer_content = "" # 定义完整回复
tool_info = [] # 存储工具调用信息
is_answering = False # 判断是否结束思考过程并开始回复
# print("="*20+"思考过程"+"="*20)
            # Consume the stream: accumulate reasoning, answer text, and tool calls.
            for chunk in response:
                delta = chunk.choices[0].delta
                # Reasoning tokens arrive on `reasoning_content` while the model thinks.
                if hasattr(delta, 'reasoning_content') and delta.reasoning_content is not None:
                    reasoning_content += delta.reasoning_content
                # Everything else belongs to the answer phase.
                else:
                    if not is_answering:  # first chunk of the answer phase
                        is_answering = True
                    if delta.content is not None:
                        answer_content += delta.content
                    # Collect tool-call deltas (parallel tool calls supported).
                    if delta.tool_calls is not None:
                        for tool_call in delta.tool_calls:
                            index = tool_call.index  # distinguishes parallel calls
                            # Grow the tool_info list on demand.
                            while len(tool_info) <= index:
                                tool_info.append({})
                            # Accumulate the function name (used later to route the call).
                            if tool_call.function and tool_call.function.name:
                                tool_info[index]['name'] = tool_info[index].get('name', '') + tool_call.function.name
                            # Accumulate the JSON argument string (parsed after streaming ends).
                            if tool_call.function and tool_call.function.arguments:
                                tool_info[index]['arguments'] = tool_info[index].get('arguments', '') + tool_call.function.arguments
tools_response = ""
for tool in tool_info:
tools_response += ("<tool_call>\n" + json.dumps(tool, ensure_ascii=False) + "\n</tool_call>\n")
response = "<think>\n" + reasoning_content + "\n</think>\n" + "<answer>\n" + answer_content + tools_response + "\n</answer>\n"
return response, tool_info
# return reasoning_content, answer_content, tool_info
except RateLimitError as rate_error:
retries += 1
if retries > max_retries:
print(f"Max retries exceeded for RateLimitError: {rate_error}")
return 'apierror', []
# Exponential backoff with jitter
backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
print(f"Rate limit hit, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries})")
time.sleep(backoff_time)
except APIError as api_error:
retries += 1
if retries > max_retries:
print(f"Max retries exceeded for APIError: {api_error}")
return 'apierror', []
# Check if the error is retryable
error_str = str(api_error)
if "timeout" in error_str.lower() or "connection" in error_str.lower() or "server" in error_str.lower():
# Exponential backoff with jitter
backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
print(f"API error, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries}): {api_error}")
time.sleep(backoff_time)
else:
# Non-retryable API error
print(f"Non-retryable API error: {api_error}")
return 'apierror', []
except Exception as e:
print(f"generate_design_question Unexpected error: {e}")
return 'unexpectederror', []
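
# Illustrative call of the streaming helper; the model name below is an
# assumption, any OpenAI-compatible reasoning model with streaming works:
#
#     text, tools = get_response_from_qwq(
#         [{"role": "user", "content": "Is silicon a direct-gap semiconductor?"}],
#         model_name="qwq-32b",
#     )
#     # `text` has the form "<think>...</think>\n<answer>...</answer>\n";
#     # `tools` is a list like [{"name": "...", "arguments": "<json string>"}].
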
def read_json_file(file_path):
"""Read the json file and return its content."""
try:
with open(file_path, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
print(f"Error reading file {file_path}: {e}")
return None
def read_jsonline_file(file_path):
    """Read a JSON Lines file and return its records as a list."""
    with jsonlines.open(file_path, mode='r') as reader:
        return [line for line in reader]
################################## utils
def clean_all_repetitions_with_details(text, min_length=10, threshold=10):
    """
    Clean all kinds of repeated content in the text and return details.
    Args:
        text: The text to clean.
        min_length: Minimum length of a repeated fragment.
        threshold: Maximum allowed number of repetitions before cleaning.
    Returns:
        cleaned_text: The cleaned text.
        is_repetitive: Whether repetition was detected.
        repetition_details: Details of the repeated content.
    """
original_text = text
is_repetitive = False
repetition_details = []
    # 1. First handle repetition among newline-separated lines
if '\n' in text:
lines = text.split('\n')
unique_lines = []
line_counts = {}
for i, line in enumerate(lines):
normalized = line.strip().lower()
if not normalized:
unique_lines.append(line)
continue
line_counts[normalized] = line_counts.get(normalized, 0) + 1
if line_counts[normalized] <= threshold:
unique_lines.append(line)
            # Record the repetition details the first time the threshold is exceeded
            if line_counts[normalized] == threshold + 1:
                # Find the original form (preserving case)
original_form = None
for l in lines[:i]:
if l.strip().lower() == normalized:
original_form = l
break
if original_form is None:
original_form = line
repetition_details.append({
'type': 'line_repetition',
'repeated_string': original_form,
'repeat_count': line_counts[normalized]
})
if any(count > threshold for count in line_counts.values()):
text = '\n'.join(unique_lines)
is_repetitive = True
    # 2. Handle consecutive repeated patterns within a single line
for length in range(min_length, 101):
pattern = r'(.{' + str(length) + r'})(\1)+'
while True:
match = re.search(pattern, text)
if not match:
break
repeated_part = match.group(1)
full_match = match.group(0)
            # Count how many times the fragment repeats
            repeat_count = len(full_match) // len(repeated_part)
            # Record the repetition details
repetition_details.append({
'type': 'inline_repetition',
'repeated_string': repeated_part,
'repeat_count': repeat_count,
'total_length': len(full_match),
'position': match.start()
})
text = text.replace(full_match, repeated_part)
is_repetitive = True
    # 3. Handle sentence-level repetition
sentences = re.split(r'(?<=[.!?。?!])\s+', text)
if len(sentences) > 1:
sentence_counter = Counter(sentences)
for sentence, count in sentence_counter.items():
if count > threshold:
repetition_details.append({
'type': 'sentence_repetition',
'repeated_string': sentence,
'repeat_count': count
})
if any(count > threshold for count in sentence_counter.values()):
unique_sentences = []
seen_sentences = {}
for sentence in sentences:
seen_sentences[sentence] = seen_sentences.get(sentence, 0) + 1
if seen_sentences[sentence] <= threshold:
unique_sentences.append(sentence)
            # Reassemble the text
text = ' '.join(unique_sentences)
is_repetitive = True
    # 4. Handle shorter repetitions (when nothing was detected above)
    if not is_repetitive and min_length > 5:
        for length in range(5, min_length):
            pattern = r'(.{' + str(length) + r'})(\1){2,}'  # only clean fragments repeated at least 3 times
while True:
match = re.search(pattern, text)
if not match:
break
repeated_part = match.group(1)
full_match = match.group(0)
                # Count how many times the fragment repeats
                repeat_count = len(full_match) // len(repeated_part)
                # Record the repetition details
repetition_details.append({
'type': 'short_repetition',
'repeated_string': repeated_part,
'repeat_count': repeat_count,
'total_length': len(full_match),
'position': match.start()
})
text = text.replace(full_match, repeated_part)
is_repetitive = True
    # Sort by fragment length (longest first), then by repetition type
repetition_details.sort(key=lambda x: (-len(x['repeated_string']), x['type']))
return text, is_repetitive or text != original_text, repetition_details
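
# Quick illustration of the repetition cleaner:
#
#     cleaned, repetitive, details = clean_all_repetitions_with_details(
#         "The answer is 42. " * 20
#     )
#     # `repetitive` is True; `details` records the repeated fragment and its
#     # repeat count; `cleaned` has the run collapsed.
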
def create_table(table_name, connection_pool):
"""Create the required MySQL table if it does not exist."""
db = connection_pool.get_connection()
cursor = db.cursor()
create_table_query = f"""
CREATE TABLE IF NOT EXISTS {table_name} (
id INT AUTO_INCREMENT PRIMARY KEY,
mp_id TEXT,
question_model TEXT,
question TEXT,
answer_model TEXT,
answer TEXT,
answer_len INT
)
"""
cursor.execute(create_table_query)
db.commit()
cursor.close()
db.close()
def record_exists(mp_id, table_name, connection_pool):
"""Check if a mp_id already exists in the table."""
db = connection_pool.get_connection()
cursor = db.cursor()
query = f"SELECT * FROM {table_name} WHERE mp_id = %s"
cursor.execute(query, (mp_id,))
result = cursor.fetchone()
cursor.fetchall() # Ensure all results are processed
cursor.close()
db.close()
return result is not None
def insert_record(entry, table_name, connection_pool):
"""Insert a record into the MySQL table."""
db = None
cursor = None
try:
db = connection_pool.get_connection()
cursor = db.cursor()
insert_query = f"""
INSERT INTO {table_name}
(mp_id, question_model, question, answer_model, answer, answer_len)
VALUES (%s, %s, %s, %s, %s, %s)
"""
values = (
entry["mp_id"], entry["question_model"],
entry["question"], entry["answer_model"], entry["answer"], entry["answer_len"],
)
cursor.execute(insert_query, values)
db.commit()
    except Error as e:
        print(f"Error: {e}")
        if db:
            db.rollback()
finally:
# Ensure cursor is closed
if cursor:
cursor.close()
# Ensure connection is returned to the pool
if db:
db.close()
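
# A minimal sketch of the connection pool these MySQL helpers expect; the
# pool name, size, and credentials below are placeholders, not the project's
# real configuration:
#
#     connection_pool = pooling.MySQLConnectionPool(
#         pool_name="sft_pool",
#         pool_size=8,
#         host="localhost",
#         user=os.environ.get("MYSQL_USER", "root"),
#         password=os.environ.get("MYSQL_PASSWORD", ""),
#         database="sft_data",
#     )
#     create_table("qa_pairs", connection_pool)
#     if not record_exists("mp-149", "qa_pairs", connection_pool):
#         insert_record(entry, "qa_pairs", connection_pool)
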
# Initialize SQLite database connection
def initialize_db():
conn = sqlite3.connect('multi_turns_data.db', check_same_thread=False)
cursor = conn.cursor()
cursor.execute('''
CREATE TABLE IF NOT EXISTS conversations (
mp_id TEXT PRIMARY KEY,
sample TEXT,
token_num INTEGER
)
''')
conn.commit()
return conn
# Save sample to SQLite database
def save_to_db(conn, mp_id, sample, total_token):
cursor = conn.cursor()
cursor.execute('''
INSERT OR REPLACE INTO conversations (mp_id, sample, token_num)
VALUES (?, ?, ?)
''', (mp_id, str(sample), total_token))
conn.commit()
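
# Typical round-trip through the SQLite helpers:
#
#     conn = initialize_db()
#     save_to_db(conn, "mp-149", {"conversations": []}, 0)
#     conn.close()
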
def read_cif_txt_file(file_path):
"""Read the markdown file and return its content."""
try:
with open(file_path, 'r', encoding='utf-8') as f:
return f.read()
except Exception as e:
print(f"Error reading file {file_path}: {e}")
return None
def round_values(data, precision=3):
    """
    Recursively round all numeric values in a nested structure to `precision` decimals.
    """
    if isinstance(data, dict):  # recurse into dictionaries
        return {key: round_values(value, precision) for key, value in data.items()}
    elif isinstance(data, list):  # recurse into lists
        return [round_values(item, precision) for item in data]
    elif isinstance(data, (int, float)):  # round numbers
        return round(data, precision)
    else:  # leave other types unchanged
        return data
def decimal_to_fraction(decimal_value, max_denominator=1000):
"""
将小数转换为分数表示
参数:
decimal_value: 要转换的小数
max_denominator: 分母的最大值,用于控制精度
返回:
分数表示的字符串
"""
frac = Fraction(decimal_value).limit_denominator(max_denominator)
return f"{frac.numerator}/{frac.denominator}"
def poscar_to_fractional_representation(poscar_content, max_denominator=1000):
"""
将POSCAR文件中的数值转换为分数表示
参数:
poscar_content: POSCAR文件内容
max_denominator: 分母的最大值,用于控制精度
返回:
转换后的POSCAR内容数值以分数表示
"""
lines = poscar_content.strip().split('\n')
result_lines = []
    # Keep the system name line
    result_lines.append(lines[0])
    # Keep the scaling factor line unchanged (the value itself is not needed here)
    result_lines.append(lines[1])
    # Convert the three lattice vectors
    for i in range(2, 5):
        vector = [float(x) for x in lines[i].split()]
        # Convert each component to a fraction
fractional_vector = [decimal_to_fraction(x, max_denominator) for x in vector]
result_lines.append(" " + " ".join(fractional_vector))
    # Keep the element symbols and counts
    if len(lines) > 5:
        result_lines.append(lines[5])
    if len(lines) > 6:
        result_lines.append(lines[6])
    # Keep the coordinate-mode line (Direct/Cartesian)
    if len(lines) > 7:
        result_lines.append(lines[7])
    # Convert the atomic coordinates
    for i in range(8, len(lines)):
parts = lines[i].split()
if len(parts) >= 3:
            # Convert the coordinates to fractions
            coords = [float(parts[j]) for j in range(3)]
            fractional_coords = [decimal_to_fraction(x, max_denominator) for x in coords]
            # Build the new line, keeping any trailing fields
            new_line = " " + " ".join(fractional_coords)
            if len(parts) > 3:
                new_line += " " + " ".join(parts[3:])
            result_lines.append(new_line)
        else:
            # Keep non-coordinate lines unchanged
            result_lines.append(lines[i])
return "\n".join(result_lines)
def remove_symmetry_equiv_xyz(cif_content):
"""
删除CIF文件中的对称性操作部分
参数:
cif_content: CIF文件内容字符串
返回:
清理后的CIF内容字符串
"""
lines = cif_content.split('\n')
output_lines = []
i = 0
while i < len(lines):
line = lines[i].strip()
        # Detect the start of a loop_ block
        if line == 'loop_':
            # Peek at the following tag lines to see whether this is the symmetry loop
            next_lines = []
            j = i + 1
            while j < len(lines) and lines[j].strip().startswith('_'):
                next_lines.append(lines[j].strip())
                j += 1
            # Check whether the loop contains the symmetry-operation tag
            if any('_symmetry_equiv_pos_as_xyz' in tag for tag in next_lines):
                # Skip the entire loop block
                while i < len(lines):
                    if i + 1 >= len(lines):
                        break
                    next_line = lines[i + 1].strip()
                    # Stop at the next loop or data block
                    if next_line == 'loop_' or next_line.startswith('data_'):
                        break
                    # Stop at the atom-site section
                    if next_line.startswith('_atom_site_'):
                        break
                    i += 1
            else:
                # Not a symmetry loop; keep the loop_ line
                output_lines.append(lines[i])
        else:
            # Not a loop start; keep the line as-is
            output_lines.append(lines[i])
i += 1
return '\n'.join(output_lines)
def remove_null_values(d):
"""
Recursively remove key-value pairs with null (None) values from a dictionary.
Args:
d (dict): The dictionary to clean.
Returns:
dict: A new dictionary without null values.
"""
if not isinstance(d, dict):
raise ValueError("Input must be a dictionary")
_d = copy.deepcopy(d)
def recursive_remove(d):
cleaned_dict = {}
for key, value in d.items():
if isinstance(value, dict):
# Recursively clean nested dictionaries
nested_cleaned = recursive_remove(value)
if nested_cleaned: # Only add non-empty dictionaries
cleaned_dict[key] = nested_cleaned
elif value is not None and key != 'version':
cleaned_dict[key] = value
return cleaned_dict
clean_dict = recursive_remove(d)
    # Special case: when both band edges (cbm, vbm) are null, a reported
    # band_gap is not meaningful, so drop it as well. .get() avoids a
    # KeyError when those fields are absent entirely.
    if _d.get('cbm') is None and _d.get('vbm') is None and _d.get('band_gap') is not None:
        clean_dict.pop('band_gap', None)
return clean_dict
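
# For example:
#
#     remove_null_values({"cbm": None, "vbm": 1.2, "extra": {"note": None}})
#     # -> {"vbm": 1.2}  (null values and the emptied nested dict are dropped)
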
def get_extra_cif_info(path: str, fields_name: list):
"""Extract specific fields from the CIF description."""
basic_fields = ['formula_pretty', 'chemsys', 'composition', 'elements', 'symmetry', 'nsites', 'volume', 'density']
energy_electronic_fields = ['formation_energy_per_atom', 'energy_above_hull', 'is_stable', 'efermi', 'cbm', 'vbm', 'band_gap', 'is_gap_direct']
    metal_magentic_fields = ['is_metal', 'is_magnetic', "ordering", 'total_magnetization', 'num_magnetic_sites']
    # Callers select groups by these exact strings, so the historically
    # misspelled 'metal_magentic_fields' name is kept for compatibility.
    # An explicit mapping replaces the fragile locals().get() lookup.
    field_groups = {
        'basic_fields': basic_fields,
        'energy_electronic_fields': energy_electronic_fields,
        'metal_magentic_fields': metal_magentic_fields,
    }
    selected_fields = []
    if fields_name[0] == 'all_fields':
        selected_fields = basic_fields + energy_electronic_fields + metal_magentic_fields
    else:
        for field in fields_name:
            selected_fields.extend(field_groups.get(field, []))
with open(path, 'r') as f:
docs = json.load(f)
new_docs = {}
for field_name in selected_fields:
new_docs[field_name] = docs.get(field_name, '')
return new_docs
def extract_json(text):
    """Extract the first balanced JSON object from a block of text."""
    # Python's `re` module does not support recursive patterns such as (?R),
    # so balanced braces are matched with a simple depth counter instead.
    start = text.find('{')
    while start != -1:
        depth = 0
        for i in range(start, len(text)):
            if text[i] == '{':
                depth += 1
            elif text[i] == '}':
                depth -= 1
                if depth == 0:
                    try:
                        return json.loads(text[start:i + 1])
                    except json.JSONDecodeError:
                        break
        start = text.find('{', start + 1)
    return None
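
# For example:
#
#     extract_json('Answer: {"formula": "SiO2", "stable": true} as requested')
#     # -> {"formula": "SiO2", "stable": True}
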
def extract_and_parse_json(response):
    """Extract and parse JSON from an LLM response."""
    # Prefer a fenced ```json ... ``` block; otherwise parse the whole response.
    json_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', response)
    json_str = json_match.group(1) if json_match else response.strip()
    # Double the backslashes inside $...$ math spans so LaTeX survives json.loads.
    json_str = re.sub(r'(\$[^\$]*\$)', lambda m: m.group(1).replace('\\', '\\\\'), json_str)
    # Undo over-escaped quotes.
    json_str = json_str.replace('\\"', '"').replace("\\'", "'")
    try:
        return json.loads(json_str)
    except json.JSONDecodeError as e:
        print(f"JSON parse error: {e}")
        return 'errformat'
# Count the tokens of the input text
def count_message_tokens(messages, model_name):
    """Return the token count of `messages` (a string) under `model_name`'s encoding."""
    encoding = tiktoken.encoding_for_model(model_name)
    return len(encoding.encode(messages))
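
# For example:
#
#     n = count_message_tokens("Hello, world!", "gpt-4o")
#     # n is the token count of the string under gpt-4o's tokenizer
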
def make_multi_turns_sharegpt_sample(humans: list[str], gpts: list[str], system: str="{SYSTEM}"):
sample = {}
conversations = []
if system is not None and system != "":
sample["system"] = system
    assert len(humans) != 0, "humans cannot be empty"
    assert len(gpts) == len(humans), "humans and gpts must have the same length"
for human, gpt in zip(humans, gpts):
if human is not None and human != "":
assert gpt is not None, "gpt cannot be None"
assert gpt != "", "gpt cannot be empty"
            # The order of these two appends must not change
conversations.append({"from": "human", "value": human})
conversations.append({"from": "gpt", "value": gpt})
sample["conversations"] = conversations
return sample
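
# Example of the resulting ShareGPT-style sample:
#
#     sample = make_multi_turns_sharegpt_sample(
#         humans=["What is the space group of rock-salt NaCl?"],
#         gpts=["Fm-3m (space group No. 225)."],
#         system="You are a materials-science assistant.",
#     )
#     # -> {"system": "...",
#     #     "conversations": [{"from": "human", "value": "..."},
#     #                       {"from": "gpt", "value": "..."}]}
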
##################################### utils