"""
|
||
This script generates questions and answers from a given set of CIFs.
|
||
It uses the OpenAI API and MySQL for storing and retrieving data.
|
||
@author: Yutang Li
|
||
"""
|
||
import multiprocessing
import sqlite3
import jsonlines
import tiktoken
import re
from fractions import Fraction
import numpy as np
import glob
import tqdm
import copy
import json
import time
import random
from openai import OpenAI, APIError, RateLimitError
from mysql.connector import pooling, Error
from collections import Counter


def get_response_from_deepseek_r1(messages: list[dict], prefix: bool = False, max_retries: int = 3, initial_backoff: float = 1.0):
    """
    Get a response from the DeepSeek API with a retry mechanism.

    Args:
        messages: List of message dictionaries
        prefix: Whether to use the DeepSeek beta (prefix-completion) endpoint
        max_retries: Maximum number of retry attempts
        initial_backoff: Initial backoff time in seconds

    Returns:
        Tuple of (reasoning_content, content), or error markers on failure
    """
    retries = 0
    while retries <= max_retries:
        try:
            base_url = "https://api.deepseek.com/beta" if prefix else "https://vip.apiyi.com/v1"
            api_key = "sk-59279cc16ec740089146ef9aef9c1671" if prefix else "sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d"

            client = OpenAI(api_key=api_key, base_url=base_url)
            response = client.chat.completions.create(
                model="deepseek-r1",
                messages=messages,
                temperature=0.6
            )

            # The model returns "<think>\n...reasoning...</think>\n...answer";
            # split out the reasoning trace and the final answer.
            full_content = response.choices[0].message.content
            reasoning_content = full_content.split("</think>\n")[0].split("<think>\n")[-1]
            if reasoning_content == '':
                parts = full_content.split("</think>\n")
                reasoning_content = parts[1] if len(parts) > 1 else ''

            content = full_content.split("</think>\n")[-1]
            return reasoning_content, content

        except RateLimitError as rate_error:
            retries += 1
            if retries > max_retries:
                print(f"Max retries exceeded for RateLimitError: {rate_error}")
                return 'apierror', 'apierror'

            # Exponential backoff with jitter
            backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
            print(f"Rate limit hit, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries})")
            time.sleep(backoff_time)

        except APIError as api_error:
            retries += 1
            if retries > max_retries:
                print(f"Max retries exceeded for APIError: {api_error}")
                return 'apierror', 'apierror'

            # Retry only on transient errors (timeouts, connection drops, server errors)
            error_str = str(api_error)
            if "timeout" in error_str.lower() or "connection" in error_str.lower() or "server" in error_str.lower():
                # Exponential backoff with jitter
                backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
                print(f"API error, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries}): {api_error}")
                time.sleep(backoff_time)
            else:
                # Non-retryable API error
                print(f"Non-retryable API error: {api_error}")
                return 'apierror', 'apierror'

        except Exception as e:
            print(f"get_response_from_deepseek_r1 unexpected error: {e}")
            return 'unexpectederror', 'unexpectederror'


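# Usage sketch for the helper above (the question is illustrative, and it is
# an assumption that the configured endpoint returns <think>-tagged reasoning
# the way DeepSeek-R1 does):
def _demo_deepseek_r1_call():
    reasoning, answer = get_response_from_deepseek_r1(
        [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
    )
    if answer not in ("apierror", "unexpectederror"):
        print(reasoning)
        print(answer)

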
def get_response_from_llm(messages: list[dict], model_name: str, tools: list = None, max_retries: int = 3, initial_backoff: float = 1.0):
    """
    Get a response from an LLM API with a retry mechanism.

    Args:
        messages: List of message dictionaries
        model_name: Name of the model to use
        tools: Optional list of tools to use
        max_retries: Maximum number of retry attempts
        initial_backoff: Initial backoff time in seconds

    Returns:
        Content of the response, or an error marker
    """
    retries = 0
    while retries <= max_retries:
        try:
            client = OpenAI(api_key="sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d", base_url="https://vip.apiyi.com/v1")
            if tools is None:
                response = client.chat.completions.create(
                    model=model_name,
                    messages=messages,
                )
            else:
                response = client.chat.completions.create(
                    model=model_name,
                    messages=messages,
                    tools=tools,
                    tool_choice='auto',
                    parallel_tool_calls=True
                )
            content = response.choices[0].message.content
            return content

        except RateLimitError as rate_error:
            retries += 1
            if retries > max_retries:
                print(f"Max retries exceeded for RateLimitError: {rate_error}")
                return 'apierror'

            # Exponential backoff with jitter
            backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
            print(f"Rate limit hit, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries})")
            time.sleep(backoff_time)

        except APIError as api_error:
            retries += 1
            if retries > max_retries:
                print(f"Max retries exceeded for APIError: {api_error}")
                return 'apierror'

            # Retry only on transient errors (timeouts, connection drops, server errors)
            error_str = str(api_error)
            if "timeout" in error_str.lower() or "connection" in error_str.lower() or "server" in error_str.lower():
                # Exponential backoff with jitter
                backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
                print(f"API error, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries}): {api_error}")
                time.sleep(backoff_time)
            else:
                # Non-retryable API error
                print(f"Non-retryable API error: {api_error}")
                return 'apierror'

        except Exception as e:
            print(f"get_response_from_llm unexpected error: {e}")
            return 'unexpectederror'


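# Usage sketch for get_response_from_llm with a function tool. The tool name,
# its parameter schema, and the model name are illustrative assumptions, not
# part of this project:
def _demo_llm_tool_call():
    tools = [{
        "type": "function",
        "function": {
            "name": "get_density",
            "description": "Look up the density of a material by formula.",
            "parameters": {
                "type": "object",
                "properties": {"formula": {"type": "string"}},
                "required": ["formula"],
            },
        },
    }]
    content = get_response_from_llm(
        [{"role": "user", "content": "What is the density of NaCl?"}],
        model_name="gpt-4o",
        tools=tools,
    )
    print(content)

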
def get_response_from_qwq(messages: list[dict], model_name: str, tools: list = None, max_retries: int = 3, initial_backoff: float = 1.0):
    """
    Get a streamed response from an LLM API with a retry mechanism.

    Args:
        messages: List of message dictionaries
        model_name: Name of the model to use
        tools: Optional list of tools to use
        max_retries: Maximum number of retry attempts
        initial_backoff: Initial backoff time in seconds

    Returns:
        Tuple of (tagged response string, tool-call info list), or error markers
    """
    retries = 0
    while retries <= max_retries:
        try:
            client = OpenAI(api_key="sk-oYh3Xrhg8oDY2gW02c966f31C84449Ad86F9Cd9dF6E64a8d", base_url="https://vip.apiyi.com/v1")
            # Alternative endpoint:
            # client = OpenAI(api_key="sk-df98afdc6b5b48db8195dcb4a68e804b", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
            if tools is None:
                response = client.chat.completions.create(
                    model=model_name,
                    messages=messages,
                    stream=True
                )
            else:
                response = client.chat.completions.create(
                    model=model_name,
                    messages=messages,
                    tools=tools,
                    tool_choice='auto',
                    parallel_tool_calls=True,
                    stream=True
                )

            reasoning_content = ""   # Full reasoning trace
            answer_content = ""      # Full reply
            tool_info = []           # Collected tool-call information
            is_answering = False     # Whether reasoning has ended and the reply has started
            for chunk in response:
                if not chunk.choices:
                    # Usage-only chunks carry no delta; skip them
                    continue
                delta = chunk.choices[0].delta
                # Accumulate the model's reasoning (chain of thought)
                if hasattr(delta, 'reasoning_content') and delta.reasoning_content is not None:
                    reasoning_content += delta.reasoning_content

                # Accumulate the final reply content
                else:
                    if not is_answering:  # First chunk of the reply phase
                        is_answering = True
                    if delta.content is not None:
                        answer_content += delta.content

                    # Collect tool-call information (supports parallel tool calls)
                    if delta.tool_calls is not None:
                        for tool_call in delta.tool_calls:
                            index = tool_call.index  # Tool-call index, used for parallel calls

                            # Grow the tool-info list on demand
                            while len(tool_info) <= index:
                                tool_info.append({})

                            # Collect the function name (used to route to the actual function)
                            if tool_call.function and tool_call.function.name:
                                tool_info[index]['name'] = tool_info[index].get('name', '') + tool_call.function.name

                            # Collect the function arguments (a JSON string, parsed later)
                            if tool_call.function and tool_call.function.arguments:
                                tool_info[index]['arguments'] = tool_info[index].get('arguments', '') + tool_call.function.arguments

            tools_response = ""
            for tool in tool_info:
                tools_response += ("<tool_call>\n" + json.dumps(tool, ensure_ascii=False) + "\n</tool_call>\n")
            response = "<think>\n" + reasoning_content + "\n</think>\n" + "<answer>\n" + answer_content + tools_response + "\n</answer>\n"
            return response, tool_info

        except RateLimitError as rate_error:
            retries += 1
            if retries > max_retries:
                print(f"Max retries exceeded for RateLimitError: {rate_error}")
                return 'apierror', []

            # Exponential backoff with jitter
            backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
            print(f"Rate limit hit, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries})")
            time.sleep(backoff_time)

        except APIError as api_error:
            retries += 1
            if retries > max_retries:
                print(f"Max retries exceeded for APIError: {api_error}")
                return 'apierror', []

            # Retry only on transient errors (timeouts, connection drops, server errors)
            error_str = str(api_error)
            if "timeout" in error_str.lower() or "connection" in error_str.lower() or "server" in error_str.lower():
                # Exponential backoff with jitter
                backoff_time = initial_backoff * (2 ** (retries - 1)) * (0.5 + random.random())
                print(f"API error, retrying in {backoff_time:.2f} seconds (attempt {retries}/{max_retries}): {api_error}")
                time.sleep(backoff_time)
            else:
                # Non-retryable API error
                print(f"Non-retryable API error: {api_error}")
                return 'apierror', []

        except Exception as e:
            print(f"get_response_from_qwq unexpected error: {e}")
            return 'unexpectederror', []


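# Usage sketch for the streaming helper above (the question and model name
# are illustrative; it assumes a QwQ-style model that emits reasoning_content
# deltas during streaming):
def _demo_qwq_call():
    tagged, tool_info = get_response_from_qwq(
        [{"role": "user", "content": "9.11 and 9.8, which is greater?"}],
        model_name="qwq-32b",
    )
    # `tagged` wraps the reasoning in <think>...</think> and the reply
    # (plus any <tool_call> blocks) in <answer>...</answer>
    print(tagged)
    print(tool_info)

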
def read_json_file(file_path):
    """Read a JSON file and return its content."""
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except Exception as e:
        print(f"Error reading file {file_path}: {e}")
        return None


def read_jsonline_file(file_path):
    """Read a JSON Lines file and return its records as a list."""
    with jsonlines.open(file_path, mode='r') as reader:
        datas = [line for line in reader]
    return datas

################################## utils


def clean_all_repetitions_with_details(text, min_length=10, threshold=10):
    """
    Clean every kind of repeated content from a text and report the details.

    Args:
        text: The text to clean
        min_length: Minimum length of a repeated fragment
        threshold: Occurrence count above which content is treated as repeated

    Returns:
        cleaned_text: The cleaned text
        is_repetitive: Whether any repetition was detected
        repetition_details: Details of the repeated content
    """
    original_text = text
    is_repetitive = False
    repetition_details = []

    # 1. First handle repeated lines (text containing newlines)
    if '\n' in text:
        lines = text.split('\n')
        unique_lines = []
        line_counts = {}

        for i, line in enumerate(lines):
            normalized = line.strip().lower()
            if not normalized:
                unique_lines.append(line)
                continue

            line_counts[normalized] = line_counts.get(normalized, 0) + 1

            if line_counts[normalized] <= threshold:
                unique_lines.append(line)

            # Record the details the first time the threshold is exceeded
            if line_counts[normalized] == threshold + 1:
                # Find the original form (preserving case)
                original_form = None
                for l in lines[:i]:
                    if l.strip().lower() == normalized:
                        original_form = l
                        break

                if original_form is None:
                    original_form = line

                repetition_details.append({
                    'type': 'line_repetition',
                    'repeated_string': original_form,
                    'repeat_count': line_counts[normalized]
                })

        if any(count > threshold for count in line_counts.values()):
            text = '\n'.join(unique_lines)
            is_repetitive = True

    # 2. Handle consecutive repeated patterns within a line
    for length in range(min_length, 101):
        pattern = r'(.{' + str(length) + r'})(\1)+'

        while True:
            match = re.search(pattern, text)
            if not match:
                break

            repeated_part = match.group(1)
            full_match = match.group(0)

            # Count the repetitions
            repeat_count = len(full_match) // len(repeated_part)

            # Record the repetition details
            repetition_details.append({
                'type': 'inline_repetition',
                'repeated_string': repeated_part,
                'repeat_count': repeat_count,
                'total_length': len(full_match),
                'position': match.start()
            })

            text = text.replace(full_match, repeated_part)
            is_repetitive = True

    # 3. Handle sentence-level repetition
    sentences = re.split(r'(?<=[.!?。?!])\s+', text)
    if len(sentences) > 1:
        sentence_counter = Counter(sentences)

        for sentence, count in sentence_counter.items():
            if count > threshold:
                repetition_details.append({
                    'type': 'sentence_repetition',
                    'repeated_string': sentence,
                    'repeat_count': count
                })

        if any(count > threshold for count in sentence_counter.values()):
            unique_sentences = []
            seen_sentences = {}

            for sentence in sentences:
                seen_sentences[sentence] = seen_sentences.get(sentence, 0) + 1
                if seen_sentences[sentence] <= threshold:
                    unique_sentences.append(sentence)

            # Reassemble the text
            text = ' '.join(unique_sentences)
            is_repetitive = True

    # 4. Handle shorter repetitions (if nothing was detected above)
    if not is_repetitive and min_length > 5:
        for length in range(5, min_length):
            pattern = r'(.{' + str(length) + r'})(\1){2,}'  # Only act on at least 3 consecutive repeats

            while True:
                match = re.search(pattern, text)
                if not match:
                    break

                repeated_part = match.group(1)
                full_match = match.group(0)

                # Count the repetitions
                repeat_count = len(full_match) // len(repeated_part)

                # Record the repetition details
                repetition_details.append({
                    'type': 'short_repetition',
                    'repeated_string': repeated_part,
                    'repeat_count': repeat_count,
                    'total_length': len(full_match),
                    'position': match.start()
                })

                text = text.replace(full_match, repeated_part)
                is_repetitive = True

    # Sort by repeated-string length (longest first), then by type
    repetition_details.sort(key=lambda x: (-len(x['repeated_string']), x['type']))

    return text, is_repetitive or text != original_text, repetition_details


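# Usage sketch for the repetition cleaner (the noisy sample is synthetic):
def _demo_clean_repetitions():
    noisy = "The cell is cubic. " * 15
    cleaned, was_repetitive, details = clean_all_repetitions_with_details(noisy)
    print(was_repetitive)   # True: the fragment repeats beyond the threshold
    print(details[:1])
    print(cleaned)

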
def create_table(table_name, connection_pool):
    """Create the required MySQL table if it does not exist."""
    db = connection_pool.get_connection()
    cursor = db.cursor()
    create_table_query = f"""
    CREATE TABLE IF NOT EXISTS {table_name} (
        id INT AUTO_INCREMENT PRIMARY KEY,
        mp_id TEXT,
        question_model TEXT,
        question TEXT,
        answer_model TEXT,
        answer TEXT,
        answer_len INT
    )
    """
    cursor.execute(create_table_query)
    db.commit()
    cursor.close()
    db.close()


def record_exists(mp_id, table_name, connection_pool):
    """Check whether an mp_id already exists in the table."""
    db = connection_pool.get_connection()
    cursor = db.cursor()
    query = f"SELECT * FROM {table_name} WHERE mp_id = %s"
    cursor.execute(query, (mp_id,))
    result = cursor.fetchone()
    cursor.fetchall()  # Drain any remaining rows so the connection can be reused
    cursor.close()
    db.close()
    return result is not None


def insert_record(entry, table_name, connection_pool):
    """Insert a record into the MySQL table."""
    db = None
    cursor = None
    try:
        db = connection_pool.get_connection()
        cursor = db.cursor()

        insert_query = f"""
        INSERT INTO {table_name}
        (mp_id, question_model, question, answer_model, answer, answer_len)
        VALUES (%s, %s, %s, %s, %s, %s)
        """
        values = (
            entry["mp_id"], entry["question_model"],
            entry["question"], entry["answer_model"], entry["answer"], entry["answer_len"],
        )
        cursor.execute(insert_query, values)
        db.commit()

    except Error as e:
        print(f"Error: {e}")
        if db:
            db.rollback()
    finally:
        # Ensure the cursor is closed
        if cursor:
            cursor.close()
        # Ensure the connection is returned to the pool
        if db:
            db.close()


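# Usage sketch for the MySQL helpers above. The pool settings (host, user,
# password, database) and the sample record are placeholder assumptions:
def _demo_mysql_helpers():
    pool = pooling.MySQLConnectionPool(
        pool_name="qa_pool",
        pool_size=4,
        host="localhost",
        user="root",
        password="password",
        database="qa_db",
    )
    create_table("qa_records", pool)
    if not record_exists("mp-149", "qa_records", pool):
        insert_record({
            "mp_id": "mp-149",
            "question_model": "deepseek-r1",
            "question": "What is the space group of Si?",
            "answer_model": "deepseek-r1",
            "answer": "Fd-3m",
            "answer_len": 5,
        }, "qa_records", pool)

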
# Initialize the SQLite database connection
def initialize_db():
    conn = sqlite3.connect('multi_turns_data.db', check_same_thread=False)
    cursor = conn.cursor()
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS conversations (
            mp_id TEXT PRIMARY KEY,
            sample TEXT,
            token_num INTEGER
        )
    ''')
    conn.commit()
    return conn


# Save a sample to the SQLite database
def save_to_db(conn, mp_id, sample, total_token):
    cursor = conn.cursor()
    cursor.execute('''
        INSERT OR REPLACE INTO conversations (mp_id, sample, token_num)
        VALUES (?, ?, ?)
    ''', (mp_id, str(sample), total_token))
    conn.commit()


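# Usage sketch for the SQLite helpers (the sample payload is illustrative):
def _demo_sqlite_helpers():
    conn = initialize_db()
    save_to_db(conn, "mp-149", {"conversations": []}, 128)
    cur = conn.execute("SELECT mp_id, token_num FROM conversations")
    print(cur.fetchall())
    conn.close()

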
def read_cif_txt_file(file_path):
    """Read the CIF text file and return its content."""
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return f.read()
    except Exception as e:
        print(f"Error reading file {file_path}: {e}")
        return None


def round_values(data, precision=3):
    """
    Recursively round every numeric value in a nested structure to `precision` decimal places.
    """
    if isinstance(data, dict):  # Dictionaries: recurse over the values
        return {key: round_values(value, precision) for key, value in data.items()}
    elif isinstance(data, list):  # Lists: recurse over each element
        return [round_values(item, precision) for item in data]
    elif isinstance(data, (int, float)):  # Numbers: round to the requested precision
        return round(data, precision)
    else:  # Everything else is returned unchanged
        return data


def decimal_to_fraction(decimal_value, max_denominator=1000):
    """
    Convert a decimal to a fraction representation.

    Args:
        decimal_value: The decimal to convert
        max_denominator: Maximum denominator, controlling the precision

    Returns:
        The fraction as a string
    """
    frac = Fraction(decimal_value).limit_denominator(max_denominator)
    return f"{frac.numerator}/{frac.denominator}"


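# Quick sketch of the fraction helper (values chosen for illustration):
def _demo_fraction_helper():
    print(decimal_to_fraction(0.5))        # -> 1/2
    print(decimal_to_fraction(0.3333333))  # -> 1/3
    print(decimal_to_fraction(0.125))      # -> 1/8

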
def poscar_to_fractional_representation(poscar_content, max_denominator=1000):
    """
    Convert the numeric values in a POSCAR file to fraction form.

    Args:
        poscar_content: Content of the POSCAR file
        max_denominator: Maximum denominator, controlling the precision

    Returns:
        The converted POSCAR content with values written as fractions
    """
    lines = poscar_content.strip().split('\n')
    result_lines = []

    # Keep the system name
    result_lines.append(lines[0])

    # Keep the scaling factor
    result_lines.append(lines[1])

    # Convert the lattice vectors
    for i in range(2, 5):
        vector = [float(x) for x in lines[i].split()]
        # Convert each component to a fraction
        fractional_vector = [decimal_to_fraction(x, max_denominator) for x in vector]
        result_lines.append(" " + " ".join(fractional_vector))

    # Keep the element symbols and counts
    if len(lines) > 5:
        result_lines.append(lines[5])
    if len(lines) > 6:
        result_lines.append(lines[6])

    # Keep the coordinate-type line
    if len(lines) > 7:
        result_lines.append(lines[7])

    # Convert the atomic coordinates
    for i in range(8, len(lines)):
        parts = lines[i].split()
        if len(parts) >= 3:
            # Convert the coordinates to fractions
            coords = [float(parts[j]) for j in range(3)]
            fractional_coords = [decimal_to_fraction(x, max_denominator) for x in coords]

            # Build the new line
            new_line = " " + " ".join(fractional_coords)
            if len(parts) > 3:
                new_line += " " + " ".join(parts[3:])
            result_lines.append(new_line)
        else:
            # Keep non-coordinate lines unchanged
            result_lines.append(lines[i])

    return "\n".join(result_lines)


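# Minimal round-trip sketch for the POSCAR converter (the Si structure below
# is a toy example):
def _demo_poscar_fractions():
    poscar = "\n".join([
        "Si2",
        "1.0",
        "0.0 2.75 2.75",
        "2.75 0.0 2.75",
        "2.75 2.75 0.0",
        "Si",
        "2",
        "Direct",
        "0.0 0.0 0.0",
        "0.25 0.25 0.25",
    ])
    print(poscar_to_fractional_representation(poscar))

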
def remove_symmetry_equiv_xyz(cif_content):
    """
    Remove the symmetry-operation section from a CIF file.

    Args:
        cif_content: CIF file content as a string

    Returns:
        The cleaned CIF content as a string
    """
    lines = cif_content.split('\n')
    output_lines = []

    i = 0
    while i < len(lines):
        line = lines[i].strip()

        # Detect the start of a loop
        if line == 'loop_':
            # Look ahead to check whether this is a symmetry loop
            next_lines = []
            j = i + 1
            while j < len(lines) and lines[j].strip().startswith('_'):
                next_lines.append(lines[j].strip())
                j += 1

            # Check for the symmetry-operation tag
            if any('_symmetry_equiv_pos_as_xyz' in tag for tag in next_lines):
                # Skip the entire loop block
                while i < len(lines):
                    if i + 1 >= len(lines):
                        break

                    next_line = lines[i + 1].strip()
                    # Stop at the next loop or data block
                    if next_line == 'loop_' or next_line.startswith('data_'):
                        break

                    # Stop at the atom-site section
                    if next_line.startswith('_atom_site_'):
                        break

                    i += 1
            else:
                # Not a symmetry loop; keep the loop_ line
                output_lines.append(lines[i])
        else:
            # Not the start of a loop; keep the line as-is
            output_lines.append(lines[i])

        i += 1

    return '\n'.join(output_lines)


def remove_null_values(d):
    """
    Recursively remove key-value pairs with null (None) values from a dictionary.

    Also drops 'version' keys, and drops 'band_gap' when both 'cbm' and 'vbm'
    are null, since the gap value is not meaningful in that case.

    Args:
        d (dict): The dictionary to clean.

    Returns:
        dict: A new dictionary without null values.
    """
    if not isinstance(d, dict):
        raise ValueError("Input must be a dictionary")
    _d = copy.deepcopy(d)

    def recursive_remove(d):
        cleaned_dict = {}
        for key, value in d.items():
            if isinstance(value, dict):
                # Recursively clean nested dictionaries
                nested_cleaned = recursive_remove(value)
                if nested_cleaned:  # Only add non-empty dictionaries
                    cleaned_dict[key] = nested_cleaned
            elif value is not None and key != 'version':
                cleaned_dict[key] = value

        return cleaned_dict

    clean_dict = recursive_remove(d)
    if _d.get('cbm') is None and _d.get('vbm') is None and _d.get('band_gap') is not None:
        clean_dict.pop('band_gap', None)
    return clean_dict


def get_extra_cif_info(path: str, fields_name: list):
    """Extract specific fields from the CIF description."""
    basic_fields = ['formula_pretty', 'chemsys', 'composition', 'elements', 'symmetry', 'nsites', 'volume', 'density']
    energy_electronic_fields = ['formation_energy_per_atom', 'energy_above_hull', 'is_stable', 'efermi', 'cbm', 'vbm', 'band_gap', 'is_gap_direct']
    metal_magnetic_fields = ['is_metal', 'is_magnetic', 'ordering', 'total_magnetization', 'num_magnetic_sites']
    field_groups = {
        'basic_fields': basic_fields,
        'energy_electronic_fields': energy_electronic_fields,
        'metal_magnetic_fields': metal_magnetic_fields,
        # Accept the original (misspelled) group name for backward compatibility
        'metal_magentic_fields': metal_magnetic_fields,
    }

    if fields_name and fields_name[0] == 'all_fields':
        selected_fields = basic_fields + energy_electronic_fields + metal_magnetic_fields
    else:
        selected_fields = []
        for field in fields_name:
            selected_fields.extend(field_groups.get(field, []))

    with open(path, 'r') as f:
        docs = json.load(f)

    new_docs = {}
    for field_name in selected_fields:
        new_docs[field_name] = docs.get(field_name, '')

    return new_docs


def extract_json(text):
    """Extract the first balanced JSON object from a block of text."""
    # Python's built-in `re` module does not support recursive patterns such
    # as (?R), so scan for the first balanced {...} block instead.
    start = text.find('{')
    if start == -1:
        return None
    depth = 0
    for i in range(start, len(text)):
        if text[i] == '{':
            depth += 1
        elif text[i] == '}':
            depth -= 1
            if depth == 0:
                try:
                    return json.loads(text[start:i + 1])
                except json.JSONDecodeError:
                    return None
    return None


def extract_and_parse_json(response):
    """Extract and parse JSON from a response."""
    json_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', response)
    json_str = json_match.group(1) if json_match else response.strip()
    # Escape backslashes inside $...$ LaTeX spans so they survive json.loads
    json_str = re.sub(r'(\$[^\$]*\$)', lambda m: m.group(1).replace('\\', '\\\\'), json_str)
    json_str = json_str.replace('\\"', '"').replace("\\'", "'")
    try:
        return json.loads(json_str)
    except json.JSONDecodeError as e:
        print(f"JSON parse error: {e}")
        return 'errformat'


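# Usage sketch for the JSON extraction helpers (the response string is a toy):
def _demo_json_extraction():
    response = 'Here is the result:\n```json\n{"question": "Q1", "answer": "A1"}\n```'
    print(extract_and_parse_json(response))  # -> {'question': 'Q1', 'answer': 'A1'}
    print(extract_json('noise {"a": 1} noise'))  # -> {'a': 1}

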
# Count the tokens of an input message string
def count_message_tokens(messages, model_name):
    """Count the tokens in `messages` (a plain string, not a message list)."""
    encoding = tiktoken.encoding_for_model(model_name)
    num_tokens = len(encoding.encode(messages))
    return num_tokens


def make_multi_turns_sharegpt_sample(humans: list[str], gpts: list[str], system: str = "{SYSTEM}"):
    """Build a multi-turn ShareGPT-format sample from paired human/gpt turns."""
    sample = {}
    conversations = []

    if system is not None and system != "":
        sample["system"] = system

    assert len(humans) != 0, "humans cannot be empty"
    assert len(gpts) == len(humans), "humans and gpts must have the same length"

    for human, gpt in zip(humans, gpts):
        if human is not None and human != "":
            assert gpt is not None, "gpt cannot be None"
            assert gpt != "", "gpt cannot be empty"
            # The order below must not change: human first, then gpt
            conversations.append({"from": "human", "value": human})
            conversations.append({"from": "gpt", "value": gpt})

    sample["conversations"] = conversations
    return sample


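# Usage sketch for the ShareGPT sample builder (all strings are illustrative):
def _demo_sharegpt_sample():
    sample = make_multi_turns_sharegpt_sample(
        humans=["What is the band gap of Si?", "Is the gap direct?"],
        gpts=["About 1.1 eV.", "No, silicon has an indirect gap."],
        system="You are a materials-science assistant.",
    )
    print(json.dumps(sample, indent=2))

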
##################################### utils