mcp: data generation code
800
generate_data/utils.py
Executable file
@@ -0,0 +1,800 @@
"""
This script generates questions and answers from a given set of CIFs.
It uses the OpenAI API and MySQL for storing and retrieving data.
@author: Yutang Li
"""
import multiprocessing
import sqlite3
import tiktoken
import re
import regex  # third-party `regex` package; the stdlib `re` cannot compile the recursive (?R) pattern used in extract_json
from fractions import Fraction
import numpy as np
import glob
import tqdm
import copy
import json
import time
import random
from openai import OpenAI, APIError, RateLimitError
from mysql.connector import pooling, Error
from collections import Counter


def get_response_from_deepseek_r1(messages: list[dict], prefix: bool = False, max_retries: int = 3, initial_backoff: float = 1.0):
    """
    Get a response from the DeepSeek API with a retry mechanism.

    Args:
        messages: List of message dictionaries
        prefix: Whether to use the prefix (beta) endpoint
        max_retries: Maximum number of retry attempts
        initial_backoff: Initial backoff time in seconds

    Returns:
        Tuple of (reasoning_content, content), or error markers on failure
    """
    retries = 0
    while retries <= max_retries:
        try:
            base_url = "https://api.deepseek.com/beta" if prefix else "https://vip.apiyi.com/v1"
            api_key = "sk-59279cc16ec740089146ef9aef9c1671" if prefix else "sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d"

            client = OpenAI(api_key=api_key, base_url=base_url)

            response = client.chat.completions.create(
                model="deepseek-r1",
                messages=messages,
                temperature=0.6
            )

            # The reply arrives as "<think>\n...\n</think>\n..."; split it into reasoning and answer.
            reasoning_content = response.choices[0].message.content.split("</think>\n")[0].split("<think>\n")[-1]
            content = response.choices[0].message.content.split("</think>\n")[-1]
            return reasoning_content, content

        except RateLimitError as rate_error:
            retries += 1
            if retries > max_retries:
                print(f"Max retries exceeded for RateLimitError: {rate_error}")
                return 'apierror', 'apierror'

            # Exponential backoff with jitter
            backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
            print(f"Rate limit hit, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries})")
            time.sleep(backoff_time)

        except APIError as api_error:
            retries += 1
            if retries > max_retries:
                print(f"Max retries exceeded for APIError: {api_error}")
                return 'apierror', 'apierror'

            # Check if the error is retryable
            error_str = str(api_error)
            if "timeout" in error_str.lower() or "connection" in error_str.lower() or "server" in error_str.lower():
                # Exponential backoff with jitter
                backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
                print(f"API error, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries}): {api_error}")
                time.sleep(backoff_time)
            else:
                # Non-retryable API error
                print(f"Non-retryable API error: {api_error}")
                return 'apierror', 'apierror'

        except Exception as e:
            print(f"get_response_from_deepseek_r1 unexpected error: {e}")
            return 'unexpectederror', 'unexpectederror'


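# Usage sketch (illustrative, not called at import time): the message format is
# the standard OpenAI chat format already used throughout this file; the sample
# prompt comes from the commented example that shipped with these functions.
def _example_deepseek_r1():
    messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
    reasoning, answer = get_response_from_deepseek_r1(messages, prefix=False)
    print(reasoning)
    print(answer)

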
def get_response_from_llm(messages: list[dict], model_name: str, tools: list = None, max_retries: int = 3, initial_backoff: float = 1.0):
    """
    Get a response from an LLM API with a retry mechanism.

    Args:
        messages: List of message dictionaries
        model_name: Name of the model to use
        tools: Optional list of tools to use
        max_retries: Maximum number of retry attempts
        initial_backoff: Initial backoff time in seconds

    Returns:
        Content of the response, or an error marker
    """
    retries = 0
    while retries <= max_retries:
        try:
            client = OpenAI(api_key="sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d", base_url="https://vip.apiyi.com/v1")
            if tools is None:
                response = client.chat.completions.create(
                    model=model_name,
                    messages=messages,
                )
            else:
                response = client.chat.completions.create(
                    model=model_name,
                    messages=messages,
                    tools=tools,
                    tool_choice='auto',
                    parallel_tool_calls=True
                )
            content = response.choices[0].message.content
            return content

        except RateLimitError as rate_error:
            retries += 1
            if retries > max_retries:
                print(f"Max retries exceeded for RateLimitError: {rate_error}")
                return 'apierror'

            # Exponential backoff with jitter
            backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
            print(f"Rate limit hit, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries})")
            time.sleep(backoff_time)

        except APIError as api_error:
            retries += 1
            if retries > max_retries:
                print(f"Max retries exceeded for APIError: {api_error}")
                return 'apierror'

            # Check if the error is retryable
            error_str = str(api_error)
            if "timeout" in error_str.lower() or "connection" in error_str.lower() or "server" in error_str.lower():
                # Exponential backoff with jitter
                backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
                print(f"API error, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries}): {api_error}")
                time.sleep(backoff_time)
            else:
                # Non-retryable API error
                print(f"Non-retryable API error: {api_error}")
                return 'apierror'

        except Exception as e:
            print(f"get_response_from_llm unexpected error: {e}")
            return 'unexpectederror'


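# Usage sketch (illustrative): the tool below is a hypothetical example written
# in the standard OpenAI function-tool schema; any real tool list would be
# supplied by the caller, and the model name is an assumption, not something
# this file pins down.
def _example_llm_with_tools():
    tools = [{
        "type": "function",
        "function": {
            "name": "lookup_material",  # hypothetical tool name
            "description": "Look up a material by its mp_id.",
            "parameters": {
                "type": "object",
                "properties": {"mp_id": {"type": "string"}},
                "required": ["mp_id"],
            },
        },
    }]
    messages = [{"role": "user", "content": "Look up mp-149."}]
    return get_response_from_llm(messages, model_name="gpt-4o", tools=tools)  # model name is an assumption

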
def get_response_from_qwq(messages: list[dict], model_name: str, tools: list = None, max_retries: int = 3, initial_backoff: float = 1.0):
    """
    Get a streaming response from an LLM API with a retry mechanism.

    Args:
        messages: List of message dictionaries
        model_name: Name of the model to use
        tools: Optional list of tools to use
        max_retries: Maximum number of retry attempts
        initial_backoff: Initial backoff time in seconds

    Returns:
        Tuple of (formatted response string, tool-call info), or an error marker and []
    """
    retries = 0
    while retries <= max_retries:
        try:
            # Randomly pick one of two DashScope API keys to spread the load.
            if random.random() > 0.5:
                client = OpenAI(api_key="sk-124748a0bdb24f4aa5ec2776e97cea2e", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
            else:
                client = OpenAI(api_key="sk-f3dddc436b054ed1bb524d544bcb8f0f", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
            if tools is None:
                response = client.chat.completions.create(
                    model=model_name,
                    messages=messages,
                    stream=True
                )
            else:
                response = client.chat.completions.create(
                    model=model_name,
                    messages=messages,
                    tools=tools,
                    tool_choice='auto',
                    parallel_tool_calls=True,
                    stream=True
                )

            reasoning_content = ""  # full chain of thought accumulated so far
            answer_content = ""     # full reply accumulated so far
            tool_info = []          # accumulated tool-call information
            is_answering = False    # whether the model has finished thinking and started answering
            for chunk in response:
                delta = chunk.choices[0].delta
                # Accumulate the model's reasoning (chain of thought)
                if hasattr(delta, 'reasoning_content') and delta.reasoning_content is not None:
                    reasoning_content += delta.reasoning_content

                # Accumulate the final reply content
                else:
                    if not is_answering:  # first chunk of the answer phase
                        is_answering = True
                    if delta.content is not None:
                        answer_content += delta.content

                    # Accumulate tool-call information (supports parallel tool calls)
                    if delta.tool_calls is not None:
                        for tool_call in delta.tool_calls:
                            index = tool_call.index  # tool-call index, used for parallel calls

                            # Grow the tool-info list on demand
                            while len(tool_info) <= index:
                                tool_info.append({})

                            # Collect the function name (used to route to the concrete function later)
                            if tool_call.function and tool_call.function.name:
                                tool_info[index]['name'] = tool_info[index].get('name', '') + tool_call.function.name

                            # Collect the function arguments (a JSON string, parsed later)
                            if tool_call.function and tool_call.function.arguments:
                                tool_info[index]['arguments'] = tool_info[index].get('arguments', '') + tool_call.function.arguments

            tools_response = ""
            for tool in tool_info:
                tools_response += ("<tool_call>\n" + json.dumps(tool, ensure_ascii=False) + "\n</tool_call>\n")
            response = "<think>\n" + reasoning_content + "\n</think>\n" + "<answer>\n" + answer_content + tools_response + "\n</answer>\n"
            return response, tool_info

        except RateLimitError as rate_error:
            retries += 1
            if retries > max_retries:
                print(f"Max retries exceeded for RateLimitError: {rate_error}")
                return 'apierror', []

            # Exponential backoff with jitter
            backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
            print(f"Rate limit hit, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries})")
            time.sleep(backoff_time)

        except APIError as api_error:
            retries += 1
            if retries > max_retries:
                print(f"Max retries exceeded for APIError: {api_error}")
                return 'apierror', []

            # Check if the error is retryable
            error_str = str(api_error)
            if "timeout" in error_str.lower() or "connection" in error_str.lower() or "server" in error_str.lower():
                # Exponential backoff with jitter
                backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
                print(f"API error, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries}): {api_error}")
                time.sleep(backoff_time)
            else:
                # Non-retryable API error
                print(f"Non-retryable API error: {api_error}")
                return 'apierror', []

        except Exception as e:
            print(f"get_response_from_qwq unexpected error: {e}")
            return 'unexpectederror', []


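# Usage sketch (illustrative): the `arguments` field accumulated for each tool
# call above is a JSON string and must be parsed before dispatching to a real
# function. The tools list comes from the caller; the model name here is an
# assumption, not fixed by this file.
def _example_dispatch_qwq_tool_calls(tools, model_name="qwq-plus"):
    response, tool_info = get_response_from_qwq(
        [{"role": "user", "content": "Look up mp-149."}],
        model_name=model_name,
        tools=tools,
    )
    for tool in tool_info:
        name = tool.get('name')
        args = json.loads(tool.get('arguments') or '{}')
        print(f"Would call {name} with {args}")
    return response

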
def read_json_file(file_path):
    """Read the json file and return its content."""
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except Exception as e:
        print(f"Error reading file {file_path}: {e}")
        return None


################################## utils

def clean_all_repetitions_with_details(text, min_length=10, threshold=10):
    """
    Clean all kinds of repeated content from the text, and return details.

    Args:
    - text: the text to clean
    - min_length: minimum length of a repeated fragment
    - threshold: occurrence count above which content is treated as repeated

    Returns:
    - cleaned_text: the cleaned text
    - is_repetitive: whether repetition was detected
    - repetition_details: details of the repeated content
    """
    original_text = text
    is_repetitive = False
    repetition_details = []

    # 1. First handle repetitions that span line breaks
    if '\n' in text:
        lines = text.split('\n')
        unique_lines = []
        line_counts = {}

        for i, line in enumerate(lines):
            normalized = line.strip().lower()
            if not normalized:
                unique_lines.append(line)
                continue

            line_counts[normalized] = line_counts.get(normalized, 0) + 1

            if line_counts[normalized] <= threshold:
                unique_lines.append(line)

            # Record the details the first time the threshold is exceeded
            if line_counts[normalized] == threshold + 1:
                # Find the original form (preserving case)
                original_form = None
                for l in lines[:i]:
                    if l.strip().lower() == normalized:
                        original_form = l
                        break

                if original_form is None:
                    original_form = line

                repetition_details.append({
                    'type': 'line_repetition',
                    'repeated_string': original_form,
                    'repeat_count': line_counts[normalized]
                })

        if any(count > threshold for count in line_counts.values()):
            text = '\n'.join(unique_lines)
            is_repetitive = True

    # 2. Handle consecutive repeated patterns within a single line
    for length in range(min_length, 101):
        pattern = r'(.{' + str(length) + r'})(\1)+'

        while True:
            match = re.search(pattern, text)
            if not match:
                break

            repeated_part = match.group(1)
            full_match = match.group(0)

            # Count the repetitions
            repeat_count = len(full_match) // len(repeated_part)

            # Record the repetition details
            repetition_details.append({
                'type': 'inline_repetition',
                'repeated_string': repeated_part,
                'repeat_count': repeat_count,
                'total_length': len(full_match),
                'position': match.start()
            })

            text = text.replace(full_match, repeated_part)
            is_repetitive = True

    # 3. Handle sentence-level repetition
    sentences = re.split(r'(?<=[.!?。?!])\s+', text)
    if len(sentences) > 1:
        sentence_counter = Counter(sentences)

        for sentence, count in sentence_counter.items():
            if count > threshold:
                repetition_details.append({
                    'type': 'sentence_repetition',
                    'repeated_string': sentence,
                    'repeat_count': count
                })

        if any(count > threshold for count in sentence_counter.values()):
            unique_sentences = []
            seen_sentences = {}

            for sentence in sentences:
                seen_sentences[sentence] = seen_sentences.get(sentence, 0) + 1
                if seen_sentences[sentence] <= threshold:
                    unique_sentences.append(sentence)

            # Reassemble the text
            text = ' '.join(unique_sentences)
            is_repetitive = True

    # 4. Handle shorter repetitions (only if nothing was detected above)
    if not is_repetitive and min_length > 5:
        for length in range(5, min_length):
            pattern = r'(.{' + str(length) + r'})(\1){2,}'  # only handle fragments repeated at least 3 times

            while True:
                match = re.search(pattern, text)
                if not match:
                    break

                repeated_part = match.group(1)
                full_match = match.group(0)

                # Count the repetitions
                repeat_count = len(full_match) // len(repeated_part)

                # Record the repetition details
                repetition_details.append({
                    'type': 'short_repetition',
                    'repeated_string': repeated_part,
                    'repeat_count': repeat_count,
                    'total_length': len(full_match),
                    'position': match.start()
                })

                text = text.replace(full_match, repeated_part)
                is_repetitive = True

    # Sort by fragment length (longest first), then by repetition type
    repetition_details.sort(key=lambda x: (-len(x['repeated_string']), x['type']))

    return text, is_repetitive or text != original_text, repetition_details


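# Usage sketch (illustrative): a line repeated more than `threshold` times is
# trimmed back to `threshold` occurrences and reported in the details list.
def _example_clean_repetitions():
    noisy = "\n".join(["The answer is 42."] * 15)
    cleaned, was_repetitive, details = clean_all_repetitions_with_details(noisy, threshold=10)
    print(was_repetitive)            # True
    print(len(cleaned.split("\n")))  # 10 lines kept
    print(details[0]['repeat_count'])  # recorded when the threshold was first exceeded

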
def create_table(table_name, connection_pool):
    """Create the required MySQL table if it does not exist."""
    db = connection_pool.get_connection()
    cursor = db.cursor()
    create_table_query = f"""
        CREATE TABLE IF NOT EXISTS {table_name} (
            id INT AUTO_INCREMENT PRIMARY KEY,
            mp_id TEXT,
            question_model TEXT,
            question TEXT,
            answer_model TEXT,
            answer TEXT,
            answer_len INT
        )
    """
    cursor.execute(create_table_query)
    db.commit()
    cursor.close()
    db.close()


def record_exists(mp_id, table_name, connection_pool):
    """Check if a mp_id already exists in the table."""
    db = connection_pool.get_connection()
    cursor = db.cursor()
    query = f"SELECT * FROM {table_name} WHERE mp_id = %s"
    cursor.execute(query, (mp_id,))
    result = cursor.fetchone()
    cursor.fetchall()  # Ensure all results are processed
    cursor.close()
    db.close()
    return result is not None


def insert_record(entry, table_name, connection_pool):
    """Insert a record into the MySQL table."""
    db = None
    cursor = None
    try:
        db = connection_pool.get_connection()
        cursor = db.cursor()

        insert_query = f"""
            INSERT INTO {table_name}
            (mp_id, question_model, question, answer_model, answer, answer_len)
            VALUES (%s, %s, %s, %s, %s, %s)
        """
        values = (
            entry["mp_id"], entry["question_model"],
            entry["question"], entry["answer_model"], entry["answer"], entry["answer_len"],
        )
        cursor.execute(insert_query, values)
        db.commit()

    except Error as e:
        print(f"Error: {e}")
        if db:  # the connection may not have been acquired yet
            db.rollback()
    finally:
        # Ensure cursor is closed
        if cursor:
            cursor.close()
        # Ensure connection is returned to the pool
        if db:
            db.close()


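# Usage sketch (illustrative): the pool name, size, and credentials below are
# placeholders; `pooling` is mysql.connector's pooling module, imported above.
def _example_mysql_roundtrip():
    pool = pooling.MySQLConnectionPool(
        pool_name="qa_pool",  # placeholder
        pool_size=4,
        host="localhost",     # placeholder credentials
        user="user",
        password="password",
        database="qa_db",
    )
    create_table("qa_pairs", pool)
    entry = {
        "mp_id": "mp-149", "question_model": "deepseek-r1",
        "question": "What is the band gap?", "answer_model": "deepseek-r1",
        "answer": "...", "answer_len": 3,
    }
    if not record_exists(entry["mp_id"], "qa_pairs", pool):
        insert_record(entry, "qa_pairs", pool)

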
# Initialize SQLite database connection
def initialize_db():
    conn = sqlite3.connect('multi_turns_data.db', check_same_thread=False)
    cursor = conn.cursor()
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS conversations (
            mp_id TEXT PRIMARY KEY,
            sample TEXT,
            token_num INTEGER
        )
    ''')
    conn.commit()
    return conn

# Save sample to SQLite database
def save_to_db(conn, mp_id, sample, total_token):
    cursor = conn.cursor()
    cursor.execute('''
        INSERT OR REPLACE INTO conversations (mp_id, sample, token_num)
        VALUES (?, ?, ?)
    ''', (mp_id, str(sample), total_token))
    conn.commit()


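# Usage sketch (illustrative): write one sample and read it back. Note that
# save_to_db stores `sample` via str(); serializing with json.dumps before
# saving would make the stored value easier to parse later.
def _example_sqlite_roundtrip():
    conn = initialize_db()
    save_to_db(conn, "mp-149", {"conversations": []}, 0)
    row = conn.execute(
        "SELECT mp_id, sample, token_num FROM conversations WHERE mp_id = ?",
        ("mp-149",),
    ).fetchone()
    print(row)
    conn.close()

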
def read_cif_txt_file(file_path):
    """Read a CIF text file and return its content."""
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return f.read()
    except Exception as e:
        print(f"Error reading file {file_path}: {e}")
        return None


def round_values(data, precision=3):
    """
    Recursively round every numeric value in a nested structure to `precision` decimal places.
    """
    if isinstance(data, dict):  # dictionary: recurse into each value
        return {key: round_values(value, precision) for key, value in data.items()}
    elif isinstance(data, list):  # list: recurse into each element
        return [round_values(item, precision) for item in data]
    elif isinstance(data, (int, float)):  # number: round to the given precision
        return round(data, precision)
    else:  # anything else is returned unchanged
        return data


def decimal_to_fraction(decimal_value, max_denominator=1000):
    """
    Convert a decimal value to a fraction string.

    Args:
        decimal_value: the decimal to convert
        max_denominator: upper bound on the denominator, controlling precision

    Returns:
        The fraction as a "numerator/denominator" string
    """
    frac = Fraction(decimal_value).limit_denominator(max_denominator)
    return f"{frac.numerator}/{frac.denominator}"


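# Quick examples: limit_denominator picks the closest fraction whose denominator
# stays within the bound, so exact decimals and near-thirds behave differently.
# decimal_to_fraction(0.5)    -> "1/2"
# decimal_to_fraction(1 / 3)  -> "1/3"       (the float is snapped to the nearest small fraction)
# decimal_to_fraction(0.333)  -> "333/1000"  (0.333 is closer to 333/1000 than to 1/3)

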
def poscar_to_fractional_representation(poscar_content, max_denominator=1000):
    """
    Convert the numeric values in a POSCAR file to fraction strings.

    Args:
        poscar_content: POSCAR file content
        max_denominator: upper bound on the denominators, controlling precision

    Returns:
        The converted POSCAR content with values written as fractions
    """
    lines = poscar_content.strip().split('\n')
    result_lines = []

    # Keep the system name
    result_lines.append(lines[0])

    # Keep the scaling factor (parsed only to validate that it is numeric)
    scaling_factor = float(lines[1])
    result_lines.append(lines[1])

    # Convert the lattice vectors
    for i in range(2, 5):
        vector = [float(x) for x in lines[i].split()]
        # Convert each component to a fraction
        fractional_vector = [decimal_to_fraction(x, max_denominator) for x in vector]
        result_lines.append(" " + " ".join(fractional_vector))

    # Keep the element symbols and counts
    if len(lines) > 5:
        result_lines.append(lines[5])
    if len(lines) > 6:
        result_lines.append(lines[6])

    # Keep the coordinate type line (Direct/Cartesian)
    if len(lines) > 7:
        result_lines.append(lines[7])

    # Convert the atomic coordinates
    for i in range(8, len(lines)):
        parts = lines[i].split()
        if len(parts) >= 3:
            # Convert the coordinates to fractions
            coords = [float(parts[j]) for j in range(3)]
            fractional_coords = [decimal_to_fraction(x, max_denominator) for x in coords]

            # Rebuild the line, keeping any trailing fields
            new_line = " " + " ".join(fractional_coords)
            if len(parts) > 3:
                new_line += " " + " ".join(parts[3:])
            result_lines.append(new_line)
        else:
            # Keep non-coordinate lines unchanged
            result_lines.append(lines[i])

    return "\n".join(result_lines)


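# Usage sketch (illustrative): a minimal silicon-like POSCAR; only the file
# shape matters here, not the physics.
def _example_poscar_fractions():
    poscar = "\n".join([
        "Si",
        "1.0",
        "5.43 0.0 0.0",
        "0.0 5.43 0.0",
        "0.0 0.0 5.43",
        "Si",
        "2",
        "Direct",
        "0.0 0.0 0.0",
        "0.25 0.25 0.25",
    ])
    print(poscar_to_fractional_representation(poscar))
    # e.g. the second atom line becomes " 1/4 1/4 1/4"

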
def remove_symmetry_equiv_xyz(cif_content):
    """
    Remove the symmetry-operation block from a CIF file.

    Args:
        cif_content: CIF file content as a string

    Returns:
        The cleaned CIF content as a string
    """
    lines = cif_content.split('\n')
    output_lines = []

    i = 0
    while i < len(lines):
        line = lines[i].strip()

        # Detect the start of a loop
        if line == 'loop_':
            # Peek at the following tag lines to see whether this is the symmetry loop
            next_lines = []
            j = i + 1
            while j < len(lines) and lines[j].strip().startswith('_'):
                next_lines.append(lines[j].strip())
                j += 1

            # Check for the symmetry-operation tag
            if any('_symmetry_equiv_pos_as_xyz' in tag for tag in next_lines):
                # Skip the whole loop block
                while i < len(lines):
                    if i + 1 >= len(lines):
                        break

                    next_line = lines[i + 1].strip()
                    # Stop when the next loop or data block starts
                    if next_line == 'loop_' or next_line.startswith('data_'):
                        break

                    # Stop when the atom-site section starts
                    if next_line.startswith('_atom_site_'):
                        break

                    i += 1
            else:
                # Not the symmetry loop: keep the loop_ line
                output_lines.append(lines[i])
        else:
            # Not a loop start: keep the line
            output_lines.append(lines[i])

        i += 1

    return '\n'.join(output_lines)


def remove_null_values(d):
    """
    Recursively remove key-value pairs with null (None) values from a dictionary.

    Args:
        d (dict): The dictionary to clean.

    Returns:
        dict: A new dictionary without null values.
    """
    if not isinstance(d, dict):
        raise ValueError("Input must be a dictionary")
    _d = copy.deepcopy(d)

    def recursive_remove(d):
        cleaned_dict = {}
        for key, value in d.items():
            if isinstance(value, dict):
                # Recursively clean nested dictionaries
                nested_cleaned = recursive_remove(value)
                if nested_cleaned:  # Only add non-empty dictionaries
                    cleaned_dict[key] = nested_cleaned
            elif value is not None and key != 'version':
                cleaned_dict[key] = value

        return cleaned_dict

    clean_dict = recursive_remove(d)
    # Drop band_gap when both cbm and vbm are missing
    if _d.get('cbm') is None and _d.get('vbm') is None and _d.get('band_gap') is not None:
        clean_dict.pop('band_gap', None)
    return clean_dict


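# Usage sketch (illustrative):
# remove_null_values({"cbm": None, "vbm": None, "band_gap": 0.0, "nsites": 2})
#   -> {"nsites": 2}
# None values (and 'version' keys) are stripped, and band_gap is dropped
# because cbm and vbm are both missing.

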

def get_extra_cif_info(path: str, fields_name: list):
    """Extract specific fields from the CIF description."""
    basic_fields = ['formula_pretty', 'chemsys', 'composition', 'elements', 'symmetry', 'nsites', 'volume', 'density']
    energy_electronic_fields = ['formation_energy_per_atom', 'energy_above_hull', 'is_stable', 'efermi', 'cbm', 'vbm', 'band_gap', 'is_gap_direct']
    metal_magentic_fields = ['is_metal', 'is_magnetic', "ordering", 'total_magnetization', 'num_magnetic_sites']

    # Explicit lookup table instead of reading locals(), which is fragile.
    # The 'metal_magentic_fields' key keeps the original (misspelled) name so
    # existing callers are not broken.
    field_groups = {
        'basic_fields': basic_fields,
        'energy_electronic_fields': energy_electronic_fields,
        'metal_magentic_fields': metal_magentic_fields,
    }

    selected_fields = []
    if fields_name[0] == 'all_fields':
        selected_fields = basic_fields + energy_electronic_fields + metal_magentic_fields
    else:
        for field in fields_name:
            selected_fields.extend(field_groups.get(field, []))

    with open(path, 'r') as f:
        docs = json.load(f)

    new_docs = {}
    for field_name in selected_fields:
        new_docs[field_name] = docs.get(field_name, '')

    return new_docs


def extract_json(text):
    """Extract JSON content from a block of text using a recursive regex."""
    # The stdlib `re` cannot compile the recursive (?R) construct, and the
    # original pattern double-escaped the braces; use the third-party `regex`
    # module and match the braces literally.
    json_pattern = regex.compile(r'\{(?:[^{}]|(?R))*\}')
    matches = json_pattern.search(text)
    if matches:
        json_str = matches.group(0)
        try:
            return json.loads(json_str)
        except json.JSONDecodeError:
            return None
    return None


def extract_and_parse_json(response):
    """Extract and parse JSON from a response."""
    # Prefer a fenced ```json ... ``` block; otherwise treat the whole response as JSON.
    json_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', response)
    json_str = json_match.group(1) if json_match else response.strip()
    # Double the backslashes inside $...$ math spans so LaTeX survives json.loads.
    json_str = re.sub(r'(\$[^\$]*\$)', lambda m: m.group(1).replace('\\', '\\\\'), json_str)
    # Unescape quotes that some models emit pre-escaped.
    json_str = json_str.replace('\\"', '"').replace("\\'", "'")
    try:
        return json.loads(json_str)
    except json.JSONDecodeError as e:
        print(f"JSON parse error: {e}")
        return 'errformat'


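# Usage sketch (illustrative):
# extract_and_parse_json('Here you go:\n```json\n{"question": "Q1"}\n```')
#   -> {'question': 'Q1'}

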

# Count the tokens in the input text
def count_message_tokens(messages, model_name):
    # Note: `messages` is expected to be a single string here, not a list of dicts.
    encoding = tiktoken.encoding_for_model(model_name)
    num_tokens = 0
    num_tokens += len(encoding.encode(messages))
    return num_tokens


def make_multi_turns_sharegpt_sample(humans: list[str], gpts: list[str], system: str = "{SYSTEM}"):
    sample = {}
    conversations = []

    if system is not None and system != "":
        sample["system"] = system

    assert len(humans) != 0, "humans cannot be empty"
    assert len(gpts) == len(humans), "humans and gpts must have the same length"

    for human, gpt in zip(humans, gpts):
        if human is not None and human != "":
            assert gpt is not None, "gpt cannot be None"
            assert gpt != "", "gpt cannot be empty"
            # The order below must not change: human first, then gpt
            conversations.append({"from": "human", "value": human})
            conversations.append({"from": "gpt", "value": gpt})

    sample["conversations"] = conversations
    return sample


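# Usage sketch (illustrative): build a two-turn ShareGPT-style sample; the
# system prompt below is a placeholder.
def _example_sharegpt_sample():
    sample = make_multi_turns_sharegpt_sample(
        humans=["What is the formula?", "And the band gap?"],
        gpts=["The formula is Si.", "The band gap is about 1.1 eV."],
        system="You are a materials-science assistant.",  # placeholder
    )
    print(json.dumps(sample, indent=2, ensure_ascii=False))

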
##################################### utils