CoACT initialize (#292)
This commit is contained in:
@@ -136,6 +136,82 @@ class PythonController:
|
|||||||
|
|
||||||
logger.error("Failed to execute command.")
|
logger.error("Failed to execute command.")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def run_python_script(self, script: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Executes a python script on the server.
|
||||||
|
"""
|
||||||
|
payload = json.dumps({"code": script})
|
||||||
|
|
||||||
|
for _ in range(self.retry_times):
|
||||||
|
try:
|
||||||
|
response = requests.post(self.http_server + "/run_python", headers={'Content-Type': 'application/json'},
|
||||||
|
data=payload, timeout=90)
|
||||||
|
if response.status_code == 200:
|
||||||
|
return response.json()
|
||||||
|
else:
|
||||||
|
return {"status": "error", "message": "Failed to execute command.", "output": None, "error": response.json()["error"]}
|
||||||
|
except requests.exceptions.ReadTimeout:
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
logger.error("An error occurred while trying to execute the command: %s", traceback.format_exc())
|
||||||
|
logger.info("Retrying to execute command.")
|
||||||
|
time.sleep(self.retry_interval)
|
||||||
|
|
||||||
|
logger.error("Failed to execute command.")
|
||||||
|
return {"status": "error", "message": "Failed to execute command.", "output": "", "error": "Retry limit reached."}
|
||||||
|
|
||||||
|
def run_bash_script(self, script: str, timeout: int = 30, working_dir: Optional[str] = None) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Executes a bash script on the server.
|
||||||
|
|
||||||
|
:param script: The bash script content (can be multi-line)
|
||||||
|
:param timeout: Execution timeout in seconds (default: 30)
|
||||||
|
:param working_dir: Working directory for script execution (optional)
|
||||||
|
:return: Dictionary with status, output, error, and returncode, or None if failed
|
||||||
|
"""
|
||||||
|
payload = json.dumps({
|
||||||
|
"script": script,
|
||||||
|
"timeout": timeout,
|
||||||
|
"working_dir": working_dir
|
||||||
|
})
|
||||||
|
|
||||||
|
for _ in range(self.retry_times):
|
||||||
|
try:
|
||||||
|
response = requests.post(
|
||||||
|
self.http_server + "/run_bash_script",
|
||||||
|
headers={'Content-Type': 'application/json'},
|
||||||
|
data=payload,
|
||||||
|
timeout=timeout + 100 # Add buffer to HTTP timeout
|
||||||
|
)
|
||||||
|
if response.status_code == 200:
|
||||||
|
result = response.json()
|
||||||
|
logger.info("Bash script executed successfully with return code: %d", result.get("returncode", -1))
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
logger.error("Failed to execute bash script. Status code: %d, response: %s",
|
||||||
|
response.status_code, response.text)
|
||||||
|
logger.info("Retrying to execute bash script.")
|
||||||
|
except requests.exceptions.ReadTimeout:
|
||||||
|
logger.error("Bash script execution timed out")
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"output": "",
|
||||||
|
"error": f"Script execution timed out after {timeout} seconds",
|
||||||
|
"returncode": -1
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("An error occurred while trying to execute the bash script: %s", e)
|
||||||
|
logger.info("Retrying to execute bash script.")
|
||||||
|
time.sleep(self.retry_interval)
|
||||||
|
|
||||||
|
logger.error("Failed to execute bash script after %d retries.", self.retry_times)
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"output": "",
|
||||||
|
"error": f"Failed to execute bash script after {self.retry_times} retries",
|
||||||
|
"returncode": -1
|
||||||
|
}
|
||||||
|
|
||||||
def execute_action(self, action: Dict[str, Any]):
|
def execute_action(self, action: Dict[str, Any]):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1568,5 +1568,230 @@ def end_recording():
|
|||||||
return abort(500, description=f"Recording failed. The output file is missing or empty. ffmpeg stderr: {error_output}")
|
return abort(500, description=f"Recording failed. The output file is missing or empty. ffmpeg stderr: {error_output}")
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/run_python", methods=['POST'])
|
||||||
|
def run_python():
|
||||||
|
data = request.json
|
||||||
|
code = data.get('code', None)
|
||||||
|
|
||||||
|
if not code:
|
||||||
|
return jsonify({'status': 'error', 'message': 'Code not supplied!'}), 400
|
||||||
|
|
||||||
|
# Create a temporary file to save the Python code
|
||||||
|
import tempfile
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
# Generate unique filename
|
||||||
|
temp_filename = f"/tmp/python_exec_{uuid.uuid4().hex}.py"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Write code to temporary file
|
||||||
|
with open(temp_filename, 'w') as f:
|
||||||
|
f.write(code)
|
||||||
|
|
||||||
|
# Execute the file using subprocess to capture all output
|
||||||
|
result = subprocess.run(
|
||||||
|
['/usr/bin/python3', temp_filename],
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.PIPE,
|
||||||
|
text=True,
|
||||||
|
timeout=30 # 30 second timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clean up the temporary file
|
||||||
|
try:
|
||||||
|
os.remove(temp_filename)
|
||||||
|
except:
|
||||||
|
pass # Ignore cleanup errors
|
||||||
|
|
||||||
|
# Prepare response
|
||||||
|
output = result.stdout
|
||||||
|
error_output = result.stderr
|
||||||
|
|
||||||
|
# Combine output and errors if both exist
|
||||||
|
combined_message = output
|
||||||
|
if error_output:
|
||||||
|
combined_message += ('\n' + error_output) if output else error_output
|
||||||
|
|
||||||
|
# Determine status based on return code and errors
|
||||||
|
if result.returncode != 0:
|
||||||
|
status = 'error'
|
||||||
|
if not error_output:
|
||||||
|
# If no stderr but non-zero return code, add a generic error message
|
||||||
|
error_output = f"Process exited with code {result.returncode}"
|
||||||
|
combined_message = combined_message + '\n' + error_output if combined_message else error_output
|
||||||
|
else:
|
||||||
|
status = 'success'
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
'status': status,
|
||||||
|
'message': combined_message,
|
||||||
|
'need_more': False, # Not applicable for file execution
|
||||||
|
'output': output, # stdout only
|
||||||
|
'error': error_output, # stderr only
|
||||||
|
'return_code': result.returncode
|
||||||
|
})
|
||||||
|
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
# Clean up the temporary file on timeout
|
||||||
|
try:
|
||||||
|
os.remove(temp_filename)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
'status': 'error',
|
||||||
|
'message': 'Execution timeout: Code took too long to execute',
|
||||||
|
'error': 'TimeoutExpired',
|
||||||
|
'need_more': False,
|
||||||
|
'output': None,
|
||||||
|
}), 500
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Clean up the temporary file on error
|
||||||
|
try:
|
||||||
|
os.remove(temp_filename)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Capture the exception details
|
||||||
|
return jsonify({
|
||||||
|
'status': 'error',
|
||||||
|
'message': f'Execution error: {str(e)}',
|
||||||
|
'error': traceback.format_exc(),
|
||||||
|
'need_more': False,
|
||||||
|
'output': None,
|
||||||
|
}), 500
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/run_bash_script", methods=['POST'])
|
||||||
|
def run_bash_script():
|
||||||
|
data = request.json
|
||||||
|
script = data.get('script', None)
|
||||||
|
timeout = data.get('timeout', 100) # Default timeout of 30 seconds
|
||||||
|
working_dir = data.get('working_dir', None)
|
||||||
|
|
||||||
|
if not script:
|
||||||
|
return jsonify({
|
||||||
|
'status': 'error',
|
||||||
|
'output': 'Script not supplied!',
|
||||||
|
'error': "", # Always empty as requested
|
||||||
|
'returncode': -1
|
||||||
|
}), 400
|
||||||
|
|
||||||
|
# Expand user directory if provided
|
||||||
|
if working_dir:
|
||||||
|
working_dir = os.path.expanduser(working_dir)
|
||||||
|
if not os.path.exists(working_dir):
|
||||||
|
return jsonify({
|
||||||
|
'status': 'error',
|
||||||
|
'output': f'Working directory does not exist: {working_dir}',
|
||||||
|
'error': "", # Always empty as requested
|
||||||
|
'returncode': -1
|
||||||
|
}), 400
|
||||||
|
|
||||||
|
# Create a temporary script file
|
||||||
|
import tempfile
|
||||||
|
with tempfile.NamedTemporaryFile(mode='w', suffix='.sh', delete=False) as tmp_file:
|
||||||
|
if "#!/bin/bash" not in script:
|
||||||
|
script = "#!/bin/bash\n\n" + script
|
||||||
|
tmp_file.write(script)
|
||||||
|
tmp_file_path = tmp_file.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Make the script executable
|
||||||
|
os.chmod(tmp_file_path, 0o755)
|
||||||
|
|
||||||
|
# Execute the script
|
||||||
|
if platform_name == "Windows":
|
||||||
|
# On Windows, use Git Bash or WSL if available, otherwise cmd
|
||||||
|
flags = subprocess.CREATE_NO_WINDOW
|
||||||
|
# Try to use bash if available (Git Bash, WSL, etc.)
|
||||||
|
result = subprocess.run(
|
||||||
|
['bash', tmp_file_path],
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.STDOUT, # Merge stderr into stdout
|
||||||
|
text=True,
|
||||||
|
timeout=timeout,
|
||||||
|
cwd=working_dir,
|
||||||
|
creationflags=flags,
|
||||||
|
shell=False
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# On Unix-like systems, use bash directly
|
||||||
|
flags = 0
|
||||||
|
result = subprocess.run(
|
||||||
|
['/bin/bash', tmp_file_path],
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.STDOUT, # Merge stderr into stdout
|
||||||
|
text=True,
|
||||||
|
timeout=timeout,
|
||||||
|
cwd=working_dir,
|
||||||
|
creationflags=flags,
|
||||||
|
shell=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Log the command execution for trajectory recording
|
||||||
|
_append_event("BashScript",
|
||||||
|
{"script": script, "output": result.stdout, "error": "", "returncode": result.returncode},
|
||||||
|
ts=time.time())
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
'status': 'success' if result.returncode == 0 else 'error',
|
||||||
|
'output': result.stdout, # Contains both stdout and stderr merged
|
||||||
|
'error': "", # Always empty as requested
|
||||||
|
'returncode': result.returncode
|
||||||
|
})
|
||||||
|
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
return jsonify({
|
||||||
|
'status': 'error',
|
||||||
|
'output': f'Script execution timed out after {timeout} seconds',
|
||||||
|
'error': "", # Always empty as requested
|
||||||
|
'returncode': -1
|
||||||
|
}), 500
|
||||||
|
except FileNotFoundError:
|
||||||
|
# Bash not found, try with sh
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
['sh', tmp_file_path],
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.STDOUT, # Merge stderr into stdout
|
||||||
|
text=True,
|
||||||
|
timeout=timeout,
|
||||||
|
cwd=working_dir,
|
||||||
|
shell=False
|
||||||
|
)
|
||||||
|
|
||||||
|
_append_event("BashScript",
|
||||||
|
{"script": script, "output": result.stdout, "error": "", "returncode": result.returncode},
|
||||||
|
ts=time.time())
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
'status': 'success' if result.returncode == 0 else 'error',
|
||||||
|
'output': result.stdout, # Contains both stdout and stderr merged
|
||||||
|
'error': "", # Always empty as requested
|
||||||
|
'returncode': result.returncode,
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
return jsonify({
|
||||||
|
'status': 'error',
|
||||||
|
'output': f'Failed to execute script: {str(e)}',
|
||||||
|
'error': "", # Always empty as requested
|
||||||
|
'returncode': -1
|
||||||
|
}), 500
|
||||||
|
except Exception as e:
|
||||||
|
return jsonify({
|
||||||
|
'status': 'error',
|
||||||
|
'output': f'Failed to execute script: {str(e)}',
|
||||||
|
'error': "", # Always empty as requested
|
||||||
|
'returncode': -1
|
||||||
|
}), 500
|
||||||
|
finally:
|
||||||
|
# Clean up the temporary file
|
||||||
|
try:
|
||||||
|
os.unlink(tmp_file_path)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
app.run(debug=True, host="0.0.0.0")
|
app.run(debug=True, host="0.0.0.0")
|
||||||
|
|||||||
27
mm_agents/coact/OAI_CONFIG_LIST
Normal file
27
mm_agents/coact/OAI_CONFIG_LIST
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"model": "gpt-4o",
|
||||||
|
"api_key": "KEY",
|
||||||
|
"tags": ["gpt-4o", "code", "explainer"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model": "o3",
|
||||||
|
"api_key": "KEY",
|
||||||
|
"tags": ["o3", "coding", "explainer"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model": "gpt-4.1",
|
||||||
|
"api_key": "KEY",
|
||||||
|
"tags": ["gpt-4.1", "coding", "explainer"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model": "o4-mini",
|
||||||
|
"api_key": "KEY",
|
||||||
|
"tags": ["o4-mini", "coding", "explainer"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model": "o3-mini",
|
||||||
|
"api_key": "KEY",
|
||||||
|
"tags": ["o3-mini", "coding", "explainer"]
|
||||||
|
}
|
||||||
|
]
|
||||||
0
mm_agents/coact/__init__.py
Normal file
0
mm_agents/coact/__init__.py
Normal file
81
mm_agents/coact/autogen/__init__.py
Normal file
81
mm_agents/coact/autogen/__init__.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from .agentchat import (
|
||||||
|
Agent,
|
||||||
|
AssistantAgent,
|
||||||
|
ChatResult,
|
||||||
|
ConversableAgent,
|
||||||
|
GroupChat,
|
||||||
|
GroupChatManager,
|
||||||
|
UpdateSystemMessage,
|
||||||
|
UserProxyAgent,
|
||||||
|
gather_usage_summary,
|
||||||
|
initiate_chats,
|
||||||
|
register_function,
|
||||||
|
)
|
||||||
|
from .agentchat.group.context_expression import ContextExpression
|
||||||
|
from .code_utils import DEFAULT_MODEL, FAST_MODEL
|
||||||
|
from .exception_utils import (
|
||||||
|
AgentNameConflictError,
|
||||||
|
InvalidCarryOverTypeError,
|
||||||
|
NoEligibleSpeakerError,
|
||||||
|
SenderRequiredError,
|
||||||
|
UndefinedNextAgentError,
|
||||||
|
)
|
||||||
|
from .llm_config import LLMConfig
|
||||||
|
from .oai import (
|
||||||
|
Cache,
|
||||||
|
ModelClient,
|
||||||
|
OpenAIWrapper,
|
||||||
|
config_list_from_dotenv,
|
||||||
|
config_list_from_json,
|
||||||
|
config_list_from_models,
|
||||||
|
config_list_gpt4_gpt35,
|
||||||
|
config_list_openai_aoai,
|
||||||
|
filter_config,
|
||||||
|
get_config_list,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set the root logger.
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DEFAULT_MODEL",
|
||||||
|
"FAST_MODEL",
|
||||||
|
"Agent",
|
||||||
|
"AgentNameConflictError",
|
||||||
|
"AssistantAgent",
|
||||||
|
"Cache",
|
||||||
|
"ChatResult",
|
||||||
|
"ContextExpression",
|
||||||
|
"ConversableAgent",
|
||||||
|
"GroupChat",
|
||||||
|
"GroupChatManager",
|
||||||
|
"InvalidCarryOverTypeError",
|
||||||
|
"LLMConfig",
|
||||||
|
"ModelClient",
|
||||||
|
"NoEligibleSpeakerError",
|
||||||
|
"OpenAIWrapper",
|
||||||
|
"SenderRequiredError",
|
||||||
|
"UndefinedNextAgentError",
|
||||||
|
"UpdateSystemMessage",
|
||||||
|
"UserProxyAgent",
|
||||||
|
"config_list_from_dotenv",
|
||||||
|
"config_list_from_json",
|
||||||
|
"config_list_from_models",
|
||||||
|
"config_list_gpt4_gpt35",
|
||||||
|
"config_list_openai_aoai",
|
||||||
|
"filter_config",
|
||||||
|
"gather_usage_summary",
|
||||||
|
"get_config_list",
|
||||||
|
"initiate_chats",
|
||||||
|
"register_function",
|
||||||
|
]
|
||||||
38
mm_agents/coact/autogen/agentchat/__init__.py
Normal file
38
mm_agents/coact/autogen/agentchat/__init__.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
from .agent import Agent, LLMAgent
|
||||||
|
from .assistant_agent import AssistantAgent
|
||||||
|
from .chat import ChatResult, a_initiate_chats, initiate_chats
|
||||||
|
|
||||||
|
from .conversable_agent import ConversableAgent, UpdateSystemMessage, register_function
|
||||||
|
from .group.multi_agent_chat import a_initiate_group_chat, a_run_group_chat, initiate_group_chat, run_group_chat
|
||||||
|
from .groupchat import GroupChat, GroupChatManager
|
||||||
|
from .user_proxy_agent import UserProxyAgent
|
||||||
|
from .utils import gather_usage_summary
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Agent",
|
||||||
|
"AssistantAgent",
|
||||||
|
"ChatResult",
|
||||||
|
"ConversableAgent",
|
||||||
|
"GroupChat",
|
||||||
|
"GroupChatManager",
|
||||||
|
"LLMAgent",
|
||||||
|
"UpdateSystemMessage",
|
||||||
|
"UserProxyAgent",
|
||||||
|
"a_initiate_chats",
|
||||||
|
"a_initiate_group_chat",
|
||||||
|
"a_initiate_swarm_chat",
|
||||||
|
"a_run_group_chat",
|
||||||
|
"a_run_swarm",
|
||||||
|
"gather_usage_summary",
|
||||||
|
"initiate_chats",
|
||||||
|
"initiate_group_chat",
|
||||||
|
"register_function",
|
||||||
|
"run_group_chat",
|
||||||
|
"run_swarm",
|
||||||
|
]
|
||||||
182
mm_agents/coact/autogen/agentchat/agent.py
Normal file
182
mm_agents/coact/autogen/agentchat/agent.py
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
from typing import TYPE_CHECKING, Any, Optional, Protocol, TypeVar, Union, runtime_checkable
|
||||||
|
|
||||||
|
from ..doc_utils import export_module
|
||||||
|
|
||||||
|
__all__ = ["Agent", "LLMAgent", "LLMMessageType"]
|
||||||
|
|
||||||
|
Tool = TypeVar("Tool")
|
||||||
|
|
||||||
|
LLMMessageType = dict[str, Any]
|
||||||
|
|
||||||
|
DEFAULT_SUMMARY_METHOD = "last_msg"
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
@export_module("autogen")
|
||||||
|
class Agent(Protocol):
|
||||||
|
"""(In preview) A protocol for Agent.
|
||||||
|
|
||||||
|
An agent can communicate with other agents and perform actions.
|
||||||
|
Different agents can differ in what actions they perform in the `receive` method.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""The name of the agent."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
"""The description of the agent. Used for the agent's introduction in
|
||||||
|
a group chat setting.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def send(
|
||||||
|
self,
|
||||||
|
message: Union[dict[str, Any], str],
|
||||||
|
recipient: "Agent",
|
||||||
|
request_reply: Optional[bool] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Send a message to another agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message (dict or str): the message to send. If a dict, it should be
|
||||||
|
a JSON-serializable and follows the OpenAI's ChatCompletion schema.
|
||||||
|
recipient (Agent): the recipient of the message.
|
||||||
|
request_reply (bool): whether to request a reply from the recipient.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
async def a_send(
|
||||||
|
self,
|
||||||
|
message: Union[dict[str, Any], str],
|
||||||
|
recipient: "Agent",
|
||||||
|
request_reply: Optional[bool] = None,
|
||||||
|
) -> None:
|
||||||
|
"""(Async) Send a message to another agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message (dict or str): the message to send. If a dict, it should be
|
||||||
|
a JSON-serializable and follows the OpenAI's ChatCompletion schema.
|
||||||
|
recipient (Agent): the recipient of the message.
|
||||||
|
request_reply (bool): whether to request a reply from the recipient.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def receive(
|
||||||
|
self,
|
||||||
|
message: Union[dict[str, Any], str],
|
||||||
|
sender: "Agent",
|
||||||
|
request_reply: Optional[bool] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Receive a message from another agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message (dict or str): the message received. If a dict, it should be
|
||||||
|
a JSON-serializable and follows the OpenAI's ChatCompletion schema.
|
||||||
|
sender (Agent): the sender of the message.
|
||||||
|
request_reply (bool): whether the sender requests a reply.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def a_receive(
|
||||||
|
self,
|
||||||
|
message: Union[dict[str, Any], str],
|
||||||
|
sender: "Agent",
|
||||||
|
request_reply: Optional[bool] = None,
|
||||||
|
) -> None:
|
||||||
|
"""(Async) Receive a message from another agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message (dict or str): the message received. If a dict, it should be
|
||||||
|
a JSON-serializable and follows the OpenAI's ChatCompletion schema.
|
||||||
|
sender (Agent): the sender of the message.
|
||||||
|
request_reply (bool): whether the sender requests a reply.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def generate_reply(
|
||||||
|
self,
|
||||||
|
messages: Optional[list[dict[str, Any]]] = None,
|
||||||
|
sender: Optional["Agent"] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Union[str, dict[str, Any], None]:
|
||||||
|
"""Generate a reply based on the received messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages (list[dict[str, Any]]): a list of messages received from other agents.
|
||||||
|
The messages are dictionaries that are JSON-serializable and
|
||||||
|
follows the OpenAI's ChatCompletion schema.
|
||||||
|
sender: sender of an Agent instance.
|
||||||
|
**kwargs: Additional keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str or dict or None: the generated reply. If None, no reply is generated.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def a_generate_reply(
|
||||||
|
self,
|
||||||
|
messages: Optional[list[dict[str, Any]]] = None,
|
||||||
|
sender: Optional["Agent"] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Union[str, dict[str, Any], None]:
|
||||||
|
"""(Async) Generate a reply based on the received messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages (list[dict[str, Any]]): a list of messages received from other agents.
|
||||||
|
The messages are dictionaries that are JSON-serializable and
|
||||||
|
follows the OpenAI's ChatCompletion schema.
|
||||||
|
sender: sender of an Agent instance.
|
||||||
|
**kwargs: Additional keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str or dict or None: the generated reply. If None, no reply is generated.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def set_ui_tools(self, tools: list[Tool]) -> None:
|
||||||
|
"""Set the UI tools for the agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tools: a list of UI tools to set.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def unset_ui_tools(self, tools: list[Tool]) -> None:
|
||||||
|
"""Unset the UI tools for the agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tools: a list of UI tools to set.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
@export_module("autogen")
|
||||||
|
class LLMAgent(Agent, Protocol):
|
||||||
|
"""(In preview) A protocol for an LLM agent."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def system_message(self) -> str:
|
||||||
|
"""The system message of this agent."""
|
||||||
|
|
||||||
|
def update_system_message(self, system_message: str) -> None:
|
||||||
|
"""Update this agent's system message.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
system_message (str): system message for inference.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
# mypy will fail if Conversable agent does not implement Agent protocol
|
||||||
|
from .conversable_agent import ConversableAgent
|
||||||
|
|
||||||
|
def _check_protocol_implementation(agent: ConversableAgent) -> Agent:
|
||||||
|
return agent
|
||||||
85
mm_agents/coact/autogen/agentchat/assistant_agent.py
Normal file
85
mm_agents/coact/autogen/agentchat/assistant_agent.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
from typing import Any, Callable, Literal, Optional, Union
|
||||||
|
|
||||||
|
from ..doc_utils import export_module
|
||||||
|
from ..llm_config import LLMConfig
|
||||||
|
from ..runtime_logging import log_new_agent, logging_enabled
|
||||||
|
from .conversable_agent import ConversableAgent
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen")
|
||||||
|
class AssistantAgent(ConversableAgent):
|
||||||
|
"""(In preview) Assistant agent, designed to solve a task with LLM.
|
||||||
|
|
||||||
|
AssistantAgent is a subclass of ConversableAgent configured with a default system message.
|
||||||
|
The default system message is designed to solve a task with LLM,
|
||||||
|
including suggesting python code blocks and debugging.
|
||||||
|
`human_input_mode` is default to "NEVER"
|
||||||
|
and `code_execution_config` is default to False.
|
||||||
|
This agent doesn't execute code by default, and expects the user to execute the code.
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_SYSTEM_MESSAGE = """You are a helpful AI assistant.
|
||||||
|
Solve tasks using your coding and language skills.
|
||||||
|
In the following cases, suggest python code (in a python coding block) or shell script (in a sh coding block) for the user to execute.
|
||||||
|
1. When you need to collect info, use the code to output the info you need, for example, browse or search the web, download/read a file, print the content of a webpage or a file, get the current date/time, check the operating system. After sufficient info is printed and the task is ready to be solved based on your language skill, you can solve the task by yourself.
|
||||||
|
2. When you need to perform some task with code, use the code to perform the task and output the result. Finish the task smartly.
|
||||||
|
Solve the task step by step if you need to. If a plan is not provided, explain your plan first. Be clear which step uses code, and which step uses your language skill.
|
||||||
|
When using code, you must indicate the script type in the code block. The user cannot provide any other feedback or perform any other action beyond executing the code you suggest. The user can't modify your code. So do not suggest incomplete code which requires users to modify. Don't use a code block if it's not intended to be executed by the user.
|
||||||
|
If you want the user to save the code in a file before executing it, put # filename: <filename> inside the code block as the first line. Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use 'print' function for the output when relevant. Check the execution result returned by the user.
|
||||||
|
If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, collect additional info you need, and think of a different approach to try.
|
||||||
|
When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible.
|
||||||
|
Reply "TERMINATE" in the end when everything is done.
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_DESCRIPTION = "A helpful and general-purpose AI assistant that has strong language skills, Python skills, and Linux command line skills."
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
|
||||||
|
llm_config: Optional[Union[LLMConfig, dict[str, Any], Literal[False]]] = None,
|
||||||
|
is_termination_msg: Optional[Callable[[dict[str, Any]], bool]] = None,
|
||||||
|
max_consecutive_auto_reply: Optional[int] = None,
|
||||||
|
human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER",
|
||||||
|
description: Optional[str] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
):
|
||||||
|
"""Args:
|
||||||
|
name (str): agent name.
|
||||||
|
system_message (str): system message for the ChatCompletion inference.
|
||||||
|
Please override this attribute if you want to reprogram the agent.
|
||||||
|
llm_config (dict or False or None): llm inference configuration.
|
||||||
|
Please refer to [OpenAIWrapper.create](https://docs.ag2.ai/latest/docs/api-reference/autogen/OpenAIWrapper/#autogen.OpenAIWrapper.create)
|
||||||
|
for available options.
|
||||||
|
is_termination_msg (function): a function that takes a message in the form of a dictionary
|
||||||
|
and returns a boolean value indicating if this received message is a termination message.
|
||||||
|
The dict can contain the following keys: "content", "role", "name", "function_call".
|
||||||
|
max_consecutive_auto_reply (int): the maximum number of consecutive auto replies.
|
||||||
|
default to None (no limit provided, class attribute MAX_CONSECUTIVE_AUTO_REPLY will be used as the limit in this case).
|
||||||
|
The limit only plays a role when human_input_mode is not "ALWAYS".
|
||||||
|
**kwargs (dict): Please refer to other kwargs in
|
||||||
|
[ConversableAgent](https://docs.ag2.ai/latest/docs/api-reference/autogen/ConversableAgent).
|
||||||
|
"""
|
||||||
|
super().__init__(
|
||||||
|
name,
|
||||||
|
system_message,
|
||||||
|
is_termination_msg,
|
||||||
|
max_consecutive_auto_reply,
|
||||||
|
human_input_mode,
|
||||||
|
llm_config=llm_config,
|
||||||
|
description=description,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
if logging_enabled():
|
||||||
|
log_new_agent(self, locals())
|
||||||
|
|
||||||
|
# Update the provided description if None, and we are using the default system_message,
|
||||||
|
# then use the default description.
|
||||||
|
if description is None and system_message == self.DEFAULT_SYSTEM_MESSAGE:
|
||||||
|
self.description = self.DEFAULT_DESCRIPTION
|
||||||
309
mm_agents/coact/autogen/agentchat/chat.py
Normal file
309
mm_agents/coact/autogen/agentchat/chat.py
Normal file
@@ -0,0 +1,309 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
import asyncio
|
||||||
|
import datetime
|
||||||
|
import logging
|
||||||
|
import warnings
|
||||||
|
from collections import defaultdict
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from functools import partial
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from ..doc_utils import export_module
|
||||||
|
from ..events.agent_events import PostCarryoverProcessingEvent
|
||||||
|
from ..io.base import IOStream
|
||||||
|
from .utils import consolidate_chat_info
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
Prerequisite = tuple[int, int]
|
||||||
|
|
||||||
|
__all__ = ["ChatResult", "a_initiate_chats", "initiate_chats"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
@export_module("autogen")
class ChatResult:
    """(Experimental) The result of a chat. Almost certain to be changed."""

    # NOTE(review): every field defaults to None even though the annotations are
    # non-Optional; consumers must treat all attributes as possibly unset.
    # chat_id is assigned after the fact by the async completion callback
    # (_on_chat_future_done) when chats run via a_initiate_chats.
    chat_id: int = None
    """chat id"""
    chat_history: list[dict[str, Any]] = None
    """The chat history."""
    summary: str = None
    """A summary obtained from the chat."""
    cost: dict[str, dict[str, Any]] = (
        None  # keys: "usage_including_cached_inference", "usage_excluding_cached_inference"
    )
    """The cost of the chat.

    The value for each usage type is a dictionary containing cost information for that specific type.

    - "usage_including_cached_inference": Cost information on the total usage, including the tokens in cached inference.
    - "usage_excluding_cached_inference": Cost information on the usage of tokens, excluding the tokens in cache. No larger than "usage_including_cached_inference".
    """
    human_input: list[str] = None
    """A list of human input solicited during the chat."""
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_recipients(chat_queue: list[dict[str, Any]]) -> None:
    """Check that every queued chat names a recipient; warn when a recipient repeats.

    Raises:
        AssertionError: if any entry in ``chat_queue`` lacks a "recipient" key.
    """
    unique_recipients = set()
    for entry in chat_queue:
        assert "recipient" in entry, "recipient must be provided."
        unique_recipients.add(entry["recipient"])
    # Fewer distinct recipients than chats means at least one agent appears twice.
    if len(unique_recipients) < len(chat_queue):
        warnings.warn(
            "Repetitive recipients detected: The chat history will be cleared by default if a recipient appears more than once. To retain the chat history, please set 'clear_history=False' in the configuration of the repeating agent.",
            UserWarning,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def __create_async_prerequisites(chat_queue: list[dict[str, Any]]) -> list[Prerequisite]:
    """Collect (chat_id, prerequisite_chat_id) pairs for async chat scheduling.

    Raises:
        ValueError: if any chat entry lacks a "chat_id", or a prerequisite id is not an int.
    """
    pairs: list[Prerequisite] = []
    for entry in chat_queue:
        if "chat_id" not in entry:
            raise ValueError("Each chat must have a unique id for async multi-chat execution.")
        current_id = entry["chat_id"]
        for prerequisite_id in entry.get("prerequisites", []):
            if not isinstance(prerequisite_id, int):
                raise ValueError("Prerequisite chat id is not int.")
            pairs.append((current_id, prerequisite_id))
    return pairs
|
||||||
|
|
||||||
|
|
||||||
|
def __find_async_chat_order(chat_ids: set[int], prerequisites: list[Prerequisite]) -> list[int]:
    """Order chats for async execution so every chat runs after its prerequisites (Kahn's algorithm).

    Args:
        chat_ids: A set of all chat IDs that need to be scheduled
        prerequisites: A list of tuples (chat_id, prerequisite_chat_id) where each tuple indicates that chat_id depends on prerequisite_chat_id

    Returns:
        list: chat ids in a valid execution order, or an empty list when the
        dependency graph contains a cycle.
    """
    dependents = defaultdict(set)  # prerequisite id -> chats waiting on it
    remaining = defaultdict(int)  # chat id -> number of unmet prerequisites
    for chat, pre in prerequisites:
        # Guard against duplicate (chat, pre) pairs inflating the count.
        if chat not in dependents[pre]:
            remaining[chat] += 1
            dependents[pre].add(chat)

    frontier = [cid for cid in chat_ids if cid not in remaining]
    chat_order = []
    total_constrained = len(remaining)
    for _ in range(total_constrained + 1):
        if not frontier:
            break
        chat_order.extend(frontier)
        next_frontier = []
        for finished in frontier:
            for dependent in dependents.pop(finished, ()):
                remaining[dependent] -= 1
                if remaining[dependent] == 0:
                    next_frontier.append(dependent)
                    remaining.pop(dependent)
        frontier = next_frontier

    # Anything still constrained is part of a cycle: no valid order exists.
    if remaining:
        return []
    return chat_order
|
||||||
|
|
||||||
|
|
||||||
|
def _post_process_carryover_item(carryover_item):
    """Normalize a single carryover entry to a string.

    Strings pass through unchanged; message dicts contribute their "content" field;
    anything else is stringified.
    """
    if isinstance(carryover_item, str):
        return carryover_item
    if isinstance(carryover_item, dict) and "content" in carryover_item:
        return str(carryover_item["content"])
    return str(carryover_item)
|
||||||
|
|
||||||
|
|
||||||
|
def __post_carryover_processing(chat_info: dict[str, Any]) -> None:
    """Announce the assembled carryover for one chat on the default IO stream.

    Warns when the chat entry has no "message", since the chat will then fall back
    to prompting the user via input().
    """
    iostream = IOStream.get_default()

    if "message" not in chat_info:
        warnings.warn(
            "message is not provided in a chat_queue entry. input() will be called to get the initial message.",
            UserWarning,
        )

    iostream.send(PostCarryoverProcessingEvent(chat_info=chat_info))
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen")
def initiate_chats(chat_queue: list[dict[str, Any]]) -> list[ChatResult]:
    """Run a sequence of chats one after another, threading summaries forward as carryover.

    Args:
        chat_queue (List[Dict]): A list of dictionaries, each holding the input arguments for
            [`ConversableAgent.initiate_chat`](../ConversableAgent#initiate-chat). Commonly used keys:

            - `"sender"` - the sender agent.
            - `"recipient"` - the recipient agent.
            - `"clear_history"` (bool) - whether to clear the chat history with the agent. Default is True.
            - `"silent"` (bool or None) - (Experimental) whether to print the messages in this conversation. Default is False.
            - `"cache"` (Cache or None) - the cache client to use for this conversation. Default is None.
            - `"max_turns"` (int or None) - maximum number of turns for the chat; None runs until a termination condition is met. Default is None.
            - `"summary_method"` (str or callable) - how to produce the chat summary. Default is "last_msg".
            - `"summary_args"` (dict) - arguments passed to the summary_method. Default is {}.
            - `"message"` (str, callable or None) - if None, input() will be called to get the initial message.
            - `**context` - additional context information to be passed to the chat.
            - `"carryover"` - extra carryover combined with the "message" content when generating the initial chat message.
            - `"finished_chat_indexes_to_exclude_from_carryover"` - indexes into the finished-chat list whose
              summaries should NOT be carried over; omitted or empty carries over every finished chat's summary.

    Returns:
        (list): a list of ChatResult objects corresponding to the finished chats in the chat_queue.
    """
    consolidate_chat_info(chat_queue)
    _validate_recipients(chat_queue)
    finished_chats: list[ChatResult] = []
    # Iterate a shallow copy; each chat_info dict is still mutated in place
    # (its "carryover" key is rewritten), matching the original behavior.
    for chat_info in list(chat_queue):
        carryover = chat_info.get("carryover", [])
        excluded_indexes = chat_info.get("finished_chat_indexes_to_exclude_from_carryover", [])

        if isinstance(carryover, str):
            carryover = [carryover]
        chat_info["carryover"] = carryover + [
            result.summary for idx, result in enumerate(finished_chats) if idx not in excluded_indexes
        ]

        if not chat_info.get("silent", False):
            __post_carryover_processing(chat_info)

        finished_chats.append(chat_info["sender"].initiate_chat(**chat_info))
    return finished_chats
|
||||||
|
|
||||||
|
|
||||||
|
def __system_now_str():
    """Return a space-padded wall-clock timestamp fragment for log messages."""
    return " System time at {}. ".format(datetime.datetime.now())
|
||||||
|
|
||||||
|
|
||||||
|
def _on_chat_future_done(chat_future: asyncio.Future, chat_id: int):
    """Done-callback: stamp the finished chat's ChatResult with its chat_id."""
    logger.debug(f"Update chat {chat_id} result on task completion." + __system_now_str())
    chat_future.result().chat_id = chat_id
|
||||||
|
|
||||||
|
|
||||||
|
async def _dependent_chat_future(
    chat_id: int, chat_info: dict[str, Any], prerequisite_chat_futures: dict[int, asyncio.Future]
) -> asyncio.Task:
    """Await this chat's prerequisites, fold their summaries into its carryover, then launch it as a Task.

    Raises:
        RuntimeError: if any prerequisite chat's future was cancelled.
    """
    logger.debug(f"Create Task for chat {chat_id}." + __system_now_str())
    carryover = chat_info.get("carryover", [])
    excluded_ids = chat_info.get("finished_chat_indexes_to_exclude_from_carryover", [])
    completed: dict[int, Any] = {}
    for pre_id, pre_future in prerequisite_chat_futures.items():
        if pre_future.cancelled():
            raise RuntimeError(f"Chat {pre_id} is cancelled.")

        # Block until each prerequisite finishes so its summary can feed the carryover.
        completed[pre_id] = await pre_future

    if isinstance(carryover, str):
        carryover = [carryover]
    summaries = [
        chat_result.summary for pre_id, chat_result in completed.items() if pre_id not in excluded_ids
    ]
    chat_info["carryover"] = carryover + summaries
    if not chat_info.get("silent", False):
        __post_carryover_processing(chat_info)

    sender = chat_info["sender"]
    task = asyncio.create_task(sender.a_initiate_chat(**chat_info))
    # Tag the eventual ChatResult with this chat's id once the task completes.
    task.add_done_callback(partial(_on_chat_future_done, chat_id=chat_id))
    logger.debug(f"Task for chat {chat_id} created." + __system_now_str())
    return task
|
||||||
|
|
||||||
|
|
||||||
|
async def a_initiate_chats(chat_queue: list[dict[str, Any]]) -> dict[int, ChatResult]:
    """(async) Initiate a list of chats concurrently, honoring per-chat prerequisites.

    Args:
        chat_queue (List[Dict]): A list of dictionaries, each holding the input arguments for
            [`ConversableAgent.initiate_chat`](../../../ConversableAgent#initiate-chat). Every entry must
            carry a unique `"chat_id"` and may list `"prerequisites"` (chat ids that must finish first).
            Commonly used keys:

            - `"sender"` - the sender agent.
            - `"recipient"` - the recipient agent.
            - `"clear_history"` (bool) - whether to clear the chat history with the agent. Default is True.
            - `"silent"` (bool or None) - (Experimental) whether to print the messages in this conversation. Default is False.
            - `"cache"` (Cache or None) - the cache client to use for this conversation. Default is None.
            - `"max_turns"` (int or None) - maximum number of turns for the chat; None runs until a termination condition is met. Default is None.
            - `"summary_method"` (str or callable) - how to produce the chat summary. Default is "last_msg".
            - `"summary_args"` (dict) - arguments passed to the summary_method. Default is {}.
            - `"message"` (str, callable or None) - if None, input() will be called to get the initial message.
            - `**context` - additional context information to be passed to the chat.
            - `"carryover"` - extra carryover combined with the "message" content when generating the initial chat message.
            - `"finished_chat_indexes_to_exclude_from_carryover"` - ids of finished prerequisite chats whose
              summaries should NOT be carried over; omitted or empty carries over every prerequisite summary.

    Returns:
        - (Dict): a dict of ChatId: ChatResult corresponding to the finished chats in the chat_queue.
    """
    consolidate_chat_info(chat_queue)
    _validate_recipients(chat_queue)
    chat_book = {entry["chat_id"]: entry for entry in chat_queue}
    all_chat_ids = chat_book.keys()
    prerequisites = __create_async_prerequisites(chat_queue)
    ordered_ids = __find_async_chat_order(all_chat_ids, prerequisites)

    chat_futures: dict[int, asyncio.Task] = {}
    for current_id in ordered_ids:
        entry = chat_book[current_id]
        # Hand each chat the futures of its prerequisites so it can await them.
        pre_futures = {pre_id: chat_futures[pre_id] for pre_id in entry.get("prerequisites", [])}
        chat_futures[current_id] = await _dependent_chat_future(current_id, entry, pre_futures)

    await asyncio.gather(*list(chat_futures.values()))
    return {cid: fut.result() for cid, fut in chat_futures.items()}
|
||||||
5
mm_agents/coact/autogen/agentchat/contrib/__init__.py
Normal file
5
mm_agents/coact/autogen/agentchat/contrib/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
# This package intentionally exports no public names.
__all__: list[str] = []
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
# This package intentionally exports no public names.
__all__: list[str] = []
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
from ...assistant_agent import ConversableAgent
|
||||||
|
|
||||||
|
|
||||||
|
class AgentCapability:
    """Base class for composable capabilities that can be added to an agent."""

    def __init__(self):
        # Stateless base; subclasses hold their own configuration.
        pass

    def add_to_agent(self, agent: ConversableAgent):
        """Adds a particular capability to the given agent. Must be implemented by the capability subclass.
        An implementation will typically call agent.register_hook() one or more times. See teachability.py as an example.

        Args:
            agent: The ConversableAgent to augment with this capability.

        Raises:
            NotImplementedError: always, on this base class.
        """
        raise NotImplementedError
|
||||||
@@ -0,0 +1,301 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
import re
|
||||||
|
from typing import Any, Literal, Optional, Protocol, Union
|
||||||
|
|
||||||
|
from .... import Agent, ConversableAgent, code_utils
|
||||||
|
from ....cache import AbstractCache
|
||||||
|
from ....import_utils import optional_import_block, require_optional_import
|
||||||
|
from ....llm_config import LLMConfig
|
||||||
|
from .. import img_utils
|
||||||
|
from ..capabilities.agent_capability import AgentCapability
|
||||||
|
from ..text_analyzer_agent import TextAnalyzerAgent
|
||||||
|
|
||||||
|
with optional_import_block():
|
||||||
|
from PIL.Image import Image
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
# Appended to an agent's system message when the ImageGeneration capability is added.
SYSTEM_MESSAGE = "You've been given the special ability to generate images."
# Appended to an agent's description when the ImageGeneration capability is added.
DESCRIPTION_MESSAGE = "This agent has the ability to generate images."

# Default instructions handed to the TextAnalyzerAgent to distill an incoming
# message into a concise image-generation prompt.
PROMPT_INSTRUCTIONS = """In detail, please summarize the provided prompt to generate the image described in the TEXT.
DO NOT include any advice. RESPOND like the following example:
EXAMPLE: Blue background, 3D shapes, ...
"""
|
||||||
|
|
||||||
|
|
||||||
|
class ImageGenerator(Protocol):
    """This class defines an interface for image generators.

    Concrete implementations of this protocol must provide a `generate_image` method that takes a string prompt as
    input and returns a PIL Image object.

    NOTE: Current implementation does not allow you to edit a previously existing image.
    """

    def generate_image(self, prompt: str) -> "Image":
        """Generates an image based on the provided prompt.

        Args:
            prompt: A string describing the desired image.

        Returns:
            A PIL Image object representing the generated image.

        Raises:
            ValueError: If the image generation fails.
        """
        ...

    def cache_key(self, prompt: str) -> str:
        """Generates a unique cache key for the given prompt.

        This key can be used to store and retrieve generated images based on the prompt.
        The key should be deterministic: the same prompt (with the same generator
        settings) must always map to the same key.

        Args:
            prompt: A string describing the desired image.

        Returns:
            A unique string that can be used as a cache key.
        """
        ...
|
||||||
|
|
||||||
|
|
||||||
|
@require_optional_import("PIL", "unknown")
@require_optional_import("openai>=1.66.2", "openai")
class DalleImageGenerator:
    """ImageGenerator backed by OpenAI's DALL-E models.

    Wraps the OpenAI images endpoint so textual prompts can be turned into PIL
    images, with configurable model, resolution, quality, and image count.

    Note: Current implementation does not allow you to edit a previously existing image.
    """

    def __init__(
        self,
        llm_config: Union[LLMConfig, dict[str, Any]],
        resolution: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"] = "1024x1024",
        quality: Literal["standard", "hd"] = "standard",
        num_images: int = 1,
    ):
        """Args:
        llm_config (LLMConfig or dict): llm config whose config_list carries a valid dalle model and OpenAI API key.
        resolution (str): output resolution; one of "256x256", "512x512", "1024x1024", "1792x1024", "1024x1792".
        quality (str): output quality; one of "standard", "hd".
        num_images (int): how many images to request per prompt.
        """
        first_entry = llm_config["config_list"][0]
        _validate_dalle_model(first_entry["model"])
        _validate_resolution_format(resolution)

        self._model = first_entry["model"]
        self._resolution = resolution
        self._quality = quality
        self._num_images = num_images
        self._dalle_client = OpenAI(api_key=first_entry["api_key"])

    def generate_image(self, prompt: str) -> "Image":
        """Generate an image for *prompt* and return it as a PIL Image.

        Raises:
            ValueError: if the API response carries no image URL.
        """
        response = self._dalle_client.images.generate(
            model=self._model,
            prompt=prompt,
            size=self._resolution,
            quality=self._quality,
            n=self._num_images,
        )

        image_url = response.data[0].url
        if image_url is None:
            raise ValueError("Failed to generate image.")

        return img_utils.get_pil_image(image_url)

    def cache_key(self, prompt: str) -> str:
        """Build a deterministic cache key from the prompt and all generation settings."""
        parts = (prompt, self._model, self._resolution, self._quality, self._num_images)
        return ",".join(str(part) for part in parts)
|
||||||
|
|
||||||
|
|
||||||
|
@require_optional_import("PIL", "unknown")
class ImageGeneration(AgentCapability):
    """Lets a ConversableAgent answer image-generation requests found in incoming messages.

    The capability works in three steps:

    1. A TextAnalyzerAgent inspects each incoming message to decide whether an image is
       being requested and to distill the generation prompt.
    2. The configured ImageGenerator (e.g., DalleImageGenerator) produces the image.
    3. Generated images are optionally cached so repeated prompts skip regeneration.

    NOTE: Because every received message is run through the TextAnalyzerAgent, adding
    this capability increases the agent's token usage.

    Example:
        ```python
        import autogen
        from autogen.agentchat.contrib.capabilities.image_generation import ImageGeneration

        # Assuming you have llm configs configured for the LLMs you want to use and Dalle.
        # Create the agent
        agent = autogen.ConversableAgent(
            name="dalle", llm_config={...}, max_consecutive_auto_reply=3, human_input_mode="NEVER"
        )

        # Create an ImageGenerator with desired settings
        dalle_gen = generate_images.DalleImageGenerator(llm_config={...})

        # Add the ImageGeneration capability to the agent
        agent.add_capability(ImageGeneration(image_generator=dalle_gen))
        ```
    """

    def __init__(
        self,
        image_generator: ImageGenerator,
        cache: Optional[AbstractCache] = None,
        text_analyzer_llm_config: Optional[Union[LLMConfig, dict[str, Any]]] = None,
        text_analyzer_instructions: str = PROMPT_INSTRUCTIONS,
        verbosity: int = 0,
        register_reply_position: int = 2,
    ):
        """Args:
        image_generator (ImageGenerator): backend used to produce images from prompts.
        cache (None or AbstractCache): cache client for generated images; None disables caching.
        text_analyzer_llm_config (LLMConfig or Dict or None): LLM config for the text analyzer;
            when None it is taken from the agent this capability is added to.
        text_analyzer_instructions (str): instructions given to the TextAnalyzerAgent for
            extracting the image prompt from a message. Customize for finer control, e.g.
            'Extract specific details from the message, like desired objects, styles, or backgrounds.'
        verbosity (int): verbosity level, >= 0. Text analyzer llm calls are silent when below 2.
        register_reply_position (int): position of the new reply function in the agent's reply
            list. Defaults to 2, after the check-termination and human replies of a ConversableAgent.
        """
        self._image_generator = image_generator
        self._cache = cache
        self._text_analyzer_llm_config = text_analyzer_llm_config
        self._text_analyzer_instructions = text_analyzer_instructions
        self._verbosity = verbosity
        self._register_reply_position = register_reply_position

        # Populated later by add_to_agent().
        self._agent: Optional[ConversableAgent] = None
        self._text_analyzer: Optional[TextAnalyzerAgent] = None

    def add_to_agent(self, agent: ConversableAgent):
        """Wire the Image Generation capability into *agent*.

        Registers the image-generation reply function, creates the TextAnalyzerAgent used
        to inspect messages, and appends notes about the new ability to the agent's system
        message and description (the latter can help in group chats).

        Args:
            agent (ConversableAgent): The ConversableAgent to add the capability to.
        """
        self._agent = agent

        agent.register_reply([Agent, None], self._image_gen_reply, position=self._register_reply_position)

        # Fall back to the agent's own llm_config when none was supplied.
        self._text_analyzer_llm_config = self._text_analyzer_llm_config or agent.llm_config
        self._text_analyzer = TextAnalyzerAgent(llm_config=self._text_analyzer_llm_config)

        agent.update_system_message(agent.system_message + "\n" + SYSTEM_MESSAGE)
        agent.description += "\n" + DESCRIPTION_MESSAGE

    def _image_gen_reply(
        self,
        recipient: ConversableAgent,
        messages: Optional[list[dict[str, Any]]],
        sender: Optional[Agent] = None,
        config: Optional[Any] = None,
    ) -> tuple[bool, Optional[Union[str, dict[str, Any]]]]:
        """Reply hook: when the last message requests an image, produce (or fetch) it and answer with it."""
        if messages is None:
            return False, None

        last_message = code_utils.content_str(messages[-1]["content"])
        if not last_message:
            return False, None
        if not self._should_generate_image(last_message):
            return False, None

        prompt = self._extract_prompt(last_message)
        image = self._cache_get(prompt)
        if image is None:
            image = self._image_generator.generate_image(prompt)
            self._cache_set(prompt, image)

        return True, self._generate_content_message(prompt, image)

    def _should_generate_image(self, message: str) -> bool:
        """Ask the text analyzer whether *message* explicitly requests image generation."""
        assert self._text_analyzer is not None

        instructions = """
        Does any part of the TEXT ask the agent to generate an image?
        The TEXT must explicitly mention that the image must be generated.
        Answer with just one word, yes or no.
        """
        analysis = self._text_analyzer.analyze_text(message, instructions)

        return "yes" in self._extract_analysis(analysis).lower()

    def _extract_prompt(self, last_message) -> str:
        """Distill an image-generation prompt from *last_message* via the text analyzer."""
        assert self._text_analyzer is not None

        analysis = self._text_analyzer.analyze_text(last_message, self._text_analyzer_instructions)
        return self._extract_analysis(analysis)

    def _cache_get(self, prompt: str) -> Optional["Image"]:
        """Return the cached image for *prompt*, or None on a miss or when caching is disabled."""
        if not self._cache:
            return None
        cached_value = self._cache.get(self._image_generator.cache_key(prompt))
        return img_utils.get_pil_image(cached_value) if cached_value else None

    def _cache_set(self, prompt: str, image: "Image"):
        """Store *image* under the generator's cache key for *prompt* (no-op without a cache)."""
        if self._cache:
            self._cache.set(self._image_generator.cache_key(prompt), img_utils.pil_to_data_uri(image))

    def _extract_analysis(self, analysis: Optional[Union[str, dict[str, Any]]]) -> str:
        """Normalize the analyzer's reply (plain string or message dict) to text."""
        content = analysis["content"] if isinstance(analysis, dict) else analysis
        return code_utils.content_str(content)

    def _generate_content_message(self, prompt: str, image: "Image") -> dict[str, Any]:
        """Build a multimodal message carrying the prompt text and the generated image."""
        return {
            "content": [
                {"type": "text", "text": f"I generated an image with the prompt: {prompt}"},
                {"type": "image_url", "image_url": {"url": img_utils.pil_to_data_uri(image)}},
            ]
        }
|
||||||
|
|
||||||
|
|
||||||
|
# Helpers
|
||||||
|
def _validate_resolution_format(resolution: str):
    """Raise ValueError unless *resolution* looks like "<digits>x<digits>" (e.g., "1024x768")."""
    # Anchored match: the whole string must be digits, "x", digits.
    if re.match(r"^\d+x\d+$", resolution) is None:
        raise ValueError(f"Invalid resolution format: {resolution}")
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_dalle_model(model: str):
    """Raise ValueError unless *model* is a supported DALL-E model name."""
    supported = ("dall-e-3", "dall-e-2")
    if model not in supported:
        raise ValueError(f"Invalid DALL-E model: {model}. Must be 'dall-e-3' or 'dall-e-2'")
|
||||||
@@ -0,0 +1,393 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
|
from ....formatting_utils import colored
|
||||||
|
from ....import_utils import optional_import_block, require_optional_import
|
||||||
|
from ....llm_config import LLMConfig
|
||||||
|
from ...assistant_agent import ConversableAgent
|
||||||
|
from ..text_analyzer_agent import TextAnalyzerAgent
|
||||||
|
from .agent_capability import AgentCapability
|
||||||
|
|
||||||
|
with optional_import_block():
|
||||||
|
import chromadb
|
||||||
|
from chromadb.config import Settings
|
||||||
|
|
||||||
|
|
||||||
|
class Teachability(AgentCapability):
    """Gives a conversable agent a long-term memory of user teachings, backed by a vector database.

    The "user" is any caller (human or not) that sends messages to the teachable agent.
    Teachability composes with other agent capabilities: instantiate both the agent and this
    class, then call teachability.add_to_agent(agent). Teachable agents sharing a group chat
    must each use a distinct path_to_db_dir.

    Adding Teachability to an agent:
    - appends a note about the new ability to the agent's system message, and
    - registers a hook on `process_last_received_message` that may expand the last received
      message with earlier related teachings (the expansion is not written back into the
      stored message history) and stores any newly detected teachings as memos.
    """

    def __init__(
        self,
        verbosity: Optional[int] = 0,
        reset_db: Optional[bool] = False,
        path_to_db_dir: Optional[str] = "./tmp/teachable_agent_db",
        recall_threshold: Optional[float] = 1.5,
        max_num_retrievals: Optional[int] = 10,
        llm_config: Optional[Union[LLMConfig, dict[str, Any], bool]] = None,
    ):
        """Args:
        verbosity (Optional, int): 0 (default) for basic info, 1 to add memory operations, 2 for analyzer messages, 3 for memo lists.
        reset_db (Optional, bool): True to clear the DB before starting. Default False.
        path_to_db_dir (Optional, str): path to the directory where this particular agent's DB is stored. Default "./tmp/teachable_agent_db"
        recall_threshold (Optional, float): The maximum distance for retrieved memos, where 0.0 is exact match. Default 1.5. Larger values allow more (but less relevant) memos to be recalled.
        max_num_retrievals (Optional, int): The maximum number of memos to retrieve from the DB. Default 10.
        llm_config (LLMConfig or dict or False): llm inference configuration passed to TextAnalyzerAgent.
            If None, TextAnalyzerAgent uses llm_config from the teachable agent.
        """
        self.verbosity = verbosity
        self.path_to_db_dir = path_to_db_dir
        self.recall_threshold = recall_threshold
        self.max_num_retrievals = max_num_retrievals
        self.llm_config = llm_config

        # Both are filled in later by add_to_agent().
        self.analyzer = None
        self.teachable_agent = None

        # The memo store persists task-advice and question-answer pairs between sessions.
        self.memo_store = MemoStore(self.verbosity, reset_db, self.path_to_db_dir)

    def add_to_agent(self, agent: ConversableAgent):
        """Adds teachability to the given agent."""
        self.teachable_agent = agent

        # Intercept incoming messages so memos can be recalled and stored.
        agent.register_hook(hookable_method="process_last_received_message", hook=self.process_last_received_message)

        # Fall back to the agent's own llm_config when none was given to the constructor.
        if self.llm_config is None:
            self.llm_config = agent.llm_config
        assert self.llm_config, "Teachability requires a valid llm_config."

        # The analyzer answers the yes/no and extraction prompts used below.
        self.analyzer = TextAnalyzerAgent(llm_config=self.llm_config)

        # Tell the agent (via its system message) about its new ability.
        agent.update_system_message(
            agent.system_message
            + "\nYou've been given the special ability to remember user teachings from prior conversations."
        )

    def prepopulate_db(self):
        """Adds a few arbitrary memos to the DB."""
        self.memo_store.prepopulate()

    def process_last_received_message(self, text: Union[dict[str, Any], str]):
        """Appends any relevant memos to the message text, and stores any apparent teachings in new memos.

        Uses TextAnalyzerAgent to make decisions about memo storage and retrieval.
        """
        # Expand the message with retrieved memos only when the store has something in it.
        expanded_text = self._consider_memo_retrieval(text) if self.memo_store.last_memo_id > 0 else text

        # Persist any new teachings found in the incoming text.
        self._consider_memo_storage(text)

        # Return the (possibly) expanded message text.
        return expanded_text

    def _consider_memo_storage(self, comment: Union[dict[str, Any], str]):
        """Decides whether to store something from one user comment in the DB."""
        added_memo = False

        # Does the comment pose a task? If so, look for reusable advice to pair with it.
        task_verdict = self._analyze(
            comment,
            "Does any part of the TEXT ask the agent to perform a task or solve a problem? Answer with just one word, yes or no.",
        )
        if "yes" in task_verdict.lower():
            advice = self._analyze(
                comment,
                "Briefly copy any advice from the TEXT that may be useful for a similar but different task in the future. But if no advice is present, just respond with 'none'.",
            )
            if "none" not in advice.lower():
                # Extract the task, then generalize it so it matches similar future problems.
                task = self._analyze(
                    comment,
                    "Briefly copy just the task from the TEXT, then stop. Don't solve it, and don't include any advice.",
                )
                general_task = self._analyze(
                    task,
                    "Summarize very briefly, in general terms, the type of task described in the TEXT. Leave out details that might not appear in a similar problem.",
                )
                if self.verbosity >= 1:
                    print(colored("\nREMEMBER THIS TASK-ADVICE PAIR", "light_yellow"))
                self.memo_store.add_input_output_pair(general_task, advice)
                added_memo = True

        # Independently, does the comment state a fact worth remembering?
        fact_verdict = self._analyze(
            comment,
            "Does the TEXT contain information that could be committed to memory? Answer with just one word, yes or no.",
        )
        if "yes" in fact_verdict.lower():
            # Store the fact keyed by the question it would answer.
            question = self._analyze(
                comment,
                "Imagine that the user forgot this information in the TEXT. How would they ask you for this information? Include no other text in your response.",
            )
            answer = self._analyze(
                comment, "Copy the information from the TEXT that should be committed to memory. Add no explanation."
            )
            if self.verbosity >= 1:
                print(colored("\nREMEMBER THIS QUESTION-ANSWER PAIR", "light_yellow"))
            self.memo_store.add_input_output_pair(question, answer)
            added_memo = True

        # Flush any new memos to disk.
        if added_memo:
            self.memo_store._save_memos()

    def _consider_memo_retrieval(self, comment: Union[dict[str, Any], str]):
        """Decides whether to retrieve memos from the DB, and add them to the chat context."""
        # First lookup: the raw comment itself.
        if self.verbosity >= 1:
            print(colored("\nLOOK FOR RELEVANT MEMOS, AS QUESTION-ANSWER PAIRS", "light_yellow"))
        memo_list = self._retrieve_relevant_memos(comment)

        # Second lookup: if the comment contains a task, also search by the generalized task.
        task_verdict = self._analyze(
            comment,
            "Does any part of the TEXT ask the agent to perform a task or solve a problem? Answer with just one word, yes or no.",
        )
        if "yes" in task_verdict.lower():
            if self.verbosity >= 1:
                print(colored("\nLOOK FOR RELEVANT MEMOS, AS TASK-ADVICE PAIRS", "light_yellow"))
            task = self._analyze(
                comment, "Copy just the task from the TEXT, then stop. Don't solve it, and don't include any advice."
            )
            general_task = self._analyze(
                task,
                "Summarize very briefly, in general terms, the type of task described in the TEXT. Leave out details that might not appear in a similar problem.",
            )
            memo_list.extend(self._retrieve_relevant_memos(general_task))

        # Drop duplicates, then append the memos to the text of the last message.
        memo_list = list(set(memo_list))
        return comment + self._concatenate_memo_texts(memo_list)

    def _retrieve_relevant_memos(self, input_text: str) -> list:
        """Returns semantically related memos from the DB."""
        memo_list = self.memo_store.get_related_memos(
            input_text, n_results=self.max_num_retrievals, threshold=self.recall_threshold
        )

        if self.verbosity >= 1 and not memo_list:
            # Nothing was within the threshold; show the nearest memo for debugging.
            print(colored("\nTHE CLOSEST MEMO IS BEYOND THE THRESHOLD:", "light_yellow"))
            self.memo_store.get_nearest_memo(input_text)
            print()  # Blank line; the memo details were printed by get_nearest_memo().

        # Keep only the output_text of each (input, output, distance) triple.
        return [memo[1] for memo in memo_list]

    def _concatenate_memo_texts(self, memo_list: list) -> str:
        """Concatenates the memo texts into a single string for inclusion in the chat context."""
        if not memo_list:
            return ""
        info = "\n# Memories that might help\n" + "".join("- " + memo + "\n" for memo in memo_list)
        if self.verbosity >= 1:
            print(colored("\nMEMOS APPENDED TO LAST MESSAGE...\n" + info + "\n", "light_yellow"))
        return "\n" + info

    def _analyze(self, text_to_analyze: Union[dict[str, Any], str], analysis_instructions: Union[dict[str, Any], str]):
        """Asks TextAnalyzerAgent to analyze the given text according to specific instructions."""
        self.analyzer.reset()  # Start the analyzer from a clean message list.
        # First send carries the text without requesting a reply; the second carries the
        # instructions and requests the analysis.
        self.teachable_agent.send(
            recipient=self.analyzer, message=text_to_analyze, request_reply=False, silent=(self.verbosity < 2)
        )
        self.teachable_agent.send(
            recipient=self.analyzer, message=analysis_instructions, request_reply=True, silent=(self.verbosity < 2)
        )
        return self.teachable_agent.last_message(self.analyzer)["content"]
|
||||||
|
|
||||||
|
|
||||||
|
@require_optional_import("chromadb", "teachable")
class MemoStore:
    """Persistent memory for a teachable agent, implemented on top of a vector database.

    Each DB entry (a "memo") is a pair of strings: an input text (a question or a task)
    and an output text (the answer, or advice on performing the task). Vector embeddings
    are currently supplied by Chroma's default Sentence Transformers.
    """

    def __init__(
        self,
        verbosity: Optional[int] = 0,
        reset: Optional[bool] = False,
        path_to_db_dir: Optional[str] = "./tmp/teachable_agent_db",
    ):
        """Args:
        - verbosity (Optional, int): 1 to print memory operations, 0 to omit them. 3+ to print memo lists.
        - reset (Optional, bool): True to clear the DB before starting. Default False.
        - path_to_db_dir (Optional, str): path to the directory where the DB is stored.
        """
        self.verbosity = verbosity
        self.path_to_db_dir = path_to_db_dir

        # Open (or create) the persistent Chroma collection that backs the store.
        chroma_settings = Settings(
            anonymized_telemetry=False, allow_reset=True, is_persistent=True, persist_directory=path_to_db_dir
        )
        self.db_client = chromadb.Client(chroma_settings)
        self.vec_db = self.db_client.create_collection("memos", get_or_create=True)  # The collection is the DB.

        # The pickled dict maps memo uid -> (input_text, output_text).
        self.path_to_dict = os.path.join(path_to_db_dir, "uid_text_dict.pkl")
        self.uid_text_dict = {}
        self.last_memo_id = 0
        if (not reset) and os.path.exists(self.path_to_dict):
            print(colored("\nLOADING MEMORY FROM DISK", "light_green"))
            print(colored(f"    Location = {self.path_to_dict}", "light_green"))
            with open(self.path_to_dict, "rb") as f:
                self.uid_text_dict = pickle.load(f)
            self.last_memo_id = len(self.uid_text_dict)
            if self.verbosity >= 3:
                self.list_memos()

        # Wipe everything when a reset was requested.
        if reset:
            self.reset_db()

    def list_memos(self):
        """Prints the contents of MemoStore."""
        print(colored("LIST OF MEMOS", "light_green"))
        for uid, (input_text, output_text) in self.uid_text_dict.items():
            print(
                colored(
                    f"  ID: {uid}\n    INPUT TEXT: {input_text}\n    OUTPUT TEXT: {output_text}",
                    "light_green",
                )
            )

    def _save_memos(self):
        """Saves self.uid_text_dict to disk."""
        with open(self.path_to_dict, "wb") as f:
            pickle.dump(self.uid_text_dict, f)

    def reset_db(self):
        """Forces immediate deletion of the DB's contents, in memory and on disk."""
        print(colored("\nCLEARING MEMORY", "light_green"))
        self.db_client.delete_collection("memos")
        self.vec_db = self.db_client.create_collection("memos")
        self.uid_text_dict = {}
        self._save_memos()

    def add_input_output_pair(self, input_text: str, output_text: str):
        """Adds an input-output pair to the vector DB."""
        self.last_memo_id += 1
        uid = str(self.last_memo_id)
        # Only the input text is embedded; the pair itself lives in uid_text_dict.
        self.vec_db.add(documents=[input_text], ids=[uid])
        self.uid_text_dict[uid] = input_text, output_text
        if self.verbosity >= 1:
            print(
                colored(
                    f"\nINPUT-OUTPUT PAIR ADDED TO VECTOR DATABASE:\n  ID\n  {self.last_memo_id}\n  INPUT\n  {input_text}\n  OUTPUT\n  {output_text}\n",
                    "light_yellow",
                )
            )
        if self.verbosity >= 3:
            self.list_memos()

    def get_nearest_memo(self, query_text: str):
        """Retrieves the nearest memo to the given query text."""
        results = self.vec_db.query(query_texts=[query_text], n_results=1)
        uid = results["ids"][0][0]
        input_text = results["documents"][0][0]
        distance = results["distances"][0][0]
        stored_input, output_text = self.uid_text_dict[uid]
        # The stored pair must match the embedded document retrieved from Chroma.
        assert input_text == stored_input
        if self.verbosity >= 1:
            print(
                colored(
                    f"\nINPUT-OUTPUT PAIR RETRIEVED FROM VECTOR DATABASE:\n  INPUT1\n  {input_text}\n  OUTPUT\n  {output_text}\n  DISTANCE\n  {distance}",
                    "light_yellow",
                )
            )
        return input_text, output_text, distance

    def get_related_memos(self, query_text: str, n_results: int, threshold: Union[int, float]):
        """Retrieves memos that are related to the given query text within the specified distance threshold."""
        # Chroma cannot return more results than there are stored memos.
        n_results = min(n_results, len(self.uid_text_dict))
        results = self.vec_db.query(query_texts=[query_text], n_results=n_results)
        memos = []
        for uid, input_text, distance in zip(results["ids"][0], results["documents"][0], results["distances"][0]):
            if distance < threshold:
                stored_input, output_text = self.uid_text_dict[uid]
                assert input_text == stored_input
                if self.verbosity >= 1:
                    print(
                        colored(
                            f"\nINPUT-OUTPUT PAIR RETRIEVED FROM VECTOR DATABASE:\n  INPUT1\n  {input_text}\n  OUTPUT\n  {output_text}\n  DISTANCE\n  {distance}",
                            "light_yellow",
                        )
                    )
                memos.append((input_text, output_text, distance))
        return memos

    def prepopulate(self):
        """Adds a few arbitrary examples to the vector DB, just to make retrieval less trivial."""
        if self.verbosity >= 1:
            print(colored("\nPREPOPULATING MEMORY", "light_green"))
        examples = [
            {"text": "When I say papers I mean research papers, which are typically pdfs.", "label": "yes"},
            {"text": "Please verify that each paper you listed actually uses langchain.", "label": "no"},
            {"text": "Tell gpt the output should still be latex code.", "label": "no"},
            {"text": "Hint: convert pdfs to text and then answer questions based on them.", "label": "yes"},
            {"text": "To create a good PPT, include enough content to make it interesting.", "label": "yes"},
            {"text": "No, for this case the columns should be aspects and the rows should be frameworks.", "label": "no"},
            {"text": "When writing code, remember to include any libraries that are used.", "label": "yes"},
            {"text": "Please summarize the papers by Eric Horvitz on bounded rationality.", "label": "no"},
            {"text": "Compare the h-index of Daniel Weld and Oren Etzioni.", "label": "no"},
            {"text": "Double check to be sure that the columns in a table correspond to what was asked for.", "label": "yes"},
        ]
        for example in examples:
            self.add_input_output_pair(example["text"], example["label"])
        self._save_memos()
|
||||||
@@ -0,0 +1,66 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
from typing import Any, Protocol
|
||||||
|
|
||||||
|
from ....import_utils import optional_import_block, require_optional_import
|
||||||
|
|
||||||
|
with optional_import_block() as result:
|
||||||
|
import llmlingua
|
||||||
|
from llmlingua import PromptCompressor
|
||||||
|
|
||||||
|
|
||||||
|
class TextCompressor(Protocol):
    """Structural interface for text compressors used to optimize agent interactions."""

    def compress_text(self, text: str, **compression_params) -> dict[str, Any]:
        """Compress *text* and return a dictionary describing the result.

        The compressed text must be stored under the 'compressed_text' key. To let callers
        compute token savings, the dictionary should also carry 'origin_tokens' and
        'compressed_tokens' keys.
        """
        ...
|
||||||
|
|
||||||
|
|
||||||
|
@require_optional_import("llmlingua", "long-context")
class LLMLingua:
    """Compresses text messages using LLMLingua for improved efficiency in processing and response generation.

    NOTE: The effectiveness of compression and the resultant token savings can vary based on the content of
    the messages and the specific configurations used for the PromptCompressor.
    """

    def __init__(
        self,
        # NOTE(review): a mutable default, but it is only read (unpacked into PromptCompressor),
        # never mutated, so sharing one dict across calls is harmless here.
        prompt_compressor_kwargs: dict = dict(
            model_name="microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank",
            use_llmlingua2=True,
            device_map="cpu",
        ),
        structured_compression: bool = False,
    ) -> None:
        """Args:
        prompt_compressor_kwargs (dict): Keyword arguments forwarded to the ``PromptCompressor``
            constructor. Defaults to model_name "microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank",
            use_llmlingua2 True, and device_map "cpu".
        structured_compression (bool): When True, the PromptCompressor's structured_compress_prompt
            method is used; otherwise compress_prompt is used. Defaults to False.

        Raises:
            ImportError: If the llmlingua library is not installed.
        """
        self._prompt_compressor = PromptCompressor(**prompt_compressor_kwargs)

        assert isinstance(self._prompt_compressor, llmlingua.PromptCompressor)
        # Bind the chosen compression entry point once, up front.
        if structured_compression:
            self._compression_method = self._prompt_compressor.structured_compress_prompt
        else:
            self._compression_method = self._prompt_compressor.compress_prompt

    def compress_text(self, text: str, **compression_params) -> dict[str, Any]:
        """Compress *text* with the configured LLMLingua method and return its result dict."""
        return self._compression_method([text], **compression_params)
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from ....agentchat import ConversableAgent
|
||||||
|
from ....tools import Tool
|
||||||
|
|
||||||
|
|
||||||
|
class ToolsCapability:
    """Adding a list of tools as composable capabilities to a single agent.

    This class can be inherited from to allow code to run at the point of creating or adding the capability.

    Note: both caller and executor of the tools are the same agent.
    """

    def __init__(self, tool_list: list[Tool]):
        """Args:
        tool_list: The tools to register on any agent this capability is added to.
        """
        # list() replaces the element-by-element comprehension copy (same result, idiomatic),
        # and still snapshots the caller's list so later external mutation has no effect here.
        self.tools = list(tool_list)

    def add_to_agent(self, agent: ConversableAgent):
        """Add tools to the given agent."""
        # Each tool registers itself as both caller and executor on the agent.
        for tool in self.tools:
            tool.register_tool(agent=agent)
|
||||||
@@ -0,0 +1,93 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
import copy
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from ....formatting_utils import colored
|
||||||
|
from .transforms import MessageTransform
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ...conversable_agent import ConversableAgent
|
||||||
|
|
||||||
|
|
||||||
|
class TransformMessages:
    """Agent capability for transforming messages before reply generation.

    This capability allows you to apply a series of message transformations to
    a ConversableAgent's incoming messages before they are processed for response
    generation. This is useful for tasks such as:

    - Limiting the number of messages considered for context.
    - Truncating messages to meet token limits.
    - Filtering sensitive information.
    - Customizing message formatting.

    To use `TransformMessages`:

    1. Create message transformations (e.g., `MessageHistoryLimiter`, `MessageTokenLimiter`).
    2. Instantiate `TransformMessages` with a list of these transformations.
    3. Add the `TransformMessages` instance to your `ConversableAgent` using `add_to_agent`.

    NOTE: Order of message transformations is important. You could get different results based on
    the order of transformations.

    Example:
        ```python
        from agentchat import ConversableAgent
        from agentchat.contrib.capabilities import TransformMessages, MessageHistoryLimiter, MessageTokenLimiter

        max_messages = MessageHistoryLimiter(max_messages=2)
        truncate_messages = MessageTokenLimiter(max_tokens=500)
        transform_messages = TransformMessages(transforms=[max_messages, truncate_messages])

        agent = ConversableAgent(...)
        transform_messages.add_to_agent(agent)
        ```
    """

    def __init__(self, *, transforms: list[MessageTransform] = [], verbose: bool = True):
        """Args:
        transforms: A list of message transformations to apply.
        verbose: Whether to print logs of each transformation or not.
        """
        # Copy defensively: the default value is a single shared mutable list, and aliasing
        # the caller's list would let later external mutation silently change this capability.
        self._transforms = list(transforms)
        self._verbose = verbose

    def add_to_agent(self, agent: "ConversableAgent"):
        """Adds the message transformations capability to the specified ConversableAgent.

        This function performs the following modifications to the agent:

        1. Registers a hook that automatically transforms all messages before they are processed for
        response generation.
        """
        agent.register_hook(hookable_method="process_all_messages_before_reply", hook=self._transform_messages)

    def _transform_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Applies every configured transform, preserving any leading system message.

        Returns a new (deep-copied) list; the input list is not modified.
        """
        post_transform_messages = copy.deepcopy(messages)
        system_message = None

        # Set any leading system message aside so transforms cannot drop or truncate it.
        # Guards: an empty list or a first message without a "role" key previously raised here.
        if messages and messages[0].get("role") == "system":
            system_message = copy.deepcopy(messages[0])
            post_transform_messages.pop(0)

        for transform in self._transforms:
            # deepcopy in case pre_transform_messages will later be used for logs printing
            pre_transform_messages = (
                copy.deepcopy(post_transform_messages) if self._verbose else post_transform_messages
            )
            post_transform_messages = transform.apply_transform(pre_transform_messages)

            if self._verbose:
                logs_str, had_effect = transform.get_logs(pre_transform_messages, post_transform_messages)
                if had_effect:
                    print(colored(logs_str, "yellow"))

        if system_message:
            post_transform_messages.insert(0, system_message)

        return post_transform_messages
|
||||||
@@ -0,0 +1,579 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
import copy
|
||||||
|
import sys
|
||||||
|
from typing import Any, Optional, Protocol, Union
|
||||||
|
|
||||||
|
import tiktoken
|
||||||
|
from termcolor import colored
|
||||||
|
|
||||||
|
from .... import token_count_utils
|
||||||
|
from ....cache import AbstractCache, Cache
|
||||||
|
from ....types import MessageContentType
|
||||||
|
from . import transforms_util
|
||||||
|
from .text_compressors import LLMLingua, TextCompressor
|
||||||
|
|
||||||
|
|
||||||
|
class MessageTransform(Protocol):
    """Structural contract for message transformation.

    Classes implementing this protocol should provide an `apply_transform` method
    that takes a list of messages and returns the transformed list.
    """

    def apply_transform(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Applies a transformation to a list of messages.

        Args:
            messages: A list of dictionaries representing messages.

        Returns:
            A new list of dictionaries containing the transformed messages.
        """
        ...

    def get_logs(
        self, pre_transform_messages: list[dict[str, Any]], post_transform_messages: list[dict[str, Any]]
    ) -> tuple[str, bool]:
        """Builds a log string describing the transformation.

        Alongside the string, it returns a boolean indicating whether the transformation had an effect or not.

        Args:
            pre_transform_messages: A list of dictionaries representing messages before the transformation.
            post_transform_messages: A list of dictionaries representing messages after the transformation.

        Returns:
            A tuple with a string with the logs and a flag indicating whether the transformation had an effect or not.
        """
        ...
|
||||||
|
|
||||||
|
|
||||||
|
class MessageHistoryLimiter:
    """Limits the number of messages considered by an agent for response generation.

    This transform keeps only the most recent messages up to the specified maximum number of messages (max_messages).
    It trims the conversation history by removing older messages, retaining only the most recent messages.
    """

    def __init__(
        self,
        max_messages: Optional[int] = None,
        keep_first_message: bool = False,
        exclude_names: Optional[list[str]] = None,
    ):
        """Args:
        max_messages Optional[int]: Maximum number of messages to keep in the context. Must be greater than 0 if not None.
        keep_first_message bool: Whether to keep the original first message in the conversation history.
            Defaults to False.
        exclude_names Optional[list[str]]: List of message sender names to exclude from the message history.
            Messages from these senders will be filtered out before applying the message limit. Defaults to None.
        """
        self._validate_max_messages(max_messages)
        self._max_messages = max_messages
        self._keep_first_message = keep_first_message
        self._exclude_names = exclude_names

    def apply_transform(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Truncates the conversation history to the specified maximum number of messages.

        This method returns a new list containing the most recent messages up to the specified
        maximum number of messages (max_messages). If max_messages is None, it returns the
        original list of messages unmodified.

        Args:
            messages (List[Dict]): The list of messages representing the conversation history.

        Returns:
            List[Dict]: A new list containing the most recent messages up to the specified maximum.
        """
        # getattr guards against instances deserialized from before _exclude_names existed.
        exclude_names = getattr(self, "_exclude_names", None)

        # Drop messages from excluded senders before applying the limit.
        filtered = [msg for msg in messages if msg.get("name") not in exclude_names] if exclude_names else messages

        if self._max_messages is None or len(filtered) <= self._max_messages:
            return filtered

        truncated_messages = []
        remaining_count = self._max_messages

        # Start with the first message if we need to keep it.
        if self._keep_first_message and filtered:
            truncated_messages = [filtered[0]]
            remaining_count -= 1

        # Newer messages are inserted at this index so the kept first message stays in front.
        insert_at = 1 if self._keep_first_message else 0

        # Walk the history newest-to-oldest, filling the remaining slots.
        for i in range(len(filtered) - 1, 0, -1):
            if remaining_count > 1:
                truncated_messages.insert(insert_at, filtered[i])
            if remaining_count == 1:  # noqa: SIM102
                # If there's only 1 slot left and it's a 'tools' message, ignore it.
                if filtered[i].get("role") != "tool":
                    # BUG FIX: previously this always inserted at index 1, which misordered
                    # the oldest kept message whenever keep_first_message was False.
                    truncated_messages.insert(insert_at, filtered[i])

            remaining_count -= 1
            if remaining_count == 0:
                break

        return truncated_messages

    def get_logs(
        self, pre_transform_messages: list[dict[str, Any]], post_transform_messages: list[dict[str, Any]]
    ) -> tuple[str, bool]:
        """Reports how many messages the transform removed, and whether it had any effect."""
        pre_transform_messages_len = len(pre_transform_messages)
        post_transform_messages_len = len(post_transform_messages)

        if post_transform_messages_len < pre_transform_messages_len:
            logs_str = (
                f"Removed {pre_transform_messages_len - post_transform_messages_len} messages. "
                f"Number of messages reduced from {pre_transform_messages_len} to {post_transform_messages_len}."
            )
            return logs_str, True
        return "No messages were removed.", False

    def _validate_max_messages(self, max_messages: Optional[int]):
        # Any value >= 1 is accepted; the message now matches the check (it said "greater than 1").
        if max_messages is not None and max_messages < 1:
            raise ValueError("max_messages must be None or greater than 0")
|
||||||
|
|
||||||
|
|
||||||
|
class MessageTokenLimiter:
    """Truncates messages to meet token limits for efficient processing and response generation.

    This transformation applies two levels of truncation to the conversation history:

    1. Truncates each individual message to the maximum number of tokens specified by max_tokens_per_message.
    2. Truncates the overall conversation history to the maximum number of tokens specified by max_tokens.

    NOTE: Tokens are counted using the encoder for the specified model. Different models may yield different token
    counts for the same text.

    NOTE: For multimodal LLMs, the token count may be inaccurate as it does not account for the non-text input
    (e.g images).

    The truncation process follows these steps in order:

    1. The minimum tokens threshold (`min_tokens`) is checked (0 by default). If the total number of tokens in messages
    is less than this threshold, then the messages are returned as is. In other case, the following process is applied.
    2. Messages are processed in reverse order (newest to oldest).
    3. Individual messages are truncated based on max_tokens_per_message. For multimodal messages containing both text
    and other types of content, only the text content is truncated.
    4. The overall conversation history is truncated based on the max_tokens limit. Once the accumulated token count
    exceeds this limit, the current message being processed get truncated to meet the total token count and any
    remaining messages get discarded.
    5. The truncated conversation history is reconstructed by prepending the messages to a new list to preserve the
    original message order.
    """

    def __init__(
        self,
        max_tokens_per_message: Optional[int] = None,
        max_tokens: Optional[int] = None,
        min_tokens: Optional[int] = None,
        model: str = "gpt-3.5-turbo-0613",
        filter_dict: Optional[dict[str, Any]] = None,
        exclude_filter: bool = True,
    ):
        """Args:
        max_tokens_per_message (None or int): Maximum number of tokens to keep in each message.
            Must be greater than or equal to 0 if not None.
        max_tokens (Optional[int]): Maximum number of tokens to keep in the chat history.
            Must be greater than or equal to 0 if not None.
        min_tokens (Optional[int]): Minimum number of tokens in messages to apply the transformation.
            Must be greater than or equal to 0 if not None.
        model (str): The target OpenAI model for tokenization alignment.
        filter_dict (None or dict): A dictionary to filter out messages that you want/don't want to compress.
            If None, no filters will be applied.
        exclude_filter (bool): If exclude filter is True (the default value), messages that match the filter will be
            excluded from token truncation. If False, messages that match the filter will be truncated.
        """
        self._model = model
        # Validation normalizes None limits (max -> sys.maxsize, min -> 0), so
        # the arithmetic in apply_transform never sees None.
        self._max_tokens_per_message = self._validate_max_tokens(max_tokens_per_message)
        self._max_tokens = self._validate_max_tokens(max_tokens)
        self._min_tokens = self._validate_min_tokens(min_tokens, max_tokens)
        self._filter_dict = filter_dict
        self._exclude_filter = exclude_filter

    def apply_transform(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Applies token truncation to the conversation history.

        Args:
            messages (List[Dict]): The list of messages representing the conversation history.

        Returns:
            List[Dict]: A new list containing the truncated messages up to the specified token limits.
        """
        # __init__ replaces None with concrete ints, so these hold for any
        # properly constructed instance.
        assert self._max_tokens_per_message is not None
        assert self._max_tokens is not None
        assert self._min_tokens is not None

        # if the total number of tokens in the messages is less than the min_tokens, return the messages as is
        if not transforms_util.min_tokens_reached(messages, self._min_tokens):
            return messages

        # Deep copy so per-message truncation never mutates the caller's dicts.
        temp_messages = copy.deepcopy(messages)
        processed_messages = []
        processed_messages_tokens = 0

        # Walk newest-to-oldest so the most recent messages survive truncation.
        for msg in reversed(temp_messages):
            # Some messages may not have content.
            if not transforms_util.is_content_right_type(msg.get("content")):
                processed_messages.insert(0, msg)
                continue

            # Filtered-out messages are kept untouched but still count toward
            # the overall token budget.
            if not transforms_util.should_transform_message(msg, self._filter_dict, self._exclude_filter):
                processed_messages.insert(0, msg)
                processed_messages_tokens += transforms_util.count_text_tokens(msg["content"])
                continue

            # Budget left if this message were kept at its per-message cap.
            expected_tokens_remained = self._max_tokens - processed_messages_tokens - self._max_tokens_per_message

            # If adding this message would exceed the token limit, truncate the last message to meet the total token
            # limit and discard all remaining messages
            if expected_tokens_remained < 0:
                msg["content"] = self._truncate_str_to_tokens(
                    msg["content"], self._max_tokens - processed_messages_tokens
                )
                processed_messages.insert(0, msg)
                break

            msg["content"] = self._truncate_str_to_tokens(msg["content"], self._max_tokens_per_message)
            msg_tokens = transforms_util.count_text_tokens(msg["content"])

            # prepend the message to the list to preserve order
            processed_messages_tokens += msg_tokens
            processed_messages.insert(0, msg)

        return processed_messages

    def get_logs(
        self, pre_transform_messages: list[dict[str, Any]], post_transform_messages: list[dict[str, Any]]
    ) -> tuple[str, bool]:
        """Report how many tokens the transform removed; flag is True only when tokens were truncated."""
        pre_transform_messages_tokens = sum(
            transforms_util.count_text_tokens(msg["content"]) for msg in pre_transform_messages if "content" in msg
        )
        post_transform_messages_tokens = sum(
            transforms_util.count_text_tokens(msg["content"]) for msg in post_transform_messages if "content" in msg
        )

        if post_transform_messages_tokens < pre_transform_messages_tokens:
            logs_str = (
                f"Truncated {pre_transform_messages_tokens - post_transform_messages_tokens} tokens. "
                f"Number of tokens reduced from {pre_transform_messages_tokens} to {post_transform_messages_tokens}"
            )
            return logs_str, True
        return "No tokens were truncated.", False

    def _truncate_str_to_tokens(self, contents: Union[str, list], n_tokens: int) -> Union[str, list]:
        """Dispatch truncation: plain strings directly, multimodal lists element-wise."""
        if isinstance(contents, str):
            return self._truncate_tokens(contents, n_tokens)
        elif isinstance(contents, list):
            return self._truncate_multimodal_text(contents, n_tokens)
        else:
            raise ValueError(f"Contents must be a string or a list of dictionaries. Received type: {type(contents)}")

    def _truncate_multimodal_text(self, contents: list[dict[str, Any]], n_tokens: int) -> list[dict[str, Any]]:
        """Truncates text content within a list of multimodal elements, preserving the overall structure."""
        # NOTE: each "text" element is truncated to n_tokens independently, not
        # against a shared budget across elements.
        tmp_contents = []
        for content in contents:
            if content["type"] == "text":
                truncated_text = self._truncate_tokens(content["text"], n_tokens)
                tmp_contents.append({"type": "text", "text": truncated_text})
            else:
                # Non-text items (e.g. images) pass through unchanged.
                tmp_contents.append(content)
        return tmp_contents

    def _truncate_tokens(self, text: str, n_tokens: int) -> str:
        """Truncate `text` to at most `n_tokens` tokens using the model's tokenizer."""
        encoding = tiktoken.encoding_for_model(self._model)  # Get the appropriate tokenizer

        encoded_tokens = encoding.encode(text)
        truncated_tokens = encoded_tokens[:n_tokens]
        truncated_text = encoding.decode(truncated_tokens)  # Decode back to text

        return truncated_text

    def _validate_max_tokens(self, max_tokens: Optional[int] = None) -> Optional[int]:
        """Validate and normalize a token limit.

        Returns the model's context limit when `max_tokens` exceeds it, `sys.maxsize`
        (effectively unlimited) when `max_tokens` is None, and `max_tokens` otherwise.

        Raises:
            ValueError: If max_tokens is a negative int.
        """
        if max_tokens is not None and max_tokens < 0:
            raise ValueError("max_tokens and max_tokens_per_message must be None or greater than or equal to 0")

        try:
            allowed_tokens = token_count_utils.get_max_token_limit(self._model)
        except Exception:
            # Unknown model: we cannot cap against a context window; warn and continue.
            print(colored(f"Model {self._model} not found in token_count_utils.", "yellow"))
            allowed_tokens = None

        if max_tokens is not None and allowed_tokens is not None and max_tokens > allowed_tokens:
            print(
                colored(
                    f"Max token was set to {max_tokens}, but {self._model} can only accept {allowed_tokens} tokens. Capping it to {allowed_tokens}.",
                    "yellow",
                )
            )
            return allowed_tokens

        return max_tokens if max_tokens is not None else sys.maxsize

    def _validate_min_tokens(self, min_tokens: Optional[int], max_tokens: Optional[int]) -> int:
        """Validate min_tokens against max_tokens; None is normalized to 0 (always apply).

        Raises:
            ValueError: If min_tokens is negative or exceeds max_tokens.
        """
        if min_tokens is None:
            return 0
        if min_tokens < 0:
            raise ValueError("min_tokens must be None or greater than or equal to 0.")
        if max_tokens is not None and min_tokens > max_tokens:
            raise ValueError("min_tokens must not be more than max_tokens.")
        return min_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class TextMessageCompressor:
    """A transform for compressing text messages in a conversation history.

    It uses a specified text compression method to reduce the token count of messages, which can lead to more efficient
    processing and response generation by downstream models.
    """

    def __init__(
        self,
        text_compressor: Optional[TextCompressor] = None,
        min_tokens: Optional[int] = None,
        compression_params: Optional[dict] = None,
        cache: Optional[AbstractCache] = None,
        filter_dict: Optional[dict[str, Any]] = None,
        exclude_filter: bool = True,
    ):
        """Args:
        text_compressor (TextCompressor or None): An instance of a class that implements the TextCompressor
            protocol. If None, it defaults to LLMLingua.
        min_tokens (int or None): Minimum number of tokens in messages to apply the transformation. Must be greater
            than or equal to 0 if not None. If None, no threshold-based compression is applied.
        compression_params (dict or None): A dictionary of arguments for the compression method. Defaults to an
            empty dictionary.
        cache (None or AbstractCache): The cache client to use to store and retrieve previously compressed messages.
            If None, no caching will be used.
        filter_dict (None or dict): A dictionary to filter out messages that you want/don't want to compress.
            If None, no filters will be applied.
        exclude_filter (bool): If exclude filter is True (the default value), messages that match the filter will be
            excluded from compression. If False, messages that match the filter will be compressed.
        """
        if text_compressor is None:
            text_compressor = LLMLingua()

        self._validate_min_tokens(min_tokens)

        self._text_compressor = text_compressor
        self._min_tokens = min_tokens
        # BUGFIX: the previous `compression_params: dict = dict()` default was a
        # mutable default argument shared by every instance; use None and build
        # a fresh dict per instance instead.
        self._compression_args = {} if compression_params is None else compression_params
        self._filter_dict = filter_dict
        self._exclude_filter = exclude_filter

        if cache is None:
            self._cache = Cache.disk()
        else:
            self._cache = cache

        # Optimizing savings calculations to optimize log generation:
        # savings from the most recent apply_transform call, read by get_logs.
        self._recent_tokens_savings = 0

    def apply_transform(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Applies compression to messages in a conversation history based on the specified configuration.

        The function processes each message according to the `compression_args` and `min_tokens` settings, applying
        the specified compression configuration and returning a new list of messages with reduced token counts
        where possible.

        Args:
            messages (List[Dict]): A list of message dictionaries to be compressed.

        Returns:
            List[Dict]: A list of dictionaries with the message content compressed according to the configured
                method and scope.
        """
        # Make sure there is at least one message
        if not messages:
            return messages

        # if the total number of tokens in the messages is less than the min_tokens, return the messages as is
        if not transforms_util.min_tokens_reached(messages, self._min_tokens):
            return messages

        total_savings = 0
        # NOTE(review): this is a shallow copy; compression mutates the message
        # dicts in place, so the caller's dicts are affected — confirm intended.
        processed_messages = messages.copy()
        for message in processed_messages:
            # Some messages may not have content.
            if not transforms_util.is_content_right_type(message.get("content")):
                continue

            if not transforms_util.should_transform_message(message, self._filter_dict, self._exclude_filter):
                continue

            if transforms_util.is_content_text_empty(message["content"]):
                continue

            # Reuse a previously compressed result for identical content, if cached.
            cache_key = transforms_util.cache_key(message["content"], self._min_tokens)
            cached_content = transforms_util.cache_content_get(self._cache, cache_key)
            if cached_content is not None:
                message["content"], savings = cached_content
            else:
                message["content"], savings = self._compress(message["content"])

                transforms_util.cache_content_set(self._cache, cache_key, message["content"], savings)

            assert isinstance(savings, int)
            total_savings += savings

        self._recent_tokens_savings = total_savings
        return processed_messages

    def get_logs(
        self, pre_transform_messages: list[dict[str, Any]], post_transform_messages: list[dict[str, Any]]
    ) -> tuple[str, bool]:
        """Report tokens saved by the most recent apply_transform call."""
        if self._recent_tokens_savings > 0:
            return f"{self._recent_tokens_savings} tokens saved with text compression.", True
        else:
            return "No tokens saved with text compression.", False

    def _compress(self, content: MessageContentType) -> tuple[MessageContentType, int]:
        """Compresses the given text or multimodal content using the specified compression method."""
        if isinstance(content, str):
            return self._compress_text(content)
        elif isinstance(content, list):
            return self._compress_multimodal(content)
        else:
            # Unknown content shapes pass through uncompressed.
            return content, 0

    def _compress_multimodal(self, content: MessageContentType) -> tuple[MessageContentType, int]:
        """Compresses every text element of a multimodal content list in place.

        Returns the (mutated) content list and the total tokens saved across items.
        """
        tokens_saved = 0
        for index, item in enumerate(content):
            if isinstance(item, dict) and "text" in item:
                item["text"], savings = self._compress_text(item["text"])
                tokens_saved += savings

            elif isinstance(item, str):
                # BUGFIX: rebinding the loop variable used to discard the
                # compressed string (while still counting its savings); write
                # the result back into the list instead.
                content[index], savings = self._compress_text(item)
                tokens_saved += savings

        return content, tokens_saved

    def _compress_text(self, text: str) -> tuple[str, int]:
        """Compresses the given text using the specified compression method."""
        compressed_text = self._text_compressor.compress_text(text, **self._compression_args)

        # Savings are only known when the compressor reports both token counts.
        savings = 0
        if "origin_tokens" in compressed_text and "compressed_tokens" in compressed_text:
            savings = compressed_text["origin_tokens"] - compressed_text["compressed_tokens"]

        return compressed_text["compressed_prompt"], savings

    def _validate_min_tokens(self, min_tokens: Optional[int]):
        """Raise ValueError when min_tokens is an int that is not strictly positive."""
        if min_tokens is not None and min_tokens <= 0:
            raise ValueError("min_tokens must be greater than 0 or None")
|
||||||
|
|
||||||
|
|
||||||
|
class TextMessageContentName:
    """A transform for including the agent's name in the content of a message.

    How to create and apply the transform:
    # Imports
    from autogen.agentchat.contrib.capabilities import transform_messages, transforms

    # Create Transform
    name_transform = transforms.TextMessageContentName(position="start", format_string="'{name}' said:\n")

    # Create the TransformMessages
    context_handling = transform_messages.TransformMessages(
        transforms=[
            name_transform
        ]
    )

    # Add it to an agent so when they run inference it will apply to the messages
    context_handling.add_to_agent(my_agent)
    """

    def __init__(
        self,
        position: str = "start",
        format_string: str = "{name}:\n",
        deduplicate: bool = True,
        filter_dict: Optional[dict[str, Any]] = None,
        exclude_filter: bool = True,
    ):
        """Args:
        position (str): The position to add the name to the content. The possible options are 'start' or 'end'. Defaults to 'start'.
        format_string (str): The f-string to format the message name with. Use '{name}' as a placeholder for the agent's name. Defaults to '{name}:\n' and must contain '{name}'.
        deduplicate (bool): Whether to deduplicate the formatted string so it doesn't appear twice (sometimes the LLM will add it to new messages itself). Defaults to True.
        filter_dict (None or dict): A dictionary to filter out messages that you want/don't want to compress.
            If None, no filters will be applied.
        exclude_filter (bool): If exclude filter is True (the default value), messages that match the filter will be
            excluded from compression. If False, messages that match the filter will be compressed.
        """
        # NOTE(review): `assert`-based validation disappears under `python -O`;
        # consider raising ValueError instead. Also, the `deduplicate is not None`
        # check is redundant after `isinstance(deduplicate, bool)`.
        assert isinstance(position, str) and position in ["start", "end"]
        assert isinstance(format_string, str) and "{name}" in format_string
        assert isinstance(deduplicate, bool) and deduplicate is not None

        self._position = position
        self._format_string = format_string
        self._deduplicate = deduplicate
        self._filter_dict = filter_dict
        self._exclude_filter = exclude_filter

        # Track the number of messages changed for logging
        self._messages_changed = 0

    def apply_transform(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Applies the name change to the message based on the position and format string.

        Args:
            messages (List[Dict]): A list of message dictionaries.

        Returns:
            List[Dict]: A list of dictionaries with the message content updated with names.
        """
        # Make sure there is at least one message
        if not messages:
            return messages

        messages_changed = 0
        # Deep copy so the caller's message dicts are never mutated.
        processed_messages = copy.deepcopy(messages)
        for message in processed_messages:
            # Some messages may not have content.
            if not transforms_util.is_content_right_type(
                message.get("content")
            ) or not transforms_util.is_content_right_type(message.get("name")):
                continue

            if not transforms_util.should_transform_message(message, self._filter_dict, self._exclude_filter):
                continue

            if transforms_util.is_content_text_empty(message["content"]) or transforms_util.is_content_text_empty(
                message["name"]
            ):
                continue

            # Get and format the name in the content
            content = message["content"]
            formatted_name = self._format_string.format(name=message["name"])

            # NOTE(review): is_content_right_type also admits list content, but
            # startswith/endswith below assume `content` is a str — confirm
            # multimodal messages never reach this point.
            if self._position == "start":
                if not self._deduplicate or not content.startswith(formatted_name):
                    message["content"] = f"{formatted_name}{content}"

                    messages_changed += 1
            else:
                if not self._deduplicate or not content.endswith(formatted_name):
                    message["content"] = f"{content}{formatted_name}"

                    messages_changed += 1

        self._messages_changed = messages_changed
        return processed_messages

    def get_logs(
        self, pre_transform_messages: list[dict[str, Any]], post_transform_messages: list[dict[str, Any]]
    ) -> tuple[str, bool]:
        """Report how many messages the most recent apply_transform call modified."""
        if self._messages_changed > 0:
            return f"{self._messages_changed} message(s) changed to incorporate name.", True
        else:
            return "No messages changed to incorporate name.", False
|
||||||
@@ -0,0 +1,122 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
from collections.abc import Hashable
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from .... import token_count_utils
|
||||||
|
from ....cache.abstract_cache_base import AbstractCache
|
||||||
|
from ....oai.openai_utils import filter_config
|
||||||
|
from ....types import MessageContentType
|
||||||
|
|
||||||
|
|
||||||
|
def cache_key(content: MessageContentType, *args: Hashable) -> str:
    """Calculates the cache key for the given message content and any other hashable args.

    The key is simply the concatenation of the string forms of `content` and each
    extra argument, in order.

    Args:
        content (MessageContentType): The message content to calculate the cache key for.
        *args: Any additional hashable args to include in the cache key.
    """
    return "".join(str(part) for part in (content, *args))
|
||||||
|
|
||||||
|
|
||||||
|
def cache_content_get(cache: Optional[AbstractCache], key: str) -> Optional[tuple[MessageContentType, ...]]:
    """Retrieves cached content from the cache.

    Returns None when the cache is absent or holds no (truthy) value for `key`.

    Args:
        cache (None or AbstractCache): The cache to retrieve the content from. If None, the cache is ignored.
        key (str): The key to retrieve the content from.
    """
    if not cache:
        return None

    hit = cache.get(key)
    return hit if hit else None
|
||||||
|
|
||||||
|
|
||||||
|
def cache_content_set(cache: Optional[AbstractCache], key: str, content: MessageContentType, *extra_values):
    """Sets content into the cache.

    The stored value is a tuple of `content` followed by any `extra_values`,
    matching what `cache_content_get` returns. A falsy cache is a no-op.

    Args:
        cache (None or AbstractCache): The cache to set the content into. If None, the cache is ignored.
        key (str): The key to set the content into.
        content (MessageContentType): The message content to set into the cache.
        *extra_values: Additional values to be passed to the cache.
    """
    if not cache:
        return
    cache.set(key, (content, *extra_values))
|
||||||
|
|
||||||
|
|
||||||
|
def min_tokens_reached(messages: list[dict[str, Any]], min_tokens: Optional[int]) -> bool:
    """Returns True if the total number of tokens in the messages is greater than or equal to the specified value.

    Args:
        messages (List[Dict]): A list of messages to check.
        min_tokens (None or int): The minimum number of tokens to check for.
    """
    # No threshold (None or 0) means the transform always applies.
    if not min_tokens:
        return True

    running_total = 0
    for message in messages:
        if "content" in message:
            running_total += count_text_tokens(message["content"])
    return running_total >= min_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def count_text_tokens(content: MessageContentType) -> int:
    """Calculates the number of text tokens in the given message content.

    Strings are tokenized directly; lists are summed element-wise, with dict-like
    elements contributing the tokens of their "text" field. Any other content
    type counts as zero tokens.

    Args:
        content (MessageContentType): The message content to calculate the number of text tokens for.
    """
    if isinstance(content, str):
        return token_count_utils.count_token(content)

    if isinstance(content, list):
        total = 0
        for element in content:
            if isinstance(element, str):
                total += token_count_utils.count_token(element)
            else:
                # Multimodal items: only the text payload contributes tokens.
                total += count_text_tokens(element.get("text", ""))
        return total

    return 0
|
||||||
|
|
||||||
|
|
||||||
|
def is_content_right_type(content: Any) -> bool:
    """A helper function to check if the passed in content is of the right type."""
    # Message content is only processable as plain text or a multimodal list.
    return isinstance(content, str) or isinstance(content, list)
|
||||||
|
|
||||||
|
|
||||||
|
def is_content_text_empty(content: MessageContentType) -> bool:
    """Checks if the content of the message does not contain any text.

    A string is empty iff it equals ""; a list is empty iff none of its string
    elements or dict "text" fields contain text; any other type counts as empty.

    Args:
        content (MessageContentType): The message content to check.
    """
    if isinstance(content, str):
        return content == ""

    if isinstance(content, list):
        pieces = [
            item if isinstance(item, str) else item.get("text", "")
            for item in content
            if isinstance(item, (str, dict))
        ]
        return not any(pieces)

    return True
|
||||||
|
|
||||||
|
|
||||||
|
def should_transform_message(message: dict[str, Any], filter_dict: Optional[dict[str, Any]], exclude: bool) -> bool:
    """Validates whether the transform should be applied according to the filter dictionary.

    Args:
        message (Dict[str, Any]): The message to validate.
        filter_dict (None or Dict[str, Any]): The filter dictionary to validate against. If None, the transform is always applied.
        exclude (bool): Whether to exclude messages that match the filter dictionary.
    """
    # An absent or empty filter means every message is transformed.
    if not filter_dict:
        return True

    # filter_config keeps the message only when it passes the filter.
    return bool(filter_config([message], filter_dict, exclude))
|
||||||
@@ -0,0 +1,212 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
import copy
|
||||||
|
from typing import Any, Callable, Optional, Union
|
||||||
|
|
||||||
|
from ....code_utils import content_str
|
||||||
|
from ....oai.client import OpenAIWrapper
|
||||||
|
from ...assistant_agent import ConversableAgent
|
||||||
|
from ..img_utils import (
|
||||||
|
convert_base64_to_data_uri,
|
||||||
|
get_image_data,
|
||||||
|
get_pil_image,
|
||||||
|
gpt4v_formatter,
|
||||||
|
)
|
||||||
|
from .agent_capability import AgentCapability
|
||||||
|
|
||||||
|
# Default prompt sent to the LMM client when captioning an image found in a
# received message; used unless the caller overrides `description_prompt`.
DEFAULT_DESCRIPTION_PROMPT = (
    "Write a detailed caption for this image. "
    "Pay special attention to any details that might be useful or relevant "
    "to the ongoing conversation."
)
|
||||||
|
|
||||||
|
|
||||||
|
class VisionCapability(AgentCapability):
|
||||||
|
"""We can add vision capability to regular ConversableAgent, even if the agent does not have the multimodal capability,
|
||||||
|
such as GPT-3.5-turbo agent, Llama, Orca, or Mistral agents. This vision capability will invoke a LMM client to describe
|
||||||
|
the image (captioning) before sending the information to the agent's actual client.
|
||||||
|
|
||||||
|
The vision capability will hook to the ConversableAgent's `process_last_received_message`.
|
||||||
|
|
||||||
|
Some technical details:
|
||||||
|
When the agent (who has the vision capability) received an message, it will:
|
||||||
|
1. _process_received_message:
|
||||||
|
a. _append_oai_message
|
||||||
|
2. generate_reply: if the agent is a MultimodalAgent, it will also use the image tag.
|
||||||
|
a. hook process_last_received_message (NOTE: this is where the vision capability will be hooked to.)
|
||||||
|
b. hook process_all_messages_before_reply
|
||||||
|
3. send:
|
||||||
|
a. hook process_message_before_send
|
||||||
|
b. _append_oai_message
|
||||||
|
"""
|
||||||
|
|
||||||
|
    def __init__(
        self,
        lmm_config: dict[str, Any],
        description_prompt: Optional[str] = DEFAULT_DESCRIPTION_PROMPT,
        custom_caption_func: Callable = None,
    ) -> None:
        """Initializes a new instance, setting up the configuration for interacting with
        a Language Multimodal (LMM) client and specifying optional parameters for image
        description and captioning.

        Args:
            lmm_config (Dict): Configuration for the LMM client, which is used to call
                the LMM service for describing the image. This must be a dictionary containing
                the necessary configuration parameters. If `lmm_config` is False or an empty dictionary,
                it is considered invalid, and initialization will assert.
            description_prompt (Optional[str], optional): The prompt to use for generating
                descriptions of the image. This parameter allows customization of the
                prompt passed to the LMM service. Defaults to `DEFAULT_DESCRIPTION_PROMPT` if not provided.
            custom_caption_func (Callable, optional): A callable that, if provided, will be used
                to generate captions for images. This allows for custom captioning logic outside
                of the standard LMM service interaction.
                The callable should take three parameters as input:
                1. an image URL (or local location)
                2. image_data (a PIL image)
                3. lmm_client (to call remote LMM)
                and then return a description (as string).
                If not provided, captioning will rely on the LMM client configured via `lmm_config`.
                If provided, we will not run the default self._get_image_caption method.

        Raises:
            AssertionError: If neither a valid `lmm_config` nor a `custom_caption_func` is provided,
                an AssertionError is raised to indicate that the Vision Capability requires
                one of these to be valid for operation.
        """
        self._lmm_config = lmm_config
        self._description_prompt = description_prompt
        # Set when the capability is attached to an agent via add_to_agent().
        self._parent_agent = None

        # A falsy lmm_config (False or {}) means no LMM client can be built; the
        # assert below then requires a custom captioning function instead.
        if lmm_config:
            self._lmm_client = OpenAIWrapper(**lmm_config)
        else:
            self._lmm_client = None

        self._custom_caption_func = custom_caption_func
        # NOTE(review): `assert` is stripped under `python -O`, so the documented
        # AssertionError contract silently vanishes there — consider raising.
        assert self._lmm_config or custom_caption_func, (
            "Vision Capability requires a valid lmm_config or custom_caption_func."
        )
|
||||||
|
|
||||||
|
def add_to_agent(self, agent: ConversableAgent) -> None:
|
||||||
|
self._parent_agent = agent
|
||||||
|
|
||||||
|
# Append extra info to the system message.
|
||||||
|
agent.update_system_message(agent.system_message + "\nYou've been given the ability to interpret images.")
|
||||||
|
|
||||||
|
# Register a hook for processing the last message.
|
||||||
|
agent.register_hook(hookable_method="process_last_received_message", hook=self.process_last_received_message)
|
||||||
|
|
||||||
|
def process_last_received_message(self, content: Union[str, list[dict[str, Any]]]) -> str:
|
||||||
|
"""Processes the last received message content by normalizing and augmenting it
|
||||||
|
with descriptions of any included images. The function supports input content
|
||||||
|
as either a string or a list of dictionaries, where each dictionary represents
|
||||||
|
a content item (e.g., text, image). If the content contains image URLs, it
|
||||||
|
fetches the image data, generates a caption for each image, and inserts the
|
||||||
|
caption into the augmented content.
|
||||||
|
|
||||||
|
The function aims to transform the content into a format compatible with GPT-4V
|
||||||
|
multimodal inputs, specifically by formatting strings into PIL-compatible
|
||||||
|
images if needed and appending text descriptions for images. This allows for
|
||||||
|
a more accessible presentation of the content, especially in contexts where
|
||||||
|
images cannot be displayed directly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content (Union[str, List[dict[str, Any]]]): The last received message content, which
|
||||||
|
can be a plain text string or a list of dictionaries representing
|
||||||
|
different types of content items (e.g., text, image_url).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The augmented message content
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If an item in the content list is not a dictionary.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Assuming `self._get_image_caption(img_data)` returns
|
||||||
|
"A beautiful sunset over the mountains" for the image.
|
||||||
|
|
||||||
|
- Input as String:
|
||||||
|
content = "Check out this cool photo!"
|
||||||
|
Output: "Check out this cool photo!"
|
||||||
|
(Content is a string without an image, remains unchanged.)
|
||||||
|
|
||||||
|
- Input as String, with image location:
|
||||||
|
content = "What's weather in this cool photo: `<img http://example.com/photo.jpg>`"
|
||||||
|
Output: "What's weather in this cool photo: `<img http://example.com/photo.jpg>` in case you can not see, the caption of this image is:
|
||||||
|
A beautiful sunset over the mountains\n"
|
||||||
|
(Caption added after the image)
|
||||||
|
|
||||||
|
- Input as List with Text Only:
|
||||||
|
content = `[{"type": "text", "text": "Here's an interesting fact."}]`
|
||||||
|
Output: "Here's an interesting fact."
|
||||||
|
(No images in the content, it remains unchanged.)
|
||||||
|
|
||||||
|
- Input as List with Image URL:
|
||||||
|
```python
|
||||||
|
content = [
|
||||||
|
{"type": "text", "text": "What's weather in this cool photo:"},
|
||||||
|
{"type": "image_url", "image_url": "http://example.com/photo.jpg"},
|
||||||
|
]
|
||||||
|
```
|
||||||
|
Output: "What's weather in this cool photo: `<img http://example.com/photo.jpg>` in case you can not see, the caption of this image is:
|
||||||
|
A beautiful sunset over the mountains\n"
|
||||||
|
(Caption added after the image)
|
||||||
|
"""
|
||||||
|
copy.deepcopy(content)
|
||||||
|
# normalize the content into the gpt-4v format for multimodal
|
||||||
|
# we want to keep the URL format to keep it concise.
|
||||||
|
if isinstance(content, str):
|
||||||
|
content = gpt4v_formatter(content, img_format="url")
|
||||||
|
|
||||||
|
aug_content: str = ""
|
||||||
|
for item in content:
|
||||||
|
assert isinstance(item, dict)
|
||||||
|
if item["type"] == "text":
|
||||||
|
aug_content += item["text"]
|
||||||
|
elif item["type"] == "image_url":
|
||||||
|
img_url = item["image_url"]
|
||||||
|
img_caption = ""
|
||||||
|
|
||||||
|
if self._custom_caption_func:
|
||||||
|
img_caption = self._custom_caption_func(img_url, get_pil_image(img_url), self._lmm_client)
|
||||||
|
elif self._lmm_client:
|
||||||
|
img_data = get_image_data(img_url)
|
||||||
|
img_caption = self._get_image_caption(img_data)
|
||||||
|
else:
|
||||||
|
img_caption = ""
|
||||||
|
|
||||||
|
aug_content += f"<img {img_url}> in case you can not see, the caption of this image is: {img_caption}\n"
|
||||||
|
else:
|
||||||
|
print(f"Warning: the input type should either be `test` or `image_url`. Skip {item['type']} here.")
|
||||||
|
|
||||||
|
return aug_content
|
||||||
|
|
||||||
|
def _get_image_caption(self, img_data: str) -> str:
|
||||||
|
"""Args:
|
||||||
|
img_data (str): base64 encoded image data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: caption for the given image.
|
||||||
|
"""
|
||||||
|
response = self._lmm_client.create(
|
||||||
|
context=None,
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": self._description_prompt},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": convert_base64_to_data_uri(img_data),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
description = response.choices[0].message.content
|
||||||
|
return content_str(description)
|
||||||
411
mm_agents/coact/autogen/agentchat/contrib/img_utils.py
Normal file
411
mm_agents/coact/autogen/agentchat/contrib/img_utils.py
Normal file
@@ -0,0 +1,411 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
import base64
|
||||||
|
import copy
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from io import BytesIO
|
||||||
|
from math import ceil
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from ...import_utils import optional_import_block, require_optional_import
|
||||||
|
from .. import utils
|
||||||
|
|
||||||
|
with optional_import_block():
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
# Parameters for token counting for images for different models.
# max_edge / min_edge: bounds used when scaling an image before tiling;
# tile_size: billing tile edge in pixels; each tile costs `token_multiplier`
# tokens on top of the flat `base_token_count`.
# NOTE(review): values appear to mirror OpenAI's published vision pricing —
# confirm against the current pricing page before relying on exact counts.
MODEL_PARAMS = {
    "gpt-4-vision": {
        "max_edge": 2048,
        "min_edge": 768,
        "tile_size": 512,
        "base_token_count": 85,
        "token_multiplier": 170,
    },
    "gpt-4o-mini": {
        "max_edge": 2048,
        "min_edge": 768,
        "tile_size": 512,
        "base_token_count": 2833,
        "token_multiplier": 5667,
    },
    "gpt-4o": {"max_edge": 2048, "min_edge": 768, "tile_size": 512, "base_token_count": 85, "token_multiplier": 170},
}
|
||||||
|
|
||||||
|
|
||||||
|
@require_optional_import("PIL", "unknown")
def get_pil_image(image_file: Union[str, "Image.Image"]) -> "Image.Image":
    """Loads an image from a file and returns a PIL Image object.

    Accepts a PIL Image (returned as-is), an http(s) URL, a ``data:image/...``
    URI, a local file path, or a bare base64-encoded string.

    Parameters:
        image_file (str, or Image): The filename, URL, URI, or base64 string of the image file.

    Returns:
        Image.Image: The PIL Image object, converted to RGB.
    """
    if isinstance(image_file, Image.Image):
        # Already a PIL Image object
        return image_file

    # Remove surrounding quotes if present (e.g. a pasted, quoted path).
    if image_file.startswith('"') and image_file.endswith('"'):
        image_file = image_file[1:-1]
    if image_file.startswith("'") and image_file.endswith("'"):
        image_file = image_file[1:-1]

    if image_file.startswith("http://") or image_file.startswith("https://"):
        # A URL file. A bounded timeout prevents the download from hanging
        # forever on an unresponsive host (previously no timeout was set).
        response = requests.get(image_file, timeout=60)
        content = BytesIO(response.content)
        image = Image.open(content)
    # Match base64-encoded image URIs for supported formats: jpg, jpeg, png, gif, bmp, webp
    elif re.match(r"data:image/(?:jpg|jpeg|png|gif|bmp|webp);base64,", image_file):
        # A URI. Remove the prefix and decode the base64 string.
        base64_data = re.sub(r"data:image/(?:jpg|jpeg|png|gif|bmp|webp);base64,", "", image_file)
        image = _to_pil(base64_data)
    elif os.path.exists(image_file):
        # A local file
        image = Image.open(image_file)
    else:
        # Assume a bare base64 encoded string.
        image = _to_pil(image_file)

    return image.convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
|
@require_optional_import("PIL", "unknown")
def get_image_data(image_file: Union[str, "Image.Image"], use_b64=True) -> Union[bytes, str]:
    """Loads an image and returns its data either as raw bytes or in base64-encoded format.

    This function first loads an image from the specified file, URL, or base64 string using
    the `get_pil_image` function. It then saves this image in memory in PNG format and
    retrieves its binary content. Depending on the `use_b64` flag, this binary content is
    either returned directly or as a base64-encoded string.

    Parameters:
        image_file (str, or Image): The path to the image file, a URL to an image, or a base64-encoded
            string of the image.
        use_b64 (bool): If True, the function returns a base64-encoded string of the image data.
            If False, it returns the raw byte data of the image. Defaults to True.

    Returns:
        Union[bytes, str]: The image data in raw bytes if `use_b64` is False, or a
            base64-encoded string if `use_b64` is True.  (The annotation previously
            claimed `bytes` unconditionally, which was wrong for the default path.)
    """
    image = get_pil_image(image_file)

    # Re-encode as PNG in memory so the returned data has a uniform format.
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    content = buffered.getvalue()

    if use_b64:
        return base64.b64encode(content).decode("utf-8")
    else:
        return content
|
||||||
|
|
||||||
|
|
||||||
|
@require_optional_import("PIL", "unknown")
def llava_formatter(prompt: str, order_image_tokens: bool = False) -> tuple[str, list[str]]:
    """Replace `<img ...>` tags in *prompt* with LLaVA-style image tokens.

    Parameters:
    - prompt (str): The input string that may contain image tags like `<img ...>`.
    - order_image_tokens (bool, optional): Whether to number the image tokens
      (`<image 0>`, `<image 1>`, ...). Useful for GPT-4V. Defaults to False.

    Returns:
    - Tuple[str, List[str]]: The rewritten prompt and the loaded images
      (b64 format), in order of appearance.
    """
    # Tags look like `<img LOCATION>`; group(1) captures LOCATION.
    tag_pattern = re.compile(r"<img ([^>]+)>")

    rewritten = prompt
    locations: list[str] = []
    loaded_images: list[str] = []
    token_index = 0

    for tag in tag_pattern.finditer(prompt):
        location = tag.group(1)

        try:
            img_data = get_image_data(location)
        except Exception as e:
            # Unloadable image: drop the tag from the prompt and move on.
            print(f"Warning! Unable to load image from {location}, because of {e}")
            rewritten = rewritten.replace(tag.group(0), "", 1)
            continue

        locations.append(location)
        loaded_images.append(img_data)

        # Swap the tag for the (optionally numbered) image token.
        replacement = f"<image {token_index}>" if order_image_tokens else "<image>"
        rewritten = rewritten.replace(tag.group(0), replacement, 1)
        token_index += 1

    return rewritten, loaded_images
|
||||||
|
|
||||||
|
|
||||||
|
@require_optional_import("PIL", "unknown")
def pil_to_data_uri(image: "Image.Image") -> str:
    """Converts a PIL Image object to a data URI.

    Parameters:
        image (Image.Image): The PIL Image object.

    Returns:
        str: The data URI string.
    """
    # Serialize to PNG in memory, then base64-encode and wrap as a data URI.
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    raw_bytes = buffer.getvalue()
    b64_payload = base64.b64encode(raw_bytes).decode("utf-8")
    return convert_base64_to_data_uri(b64_payload)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_base64_to_data_uri(base64_image):
    """Wrap a base64-encoded image string in a ``data:`` URI.

    The MIME type is sniffed from the decoded bytes' magic numbers
    (JPEG, PNG, GIF, WEBP); anything unrecognized falls back to JPEG.
    """
    decoded = base64.b64decode(base64_image)

    # Identify the format from the leading signature bytes, checked in order.
    if decoded.startswith(b"\xff\xd8\xff"):
        mime_type = "image/jpeg"
    elif decoded.startswith(b"\x89PNG\r\n\x1a\n"):
        mime_type = "image/png"
    elif decoded.startswith((b"GIF87a", b"GIF89a")):
        mime_type = "image/gif"
    elif decoded.startswith(b"RIFF") and decoded[8:12] == b"WEBP":
        mime_type = "image/webp"
    else:
        mime_type = "image/jpeg"  # use jpeg for unknown formats, best guess.

    return f"data:{mime_type};base64,{base64_image}"
|
||||||
|
|
||||||
|
|
||||||
|
@require_optional_import("PIL", "unknown")
def gpt4v_formatter(prompt: str, img_format: str = "uri") -> list[Union[str, dict[str, Any]]]:
    """Split *prompt* into alternating text and image items for GPT-4V.

    Args:
        prompt (str): The input string that may contain image tags like `<img ...>`.
        img_format (str): what image format should be used. One of "uri", "url", "pil".

    Returns:
        List[Union[str, dict[str, Any]]]: A list of alternating text and image dictionary items.
    """
    assert img_format in ["uri", "url", "pil"]

    parts: list[Union[str, dict[str, Any]]] = []
    cursor = 0
    num_images = 0

    # Walk every <img ...> tag and convert its src per the requested format.
    for parsed_tag in utils.parse_tags_from_content("img", prompt):
        src = parsed_tag["attr"]["src"]
        try:
            if img_format == "pil":
                payload = get_pil_image(src)
            elif img_format == "uri":
                payload = convert_base64_to_data_uri(get_image_data(src))
            elif img_format == "url":
                payload = src
            else:
                raise ValueError(f"Unknown image format {img_format}")
        except Exception as e:
            # Warn and skip tags whose image cannot be loaded.
            print(f"Warning! Unable to load image from {src}, because {e}")
            continue

        # Text preceding this tag, then the image itself.
        parts.append({"type": "text", "text": prompt[cursor : parsed_tag["match"].start()]})
        parts.append({"type": "image_url", "image_url": {"url": payload}})

        cursor = parsed_tag["match"].end()
        num_images += 1

    # Any trailing text after the last tag.
    if cursor < len(prompt):
        parts.append({"type": "text", "text": prompt[cursor:]})
    return parts
|
||||||
|
|
||||||
|
|
||||||
|
def extract_img_paths(paragraph: str) -> list:
    """Extract image paths (URLs or local paths) from a text paragraph.

    Parameters:
        paragraph (str): The input text paragraph.

    Returns:
        list: A list of extracted image paths, in order of appearance.
    """
    # Matches either an http(s) URL or a bare path ending in a common image
    # extension (webp included). Case-insensitive so ".PNG" etc. also match.
    pattern = re.compile(
        r"\b(?:http[s]?://\S+\.(?:jpg|jpeg|png|gif|bmp|webp)|\S+\.(?:jpg|jpeg|png|gif|bmp|webp))\b",
        re.IGNORECASE,
    )
    return pattern.findall(paragraph)
|
||||||
|
|
||||||
|
|
||||||
|
@require_optional_import("PIL", "unknown")
def _to_pil(data: str) -> "Image.Image":
    """Converts a base64 encoded image data string to a PIL Image object.

    Parameters:
        data (str): The base64 encoded image data string.

    Returns:
        Image.Image: The PIL Image object created from the input data.
    """
    # Decode to raw bytes, then let PIL parse them from an in-memory buffer.
    raw = base64.b64decode(data)
    return Image.open(BytesIO(raw))
|
||||||
|
|
||||||
|
|
||||||
|
@require_optional_import("PIL", "unknown")
def message_formatter_pil_to_b64(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Converts the PIL image URLs in the messages to base64 encoded data URIs.

    For each message dict with a 'content' key: a string content is first
    parsed into a list of parts via ``gpt4v_formatter`` (img_format="pil");
    then every 'image_url' part whose 'url' is a PIL Image is replaced with a
    base64 data URI. The input list is not modified (each message is
    deep-copied).

    Parameters:
        messages (List[Dict]): A list of message dictionaries. Each dictionary
            may contain a 'content' key with a list of items,
            some of which might be image URLs.

    Returns:
        List[Dict]: A new list of message dictionaries with PIL image URLs in the
            'image_url' key converted to base64 encoded data URIs.

    Example:
        ``{'type': 'image_url', 'image_url': {'url': <PIL.Image.Image>}}``
        becomes
        ``{'type': 'image_url', 'image_url': {'url': 'data:image/png;base64,...'}}``
    """
    new_messages = []
    for message in messages:
        # deepcopy to avoid modifying the original message.
        message = copy.deepcopy(message)
        if isinstance(message, dict) and "content" in message:
            # First, if the content is a string, parse it into a list of parts.
            # This is for tool output that contains images.
            if isinstance(message["content"], str):
                message["content"] = gpt4v_formatter(message["content"], img_format="pil")

            # Second, if the content is a list, process any image parts.
            # Mutating items in place is safe here: `message` is our deep copy.
            if isinstance(message["content"], list):
                for item in message["content"]:
                    # NOTE(review): assumes item["image_url"] is a dict with a
                    # "url" key (the shape gpt4v_formatter emits); a
                    # plain-string image_url would raise TypeError on the
                    # ["url"] subscript — confirm no caller uses that shape.
                    if (
                        isinstance(item, dict)
                        and "image_url" in item
                        and isinstance(item["image_url"]["url"], Image.Image)
                    ):
                        item["image_url"]["url"] = pil_to_data_uri(item["image_url"]["url"])

        new_messages.append(message)

    return new_messages
|
||||||
|
|
||||||
|
|
||||||
|
@require_optional_import("PIL", "unknown")
def num_tokens_from_gpt_image(
    image_data: Union[str, "Image.Image"], model: str = "gpt-4-vision", low_quality: bool = False
) -> int:
    """Calculate the number of tokens required to process an image based on its dimensions
    after scaling for different GPT models. Supports "gpt-4-vision", "gpt-4o", and "gpt-4o-mini".

    The image is scaled so its longest edge is at most `max_edge` pixels and
    its shortest edge at most `min_edge` pixels, then covered by
    `tile_size`-square tiles; the token count is the base cost plus a
    per-tile multiplier.

    Reference: https://openai.com/api/pricing/

    Args:
        image_data : Union[str, Image.Image]: The image data which can either be a base64 encoded string, a URL, a file path, or a PIL Image object.
        model: str: The model being used for image processing. Can be "gpt-4-vision", "gpt-4o", or "gpt-4o-mini".
        low_quality: bool: Whether to use low-quality processing. Defaults to False.

    Returns:
        int: The total number of tokens required for processing the image.

    Examples:
    --------
    >>> from PIL import Image
    >>> img = Image.new("RGB", (2500, 2500), color="red")
    >>> num_tokens_from_gpt_image(img, model="gpt-4-vision")
    765
    """
    image = get_pil_image(image_data)  # PIL Image
    width, height = image.size

    # Pick the pricing parameters for the model family.
    if "gpt-4-vision" in model or "gpt-4-turbo" in model or "gpt-4v" in model or "gpt-4-v" in model:
        params = MODEL_PARAMS["gpt-4-vision"]
    elif "gpt-4o-mini" in model:  # must precede the plain "gpt-4o" substring check
        params = MODEL_PARAMS["gpt-4o-mini"]
    elif "gpt-4o" in model:
        params = MODEL_PARAMS["gpt-4o"]
    else:
        raise ValueError(
            f"Model {model} is not supported. Choose 'gpt-4-vision', 'gpt-4-turbo', 'gpt-4v', 'gpt-4-v', 'gpt-4o', or 'gpt-4o-mini'."
        )

    # Low-quality mode is billed at the flat base cost regardless of size.
    if low_quality:
        return params["base_token_count"]

    # 1. Shrink so the longest edge fits within max_edge.
    longest = max(width, height)
    if longest > params["max_edge"]:
        factor = params["max_edge"] / longest
        width, height = int(width * factor), int(height * factor)

    # 2. Shrink further so the shortest edge fits within min_edge.
    shortest = min(width, height)
    if shortest > params["min_edge"]:
        factor = params["min_edge"] / shortest
        width, height = int(width * factor), int(height * factor)

    # 3. Count covering tiles and price them.
    n_tiles = ceil(width / params["tile_size"]) * ceil(height / params["tile_size"])
    return params["base_token_count"] + params["token_multiplier"] * n_tiles
|
||||||
@@ -0,0 +1,153 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
import copy
|
||||||
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
|
from ... import OpenAIWrapper
|
||||||
|
from ...code_utils import content_str
|
||||||
|
from .. import Agent, ConversableAgent
|
||||||
|
from ..contrib.img_utils import (
|
||||||
|
gpt4v_formatter,
|
||||||
|
message_formatter_pil_to_b64,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Default system prompt used when the caller supplies none.
DEFAULT_LMM_SYS_MSG = """You are a helpful AI assistant."""
# Default vision model name placed in DEFAULT_CONFIG below.
DEFAULT_MODEL = "gpt-4-vision-preview"
|
||||||
|
|
||||||
|
|
||||||
|
class MultimodalConversableAgent(ConversableAgent):
    """A ConversableAgent whose messages may mix text and images (GPT-4V style).

    String contents are parsed for ``<img ...>`` tags, images are carried as
    PIL objects internally, and converted to base64 data URIs right before the
    OpenAI call.
    """

    # Default client config; callers typically override the model via llm_config.
    DEFAULT_CONFIG = {
        "model": DEFAULT_MODEL,
    }

    def __init__(
        self,
        name: str,
        system_message: Optional[Union[str, list]] = DEFAULT_LMM_SYS_MSG,
        # Annotation corrected from `str`: this is a predicate taking a message
        # dict and returning bool (`Any` because Callable is not imported here).
        is_termination_msg: Optional[Any] = None,
        *args,
        **kwargs: Any,
    ):
        """Args:
        name (str): agent name.
        system_message (str): system message for the OpenAIWrapper inference.
            Please override this attribute if you want to reprogram the agent.
        is_termination_msg (callable): predicate over a message dict; defaults
            to checking whether the stringified content equals "TERMINATE".
        **kwargs (dict): Please refer to other kwargs in
            [ConversableAgent](/docs/api-reference/autogen/ConversableAgent#conversableagent).
        """
        super().__init__(
            name,
            system_message,
            is_termination_msg=is_termination_msg,
            *args,
            **kwargs,
        )
        # call the setter to handle special format.
        self.update_system_message(system_message)
        self._is_termination_msg = (
            is_termination_msg
            if is_termination_msg is not None
            else (lambda x: content_str(x.get("content")) == "TERMINATE")
        )

        # Override the `generate_oai_reply` (and its async twin) with the
        # multimodal-aware versions defined on this class.
        # NOTE(review): a_generate_oai_reply is not defined in this class body
        # here — presumably inherited; confirm it handles images as intended.
        self.replace_reply_func(ConversableAgent.generate_oai_reply, MultimodalConversableAgent.generate_oai_reply)
        self.replace_reply_func(
            ConversableAgent.a_generate_oai_reply,
            MultimodalConversableAgent.a_generate_oai_reply,
        )

    def update_system_message(self, system_message: Union[dict[str, Any], list[str], str]):
        """Update the system message.

        Args:
            system_message (str): system message for the OpenAIWrapper inference.
        """
        # Normalize through _message_to_dict so <img ...> tags in a string
        # system message become multimodal content parts.
        self._oai_system_message[0]["content"] = self._message_to_dict(system_message)["content"]
        self._oai_system_message[0]["role"] = "system"

    @staticmethod
    def _message_to_dict(message: Union[dict[str, Any], list[str], str]) -> dict:
        """Convert a message to a dictionary. This implementation
        handles the GPT-4V formatting for easier prompts.

        The message can be a string, a dictionary, or a list of dictionaries:
        - If it's a string, it will be cast into a list and placed in the 'content' field.
        - If it's a list, it will be directly placed in the 'content' field.
        - If it's a dictionary, it is already in message dict format. The 'content' field of this dictionary
          will be processed using the gpt4v_formatter.

        Raises:
            ValueError: if the message is none of the three supported types.
        """
        if isinstance(message, str):
            # Parse <img ...> tags; images are kept as PIL objects at this stage.
            return {"content": gpt4v_formatter(message, img_format="pil")}
        if isinstance(message, list):
            return {"content": message}
        if isinstance(message, dict):
            assert "content" in message, "The message dict must have a `content` field"
            if isinstance(message["content"], str):
                # deepcopy so the caller's dict is not mutated.
                message = copy.deepcopy(message)
                message["content"] = gpt4v_formatter(message["content"], img_format="pil")
            try:
                # Sanity-check that the content is representable as a string.
                content_str(message["content"])
            except (TypeError, ValueError) as e:
                print("The `content` field should be compatible with the content_str function!")
                raise e
            return message
        raise ValueError(f"Unsupported message type: {type(message)}")

    def generate_oai_reply(
        self,
        messages: Optional[list[dict[str, Any]]] = None,
        sender: Optional[Agent] = None,
        config: Optional[OpenAIWrapper] = None,
    ) -> tuple[bool, Optional[Union[str, dict[str, Any]]]]:
        """Generate a reply using autogen.oai."""
        client = self.client if config is None else config
        if client is None:
            # No LLM client configured: signal "not handled" to the reply chain.
            return False, None
        if messages is None:
            messages = self._oai_messages[sender]

        # Convert any PIL images in the (system + chat) messages to b64 URIs.
        messages_with_b64_img = message_formatter_pil_to_b64(self._oai_system_message + messages)

        # Rewrite tool-response messages: the OpenAI API's 'tool' role cannot
        # carry image parts, so any screenshot in the tool output is re-sent
        # as a follow-up 'user' message.
        new_messages = []
        for message in messages_with_b64_img:
            if 'tool_responses' in message:
                for tool_response in message['tool_responses']:
                    tmp_image = None
                    tmp_list = []
                    for ctx in message['content']:
                        if ctx['type'] == 'image_url':
                            # Keeps only the last image part if there are several.
                            tmp_image = ctx
                    # NOTE(review): always forwards message['content'][0]; with
                    # multiple tool_responses every 'tool' message repeats the
                    # same first content part — confirm this is intended.
                    tmp_list.append({
                        'role': 'tool',
                        'tool_call_id': tool_response['tool_call_id'],
                        'content': [message['content'][0]]
                    })
                    if tmp_image:
                        tmp_list.append({
                            'role': 'user',
                            'content': [
                                {'type': 'text', 'text': 'I take a screenshot for the current state for you.'},
                                tmp_image
                            ]
                        })
                    new_messages.extend(tmp_list)
            else:
                new_messages.append(message)
        messages_with_b64_img = new_messages.copy()


        # TODO: #1143 handle token limit exceeded error
        response = client.create(
            context=messages[-1].pop("context", None), messages=messages_with_b64_img, agent=self.name
        )

        # TODO: line 301, line 271 is converting messages to dict. Can be removed after ChatCompletionMessage_to_dict is merged.
        extracted_response = client.extract_text_or_completion_object(response)[0]
        if not isinstance(extracted_response, str):
            # Non-string completion objects (e.g. tool calls) are returned as plain dicts.
            extracted_response = extracted_response.model_dump()
        return True, extracted_response
|
||||||
4023
mm_agents/coact/autogen/agentchat/conversable_agent.py
Normal file
4023
mm_agents/coact/autogen/agentchat/conversable_agent.py
Normal file
File diff suppressed because it is too large
Load Diff
64
mm_agents/coact/autogen/agentchat/group/__init__.py
Normal file
64
mm_agents/coact/autogen/agentchat/group/__init__.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
__all__: list[str] = []
|
||||||
|
|
||||||
|
from .available_condition import ExpressionAvailableCondition, StringAvailableCondition
|
||||||
|
from .context_condition import ExpressionContextCondition, StringContextCondition
|
||||||
|
from .context_expression import ContextExpression
|
||||||
|
from .context_str import ContextStr
|
||||||
|
from .context_variables import ContextVariables
|
||||||
|
from .handoffs import Handoffs
|
||||||
|
from .llm_condition import ContextStrLLMCondition, StringLLMCondition
|
||||||
|
from .on_condition import OnCondition
|
||||||
|
from .on_context_condition import OnContextCondition
|
||||||
|
from .reply_result import ReplyResult
|
||||||
|
from .speaker_selection_result import SpeakerSelectionResult
|
||||||
|
from .targets.group_chat_target import GroupChatConfig, GroupChatTarget
|
||||||
|
|
||||||
|
"""
|
||||||
|
from .targets.group_manager_target import (
|
||||||
|
GroupManagerSelectionMessageContextStr,
|
||||||
|
GroupManagerSelectionMessageString,
|
||||||
|
GroupManagerTarget,
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
from .targets.transition_target import (
|
||||||
|
AgentNameTarget,
|
||||||
|
AgentTarget,
|
||||||
|
AskUserTarget,
|
||||||
|
NestedChatTarget,
|
||||||
|
RevertToUserTarget,
|
||||||
|
StayTarget,
|
||||||
|
TerminateTarget,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AgentNameTarget",
|
||||||
|
"AgentTarget",
|
||||||
|
"AskUserTarget",
|
||||||
|
"ContextExpression",
|
||||||
|
"ContextStr",
|
||||||
|
"ContextStrLLMCondition",
|
||||||
|
"ContextVariables",
|
||||||
|
"ExpressionAvailableCondition",
|
||||||
|
"ExpressionContextCondition",
|
||||||
|
"GroupChatConfig",
|
||||||
|
"GroupChatTarget",
|
||||||
|
# "GroupManagerSelectionMessageContextStr",
|
||||||
|
# "GroupManagerSelectionMessageString",
|
||||||
|
# "GroupManagerTarget",
|
||||||
|
"Handoffs",
|
||||||
|
"NestedChatTarget",
|
||||||
|
"OnCondition",
|
||||||
|
"OnContextCondition",
|
||||||
|
"ReplyResult",
|
||||||
|
"RevertToUserTarget",
|
||||||
|
"SpeakerSelectionResult",
|
||||||
|
"StayTarget",
|
||||||
|
"StringAvailableCondition",
|
||||||
|
"StringContextCondition",
|
||||||
|
"StringLLMCondition",
|
||||||
|
"TerminateTarget",
|
||||||
|
]
|
||||||
@@ -0,0 +1,91 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from .context_expression import ContextExpression
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
# Avoid circular import
|
||||||
|
from ..conversable_agent import ConversableAgent
|
||||||
|
|
||||||
|
__all__ = ["AvailableCondition", "ExpressionAvailableCondition", "StringAvailableCondition"]
|
||||||
|
|
||||||
|
|
||||||
|
class AvailableCondition(BaseModel):
|
||||||
|
"""Protocol for determining if a condition is available to be evaluated."""
|
||||||
|
|
||||||
|
def is_available(self, agent: "ConversableAgent", messages: list[dict[str, Any]]) -> bool:
|
||||||
|
"""Determine if the condition should be considered for evaluation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent evaluating the condition
|
||||||
|
messages: The conversation history
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the condition should be evaluated, False otherwise
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Requires subclasses to implement.")
|
||||||
|
|
||||||
|
|
||||||
|
class StringAvailableCondition(AvailableCondition):
|
||||||
|
"""String-based available condition.
|
||||||
|
|
||||||
|
This condition checks if a named context variable exists and is truthy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
context_variable: str
|
||||||
|
|
||||||
|
def __init__(self, context_variable: str, **data: Any) -> None:
|
||||||
|
"""Initialize with a context variable name as a positional parameter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_variable: The name of the context variable to check
|
||||||
|
data: Additional data for the parent class
|
||||||
|
"""
|
||||||
|
super().__init__(context_variable=context_variable, **data)
|
||||||
|
|
||||||
|
def is_available(self, agent: "ConversableAgent", messages: list[dict[str, Any]]) -> bool:
|
||||||
|
"""Check if the named context variable is truthy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent with context variables
|
||||||
|
messages: The conversation history (not used)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the variable exists and is truthy, False otherwise
|
||||||
|
"""
|
||||||
|
return bool(agent.context_variables.get(self.context_variable, False))
|
||||||
|
|
||||||
|
|
||||||
|
class ExpressionAvailableCondition(AvailableCondition):
|
||||||
|
"""Expression-based available condition.
|
||||||
|
|
||||||
|
This condition evaluates a ContextExpression against the context variables.
|
||||||
|
"""
|
||||||
|
|
||||||
|
expression: ContextExpression
|
||||||
|
|
||||||
|
def __init__(self, expression: ContextExpression, **data: Any) -> None:
|
||||||
|
"""Initialize with an expression as a positional parameter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expression: The context expression to evaluate
|
||||||
|
data: Additional data for the parent class
|
||||||
|
"""
|
||||||
|
super().__init__(expression=expression, **data)
|
||||||
|
|
||||||
|
def is_available(self, agent: "ConversableAgent", messages: list[dict[str, Any]]) -> bool:
|
||||||
|
"""Evaluate the expression against the context variables.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent with context variables
|
||||||
|
messages: The conversation history (not used)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Boolean result of the expression evaluation
|
||||||
|
"""
|
||||||
|
return self.expression.evaluate(agent.context_variables)
|
||||||
77
mm_agents/coact/autogen/agentchat/group/context_condition.py
Normal file
77
mm_agents/coact/autogen/agentchat/group/context_condition.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from .context_expression import ContextExpression
|
||||||
|
from .context_variables import ContextVariables
|
||||||
|
|
||||||
|
__all__ = ["ContextCondition", "ExpressionContextCondition", "StringContextCondition"]
|
||||||
|
|
||||||
|
|
||||||
|
class ContextCondition(BaseModel):
|
||||||
|
"""Protocol for conditions evaluated directly using context variables."""
|
||||||
|
|
||||||
|
def evaluate(self, context_variables: ContextVariables) -> bool:
|
||||||
|
"""Evaluate the condition to a boolean result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_variables: The context variables to evaluate against
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Boolean result of the condition evaluation
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Requires subclasses to implement.")
|
||||||
|
|
||||||
|
|
||||||
|
class StringContextCondition(ContextCondition):
|
||||||
|
"""Simple string-based context condition.
|
||||||
|
|
||||||
|
This condition checks if a named context variable exists and is truthy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
variable_name: str
|
||||||
|
|
||||||
|
def evaluate(self, context_variables: ContextVariables) -> bool:
|
||||||
|
"""Check if the named context variable is truthy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_variables: The context variables to check against
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the variable exists and is truthy, False otherwise
|
||||||
|
"""
|
||||||
|
return bool(context_variables.get(self.variable_name, False))
|
||||||
|
|
||||||
|
|
||||||
|
class ExpressionContextCondition(ContextCondition):
|
||||||
|
"""Complex expression-based context condition.
|
||||||
|
|
||||||
|
This condition evaluates a ContextExpression against the context variables.
|
||||||
|
"""
|
||||||
|
|
||||||
|
expression: ContextExpression
|
||||||
|
|
||||||
|
def __init__(self, expression: ContextExpression, **data: Any) -> None:
|
||||||
|
"""Initialize with an expression as a positional parameter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expression: The context expression to evaluate
|
||||||
|
data: Additional data for the parent class
|
||||||
|
"""
|
||||||
|
super().__init__(expression=expression, **data)
|
||||||
|
|
||||||
|
def evaluate(self, context_variables: ContextVariables) -> bool:
|
||||||
|
"""Evaluate the expression against the context variables.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_variables: The context variables to evaluate against
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Boolean result of the expression evaluation
|
||||||
|
"""
|
||||||
|
return self.expression.evaluate(context_variables)
|
||||||
238
mm_agents/coact/autogen/agentchat/group/context_expression.py
Normal file
238
mm_agents/coact/autogen/agentchat/group/context_expression.py
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import ast
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from ...doc_utils import export_module
|
||||||
|
from .context_variables import ContextVariables
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
@export_module("autogen")
|
||||||
|
class ContextExpression:
|
||||||
|
"""A class to evaluate logical expressions using context variables.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expression (str): A string containing a logical expression with context variable references.
|
||||||
|
- Variable references use ${var_name} syntax: ${logged_in}, ${attempts}
|
||||||
|
- String literals can use normal quotes: 'hello', "world"
|
||||||
|
- Supported operators:
|
||||||
|
- Logical: not/!, and/&, or/|
|
||||||
|
- Comparison: >, <, >=, <=, ==, !=
|
||||||
|
- Supported functions:
|
||||||
|
- len(${var_name}): Gets the length of a list, string, or other collection
|
||||||
|
- Parentheses can be used for grouping
|
||||||
|
- Examples:
|
||||||
|
- "not ${logged_in} and ${is_admin} or ${guest_checkout}"
|
||||||
|
- "!${logged_in} & ${is_admin} | ${guest_checkout}"
|
||||||
|
- "len(${orders}) > 0 & ${user_active}"
|
||||||
|
- "len(${cart_items}) == 0 | ${checkout_started}"
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SyntaxError: If the expression cannot be parsed
|
||||||
|
ValueError: If the expression contains disallowed operations
|
||||||
|
"""
|
||||||
|
|
||||||
|
expression: str
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
# Validate the expression immediately upon creation
|
||||||
|
try:
|
||||||
|
# Extract variable references and replace with placeholders
|
||||||
|
self._variable_names = self._extract_variable_names(self.expression)
|
||||||
|
|
||||||
|
# Convert symbolic operators to Python keywords
|
||||||
|
python_expr = self._convert_to_python_syntax(self.expression)
|
||||||
|
|
||||||
|
# Sanitize for AST parsing
|
||||||
|
sanitized_expr = self._prepare_for_ast(python_expr)
|
||||||
|
|
||||||
|
# Use ast to parse and validate the expression
|
||||||
|
self._ast = ast.parse(sanitized_expr, mode="eval")
|
||||||
|
|
||||||
|
# Verify it only contains allowed operations
|
||||||
|
self._validate_operations(self._ast.body)
|
||||||
|
|
||||||
|
# Store the Python-syntax version for evaluation
|
||||||
|
self._python_expr = python_expr
|
||||||
|
|
||||||
|
except SyntaxError as e:
|
||||||
|
raise SyntaxError(f"Invalid expression syntax in '{self.expression}': {str(e)}")
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Error validating expression '{self.expression}': {str(e)}")
|
||||||
|
|
||||||
|
def _extract_variable_names(self, expr: str) -> list[str]:
|
||||||
|
"""Extract all variable references ${var_name} from the expression."""
|
||||||
|
# Find all patterns like ${var_name}
|
||||||
|
matches = re.findall(r"\${([^}]*)}", expr)
|
||||||
|
return matches
|
||||||
|
|
||||||
|
def _convert_to_python_syntax(self, expr: str) -> str:
|
||||||
|
"""Convert symbolic operators to Python keywords."""
|
||||||
|
# We need to be careful about operators inside string literals
|
||||||
|
# First, temporarily replace string literals with placeholders
|
||||||
|
string_literals = []
|
||||||
|
|
||||||
|
def replace_string_literal(match: re.Match[str]) -> str:
|
||||||
|
string_literals.append(match.group(0))
|
||||||
|
return f"__STRING_LITERAL_{len(string_literals) - 1}__"
|
||||||
|
|
||||||
|
# Replace both single and double quoted strings
|
||||||
|
expr_without_strings = re.sub(r"'[^']*'|\"[^\"]*\"", replace_string_literal, expr)
|
||||||
|
|
||||||
|
# Handle the NOT operator (!) - no parentheses handling needed
|
||||||
|
# Replace standalone ! before variables or expressions
|
||||||
|
expr_without_strings = re.sub(r"!\s*(\${|\()", "not \\1", expr_without_strings)
|
||||||
|
|
||||||
|
# Handle AND and OR operators - simpler approach without parentheses handling
|
||||||
|
expr_without_strings = re.sub(r"\s+&\s+", " and ", expr_without_strings)
|
||||||
|
expr_without_strings = re.sub(r"\s+\|\s+", " or ", expr_without_strings)
|
||||||
|
|
||||||
|
# Now put string literals back
|
||||||
|
for i, literal in enumerate(string_literals):
|
||||||
|
expr_without_strings = expr_without_strings.replace(f"__STRING_LITERAL_{i}__", literal)
|
||||||
|
|
||||||
|
return expr_without_strings
|
||||||
|
|
||||||
|
def _prepare_for_ast(self, expr: str) -> str:
|
||||||
|
"""Convert the expression to valid Python for AST parsing by replacing variables with placeholders."""
|
||||||
|
# Replace ${var_name} with var_name for AST parsing
|
||||||
|
processed_expr = expr
|
||||||
|
for var_name in self._variable_names:
|
||||||
|
processed_expr = processed_expr.replace(f"${{{var_name}}}", var_name)
|
||||||
|
|
||||||
|
return processed_expr
|
||||||
|
|
||||||
|
def _validate_operations(self, node: ast.AST) -> None:
|
||||||
|
"""Recursively validate that only allowed operations exist in the AST."""
|
||||||
|
allowed_node_types = (
|
||||||
|
# Boolean operations
|
||||||
|
ast.BoolOp,
|
||||||
|
ast.UnaryOp,
|
||||||
|
ast.And,
|
||||||
|
ast.Or,
|
||||||
|
ast.Not,
|
||||||
|
# Comparison operations
|
||||||
|
ast.Compare,
|
||||||
|
ast.Eq,
|
||||||
|
ast.NotEq,
|
||||||
|
ast.Lt,
|
||||||
|
ast.LtE,
|
||||||
|
ast.Gt,
|
||||||
|
ast.GtE,
|
||||||
|
# Basic nodes
|
||||||
|
ast.Name,
|
||||||
|
ast.Load,
|
||||||
|
ast.Constant,
|
||||||
|
ast.Expression,
|
||||||
|
# Support for basic numeric operations in comparisons
|
||||||
|
ast.Num,
|
||||||
|
ast.NameConstant,
|
||||||
|
# Support for negative numbers
|
||||||
|
ast.USub,
|
||||||
|
ast.UnaryOp,
|
||||||
|
# Support for string literals
|
||||||
|
ast.Str,
|
||||||
|
ast.Constant,
|
||||||
|
# Support for function calls (specifically len())
|
||||||
|
ast.Call,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(node, allowed_node_types):
|
||||||
|
raise ValueError(f"Operation type {type(node).__name__} is not allowed in logical expressions")
|
||||||
|
|
||||||
|
# Special validation for function calls - only allow len()
|
||||||
|
if isinstance(node, ast.Call):
|
||||||
|
if not (isinstance(node.func, ast.Name) and node.func.id == "len"):
|
||||||
|
raise ValueError(f"Only the len() function is allowed, got: {getattr(node.func, 'id', 'unknown')}")
|
||||||
|
if len(node.args) != 1:
|
||||||
|
raise ValueError(f"len() function must have exactly one argument, got {len(node.args)}")
|
||||||
|
|
||||||
|
# Special validation for Compare nodes
|
||||||
|
if isinstance(node, ast.Compare):
|
||||||
|
for op in node.ops:
|
||||||
|
if not isinstance(op, (ast.Eq, ast.NotEq, ast.Lt, ast.LtE, ast.Gt, ast.GtE)):
|
||||||
|
raise ValueError(f"Comparison operator {type(op).__name__} is not allowed")
|
||||||
|
|
||||||
|
# Recursively check child nodes
|
||||||
|
for child in ast.iter_child_nodes(node):
|
||||||
|
self._validate_operations(child)
|
||||||
|
|
||||||
|
def evaluate(self, context_variables: ContextVariables) -> bool:
|
||||||
|
"""Evaluate the expression using the provided context variables.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_variables: Dictionary of context variables to use for evaluation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: The result of evaluating the expression
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If a variable referenced in the expression is not found in the context
|
||||||
|
"""
|
||||||
|
# Create a modified expression that we can safely evaluate
|
||||||
|
eval_expr = self._python_expr # Use the Python-syntax version
|
||||||
|
|
||||||
|
# First, handle len() functions with variable references inside
|
||||||
|
len_pattern = r"len\(\${([^}]*)}\)"
|
||||||
|
len_matches = list(re.finditer(len_pattern, eval_expr))
|
||||||
|
|
||||||
|
# Process all len() operations first
|
||||||
|
for match in len_matches:
|
||||||
|
var_name = match.group(1)
|
||||||
|
# Check if variable exists in context, raise KeyError if not
|
||||||
|
if not context_variables.contains(var_name):
|
||||||
|
raise KeyError(f"Missing context variable: '{var_name}'")
|
||||||
|
|
||||||
|
var_value = context_variables.get(var_name)
|
||||||
|
|
||||||
|
# Calculate the length - works for lists, strings, dictionaries, etc.
|
||||||
|
try:
|
||||||
|
length_value = len(var_value) # type: ignore[arg-type]
|
||||||
|
except TypeError:
|
||||||
|
# If the value doesn't support len(), treat as 0
|
||||||
|
length_value = 0
|
||||||
|
|
||||||
|
# Replace the len() expression with the actual length
|
||||||
|
full_match = match.group(0)
|
||||||
|
eval_expr = eval_expr.replace(full_match, str(length_value))
|
||||||
|
|
||||||
|
# Then replace remaining variable references with their values
|
||||||
|
for var_name in self._variable_names:
|
||||||
|
# Skip variables that were already processed in len() expressions
|
||||||
|
if any(m.group(1) == var_name for m in len_matches):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if variable exists in context, raise KeyError if not
|
||||||
|
if not context_variables.contains(var_name):
|
||||||
|
raise KeyError(f"Missing context variable: '{var_name}'")
|
||||||
|
|
||||||
|
# Get the value from context
|
||||||
|
var_value = context_variables.get(var_name)
|
||||||
|
|
||||||
|
# Format the value appropriately based on its type
|
||||||
|
if isinstance(var_value, (bool, int, float)):
|
||||||
|
formatted_value = str(var_value)
|
||||||
|
elif isinstance(var_value, str):
|
||||||
|
formatted_value = f"'{var_value}'" # Quote strings
|
||||||
|
elif isinstance(var_value, (list, dict, tuple)):
|
||||||
|
# For collections, convert to their boolean evaluation
|
||||||
|
formatted_value = str(bool(var_value))
|
||||||
|
else:
|
||||||
|
formatted_value = str(var_value)
|
||||||
|
|
||||||
|
# Replace the variable reference with the formatted value
|
||||||
|
eval_expr = eval_expr.replace(f"${{{var_name}}}", formatted_value)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return eval(eval_expr) # type: ignore[no-any-return]
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"Error evaluating expression '{self.expression}' (are you sure you're using ${{my_context_variable_key}}): {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"ContextExpression('{self.expression}')"
|
||||||
41
mm_agents/coact/autogen/agentchat/group/context_str.py
Normal file
41
mm_agents/coact/autogen/agentchat/group/context_str.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from .context_variables import ContextVariables
|
||||||
|
|
||||||
|
__all__ = ["ContextStr"]
|
||||||
|
|
||||||
|
|
||||||
|
class ContextStr(BaseModel):
|
||||||
|
"""A string that requires context variable substitution.
|
||||||
|
|
||||||
|
Use the format method to substitute context variables into the string.
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""The string to be substituted with context variables. It is expected that the string will contain `{var}` placeholders and that string format will be able to replace all values."""
|
||||||
|
template: str
|
||||||
|
|
||||||
|
def format(self, context_variables: ContextVariables) -> Optional[str]:
|
||||||
|
"""Substitute context variables into the string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_variables (ContextVariables): The context variables to substitute into the string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[str]: The formatted string with context variables substituted.
|
||||||
|
"""
|
||||||
|
|
||||||
|
context = context_variables.to_dict()
|
||||||
|
|
||||||
|
if not context:
|
||||||
|
return self.template
|
||||||
|
|
||||||
|
return self.template.format(**context)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"ContextStr, unformatted: {self.template}"
|
||||||
192
mm_agents/coact/autogen/agentchat/group/context_variables.py
Normal file
192
mm_agents/coact/autogen/agentchat/group/context_variables.py
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from typing import Any, Generator, Iterable, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
__all__ = ["ContextVariables"]
|
||||||
|
|
||||||
|
# Parameter name for context variables
|
||||||
|
# Use the value in functions and they will be substituted with the context variables:
|
||||||
|
# e.g. def my_function(context_variables: ContextVariables, my_other_parameters: Any) -> Any:
|
||||||
|
__CONTEXT_VARIABLES_PARAM_NAME__ = "context_variables"
|
||||||
|
|
||||||
|
|
||||||
|
class ContextVariables(BaseModel):
|
||||||
|
"""
|
||||||
|
Stores and manages context variables for agentic workflows.
|
||||||
|
|
||||||
|
Utilises a dictionary-like interface for setting, getting, and removing variables.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Internal storage for context variables
|
||||||
|
data: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
def __init__(self, data: Optional[dict[str, Any]] = None, **kwargs: Any) -> None:
|
||||||
|
"""Initialize with data dictionary as an optional positional parameter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Initial dictionary of context variables (optional)
|
||||||
|
kwargs: Additional keyword arguments for the parent class
|
||||||
|
"""
|
||||||
|
init_data = data or {}
|
||||||
|
super().__init__(data=init_data, **kwargs)
|
||||||
|
|
||||||
|
def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
|
||||||
|
"""
|
||||||
|
Get a value from the context by key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: The key to retrieve
|
||||||
|
default: The default value to return if key is not found
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The value associated with the key or default if not found
|
||||||
|
"""
|
||||||
|
return self.data.get(key, default)
|
||||||
|
|
||||||
|
def set(self, key: str, value: Any) -> None:
|
||||||
|
"""
|
||||||
|
Set a value in the context by key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: The key to set
|
||||||
|
value: The value to store
|
||||||
|
"""
|
||||||
|
self.data[key] = value
|
||||||
|
|
||||||
|
def remove(self, key: str) -> bool:
|
||||||
|
"""
|
||||||
|
Remove a key from the context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: The key to remove
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the key was removed, False if it didn't exist
|
||||||
|
"""
|
||||||
|
if key in self.data:
|
||||||
|
del self.data[key]
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def keys(self) -> Iterable[str]:
|
||||||
|
"""
|
||||||
|
Get all keys in the context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An iterable of all keys
|
||||||
|
"""
|
||||||
|
return self.data.keys()
|
||||||
|
|
||||||
|
def values(self) -> Iterable[Any]:
|
||||||
|
"""
|
||||||
|
Get all values in the context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An iterable of all values
|
||||||
|
"""
|
||||||
|
return self.data.values()
|
||||||
|
|
||||||
|
def items(self) -> Iterable[tuple[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get all key-value pairs in the context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An iterable of all key-value pairs
|
||||||
|
"""
|
||||||
|
return self.data.items()
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Clear all keys and values from the context."""
|
||||||
|
self.data.clear()
|
||||||
|
|
||||||
|
def contains(self, key: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a key exists in the context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: The key to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the key exists, False otherwise
|
||||||
|
"""
|
||||||
|
return key in self.data
|
||||||
|
|
||||||
|
def update(self, other: dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
Update context with key-value pairs from another dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
other: Dictionary containing key-value pairs to add
|
||||||
|
"""
|
||||||
|
self.data.update(other)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Convert context variables to a dictionary.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary representation of all context variables
|
||||||
|
"""
|
||||||
|
return self.data.copy()
|
||||||
|
|
||||||
|
# Dictionary-compatible interface
|
||||||
|
def __getitem__(self, key: str) -> Any:
|
||||||
|
"""Get a value using dictionary syntax: context[key]"""
|
||||||
|
try:
|
||||||
|
return self.data[key]
|
||||||
|
except KeyError:
|
||||||
|
raise KeyError(f"Context variable '{key}' not found")
|
||||||
|
|
||||||
|
def __setitem__(self, key: str, value: Any) -> None:
|
||||||
|
"""Set a value using dictionary syntax: context[key] = value"""
|
||||||
|
self.data[key] = value
|
||||||
|
|
||||||
|
def __delitem__(self, key: str) -> None:
|
||||||
|
"""Delete a key using dictionary syntax: del context[key]"""
|
||||||
|
try:
|
||||||
|
del self.data[key]
|
||||||
|
except KeyError:
|
||||||
|
raise KeyError(f"Cannot delete non-existent context variable '{key}'")
|
||||||
|
|
||||||
|
def __contains__(self, key: str) -> bool:
|
||||||
|
"""Check if key exists using 'in' operator: key in context"""
|
||||||
|
return key in self.data
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""Get the number of items: len(context)"""
|
||||||
|
return len(self.data)
|
||||||
|
|
||||||
|
def __iter__(self) -> Generator[tuple[str, Any], None, None]:
|
||||||
|
"""Iterate over keys: for key in context"""
|
||||||
|
for key in self.data:
|
||||||
|
yield (key, self.data[key])
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
"""String representation of context variables."""
|
||||||
|
return f"ContextVariables({self.data})"
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
"""Detailed representation of context variables."""
|
||||||
|
return f"ContextVariables(data={self.data!r})"
|
||||||
|
|
||||||
|
# Utility methods
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict[str, Any]) -> "ContextVariables":
|
||||||
|
"""
|
||||||
|
Create a new ContextVariables instance from a dictionary.
|
||||||
|
|
||||||
|
E.g.:
|
||||||
|
my_context = {"user_id": "12345", "settings": {"theme": "dark"}}
|
||||||
|
context = ContextVariables.from_dict(my_context)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Dictionary of key-value pairs
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
New ContextVariables instance
|
||||||
|
"""
|
||||||
|
return cls(data=data)
|
||||||
202
mm_agents/coact/autogen/agentchat/group/group_tool_executor.py
Normal file
202
mm_agents/coact/autogen/agentchat/group/group_tool_executor.py
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import Annotated, Any, Callable, Optional
|
||||||
|
|
||||||
|
from ...oai import OpenAIWrapper
|
||||||
|
from ...tools import Depends, Tool
|
||||||
|
from ...tools.dependency_injection import inject_params, on
|
||||||
|
from ..agent import Agent
|
||||||
|
from ..conversable_agent import ConversableAgent
|
||||||
|
from .context_variables import __CONTEXT_VARIABLES_PARAM_NAME__, ContextVariables
|
||||||
|
from .reply_result import ReplyResult
|
||||||
|
from .targets.transition_target import TransitionTarget
|
||||||
|
|
||||||
|
__TOOL_EXECUTOR_NAME__ = "_Group_Tool_Executor"
|
||||||
|
|
||||||
|
|
||||||
|
class GroupToolExecutor(ConversableAgent):
|
||||||
|
"""Tool executor for the group chat initiated with initiate_group_chat"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__(
|
||||||
|
name=__TOOL_EXECUTOR_NAME__,
|
||||||
|
system_message="Tool Execution, do not use this agent directly.",
|
||||||
|
human_input_mode="NEVER",
|
||||||
|
code_execution_config=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store the next target from a tool call
|
||||||
|
self._group_next_target: Optional[TransitionTarget] = None
|
||||||
|
|
||||||
|
# Primary tool reply function for handling the tool reply and the ReplyResult and TransitionTarget returns
|
||||||
|
self.register_reply([Agent, None], self._generate_group_tool_reply, remove_other_reply_funcs=True)
|
||||||
|
|
||||||
|
def set_next_target(self, next_target: TransitionTarget) -> None:
|
||||||
|
"""Sets the next target to transition to, used in the determine_next_agent function."""
|
||||||
|
self._group_next_target = next_target
|
||||||
|
|
||||||
|
def get_next_target(self) -> TransitionTarget:
|
||||||
|
"""Gets the next target to transition to."""
|
||||||
|
"""Returns the next target to transition to, if it exists."""
|
||||||
|
if self._group_next_target is None:
|
||||||
|
raise ValueError(
|
||||||
|
"No next target set. Please set a next target before calling this method. Use has_next_target() to check if a next target exists."
|
||||||
|
)
|
||||||
|
return self._group_next_target
|
||||||
|
|
||||||
|
def has_next_target(self) -> bool:
|
||||||
|
"""Checks if there is a next target to transition to."""
|
||||||
|
return self._group_next_target is not None
|
||||||
|
|
||||||
|
def clear_next_target(self) -> None:
|
||||||
|
"""Clears the next target to transition to."""
|
||||||
|
self._group_next_target = None
|
||||||
|
|
||||||
|
def _modify_context_variables_param(
|
||||||
|
self, f: Callable[..., Any], context_variables: ContextVariables
|
||||||
|
) -> Callable[..., Any]:
|
||||||
|
"""Modifies the context_variables parameter to use dependency injection and link it to the group context variables.
|
||||||
|
|
||||||
|
This essentially changes:
|
||||||
|
def some_function(some_variable: int, context_variables: ContextVariables) -> str:
|
||||||
|
|
||||||
|
to:
|
||||||
|
|
||||||
|
def some_function(some_variable: int, context_variables: Annotated[ContextVariables, Depends(on(self.context_variables))]) -> str:
|
||||||
|
"""
|
||||||
|
sig = inspect.signature(f)
|
||||||
|
|
||||||
|
# Check if context_variables parameter exists and update it if so
|
||||||
|
if __CONTEXT_VARIABLES_PARAM_NAME__ in sig.parameters:
|
||||||
|
new_params = []
|
||||||
|
for name, param in sig.parameters.items():
|
||||||
|
if name == __CONTEXT_VARIABLES_PARAM_NAME__:
|
||||||
|
# Replace with new annotation using Depends
|
||||||
|
new_param = param.replace(annotation=Annotated[ContextVariables, Depends(on(context_variables))])
|
||||||
|
new_params.append(new_param)
|
||||||
|
else:
|
||||||
|
new_params.append(param)
|
||||||
|
|
||||||
|
# Update signature
|
||||||
|
new_sig = sig.replace(parameters=new_params)
|
||||||
|
f.__signature__ = new_sig # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
return f
|
||||||
|
|
||||||
|
def _change_tool_context_variables_to_depends(
    self, agent: ConversableAgent, current_tool: Tool, context_variables: ContextVariables
) -> None:
    """Checks for the context_variables parameter in the tool and updates it to use dependency injection.

    If the tool declares a context_variables parameter it is removed from the
    agent and re-registered with that parameter hidden from the LLM schema and
    injected at call time instead.

    Args:
        agent: The agent that owns the tool.
        current_tool: The tool to inspect and, if needed, re-register.
        context_variables: The group context variables to inject.
    """
    # Guard clause: nothing to do when the tool has no context_variables parameter.
    if __CONTEXT_VARIABLES_PARAM_NAME__ not in current_tool.tool_schema["function"]["parameters"]["properties"]:
        return

    # Capture identifying details, then remove the tool from the agent.
    name = current_tool._name
    description = current_tool._description
    agent.remove_tool_for_llm(current_tool)

    # Recreate the tool with the context_variables parameter converted to a
    # dependency-injected parameter. (Removed a dead local assignment of
    # current_tool._func that was immediately overwritten here.)
    tool_func = self._modify_context_variables_param(current_tool._func, context_variables)
    tool_func = inject_params(tool_func)
    new_tool = ConversableAgent._create_tool_if_needed(
        func_or_tool=tool_func, name=name, description=description
    )

    # Re-register the rebuilt tool with the agent.
    agent.register_for_llm()(new_tool)
|
||||||
|
|
||||||
|
def register_agents_functions(self, agents: list[ConversableAgent], context_variables: ContextVariables) -> None:
    """Adds the functions of the agents to the group tool executor."""
    for group_agent in agents:
        # Legacy function-map support; tools are the preferred mechanism going forward.
        self._function_map.update(group_agent._function_map)

        # First pass: rewrite any tools whose context_variables parameter
        # should be supplied via dependency injection.
        for agent_tool in group_agent.tools:
            self._change_tool_context_variables_to_depends(group_agent, agent_tool, context_variables)

        # Second pass: register every tool for execution on this executor.
        for agent_tool in group_agent.tools:
            self.register_for_execution(serialize=False, silent_override=True)(agent_tool)
|
||||||
|
|
||||||
|
def _generate_group_tool_reply(
    self,
    agent: ConversableAgent,
    messages: Optional[list[dict[str, Any]]] = None,
    sender: Optional[Agent] = None,
    config: Optional[OpenAIWrapper] = None,
) -> tuple[bool, Optional[dict[str, Any]]]:
    """Pre-processes and generates tool call replies.

    This function:
    1. Adds context_variables back to the tool call for the function, if necessary.
    2. Generates the tool calls reply.
    3. Updates context_variables and next_agent based on the tool call response.

    Args:
        agent: The agent executing the tool calls.
        messages: Conversation history; defaults to agent._oai_messages[sender].
        sender: Agent the messages were exchanged with (used only when messages is None).
        config: Unused beyond defaulting to the agent itself.

    Returns:
        (True, tool reply message) when the last message contains tool calls,
        (False, None) otherwise.
    """
    if config is None:
        config = agent  # type: ignore[assignment]
    if messages is None:
        messages = agent._oai_messages[sender]

    message = messages[-1]
    if "tool_calls" in message:
        tool_call_count = len(message["tool_calls"])

        # Loop through tool calls individually (so context can be updated after each function call)
        next_target: Optional[TransitionTarget] = None
        tool_responses_inner = []
        contents = []
        for index in range(tool_call_count):
            # Deep copy so restricting to one tool call doesn't mutate history.
            message_copy = deepcopy(message)

            # 1. add context_variables to the tool call arguments
            tool_call = message_copy["tool_calls"][index]

            # Ensure we are only executing the one tool at a time
            message_copy["tool_calls"] = [tool_call]

            # 2. generate tool calls reply
            _, tool_message = agent.generate_tool_calls_reply([message_copy])

            if tool_message is None:
                raise ValueError("Tool call did not return a message")

            # 3. update context_variables and next_agent, convert content to string
            for tool_response in tool_message["tool_responses"]:
                content = tool_response.get("content")

                # Tool Call returns that are a target are either a ReplyResult or a TransitionTarget are the next agent
                if isinstance(content, ReplyResult):
                    # Merge any context variable updates the tool returned.
                    if content.context_variables and content.context_variables.to_dict() != {}:
                        agent.context_variables.update(content.context_variables.to_dict())
                    if content.target is not None:
                        next_target = content.target
                elif isinstance(content, TransitionTarget):
                    next_target = content

                # Serialize the content to a string
                if content is not None:
                    tool_response["content"] = str(content)

                tool_responses_inner.append(tool_response)
                contents.append(str(tool_response["content"]))

        # Only the last tool call's target wins when several set one.
        self._group_next_target = next_target  # type: ignore[attr-defined]

        # Put the tool responses and content strings back into the response message
        # Caters for multiple tool calls
        # NOTE(review): if message["tool_calls"] were an empty list, the loop
        # never runs and tool_message is unbound here (NameError, not this
        # ValueError) — confirm callers never produce an empty tool_calls list.
        if tool_message is None:
            raise ValueError("Tool call did not return a message")

        tool_message["tool_responses"] = tool_responses_inner
        tool_message["content"] = "\n".join(contents)

        return True, tool_message
    return False, None
|
||||||
636
mm_agents/coact/autogen/agentchat/group/group_utils.py
Normal file
636
mm_agents/coact/autogen/agentchat/group/group_utils.py
Normal file
@@ -0,0 +1,636 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import copy
|
||||||
|
from functools import partial
|
||||||
|
from types import MethodType
|
||||||
|
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||||
|
|
||||||
|
from ..agent import Agent
|
||||||
|
from ..groupchat import GroupChat, GroupChatManager
|
||||||
|
from .context_variables import ContextVariables
|
||||||
|
from .group_tool_executor import GroupToolExecutor
|
||||||
|
from .targets.group_manager_target import GroupManagerTarget
|
||||||
|
from .targets.transition_target import (
|
||||||
|
AgentNameTarget,
|
||||||
|
AgentTarget,
|
||||||
|
TransitionTarget,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..conversable_agent import ConversableAgent
|
||||||
|
|
||||||
|
# Utility functions for group chat preparation and management
|
||||||
|
# These are extracted from multi_agent_chat.py to avoid circular imports
|
||||||
|
|
||||||
|
|
||||||
|
def update_conditional_functions(agent: "ConversableAgent", messages: list[dict[str, Any]]) -> None:
    """Updates the agent's functions based on the OnCondition's available condition.

    All functions are removed and then added back if they are available
    """
    for condition in agent.handoffs.llm_conditions:
        available = condition.available.is_available(agent, messages) if condition.available else True

        # Drop any previously-registered tool for this condition.
        matching_tool = next((t for t in agent.tools if t.name == condition.llm_function_name), None)
        if matching_tool is not None:
            agent.remove_tool_for_llm(matching_tool)

        # Re-add the function when available so its signature/prompt is refreshed.
        if available:
            agent._add_single_function(
                _create_on_condition_handoff_function(condition.target),
                condition.llm_function_name,
                condition.condition.get_prompt(agent, messages),
            )
|
||||||
|
|
||||||
|
|
||||||
|
def establish_group_agent(agent: "ConversableAgent") -> None:
    """Establish the group agent with the group-related attributes and hooks. Not for the tool executor.

    Args:
        agent ("ConversableAgent"): The agent to establish as a group agent.
    """

    def _display_group_agent(self: "ConversableAgent") -> str:
        # Custom __str__ replacement so transition messages show the agent name.
        return f"Group agent --> {self.name}"

    # Keep conditional handoff functions in sync with agent state (except tool executor).
    agent.register_hook("update_agent_state", update_conditional_functions)

    # Python function-based OnContextConditions must run before any other reply function.
    agent.register_reply(trigger=([Agent, None]), reply_func=_run_oncontextconditions, position=0)

    agent._get_display_name = MethodType(_display_group_agent, agent)  # type: ignore[method-assign]

    # Flag so this setup happens at most once per agent.
    agent._group_is_established = True  # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
|
def link_agents_to_group_manager(agents: list[Agent], group_chat_manager: Agent) -> None:
    """Link all agents to the GroupChatManager so they can access the underlying GroupChat and other agents.

    This is primarily used so that agents can get to the tool executor to help set the next agent.

    Does not link the Tool Executor agent.
    """
    for group_member in agents:
        group_member._group_manager = group_chat_manager  # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
|
def _evaluate_after_works_conditions(
    agent: "ConversableAgent",
    groupchat: GroupChat,
    user_agent: Optional["ConversableAgent"],
) -> Optional[Union[Agent, str]]:
    """Evaluate after_works context conditions for an agent.

    Args:
        agent: The agent to evaluate after_works conditions for
        groupchat: The current group chat
        user_agent: Optional user proxy agent

    Returns:
        The resolved speaker selection result if a condition matches, None otherwise
    """
    # Agents without handoffs (or without after_works) have nothing to evaluate.
    if not hasattr(agent, "handoffs") or not agent.handoffs.after_works:  # type: ignore[attr-defined]
        return None

    for candidate in agent.handoffs.after_works:  # type: ignore[attr-defined]
        # Availability gate: a missing availability means "always available".
        if candidate.available and not candidate.available.is_available(agent, groupchat.messages):
            continue

        # A None condition is treated as always true.
        if candidate.condition is not None and not candidate.condition.evaluate(agent.context_variables):
            continue

        # First matching condition wins: resolve it to a speaker selection result.
        resolution = candidate.target.resolve(groupchat, agent, user_agent)
        return resolution.get_speaker_selection_result(groupchat)

    return None
|
||||||
|
|
||||||
|
|
||||||
|
def _run_oncontextconditions(
    agent: "ConversableAgent",
    messages: Optional[list[dict[str, Any]]] = None,
    sender: Optional[Agent] = None,
    config: Optional[Any] = None,
) -> tuple[bool, Optional[Union[str, dict[str, Any]]]]:
    """Run OnContextConditions for an agent before any other reply function."""
    for on_condition in agent.handoffs.context_conditions:  # type: ignore[attr-defined]
        # Availability gate: a missing availability means "always available".
        if on_condition.available and not on_condition.available.is_available(agent, messages if messages else []):
            continue

        # A None condition is treated as always true.
        if on_condition.condition is not None and not on_condition.condition.evaluate(agent.context_variables):
            continue

        # Condition has been met: record the target on the Tool Executor so it
        # is picked up on the next iteration when _determine_next_agent runs.
        # (Loop variable renamed so the `agent` parameter is not shadowed.)
        for group_member in agent._group_manager.groupchat.agents:  # type: ignore[attr-defined]
            if isinstance(group_member, GroupToolExecutor):
                group_member.set_next_target(on_condition.target)
                break

        transfer_name = on_condition.target.display_name()

        return True, "[Handing off to " + transfer_name + "]"

    return False, None
|
||||||
|
|
||||||
|
|
||||||
|
def _create_on_condition_handoff_function(target: TransitionTarget) -> Callable[[], TransitionTarget]:
    """Creates a function that will be used by the tool call reply function when the condition is met.

    Args:
        target (TransitionTarget): The target to transfer to.

    Returns:
        Callable: A zero-argument closure that simply returns *target*.
    """

    def transfer_to_target() -> TransitionTarget:
        # Closure over `target`; the tool machinery invokes this when the
        # condition's function is called by the LLM.
        return target

    return transfer_to_target
|
||||||
|
|
||||||
|
|
||||||
|
def create_on_condition_handoff_functions(agent: "ConversableAgent") -> None:
    """Creates the functions for the OnConditions so that the current tool handling works.

    Args:
        agent ("ConversableAgent"): The agent to create the functions for.
    """
    # Ensure every condition has an LLM function name before registering.
    agent.handoffs.set_llm_function_names()

    # Register one callable per OnCondition under its generated function name.
    for handoff_condition in agent.handoffs.llm_conditions:
        agent._add_single_function(
            _create_on_condition_handoff_function(handoff_condition.target),
            handoff_condition.llm_function_name,
            handoff_condition.condition.get_prompt(agent, []),
        )
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_handoff_agents_in_group(agents: list["ConversableAgent"]) -> None:
    """Ensure the agents in handoffs are in the group chat.

    Checks OnCondition, OnContextCondition, and after-work targets of every
    agent against the group's agent names.

    Args:
        agents: All agents that make up the group chat.

    Raises:
        ValueError: If any handoff targets an agent not in the group.
    """

    def _targets_missing_agent(conditions: Any, agent_names: list[str]) -> bool:
        # True when any condition targets a named agent outside the group.
        # (Extracted helper: this check was previously triplicated inline.)
        return any(
            isinstance(cond.target, (AgentTarget, AgentNameTarget)) and cond.target.agent_name not in agent_names
            for cond in conditions
        )

    agent_names = [agent.name for agent in agents]
    for agent in agents:
        if _targets_missing_agent(agent.handoffs.llm_conditions, agent_names):
            raise ValueError("Agent in OnCondition Hand-offs must be in the agents list")
        if _targets_missing_agent(agent.handoffs.context_conditions, agent_names):
            raise ValueError("Agent in OnContextCondition Hand-offs must be in the agents list")
        if _targets_missing_agent(agent.handoffs.after_works, agent_names):
            raise ValueError("Agent in after work target Hand-offs must be in the agents list")
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_exclude_transit_messages(agents: list["ConversableAgent"]) -> None:
    """Preparation for excluding transit messages by getting all tool names and registering a hook on agents to remove those messages."""
    # get all transit functions names
    to_be_removed: list[str] = []
    for agent in agents:
        for on_condition in agent.handoffs.llm_conditions:
            if on_condition.llm_function_name:
                to_be_removed.append(on_condition.llm_function_name)
            else:
                # Names are assigned by set_llm_function_names; missing one is a programming error.
                raise ValueError("OnCondition must have a function name")

    # NOTE(review): make_remove_function is not among this module's visible
    # imports — confirm it is imported/defined elsewhere, otherwise this line
    # raises NameError at runtime.
    remove_function = make_remove_function(to_be_removed)

    # register hook to remove transit messages for group agents
    for agent in agents:
        agent.register_hook("process_all_messages_before_reply", remove_function)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_group_agents(
    agents: list["ConversableAgent"],
    context_variables: ContextVariables,
    exclude_transit_message: bool = True,
) -> tuple[GroupToolExecutor, list["ConversableAgent"]]:
    """Validates agents, create the tool executor, wrap necessary targets in agents.

    Args:
        agents (list["ConversableAgent"]): List of all agents in the conversation.
        context_variables (ContextVariables): Context variables to assign to all agents.
        exclude_transit_message (bool): Whether to exclude transit messages from the agents.

    Returns:
        "ConversableAgent": The tool executor agent.
        list["ConversableAgent"]: List of wrapped agents.
    """
    # Turn each plain agent into a group agent exactly once.
    for group_agent in agents:
        if not hasattr(group_agent, "_group_is_established"):
            establish_group_agent(group_agent)

    # Hand-off / after-work targets must reference agents inside the group.
    ensure_handoff_agents_in_group(agents)

    # Create Tool Executor for the group.
    tool_execution = GroupToolExecutor()

    # Wrap any handoff targets (e.g. nested chats) that need wrapper agents.
    wrapped_chat_agents: list["ConversableAgent"] = []
    for group_agent in agents:
        wrap_agent_handoff_targets(group_agent, wrapped_chat_agents)

    # OnConditions are exposed to the LLM as callable functions.
    for group_agent in agents:
        create_on_condition_handoff_functions(group_agent)

    # Register every agent's tools on the executor, using dependency injection
    # for the context variables parameter.
    tool_execution.register_agents_functions(agents + wrapped_chat_agents, context_variables)

    if exclude_transit_message:
        prepare_exclude_transit_messages(agents)

    return tool_execution, wrapped_chat_agents
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_agent_handoff_targets(agent: "ConversableAgent", wrapped_agent_list: list["ConversableAgent"]) -> None:
    """Wrap handoff targets in agents that need to be wrapped to be part of the group chat.

    Example is NestedChatTarget.

    Args:
        agent ("ConversableAgent"): The agent to wrap the handoff targets for.
        wrapped_agent_list (list["ConversableAgent"]): List of wrapped chat agents that will be appended to.
    """
    # OnCondition targets that require wrapping.
    for idx, llm_condition in enumerate(agent.handoffs.get_llm_conditions_requiring_wrapping()):
        new_agent = llm_condition.target.create_wrapper_agent(parent_agent=agent, index=idx)
        wrapped_agent_list.append(new_agent)
        # Redirect the handoff at the freshly created wrapper agent.
        llm_condition.target = AgentTarget(new_agent)

    # OnContextCondition targets that require wrapping.
    for idx, context_condition in enumerate(agent.handoffs.get_context_conditions_requiring_wrapping()):
        new_agent = context_condition.target.create_wrapper_agent(parent_agent=agent, index=idx)
        wrapped_agent_list.append(new_agent)
        # Redirect the handoff at the freshly created wrapper agent.
        context_condition.target = AgentTarget(new_agent)
|
||||||
|
|
||||||
|
|
||||||
|
def process_initial_messages(
    messages: Union[list[dict[str, Any]], str],
    user_agent: Optional["ConversableAgent"],
    agents: list["ConversableAgent"],
    wrapped_agents: list["ConversableAgent"],
) -> tuple[list[dict[str, Any]], Optional["ConversableAgent"], list[str], list[Agent]]:
    """Process initial messages, validating agent names against messages, and determining the last agent to speak.

    Args:
        messages: Initial messages to process.
        user_agent: Optional user proxy agent passed in to a_/initiate_group_chat.
        agents: Agents in the group.
        wrapped_agents: List of wrapped agents.

    Returns:
        list[dict[str, Any]]: Processed message(s).
        Agent: Last agent to speak.
        list[str]: List of agent names.
        list[Agent]: List of temporary user proxy agents to add to GroupChat.
    """
    # Function-level import, presumably to avoid a circular import with
    # conversable_agent — the original marked this "NEED SOLUTION"; confirm.
    from ..conversable_agent import ConversableAgent  # NEED SOLUTION

    # A plain string becomes a single user message.
    if isinstance(messages, str):
        messages = [{"role": "user", "content": messages}]

    group_agent_names = [agent.name for agent in agents + wrapped_agents]

    # If there's only one message and there's no identified group agent
    # Start with a user proxy agent, creating one if they haven't passed one in
    last_agent: Optional[ConversableAgent]
    temp_user_proxy: Optional[ConversableAgent] = None
    temp_user_list: list[Agent] = []
    if len(messages) == 1 and "name" not in messages[0] and not user_agent:
        temp_user_proxy = ConversableAgent(name="_User", code_execution_config=False, human_input_mode="ALWAYS")
        last_agent = temp_user_proxy
        temp_user_list.append(temp_user_proxy)
    else:
        last_message = messages[0]
        if "name" in last_message:
            if last_message["name"] in group_agent_names:
                # Named sender is a group (or wrapped) agent — resume from them.
                last_agent = next(agent for agent in agents + wrapped_agents if agent.name == last_message["name"])  # type: ignore[assignment]
            elif user_agent and last_message["name"] == user_agent.name:
                last_agent = user_agent
            else:
                raise ValueError(f"Invalid group agent name in last message: {last_message['name']}")
        else:
            # NOTE(review): temp_user_proxy is always None on this path, so
            # last_agent may be None when no user_agent was provided — confirm
            # callers handle a None last agent.
            last_agent = user_agent if user_agent else temp_user_proxy

    return messages, last_agent, group_agent_names, temp_user_list
|
||||||
|
|
||||||
|
|
||||||
|
def setup_context_variables(
    tool_execution: "ConversableAgent",
    agents: list["ConversableAgent"],
    manager: GroupChatManager,
    user_agent: Optional["ConversableAgent"],
    context_variables: ContextVariables,
) -> None:
    """Assign a common context_variables reference to all agents in the group, including the tool executor, group chat manager, and user proxy agent.

    Args:
        tool_execution: The tool execution agent.
        agents: List of all agents in the conversation.
        manager: GroupChatManager instance.
        user_agent: Optional user proxy agent.
        context_variables: Context variables to assign to all agents.
    """
    # Everyone shares the same object, so updates propagate to all participants.
    participants = [*agents, tool_execution, manager]
    if user_agent:
        participants.append(user_agent)
    for participant in participants:
        participant.context_variables = context_variables
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_temp_user_messages(chat_result: Any) -> None:
    """Remove temporary user proxy agent name from messages before returning.

    Args:
        chat_result: ChatResult instance.
    """
    # "_User" is the name given to the temporary proxy created in
    # process_initial_messages; strip it from the returned history.
    for entry in chat_result.chat_history:
        if entry.get("name") == "_User":
            del entry["name"]
|
||||||
|
|
||||||
|
|
||||||
|
def get_last_agent_speaker(
    groupchat: GroupChat, group_agent_names: list[str], tool_executor: GroupToolExecutor
) -> Agent:
    """Get the last group agent from the group chat messages. Not including the tool executor."""
    # Walk the history newest-first looking for a named group agent that is
    # not the tool executor.
    for message in reversed(groupchat.messages):
        speaker_name = message.get("name")
        if speaker_name in group_agent_names and speaker_name != tool_executor.name:
            found = groupchat.agent_by_name(name=speaker_name)
            if found:
                return found

    raise ValueError("No group agent found in the message history")
|
||||||
|
|
||||||
|
|
||||||
|
def determine_next_agent(
    last_speaker: "ConversableAgent",
    groupchat: GroupChat,
    initial_agent: "ConversableAgent",
    use_initial_agent: bool,
    tool_executor: GroupToolExecutor,
    group_agent_names: list[str],
    user_agent: Optional["ConversableAgent"],
    group_after_work: TransitionTarget,
) -> Optional[Union[Agent, str]]:
    """Determine the next agent in the conversation.

    Args:
        last_speaker ("ConversableAgent"): The last agent to speak.
        groupchat (GroupChat): GroupChat instance.
        initial_agent ("ConversableAgent"): The initial agent in the conversation.
        use_initial_agent (bool): Whether to use the initial agent straight away.
        tool_executor ("ConversableAgent"): The tool execution agent.
        group_agent_names (list[str]): List of agent names.
        user_agent (UserProxyAgent): Optional user proxy agent.
        group_after_work (TransitionTarget): Group-level Transition option when an agent doesn't select the next agent.

    Returns:
        Optional[Union[Agent, str]]: The next agent or speaker selection method.
    """

    # Logic for determining the next target (anything based on Transition Target: an agent, wrapped agent, TerminateTarget, StayTarget, RevertToUserTarget, GroupManagerTarget, etc.
    # 1. If it's the first response -> initial agent
    # 2. If the last message is a tool call -> tool execution agent
    # 3. If the Tool Executor has determined a next target (e.g. ReplyResult specified target) -> transition to tool reply target
    # 4. If the user last spoke -> return to the previous agent
    # NOW "AFTER WORK":
    # 5. Get the After Work condition (if the agent doesn't have one, get the group-level one)
    # 6. Resolve and return the After Work condition -> agent / wrapped agent / TerminateTarget / StayTarget / RevertToUserTarget / GroupManagerTarget / etc.

    # 1. If it's the first response, return the initial agent
    if use_initial_agent:
        return initial_agent

    # 2. If the last message is a tool call, return the tool execution agent
    if "tool_calls" in groupchat.messages[-1]:
        return tool_executor

    # 3. If the Tool Executor has determined a next target, return that
    if tool_executor.has_next_target():
        next_agent = tool_executor.get_next_target()
        # The pending target is consumed (cleared) even if it turns out to be
        # unresolvable below, in which case a ValueError is raised.
        tool_executor.clear_next_target()

        if next_agent.can_resolve_for_speaker_selection():
            return next_agent.resolve(groupchat, last_speaker, user_agent).get_speaker_selection_result(groupchat)
        else:
            raise ValueError(
                "Tool Executor next target must be a valid TransitionTarget that can resolve for speaker selection."
            )

    # get the last group agent
    last_agent_speaker = get_last_agent_speaker(groupchat, group_agent_names, tool_executor)

    # If we are returning from a tool execution, return to the last agent that spoke
    if groupchat.messages[-1]["role"] == "tool":
        return last_agent_speaker

    # If the user last spoke, return to the agent prior to them (if they don't have an after work, otherwise it's treated like any other agent)
    if user_agent and last_speaker == user_agent:
        if not user_agent.handoffs.after_works:
            return last_agent_speaker
        else:
            # The user agent has after_works of its own — evaluate them below.
            last_agent_speaker = user_agent

    # AFTER WORK:

    # First, try to evaluate after_works context conditions
    after_works_result = _evaluate_after_works_conditions(
        last_agent_speaker,  # type: ignore[arg-type]
        groupchat,
        user_agent,
    )
    if after_works_result is not None:
        return after_works_result

    # If no after_works conditions matched, use the group-level after_work
    # Resolve the next agent, termination, or speaker selection method
    resolved_speaker_selection_result = group_after_work.resolve(
        groupchat,
        last_agent_speaker,  # type: ignore[arg-type]
        user_agent,
    ).get_speaker_selection_result(groupchat)

    return resolved_speaker_selection_result
|
||||||
|
|
||||||
|
|
||||||
|
def create_group_transition(
    initial_agent: "ConversableAgent",
    tool_execution: GroupToolExecutor,
    group_agent_names: list[str],
    user_agent: Optional["ConversableAgent"],
    group_after_work: TransitionTarget,
) -> Callable[["ConversableAgent", GroupChat], Optional[Union[Agent, str]]]:
    """Creates a transition function for group chat with enclosed state for the use_initial_agent.

    Args:
        initial_agent ("ConversableAgent"): The first agent to speak
        tool_execution (GroupToolExecutor): The tool execution agent
        group_agent_names (list[str]): List of all agent names
        user_agent (UserProxyAgent): Optional user proxy agent
        group_after_work (TransitionTarget): Group-level after work

    Returns:
        Callable[["ConversableAgent", GroupChat], Optional[Union[Agent, str]]]: The transition function
    """
    # Enclosed flag: True only until the first invocation of group_transition.
    use_initial_agent = True

    def group_transition(last_speaker: "ConversableAgent", groupchat: GroupChat) -> Optional[Union[Agent, str]]:
        nonlocal use_initial_agent
        next_speaker = determine_next_agent(
            last_speaker=last_speaker,
            groupchat=groupchat,
            initial_agent=initial_agent,
            use_initial_agent=use_initial_agent,
            tool_executor=tool_execution,
            group_agent_names=group_agent_names,
            user_agent=user_agent,
            group_after_work=group_after_work,
        )
        use_initial_agent = False
        return next_speaker

    return group_transition
|
||||||
|
|
||||||
|
|
||||||
|
def create_group_manager(
    groupchat: GroupChat,
    group_manager_args: Optional[dict[str, Any]],
    agents: list["ConversableAgent"],
    group_after_work: TransitionTarget,
) -> GroupChatManager:
    """Create a GroupChatManager for the group chat utilising any arguments passed in and ensure an LLM Config exists if needed

    Args:
        groupchat (GroupChat): The groupchat.
        group_manager_args (dict[str, Any]): Group manager arguments to create the GroupChatManager.
        agents (list["ConversableAgent"]): List of agents in the group to check handoffs and after work.
        group_after_work (TransitionTarget): Group-level after work to check.

    Returns:
        GroupChatManager: GroupChatManager instance.
    """
    args = dict(group_manager_args) if group_manager_args else {}
    # 'groupchat' is supplied by us below; a caller-provided one would conflict.
    if "groupchat" in args:
        raise ValueError("'groupchat' cannot be specified in group_manager_args as it is set by initiate_group_chat")
    manager = GroupChatManager(groupchat, **args)

    # A GroupManagerTarget anywhere in the configuration requires the manager
    # to have an LLM config — verify that up front rather than failing later.
    if manager.llm_config is False:

        def _uses_group_manager_target(agent: "ConversableAgent") -> bool:
            # True when any of the agent's handoffs (context, LLM, or after-work)
            # targets the group manager.
            return bool(
                agent.handoffs.get_context_conditions_by_target_type(GroupManagerTarget)
                or agent.handoffs.get_llm_conditions_by_target_type(GroupManagerTarget)
                or any(isinstance(aw.target, GroupManagerTarget) for aw in agent.handoffs.after_works)
            )

        needs_llm_config = isinstance(group_after_work, GroupManagerTarget) or any(
            _uses_group_manager_target(agent) for agent in agents
        )

        if needs_llm_config:
            raise ValueError(
                "The group manager doesn't have an LLM Config and it is required for any targets or after works using a GroupManagerTarget. Use the 'llm_config' in the group_manager_args parameter to specify the LLM Config for the group manager."
            )

    return manager
|
||||||
|
|
||||||
|
|
||||||
|
def make_remove_function(tool_msgs_to_remove: list[str]) -> Callable[[list[dict[str, Any]]], list[dict[str, Any]]]:
    """Create a function to remove messages with tool calls from the messages list.

    The returned function can be registered as a hook to "process_all_messages_before_reply"" to remove messages with tool calls.
    """

    def remove_messages(messages: list[dict[str, Any]], tool_msgs_to_remove: list[str]) -> list[dict[str, Any]]:
        # Deep-copy so the caller's message objects are never mutated.
        working = copy.deepcopy(messages)
        kept: list[dict[str, Any]] = []
        dropped_call_ids: list[str] = []

        for msg in working:
            if msg.get("tool_calls") is not None:
                # Partition the tool calls: named-for-removal vs. surviving.
                surviving_calls = []
                for call in msg["tool_calls"]:
                    fn = call.get("function")
                    if fn is not None and fn["name"] in tool_msgs_to_remove:
                        # Remember the id so matching tool responses are dropped too.
                        dropped_call_ids.append(call["id"])
                    else:
                        surviving_calls.append(call)
                if surviving_calls:
                    msg["tool_calls"] = surviving_calls
                else:
                    del msg["tool_calls"]
                    if msg.get("content") in (None, "", "None"):
                        continue  # neither tool calls nor content remain: drop the message
                    # else: keep the message with tool_calls removed
            elif msg.get("tool_responses") is not None:
                # Drop responses whose originating tool call was removed above.
                surviving_responses = [
                    resp for resp in msg["tool_responses"] if resp["tool_call_id"] not in dropped_call_ids
                ]
                if surviving_responses:
                    msg["tool_responses"] = surviving_responses
                else:
                    continue  # every response belonged to a removed call

            kept.append(msg)

        return kept

    return partial(remove_messages, tool_msgs_to_remove=tool_msgs_to_remove)
|
||||||
320
mm_agents/coact/autogen/agentchat/group/handoffs.py
Normal file
320
mm_agents/coact/autogen/agentchat/group/handoffs.py
Normal file
@@ -0,0 +1,320 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from typing import Union, overload
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from .on_condition import OnCondition
|
||||||
|
from .on_context_condition import OnContextCondition
|
||||||
|
from .targets.transition_target import TransitionTarget
|
||||||
|
|
||||||
|
__all__ = ["Handoffs"]
|
||||||
|
|
||||||
|
|
||||||
|
class Handoffs(BaseModel):
    """Container for every handoff transition condition attached to a ConversableAgent.

    Conditions come in three flavours, checked in this order:
    1. OnContextConditions — evaluated from context variables, no LLM involved
    2. OnConditions — evaluated by the agent's LLM
    3. After-work entries — the fallback when nothing above triggered

    All mutator methods return ``self`` so calls can be chained:
        agent.handoffs.add_context_conditions([condition1]) \
            .add_llm_condition(condition2) \
            .set_after_work(after_work)
    """

    context_conditions: list[OnContextCondition] = Field(default_factory=list)
    llm_conditions: list[OnCondition] = Field(default_factory=list)
    after_works: list[OnContextCondition] = Field(default_factory=list)

    def add_context_condition(self, condition: OnContextCondition) -> "Handoffs":
        """Append one context condition.

        Args:
            condition: The OnContextCondition to add

        Returns:
            Self for method chaining
        """
        if not isinstance(condition, OnContextCondition):
            raise TypeError(f"Expected an OnContextCondition instance, got {type(condition).__name__}")
        self.context_conditions.append(condition)
        return self

    def add_context_conditions(self, conditions: list[OnContextCondition]) -> "Handoffs":
        """Append several context conditions at once.

        Args:
            conditions: List of OnContextConditions to add

        Returns:
            Self for method chaining
        """
        if not all(isinstance(entry, OnContextCondition) for entry in conditions):
            raise TypeError("All conditions must be of type OnContextCondition")
        self.context_conditions.extend(conditions)
        return self

    def add_llm_condition(self, condition: OnCondition) -> "Handoffs":
        """Append one LLM-evaluated condition.

        Args:
            condition: The OnCondition to add

        Returns:
            Self for method chaining
        """
        if not isinstance(condition, OnCondition):
            raise TypeError(f"Expected an OnCondition instance, got {type(condition).__name__}")
        self.llm_conditions.append(condition)
        return self

    def add_llm_conditions(self, conditions: list[OnCondition]) -> "Handoffs":
        """Append several LLM-evaluated conditions at once.

        Args:
            conditions: List of OnConditions to add

        Returns:
            Self for method chaining
        """
        if not all(isinstance(entry, OnCondition) for entry in conditions):
            raise TypeError("All conditions must be of type OnCondition")
        self.llm_conditions.extend(conditions)
        return self

    def set_after_work(self, target: TransitionTarget) -> "Handoffs":
        """Replace all after-work entries with a single unconditional fallback.

        For backward compatibility this wraps the target in an
        OnContextCondition with ``condition=None`` (always true).

        Args:
            target: The after work TransitionTarget to set

        Returns:
            Self for method chaining
        """
        if not isinstance(target, TransitionTarget):
            raise TypeError(f"Expected a TransitionTarget instance, got {type(target).__name__}")
        self.after_works = [OnContextCondition(target=target, condition=None)]
        return self

    def add_after_work(self, condition: OnContextCondition) -> "Handoffs":
        """Append one after-work entry.

        A ``condition=None`` entry is the fallback: at most one may exist and
        it always sits at the end of the list. Adding a new fallback replaces
        any existing one; adding a conditional entry slots in before it.

        Args:
            condition: The OnContextCondition to add

        Returns:
            Self for method chaining
        """
        if not isinstance(condition, OnContextCondition):
            raise TypeError(f"Expected an OnContextCondition instance, got {type(condition).__name__}")

        fallbacks = [entry for entry in self.after_works if entry.condition is None]
        conditional = [entry for entry in self.after_works if entry.condition is not None]

        if condition.condition is None:
            # New fallback replaces any existing fallback and goes last.
            self.after_works = conditional + [condition]
        else:
            # Conditional entries sit before the (single) existing fallback.
            self.after_works = conditional + [condition] + fallbacks[:1]

        return self

    def add_after_works(self, conditions: list[OnContextCondition]) -> "Handoffs":
        """Append several after-work entries at once.

        Fallback (``condition=None``) handling:
        - at most one fallback is kept (the last one supplied wins)
        - the fallback always ends up at the end of the list

        Args:
            conditions: List of OnContextConditions to add

        Returns:
            Self for method chaining
        """
        if not all(isinstance(entry, OnContextCondition) for entry in conditions):
            raise TypeError("All conditions must be of type OnContextCondition")

        incoming_fallbacks = [entry for entry in conditions if entry.condition is None]
        incoming_regular = [entry for entry in conditions if entry.condition is not None]

        # Existing fallback entries are discarded in favour of the new one (if any).
        rebuilt = [entry for entry in self.after_works if entry.condition is not None]
        rebuilt.extend(incoming_regular)
        if incoming_fallbacks:
            rebuilt.append(incoming_fallbacks[-1])  # last fallback wins
        self.after_works = rebuilt

        return self

    @overload
    def add(self, condition: OnContextCondition) -> "Handoffs": ...

    @overload
    def add(self, condition: OnCondition) -> "Handoffs": ...

    def add(self, condition: Union[OnContextCondition, OnCondition]) -> "Handoffs":
        """Append a condition of either supported type, dispatching on its class.

        Args:
            condition: The condition to add (OnContextCondition or OnCondition)

        Raises:
            TypeError: If the condition type is not supported

        Returns:
            Self for method chaining
        """
        if isinstance(condition, OnContextCondition):
            return self.add_context_condition(condition)
        if isinstance(condition, OnCondition):
            return self.add_llm_condition(condition)
        raise TypeError(f"Unsupported condition type: {type(condition).__name__}")

    def add_many(self, conditions: list[Union[OnContextCondition, OnCondition]]) -> "Handoffs":
        """Append a mixed list of OnContextCondition and OnCondition entries.

        Args:
            conditions: List of conditions to add

        Raises:
            TypeError: If an unsupported condition type is provided

        Returns:
            Self for method chaining
        """
        # Bucket by type first so each batch goes through its typed adder.
        by_context: list[OnContextCondition] = []
        by_llm: list[OnCondition] = []
        for entry in conditions:
            if isinstance(entry, OnContextCondition):
                by_context.append(entry)
            elif isinstance(entry, OnCondition):
                by_llm.append(entry)
            else:
                raise TypeError(f"Unsupported condition type: {type(entry).__name__}")

        if by_context:
            self.add_context_conditions(by_context)
        if by_llm:
            self.add_llm_conditions(by_llm)

        return self

    def clear(self) -> "Handoffs":
        """Remove every handoff condition of all three kinds.

        Returns:
            Self for method chaining
        """
        self.context_conditions.clear()
        self.llm_conditions.clear()
        self.after_works.clear()
        return self

    def get_llm_conditions_by_target_type(self, target_type: type) -> list[OnCondition]:
        """Return the LLM conditions whose target is an instance of *target_type*.

        Args:
            target_type: The type of condition to retrieve

        Returns:
            List of conditions of the specified type (possibly empty)
        """
        return [entry for entry in self.llm_conditions if entry.has_target_type(target_type)]

    def get_context_conditions_by_target_type(self, target_type: type) -> list[OnContextCondition]:
        """Return the context conditions whose target is an instance of *target_type*.

        Args:
            target_type: The type of condition to retrieve

        Returns:
            List of conditions of the specified type (possibly empty)
        """
        return [entry for entry in self.context_conditions if entry.has_target_type(target_type)]

    def get_llm_conditions_requiring_wrapping(self) -> list[OnCondition]:
        """Return the LLM conditions whose targets must be wrapped in an agent.

        Returns:
            List of LLM conditions that require wrapping
        """
        return [entry for entry in self.llm_conditions if entry.target_requires_wrapping()]

    def get_context_conditions_requiring_wrapping(self) -> list[OnContextCondition]:
        """Return the context conditions whose targets must be wrapped in an agent.

        Returns:
            List of context conditions that require wrapping
        """
        return [entry for entry in self.context_conditions if entry.target_requires_wrapping()]

    def set_llm_function_names(self) -> None:
        """Assign a unique LLM function name to every LLM condition.

        Uniqueness (via the 1-based index suffix) allows multiple OnConditions
        pointing at the same agent to coexist as distinct tools.
        """
        for index, condition in enumerate(self.llm_conditions, start=1):
            condition.llm_function_name = f"transfer_to_{condition.target.normalized_name()}_{index}"
|
||||||
93
mm_agents/coact/autogen/agentchat/group/llm_condition.py
Normal file
93
mm_agents/coact/autogen/agentchat/group/llm_condition.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from .context_str import ContextStr
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
# Avoid circular import
|
||||||
|
from ..conversable_agent import ConversableAgent
|
||||||
|
|
||||||
|
__all__ = ["ContextStrLLMCondition", "LLMCondition", "StringLLMCondition"]
|
||||||
|
|
||||||
|
|
||||||
|
class LLMCondition(BaseModel):
    """Abstract base ("protocol") for conditions whose truth is judged by an LLM.

    Subclasses supply the prompt text; the framework hands that prompt to the
    agent's LLM for evaluation.
    """

    def get_prompt(self, agent: "ConversableAgent", messages: list[dict[str, Any]]) -> str:
        """Produce the prompt text to hand to the LLM.

        Args:
            agent: The agent evaluating the condition
            messages: The conversation history

        Returns:
            The prompt text to be evaluated by the LLM

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError("Requires subclasses to implement.")
|
||||||
|
|
||||||
|
|
||||||
|
class StringLLMCondition(LLMCondition):
    """LLM condition backed by a fixed prompt string.

    The prompt is static: no substitution or processing happens before it is
    handed to the LLM.
    """

    # The static prompt text handed verbatim to the LLM.
    prompt: str

    def __init__(self, prompt: str, **data: Any) -> None:
        """Accept the prompt positionally for ergonomic construction.

        Args:
            prompt: The static prompt string to evaluate
            data: Additional data for the parent class
        """
        super().__init__(prompt=prompt, **data)

    def get_prompt(self, agent: "ConversableAgent", messages: list[dict[str, Any]]) -> str:
        """Return the prompt unchanged; *agent* and *messages* are ignored.

        Args:
            agent: The agent evaluating the condition (not used)
            messages: The conversation history (not used)

        Returns:
            The static prompt string
        """
        return self.prompt
|
||||||
|
|
||||||
|
|
||||||
|
class ContextStrLLMCondition(LLMCondition):
    """LLM condition whose prompt is templated with context variables.

    The wrapped ContextStr contains placeholders that are substituted from the
    agent's context variables immediately before LLM evaluation.
    """

    # Template with context-variable placeholders.
    context_str: ContextStr

    def __init__(self, context_str: ContextStr, **data: Any) -> None:
        """Accept the context string positionally for ergonomic construction.

        Args:
            context_str: The ContextStr object with variable placeholders
            data: Additional data for the parent class
        """
        super().__init__(context_str=context_str, **data)

    def get_prompt(self, agent: "ConversableAgent", messages: list[dict[str, Any]]) -> str:
        """Substitute the agent's context variables into the template.

        Args:
            agent: The agent evaluating the condition (provides context variables)
            messages: The conversation history (not used)

        Returns:
            The prompt with context variables substituted ("" if substitution yields None)
        """
        substituted = self.context_str.format(agent.context_variables)
        return "" if substituted is None else substituted
|
||||||
237
mm_agents/coact/autogen/agentchat/group/multi_agent_chat.py
Normal file
237
mm_agents/coact/autogen/agentchat/group/multi_agent_chat.py
Normal file
@@ -0,0 +1,237 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import threading
|
||||||
|
from typing import TYPE_CHECKING, Any, Union
|
||||||
|
|
||||||
|
from ...doc_utils import export_module
|
||||||
|
from ...events.agent_events import ErrorEvent, RunCompletionEvent
|
||||||
|
from ...io.base import IOStream
|
||||||
|
from ...io.run_response import AsyncRunResponse, AsyncRunResponseProtocol, RunResponse, RunResponseProtocol
|
||||||
|
from ...io.thread_io_stream import AsyncThreadIOStream, ThreadIOStream
|
||||||
|
from ..chat import ChatResult
|
||||||
|
from .context_variables import ContextVariables
|
||||||
|
from .group_utils import cleanup_temp_user_messages
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..agent import Agent
|
||||||
|
from .patterns.pattern import Pattern
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"a_initiate_group_chat",
|
||||||
|
"a_run_group_chat",
|
||||||
|
"initiate_group_chat",
|
||||||
|
"run_group_chat",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen")
def initiate_group_chat(
    pattern: "Pattern",
    messages: Union[list[dict[str, Any]], str],
    max_rounds: int = 20,
) -> tuple[ChatResult, ContextVariables, "Agent"]:
    """Initialize and run a group chat using a pattern for configuration.

    Args:
        pattern: Pattern object that encapsulates the chat configuration.
        messages: Initial message(s).
        max_rounds: Maximum number of conversation rounds.

    Returns:
        ChatResult: Conversations chat history.
        ContextVariables: Updated Context variables.
        "ConversableAgent": Last speaker.
    """
    # The pattern assembles every component of the chat; of the thirteen items
    # it returns, only four are needed here (the rest stay with the pattern).
    prepared = pattern.prepare_group_chat(max_rounds=max_rounds, messages=messages)
    context_variables = prepared[3]
    manager = prepared[8]
    processed_messages = prepared[9]
    last_agent = prepared[10]

    # More than one prepared message means an earlier chat is being resumed.
    if len(processed_messages) > 1:
        last_agent, last_message = manager.resume(messages=processed_messages)
        clear_history = False
    else:
        last_message = processed_messages[0]
        clear_history = True

    if last_agent is None:
        raise ValueError("No agent selected to start the conversation")

    chat_result = last_agent.initiate_chat(
        manager,
        message=last_message,
        clear_history=clear_history,
        summary_method=pattern.summary_method,
    )

    cleanup_temp_user_messages(chat_result)

    return chat_result, context_variables, manager.last_speaker
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen.agentchat")
async def a_initiate_group_chat(
    pattern: "Pattern",
    messages: Union[list[dict[str, Any]], str],
    max_rounds: int = 20,
) -> tuple[ChatResult, ContextVariables, "Agent"]:
    """Initialize and run a group chat using a pattern for configuration, asynchronously.

    Args:
        pattern: Pattern object that encapsulates the chat configuration.
        messages: Initial message(s).
        max_rounds: Maximum number of conversation rounds.

    Returns:
        ChatResult: Conversations chat history.
        ContextVariables: Updated Context variables.
        "ConversableAgent": Last speaker.
    """
    # The pattern assembles every component of the chat; of the thirteen items
    # it returns, only four are needed here (the rest stay with the pattern).
    prepared = pattern.prepare_group_chat(max_rounds=max_rounds, messages=messages)
    context_variables = prepared[3]
    manager = prepared[8]
    processed_messages = prepared[9]
    last_agent = prepared[10]

    # More than one prepared message means an earlier chat is being resumed.
    if len(processed_messages) > 1:
        last_agent, last_message = await manager.a_resume(messages=processed_messages)
        clear_history = False
    else:
        last_message = processed_messages[0]
        clear_history = True

    if last_agent is None:
        raise ValueError("No agent selected to start the conversation")

    chat_result = await last_agent.a_initiate_chat(
        manager,
        message=last_message,  # type: ignore[arg-type]
        clear_history=clear_history,
        summary_method=pattern.summary_method,
    )

    cleanup_temp_user_messages(chat_result)

    return chat_result, context_variables, manager.last_speaker
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen.agentchat")
def run_group_chat(
    pattern: "Pattern",
    messages: Union[list[dict[str, Any]], str],
    max_rounds: int = 20,
) -> RunResponseProtocol:
    """Run a group chat on a background thread, returning a response handle immediately.

    Args:
        pattern: Pattern object that encapsulates the chat configuration.
        messages: Initial message(s).
        max_rounds: Maximum number of conversation rounds.

    Returns:
        RunResponseProtocol: Handle for consuming events and the final result.
    """
    iostream = ThreadIOStream()
    # todo: add agents
    response = RunResponse(iostream, agents=[])

    def _worker() -> None:
        # Route all agent I/O of this run through the dedicated thread stream.
        with IOStream.set_default(iostream):
            try:
                chat_result, context_vars, last_speaker = initiate_group_chat(
                    pattern=pattern,
                    messages=messages,
                    max_rounds=max_rounds,
                )

                IOStream.get_default().send(
                    RunCompletionEvent(  # type: ignore[call-arg]
                        history=chat_result.chat_history,
                        summary=chat_result.summary,
                        cost=chat_result.cost,
                        last_speaker=last_speaker.name,
                        context_variables=context_vars,
                    )
                )
            except Exception as e:
                # Surface the failure to the consumer instead of dying silently.
                response.iostream.send(ErrorEvent(error=e))  # type: ignore[call-arg]

    threading.Thread(target=_worker).start()

    return response
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen.agentchat")
async def a_run_group_chat(
    pattern: "Pattern",
    messages: Union[list[dict[str, Any]], str],
    max_rounds: int = 20,
) -> AsyncRunResponseProtocol:
    """Run a group chat as a background asyncio task, returning a response handle immediately.

    Args:
        pattern: Pattern object that encapsulates the chat configuration.
        messages: Initial message(s).
        max_rounds: Maximum number of conversation rounds.

    Returns:
        AsyncRunResponseProtocol: Handle for consuming events and the final result.
    """
    iostream = AsyncThreadIOStream()
    # todo: add agents
    response = AsyncRunResponse(iostream, agents=[])

    async def _initiate_group_chat(
        pattern: "Pattern" = pattern,
        messages: Union[list[dict[str, Any]], str] = messages,
        max_rounds: int = max_rounds,
        iostream: AsyncThreadIOStream = iostream,
        response: AsyncRunResponse = response,
    ) -> None:
        # Route all agent I/O of this run through the dedicated stream.
        with IOStream.set_default(iostream):
            try:
                chat_result, context_vars, agent = await a_initiate_group_chat(
                    pattern=pattern,
                    messages=messages,
                    max_rounds=max_rounds,
                )

                IOStream.get_default().send(
                    RunCompletionEvent(  # type: ignore[call-arg]
                        history=chat_result.chat_history,
                        summary=chat_result.summary,
                        cost=chat_result.cost,
                        last_speaker=agent.name,
                        context_variables=context_vars,
                    )
                )
            except Exception as e:
                # Surface the failure to the consumer instead of dying silently.
                response.iostream.send(ErrorEvent(error=e))  # type: ignore[call-arg]

    # BUG FIX: the task returned by asyncio.create_task() was previously
    # discarded. The event loop keeps only a weak reference to tasks, so the
    # running chat could be garbage-collected before completion (see asyncio
    # docs: "Save a reference to the result of this function"). Pin a strong
    # reference to the response object, whose lifetime matches the run's.
    response._background_task = asyncio.create_task(_initiate_group_chat())  # type: ignore[attr-defined]

    return response
|
||||||
58
mm_agents/coact/autogen/agentchat/group/on_condition.py
Normal file
58
mm_agents/coact/autogen/agentchat/group/on_condition.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from ...doc_utils import export_module
|
||||||
|
from .available_condition import AvailableCondition
|
||||||
|
from .llm_condition import LLMCondition
|
||||||
|
from .targets.transition_target import TransitionTarget
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"OnCondition",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen")
|
||||||
|
@export_module("autogen")
class OnCondition(BaseModel):  # noqa: N801
    """An LLM-evaluated condition for handing off to another agent or nested chat.

    Each OnCondition is translated into a tool attached to the agent so the
    LLM can decide whether to trigger it. These are checked after the
    OnContextConditions but before any after-work fallback.

    Args:
        target (TransitionTarget): The transition (essentially an agent) to hand off to.
        condition (LLMCondition): The condition for transitioning to the target agent, evaluated by the LLM.
        available (AvailableCondition): Optional condition to determine if this OnCondition is included for the LLM to evaluate based on context variables using classes like StringAvailableCondition and ContextExpressionAvailableCondition.
        llm_function_name (Optional[str]): The name of the LLM function to use for this condition.
    """

    target: TransitionTarget
    condition: LLMCondition
    available: Optional[AvailableCondition] = None
    llm_function_name: Optional[str] = None

    def has_target_type(self, target_type: type) -> bool:
        """Report whether this condition's target is an instance of *target_type*.

        Args:
            target_type (type): The target type to check against, which should be a subclass of TransitionTarget

        Returns:
            bool: True if the target type matches, False otherwise
        """
        return isinstance(self.target, target_type)

    def target_requires_wrapping(self) -> bool:
        """Report whether the target must be wrapped in an agent before use.

        Returns:
            bool: True if the target requires wrapping, False otherwise
        """
        return self.target.needs_agent_wrapper()
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from .available_condition import AvailableCondition
|
||||||
|
from .context_condition import ContextCondition
|
||||||
|
from .targets.transition_target import TransitionTarget
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"OnContextCondition",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class OnContextCondition(BaseModel): # noqa: N801
|
||||||
|
"""Defines a condition for transitioning to another agent or nested chats using context variables and the ContextExpression class.
|
||||||
|
|
||||||
|
This is for context variable-based condition evaluation (does not use the agent's LLM).
|
||||||
|
|
||||||
|
These are evaluated before the OnCondition and after work conditions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target (TransitionTarget): The transition (essentially an agent) to hand off to.
|
||||||
|
condition (Optional[ContextCondition]): The context variable based condition for transitioning to the target agent. If None, the condition always evaluates to True.
|
||||||
|
available (AvailableCondition): Optional condition to determine if this OnCondition is included for the LLM to evaluate based on context variables using classes like StringAvailableCondition and ContextExpressionAvailableCondition.
|
||||||
|
"""
|
||||||
|
|
||||||
|
target: TransitionTarget
|
||||||
|
condition: Optional[ContextCondition] = None
|
||||||
|
available: Optional[AvailableCondition] = None
|
||||||
|
|
||||||
|
def has_target_type(self, target_type: type) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the target type matches the specified type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target_type (type): The target type to check against. Should be a subclass of TransitionTarget.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the target type matches, False otherwise
|
||||||
|
"""
|
||||||
|
return isinstance(self.target, target_type)
|
||||||
|
|
||||||
|
def target_requires_wrapping(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the target requires wrapping in an agent.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the target requires wrapping, False otherwise
|
||||||
|
"""
|
||||||
|
return self.target.needs_agent_wrapper()
|
||||||
18
mm_agents/coact/autogen/agentchat/group/patterns/__init__.py
Normal file
18
mm_agents/coact/autogen/agentchat/group/patterns/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
|
||||||
|
from .auto import AutoPattern
|
||||||
|
from .manual import ManualPattern
|
||||||
|
from .pattern import DefaultPattern
|
||||||
|
from .random import RandomPattern
|
||||||
|
from .round_robin import RoundRobinPattern
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AutoPattern",
|
||||||
|
"DefaultPattern",
|
||||||
|
"ManualPattern",
|
||||||
|
"RandomPattern",
|
||||||
|
"RoundRobinPattern",
|
||||||
|
]
|
||||||
159
mm_agents/coact/autogen/agentchat/group/patterns/auto.py
Normal file
159
mm_agents/coact/autogen/agentchat/group/patterns/auto.py
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from ..context_variables import ContextVariables
|
||||||
|
from ..targets.group_manager_target import GroupManagerSelectionMessage, GroupManagerTarget
|
||||||
|
from ..targets.transition_target import TransitionTarget
|
||||||
|
from .pattern import Pattern
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ...conversable_agent import ConversableAgent
|
||||||
|
from ...groupchat import GroupChat, GroupChatManager
|
||||||
|
from ..group_tool_executor import GroupToolExecutor
|
||||||
|
|
||||||
|
|
||||||
|
class AutoPattern(Pattern):
|
||||||
|
"""AutoPattern implements a flexible pattern where agents are selected based on their expertise.
|
||||||
|
|
||||||
|
In this pattern, a group manager automatically selects the next agent to speak based on the context
|
||||||
|
of the conversation and agent descriptions. The after_work is always set to "group_manager" as
|
||||||
|
this is the defining characteristic of this pattern.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
initial_agent: "ConversableAgent",
|
||||||
|
agents: list["ConversableAgent"],
|
||||||
|
user_agent: Optional["ConversableAgent"] = None,
|
||||||
|
group_manager_args: Optional[dict[str, Any]] = None,
|
||||||
|
context_variables: Optional[ContextVariables] = None,
|
||||||
|
selection_message: Optional[GroupManagerSelectionMessage] = None,
|
||||||
|
exclude_transit_message: bool = True,
|
||||||
|
summary_method: Optional[Union[str, Callable[..., Any]]] = "last_msg",
|
||||||
|
):
|
||||||
|
"""Initialize the AutoPattern.
|
||||||
|
|
||||||
|
The after_work is always set to group_manager selection, which is the defining
|
||||||
|
characteristic of this pattern. You can customize the selection message used
|
||||||
|
by the group manager when selecting the next agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
initial_agent: The first agent to speak in the group chat.
|
||||||
|
agents: List of all agents participating in the chat.
|
||||||
|
user_agent: Optional user proxy agent.
|
||||||
|
group_manager_args: Optional arguments for the GroupChatManager.
|
||||||
|
context_variables: Initial context variables for the chat.
|
||||||
|
selection_message: Custom message to use when the group manager is selecting agents.
|
||||||
|
exclude_transit_message: Whether to exclude transit messages from the conversation.
|
||||||
|
summary_method: Method for summarizing the conversation.
|
||||||
|
"""
|
||||||
|
# Create the group_manager after_work with the provided selection message
|
||||||
|
group_manager_after_work = GroupManagerTarget(selection_message=selection_message)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
initial_agent=initial_agent,
|
||||||
|
agents=agents,
|
||||||
|
user_agent=user_agent,
|
||||||
|
group_manager_args=group_manager_args,
|
||||||
|
context_variables=context_variables,
|
||||||
|
group_after_work=group_manager_after_work,
|
||||||
|
exclude_transit_message=exclude_transit_message,
|
||||||
|
summary_method=summary_method,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store the selection message for potential use
|
||||||
|
self.selection_message = selection_message
|
||||||
|
|
||||||
|
def prepare_group_chat(
|
||||||
|
self,
|
||||||
|
max_rounds: int,
|
||||||
|
messages: Union[list[dict[str, Any]], str],
|
||||||
|
) -> Tuple[
|
||||||
|
list["ConversableAgent"],
|
||||||
|
list["ConversableAgent"],
|
||||||
|
Optional["ConversableAgent"],
|
||||||
|
ContextVariables,
|
||||||
|
"ConversableAgent",
|
||||||
|
TransitionTarget,
|
||||||
|
"GroupToolExecutor",
|
||||||
|
"GroupChat",
|
||||||
|
"GroupChatManager",
|
||||||
|
list[dict[str, Any]],
|
||||||
|
Any,
|
||||||
|
list[str],
|
||||||
|
list[Any],
|
||||||
|
]:
|
||||||
|
"""Prepare the group chat for organic agent selection.
|
||||||
|
|
||||||
|
Ensures that:
|
||||||
|
1. The group manager has a valid LLM config
|
||||||
|
2. All agents have appropriate descriptions for the group manager to use
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_rounds: Maximum number of conversation rounds.
|
||||||
|
messages: Initial message(s) to start the conversation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple containing all necessary components for the group chat.
|
||||||
|
"""
|
||||||
|
# Validate that group_manager_args has an LLM config which is required for this pattern
|
||||||
|
if not self.group_manager_args.get("llm_config", False):
|
||||||
|
# Check if any agent has an LLM config we can use
|
||||||
|
has_llm_config = any(getattr(agent, "llm_config", False) for agent in self.agents)
|
||||||
|
|
||||||
|
if not has_llm_config:
|
||||||
|
raise ValueError(
|
||||||
|
"AutoPattern requires the group_manager_args to include an llm_config, "
|
||||||
|
"or at least one agent to have an llm_config"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check that all agents have descriptions for effective group manager selection
|
||||||
|
for agent in self.agents:
|
||||||
|
if not hasattr(agent, "description") or not agent.description:
|
||||||
|
agent.description = f"Agent {agent.name}"
|
||||||
|
|
||||||
|
# Use the parent class's implementation to prepare the agents and group chat
|
||||||
|
components = super().prepare_group_chat(
|
||||||
|
max_rounds=max_rounds,
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract the group_after_work and the rest of the components
|
||||||
|
(
|
||||||
|
agents,
|
||||||
|
wrapped_agents,
|
||||||
|
user_agent,
|
||||||
|
context_variables,
|
||||||
|
initial_agent,
|
||||||
|
_,
|
||||||
|
tool_executor,
|
||||||
|
groupchat,
|
||||||
|
manager,
|
||||||
|
processed_messages,
|
||||||
|
last_agent,
|
||||||
|
group_agent_names,
|
||||||
|
temp_user_list,
|
||||||
|
) = components
|
||||||
|
|
||||||
|
# Ensure we're using the group_manager after_work
|
||||||
|
group_after_work = self.group_after_work
|
||||||
|
|
||||||
|
# Return all components with our group_after_work
|
||||||
|
return (
|
||||||
|
agents,
|
||||||
|
wrapped_agents,
|
||||||
|
user_agent,
|
||||||
|
context_variables,
|
||||||
|
initial_agent,
|
||||||
|
group_after_work,
|
||||||
|
tool_executor,
|
||||||
|
groupchat,
|
||||||
|
manager,
|
||||||
|
processed_messages,
|
||||||
|
last_agent,
|
||||||
|
group_agent_names,
|
||||||
|
temp_user_list,
|
||||||
|
)
|
||||||
176
mm_agents/coact/autogen/agentchat/group/patterns/manual.py
Normal file
176
mm_agents/coact/autogen/agentchat/group/patterns/manual.py
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from ..context_variables import ContextVariables
|
||||||
|
from ..group_tool_executor import GroupToolExecutor
|
||||||
|
from ..targets.transition_target import AskUserTarget, TransitionTarget
|
||||||
|
from .pattern import Pattern
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ...conversable_agent import ConversableAgent
|
||||||
|
from ...groupchat import GroupChat, GroupChatManager
|
||||||
|
|
||||||
|
|
||||||
|
class ManualPattern(Pattern):
|
||||||
|
"""ManualPattern will ask the user to nominate the next agent to speak at each turn."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
initial_agent: "ConversableAgent",
|
||||||
|
agents: list["ConversableAgent"],
|
||||||
|
user_agent: Optional["ConversableAgent"] = None,
|
||||||
|
group_manager_args: Optional[dict[str, Any]] = None,
|
||||||
|
context_variables: Optional[ContextVariables] = None,
|
||||||
|
exclude_transit_message: bool = True,
|
||||||
|
summary_method: Optional[Union[str, Callable[..., Any]]] = "last_msg",
|
||||||
|
):
|
||||||
|
"""Initialize the ManualPattern.
|
||||||
|
|
||||||
|
The after_work is always set to ask_user, which will prompt the user for the next agent
|
||||||
|
|
||||||
|
Args:
|
||||||
|
initial_agent: The first agent to speak in the group chat.
|
||||||
|
agents: List of all agents participating in the chat.
|
||||||
|
user_agent: Optional user proxy agent.
|
||||||
|
group_manager_args: Optional arguments for the GroupChatManager.
|
||||||
|
context_variables: Initial context variables for the chat.
|
||||||
|
exclude_transit_message: Whether to exclude transit messages from the conversation.
|
||||||
|
summary_method: Method for summarizing the conversation.
|
||||||
|
"""
|
||||||
|
# The group after work will be to ask the user
|
||||||
|
group_after_work = AskUserTarget()
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
initial_agent=initial_agent,
|
||||||
|
agents=agents,
|
||||||
|
user_agent=user_agent,
|
||||||
|
group_manager_args=group_manager_args,
|
||||||
|
context_variables=context_variables,
|
||||||
|
group_after_work=group_after_work,
|
||||||
|
exclude_transit_message=exclude_transit_message,
|
||||||
|
summary_method=summary_method,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_group_chat(
|
||||||
|
self,
|
||||||
|
max_rounds: int,
|
||||||
|
messages: Union[list[dict[str, Any]], str],
|
||||||
|
) -> Tuple[
|
||||||
|
list["ConversableAgent"],
|
||||||
|
list["ConversableAgent"],
|
||||||
|
Optional["ConversableAgent"],
|
||||||
|
ContextVariables,
|
||||||
|
"ConversableAgent",
|
||||||
|
TransitionTarget,
|
||||||
|
"GroupToolExecutor",
|
||||||
|
"GroupChat",
|
||||||
|
"GroupChatManager",
|
||||||
|
list[dict[str, Any]],
|
||||||
|
Any,
|
||||||
|
list[str],
|
||||||
|
list[Any],
|
||||||
|
]:
|
||||||
|
"""Prepare the group chat for organic agent selection.
|
||||||
|
|
||||||
|
Ensures that:
|
||||||
|
1. The group manager has a valid LLM config
|
||||||
|
2. All agents have appropriate descriptions for the group manager to use
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_rounds: Maximum number of conversation rounds.
|
||||||
|
messages: Initial message(s) to start the conversation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple containing all necessary components for the group chat.
|
||||||
|
"""
|
||||||
|
# Use the parent class's implementation to prepare the agents and group chat
|
||||||
|
components = super().prepare_group_chat(
|
||||||
|
max_rounds=max_rounds,
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract the group_after_work and the rest of the components
|
||||||
|
(
|
||||||
|
agents,
|
||||||
|
wrapped_agents,
|
||||||
|
user_agent,
|
||||||
|
context_variables,
|
||||||
|
initial_agent,
|
||||||
|
_,
|
||||||
|
tool_executor,
|
||||||
|
groupchat,
|
||||||
|
manager,
|
||||||
|
processed_messages,
|
||||||
|
last_agent,
|
||||||
|
group_agent_names,
|
||||||
|
temp_user_list,
|
||||||
|
) = components
|
||||||
|
|
||||||
|
# Ensure we're using the group_manager after_work
|
||||||
|
group_after_work = self.group_after_work
|
||||||
|
|
||||||
|
# Set up the allowed speaker transitions to exclude user_agent and GroupToolExecutor
|
||||||
|
self._setup_allowed_transitions(groupchat, user_agent, tool_executor)
|
||||||
|
|
||||||
|
# Return all components with our group_after_work
|
||||||
|
return (
|
||||||
|
agents,
|
||||||
|
wrapped_agents,
|
||||||
|
user_agent,
|
||||||
|
context_variables,
|
||||||
|
initial_agent,
|
||||||
|
group_after_work,
|
||||||
|
tool_executor,
|
||||||
|
groupchat,
|
||||||
|
manager,
|
||||||
|
processed_messages,
|
||||||
|
last_agent,
|
||||||
|
group_agent_names,
|
||||||
|
temp_user_list,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _setup_allowed_transitions(
|
||||||
|
self, groupchat: "GroupChat", user_agent: Optional["ConversableAgent"], tool_executor: "GroupToolExecutor"
|
||||||
|
) -> None:
|
||||||
|
"""Set up the allowed speaker transitions for the group chat so that when a user selects the next agent the tool executor and user agent don't appear as options.
|
||||||
|
|
||||||
|
Creates transitions where:
|
||||||
|
1. Any agent can speak after any other agent, including themselves
|
||||||
|
2. The user_agent and GroupToolExecutor are excluded from transitions
|
||||||
|
|
||||||
|
Args:
|
||||||
|
groupchat: The GroupChat instance to configure
|
||||||
|
user_agent: The user agent to exclude from transitions
|
||||||
|
tool_executor: The GroupToolExecutor to exclude from transitions
|
||||||
|
"""
|
||||||
|
# NOTE: THIS IS NOT WORKING - THE TRANSITIONS ARE NOT BEING KEPT?!
|
||||||
|
"""
|
||||||
|
# Get all agents in the group chat
|
||||||
|
all_agents = groupchat.agents
|
||||||
|
|
||||||
|
# Filter out user_agent and group tool executor
|
||||||
|
eligible_agents = []
|
||||||
|
for agent in all_agents:
|
||||||
|
# Skip user_agent
|
||||||
|
if agent == user_agent:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Skip GroupToolExecutor
|
||||||
|
if isinstance(agent, GroupToolExecutor):
|
||||||
|
continue
|
||||||
|
|
||||||
|
eligible_agents.append(agent)
|
||||||
|
|
||||||
|
# Create a fully connected graph among eligible agents
|
||||||
|
# Each agent can be followed by any other eligible agent
|
||||||
|
allowed_transitions = {}
|
||||||
|
for agent in eligible_agents:
|
||||||
|
# For each agent, every other eligible agent can follow
|
||||||
|
allowed_transitions[agent] = eligible_agents
|
||||||
|
|
||||||
|
# Set the transitions in the group chat
|
||||||
|
groupchat.allowed_speaker_transitions_dict = allowed_transitions
|
||||||
|
"""
|
||||||
294
mm_agents/coact/autogen/agentchat/group/patterns/pattern.py
Normal file
294
mm_agents/coact/autogen/agentchat/group/patterns/pattern.py
Normal file
@@ -0,0 +1,294 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
# Patterns of agent orchestrations
|
||||||
|
# Uses the group chat or the agents' handoffs to create a pattern
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from ..context_variables import ContextVariables
|
||||||
|
from ..group_utils import (
|
||||||
|
create_group_manager,
|
||||||
|
create_group_transition,
|
||||||
|
link_agents_to_group_manager,
|
||||||
|
prepare_group_agents,
|
||||||
|
process_initial_messages,
|
||||||
|
setup_context_variables,
|
||||||
|
)
|
||||||
|
from ..targets.transition_target import TerminateTarget, TransitionTarget
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ...agent import Agent
|
||||||
|
from ...conversable_agent import ConversableAgent
|
||||||
|
from ...groupchat import GroupChat, GroupChatManager
|
||||||
|
from ..group_tool_executor import GroupToolExecutor
|
||||||
|
|
||||||
|
|
||||||
|
class Pattern(ABC):
|
||||||
|
"""Base abstract class for all orchestration patterns.
|
||||||
|
|
||||||
|
Patterns provide a reusable way to define how agents interact within a group chat.
|
||||||
|
Each pattern encapsulates the logic for setting up agents, configuring handoffs,
|
||||||
|
and determining the flow of conversation.
|
||||||
|
|
||||||
|
This is an abstract base class and should not be instantiated directly.
|
||||||
|
Use one of the concrete pattern implementations like AutoPattern,
|
||||||
|
RoundRobinPattern, RandomPattern, or ManualPattern.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
initial_agent: "ConversableAgent",
|
||||||
|
agents: list["ConversableAgent"],
|
||||||
|
user_agent: Optional["ConversableAgent"] = None,
|
||||||
|
group_manager_args: Optional[dict[str, Any]] = None,
|
||||||
|
context_variables: Optional[ContextVariables] = None,
|
||||||
|
group_after_work: Optional[TransitionTarget] = None,
|
||||||
|
exclude_transit_message: bool = True,
|
||||||
|
summary_method: Optional[Union[str, Callable[..., Any]]] = "last_msg",
|
||||||
|
):
|
||||||
|
"""Initialize the pattern with the required components.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
initial_agent: The first agent to speak in the group chat.
|
||||||
|
agents: List of all agents participating in the chat.
|
||||||
|
user_agent: Optional user proxy agent.
|
||||||
|
group_manager_args: Optional arguments for the GroupChatManager.
|
||||||
|
context_variables: Initial context variables for the chat.
|
||||||
|
group_after_work: Default after work transition behavior when no specific next agent is determined.
|
||||||
|
exclude_transit_message: Whether to exclude transit messages from the conversation.
|
||||||
|
summary_method: Method for summarizing the conversation.
|
||||||
|
"""
|
||||||
|
self.initial_agent = initial_agent
|
||||||
|
self.agents = agents
|
||||||
|
self.user_agent = user_agent
|
||||||
|
self.group_manager_args = group_manager_args or {}
|
||||||
|
self.context_variables = context_variables or ContextVariables()
|
||||||
|
self.group_after_work = group_after_work if group_after_work is not None else TerminateTarget()
|
||||||
|
self.exclude_transit_message = exclude_transit_message
|
||||||
|
self.summary_method = summary_method
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def prepare_group_chat(
|
||||||
|
self,
|
||||||
|
max_rounds: int,
|
||||||
|
messages: Union[list[dict[str, Any]], str],
|
||||||
|
) -> Tuple[
|
||||||
|
list["ConversableAgent"],
|
||||||
|
list["ConversableAgent"],
|
||||||
|
Optional["ConversableAgent"],
|
||||||
|
ContextVariables,
|
||||||
|
"ConversableAgent",
|
||||||
|
TransitionTarget,
|
||||||
|
"GroupToolExecutor",
|
||||||
|
"GroupChat",
|
||||||
|
"GroupChatManager",
|
||||||
|
list[dict[str, Any]],
|
||||||
|
"ConversableAgent",
|
||||||
|
list[str],
|
||||||
|
list["Agent"],
|
||||||
|
]:
|
||||||
|
"""Prepare the group chat for orchestration.
|
||||||
|
|
||||||
|
This is the main method called by initiate_group_chat to set up the pattern.
|
||||||
|
Subclasses must implement or extend this method to define pattern-specific behavior.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_rounds: Maximum number of conversation rounds.
|
||||||
|
messages: Initial message(s) to start the conversation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple containing:
|
||||||
|
- List of agents involved in the group chat
|
||||||
|
- List of wrapped agents
|
||||||
|
- User agent, if applicable
|
||||||
|
- Context variables for the group chat
|
||||||
|
- Initial agent for the group chat
|
||||||
|
- Group-level after work transition for the group chat
|
||||||
|
- Tool executor for the group chat
|
||||||
|
- GroupChat instance
|
||||||
|
- GroupChatManager instance
|
||||||
|
- Processed messages
|
||||||
|
- Last agent to speak
|
||||||
|
- List of group agent names
|
||||||
|
- List of temporary user agents
|
||||||
|
"""
|
||||||
|
from ...groupchat import GroupChat
|
||||||
|
|
||||||
|
# Prepare the agents using the existing helper function
|
||||||
|
tool_executor, wrapped_agents = prepare_group_agents(
|
||||||
|
self.agents, self.context_variables, self.exclude_transit_message
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process the initial messages BEFORE creating the GroupChat
|
||||||
|
# This will create a temporary user agent if needed
|
||||||
|
processed_messages, last_agent, group_agent_names, temp_user_list = process_initial_messages(
|
||||||
|
messages, self.user_agent, self.agents, wrapped_agents
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create transition function (has enclosed state for initial agent)
|
||||||
|
group_transition = create_group_transition(
|
||||||
|
initial_agent=self.initial_agent,
|
||||||
|
tool_execution=tool_executor,
|
||||||
|
group_agent_names=group_agent_names,
|
||||||
|
user_agent=self.user_agent,
|
||||||
|
group_after_work=self.group_after_work,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the group chat - now we use temp_user_list if no user_agent
|
||||||
|
groupchat = GroupChat(
|
||||||
|
agents=[tool_executor]
|
||||||
|
+ self.agents
|
||||||
|
+ wrapped_agents
|
||||||
|
+ ([self.user_agent] if self.user_agent else temp_user_list),
|
||||||
|
messages=[],
|
||||||
|
max_round=max_rounds,
|
||||||
|
speaker_selection_method=group_transition,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the group manager
|
||||||
|
manager = create_group_manager(groupchat, self.group_manager_args, self.agents, self.group_after_work)
|
||||||
|
|
||||||
|
# Point all agent's context variables to this function's context_variables
|
||||||
|
setup_context_variables(
|
||||||
|
tool_execution=tool_executor,
|
||||||
|
agents=self.agents,
|
||||||
|
manager=manager,
|
||||||
|
user_agent=self.user_agent,
|
||||||
|
context_variables=self.context_variables,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Link all agents with the GroupChatManager to allow access to the group chat
|
||||||
|
link_agents_to_group_manager(groupchat.agents, manager)
|
||||||
|
|
||||||
|
return (
|
||||||
|
self.agents,
|
||||||
|
wrapped_agents,
|
||||||
|
self.user_agent,
|
||||||
|
self.context_variables,
|
||||||
|
self.initial_agent,
|
||||||
|
self.group_after_work,
|
||||||
|
tool_executor,
|
||||||
|
groupchat,
|
||||||
|
manager,
|
||||||
|
processed_messages,
|
||||||
|
last_agent,
|
||||||
|
group_agent_names,
|
||||||
|
temp_user_list,
|
||||||
|
) # type: ignore[return-value]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_default(
|
||||||
|
cls,
|
||||||
|
initial_agent: "ConversableAgent",
|
||||||
|
agents: list["ConversableAgent"],
|
||||||
|
user_agent: Optional["ConversableAgent"] = None,
|
||||||
|
group_manager_args: Optional[dict[str, Any]] = None,
|
||||||
|
context_variables: Optional[ContextVariables] = None,
|
||||||
|
exclude_transit_message: bool = True,
|
||||||
|
summary_method: Optional[Union[str, Callable[..., Any]]] = "last_msg",
|
||||||
|
) -> "DefaultPattern":
|
||||||
|
"""Create a default pattern with minimal configuration.
|
||||||
|
|
||||||
|
This replaces the need for a separate BasePattern class by providing
|
||||||
|
a factory method that creates a simple DefaultPattern instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
initial_agent: The first agent to speak in the group chat.
|
||||||
|
agents: List of all agents participating in the chat.
|
||||||
|
user_agent: Optional user proxy agent.
|
||||||
|
group_manager_args: Optional arguments for the GroupChatManager.
|
||||||
|
context_variables: Initial context variables for the chat.
|
||||||
|
exclude_transit_message: Whether to exclude transit messages from the conversation.
|
||||||
|
summary_method: Method for summarizing the conversation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A DefaultPattern instance with basic configuration.
|
||||||
|
"""
|
||||||
|
return DefaultPattern(
|
||||||
|
initial_agent=initial_agent,
|
||||||
|
agents=agents,
|
||||||
|
user_agent=user_agent,
|
||||||
|
group_manager_args=group_manager_args,
|
||||||
|
context_variables=context_variables,
|
||||||
|
exclude_transit_message=exclude_transit_message,
|
||||||
|
summary_method=summary_method,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultPattern(Pattern):
|
||||||
|
"""DefaultPattern implements a minimal pattern for simple agent interactions.
|
||||||
|
|
||||||
|
This replaces the previous BasePattern and provides a concrete implementation
|
||||||
|
of the Pattern abstract base class.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def prepare_group_chat(
|
||||||
|
self,
|
||||||
|
max_rounds: int,
|
||||||
|
messages: Union[list[dict[str, Any]], str],
|
||||||
|
) -> Tuple[
|
||||||
|
list["ConversableAgent"],
|
||||||
|
list["ConversableAgent"],
|
||||||
|
Optional["ConversableAgent"],
|
||||||
|
ContextVariables,
|
||||||
|
"ConversableAgent",
|
||||||
|
TransitionTarget,
|
||||||
|
"GroupToolExecutor",
|
||||||
|
"GroupChat",
|
||||||
|
"GroupChatManager",
|
||||||
|
list[dict[str, Any]],
|
||||||
|
Any,
|
||||||
|
list[str],
|
||||||
|
list[Any],
|
||||||
|
]:
|
||||||
|
"""Prepare the group chat with default configuration.
|
||||||
|
|
||||||
|
This implementation calls the parent class method but ensures that
|
||||||
|
the group_after_work in the returned tuple is the pattern's own.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_rounds: Maximum number of conversation rounds.
|
||||||
|
messages: Initial message(s) to start the conversation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple containing all necessary components for the group chat.
|
||||||
|
"""
|
||||||
|
# Use the parent class's implementation to prepare the agents and group chat
|
||||||
|
(
|
||||||
|
agents,
|
||||||
|
wrapped_agents,
|
||||||
|
user_agent,
|
||||||
|
context_variables,
|
||||||
|
initial_agent,
|
||||||
|
_, # Ignore the group_after_work from parent
|
||||||
|
tool_executor,
|
||||||
|
groupchat,
|
||||||
|
manager,
|
||||||
|
processed_messages,
|
||||||
|
last_agent,
|
||||||
|
group_agent_names,
|
||||||
|
temp_user_list,
|
||||||
|
) = super().prepare_group_chat(
|
||||||
|
max_rounds=max_rounds,
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return all components with our group_after_work
|
||||||
|
return (
|
||||||
|
agents,
|
||||||
|
wrapped_agents,
|
||||||
|
user_agent,
|
||||||
|
context_variables,
|
||||||
|
initial_agent,
|
||||||
|
self.group_after_work, # Use our own group_after_work
|
||||||
|
tool_executor,
|
||||||
|
groupchat,
|
||||||
|
manager,
|
||||||
|
processed_messages,
|
||||||
|
last_agent,
|
||||||
|
group_agent_names,
|
||||||
|
temp_user_list,
|
||||||
|
)
|
||||||
106
mm_agents/coact/autogen/agentchat/group/patterns/random.py
Normal file
106
mm_agents/coact/autogen/agentchat/group/patterns/random.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from ..context_variables import ContextVariables
|
||||||
|
from ..targets.transition_target import RandomAgentTarget, TransitionTarget
|
||||||
|
from .pattern import Pattern
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ...conversable_agent import ConversableAgent
|
||||||
|
from ...groupchat import GroupChat, GroupChatManager
|
||||||
|
from ..group_tool_executor import GroupToolExecutor
|
||||||
|
|
||||||
|
|
||||||
|
class RandomPattern(Pattern):
|
||||||
|
"""RandomPattern implements a random agent selection process."""
|
||||||
|
|
||||||
|
def _generate_handoffs(
|
||||||
|
self,
|
||||||
|
initial_agent: "ConversableAgent",
|
||||||
|
agents: list["ConversableAgent"],
|
||||||
|
user_agent: Optional["ConversableAgent"],
|
||||||
|
) -> None:
|
||||||
|
"""Generate handoffs between agents in a random fashion."""
|
||||||
|
agent_list = agents + ([user_agent] if user_agent is not None else [])
|
||||||
|
|
||||||
|
for agent in agent_list:
|
||||||
|
# Get the list of agents except itself
|
||||||
|
other_agents = [a for a in agent_list if a != agent]
|
||||||
|
|
||||||
|
# Create a random after work
|
||||||
|
agent.handoffs.set_after_work(target=RandomAgentTarget(agents=other_agents))
|
||||||
|
|
||||||
|
def prepare_group_chat(
|
||||||
|
self,
|
||||||
|
max_rounds: int,
|
||||||
|
messages: Union[list[dict[str, Any]], str],
|
||||||
|
) -> Tuple[
|
||||||
|
list["ConversableAgent"],
|
||||||
|
list["ConversableAgent"],
|
||||||
|
Optional["ConversableAgent"],
|
||||||
|
ContextVariables,
|
||||||
|
"ConversableAgent",
|
||||||
|
TransitionTarget,
|
||||||
|
"GroupToolExecutor",
|
||||||
|
"GroupChat",
|
||||||
|
"GroupChatManager",
|
||||||
|
list[dict[str, Any]],
|
||||||
|
Any,
|
||||||
|
list[str],
|
||||||
|
list[Any],
|
||||||
|
]:
|
||||||
|
"""Prepare the group chat for organic agent selection.
|
||||||
|
|
||||||
|
Ensures that:
|
||||||
|
1. The group manager has a valid LLM config
|
||||||
|
2. All agents have appropriate descriptions for the group manager to use
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_rounds: Maximum number of conversation rounds.
|
||||||
|
messages: Initial message(s) to start the conversation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple containing all necessary components for the group chat.
|
||||||
|
"""
|
||||||
|
# Use the parent class's implementation to prepare the agents and group chat
|
||||||
|
(
|
||||||
|
agents,
|
||||||
|
wrapped_agents,
|
||||||
|
user_agent,
|
||||||
|
context_variables,
|
||||||
|
initial_agent,
|
||||||
|
group_after_work,
|
||||||
|
tool_executor,
|
||||||
|
groupchat,
|
||||||
|
manager,
|
||||||
|
processed_messages,
|
||||||
|
last_agent,
|
||||||
|
group_agent_names,
|
||||||
|
temp_user_list,
|
||||||
|
) = super().prepare_group_chat(
|
||||||
|
max_rounds=max_rounds,
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the random handoffs between agents
|
||||||
|
self._generate_handoffs(initial_agent=initial_agent, agents=agents, user_agent=user_agent)
|
||||||
|
|
||||||
|
# Return all components with our group_after_work
|
||||||
|
return (
|
||||||
|
agents,
|
||||||
|
wrapped_agents,
|
||||||
|
user_agent,
|
||||||
|
context_variables,
|
||||||
|
initial_agent,
|
||||||
|
group_after_work,
|
||||||
|
tool_executor,
|
||||||
|
groupchat,
|
||||||
|
manager,
|
||||||
|
processed_messages,
|
||||||
|
last_agent,
|
||||||
|
group_agent_names,
|
||||||
|
temp_user_list,
|
||||||
|
)
|
||||||
117
mm_agents/coact/autogen/agentchat/group/patterns/round_robin.py
Normal file
117
mm_agents/coact/autogen/agentchat/group/patterns/round_robin.py
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from ..context_variables import ContextVariables
|
||||||
|
from ..targets.transition_target import AgentTarget, TransitionTarget
|
||||||
|
from .pattern import Pattern
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ...conversable_agent import ConversableAgent
|
||||||
|
from ...groupchat import GroupChat, GroupChatManager
|
||||||
|
from ..group_tool_executor import GroupToolExecutor
|
||||||
|
|
||||||
|
|
||||||
|
class RoundRobinPattern(Pattern):
|
||||||
|
"""RoundRobinPattern implements a round robin with handoffs between agents."""
|
||||||
|
|
||||||
|
def _generate_handoffs(
|
||||||
|
self,
|
||||||
|
initial_agent: "ConversableAgent",
|
||||||
|
agents: list["ConversableAgent"],
|
||||||
|
user_agent: Optional["ConversableAgent"],
|
||||||
|
) -> None:
|
||||||
|
"""Generate handoffs between agents in a round-robin fashion."""
|
||||||
|
# Create a list of the agents and the user_agent but put the initial_agent first
|
||||||
|
agent_list = [initial_agent]
|
||||||
|
|
||||||
|
# Add the rest of the agents, excluding the initial_agent and user_agent
|
||||||
|
for agent in agents:
|
||||||
|
if agent != initial_agent and (user_agent is None or agent != user_agent):
|
||||||
|
agent_list.append(agent)
|
||||||
|
|
||||||
|
# Add the user_agent last if it exists
|
||||||
|
if user_agent is not None:
|
||||||
|
agent_list.append(user_agent)
|
||||||
|
|
||||||
|
# Create handoffs in a round-robin fashion
|
||||||
|
for i, agent in enumerate(agent_list):
|
||||||
|
# Last agent hands off to the first agent
|
||||||
|
# Otherwise agent hands off to the next one
|
||||||
|
handoff_target = agent_list[0] if i == len(agent_list) - 1 else agent_list[i + 1]
|
||||||
|
|
||||||
|
agent.handoffs.set_after_work(target=AgentTarget(agent=handoff_target))
|
||||||
|
|
||||||
|
def prepare_group_chat(
|
||||||
|
self,
|
||||||
|
max_rounds: int,
|
||||||
|
messages: Union[list[dict[str, Any]], str],
|
||||||
|
) -> Tuple[
|
||||||
|
list["ConversableAgent"],
|
||||||
|
list["ConversableAgent"],
|
||||||
|
Optional["ConversableAgent"],
|
||||||
|
ContextVariables,
|
||||||
|
"ConversableAgent",
|
||||||
|
TransitionTarget,
|
||||||
|
"GroupToolExecutor",
|
||||||
|
"GroupChat",
|
||||||
|
"GroupChatManager",
|
||||||
|
list[dict[str, Any]],
|
||||||
|
Any,
|
||||||
|
list[str],
|
||||||
|
list[Any],
|
||||||
|
]:
|
||||||
|
"""Prepare the group chat for organic agent selection.
|
||||||
|
|
||||||
|
Ensures that:
|
||||||
|
1. The group manager has a valid LLM config
|
||||||
|
2. All agents have appropriate descriptions for the group manager to use
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_rounds: Maximum number of conversation rounds.
|
||||||
|
messages: Initial message(s) to start the conversation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple containing all necessary components for the group chat.
|
||||||
|
"""
|
||||||
|
# Use the parent class's implementation to prepare the agents and group chat
|
||||||
|
(
|
||||||
|
agents,
|
||||||
|
wrapped_agents,
|
||||||
|
user_agent,
|
||||||
|
context_variables,
|
||||||
|
initial_agent,
|
||||||
|
group_after_work,
|
||||||
|
tool_executor,
|
||||||
|
groupchat,
|
||||||
|
manager,
|
||||||
|
processed_messages,
|
||||||
|
last_agent,
|
||||||
|
group_agent_names,
|
||||||
|
temp_user_list,
|
||||||
|
) = super().prepare_group_chat(
|
||||||
|
max_rounds=max_rounds,
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the handoffs between agents
|
||||||
|
self._generate_handoffs(initial_agent=initial_agent, agents=agents, user_agent=user_agent)
|
||||||
|
|
||||||
|
# Return all components with our group_after_work
|
||||||
|
return (
|
||||||
|
agents,
|
||||||
|
wrapped_agents,
|
||||||
|
user_agent,
|
||||||
|
context_variables,
|
||||||
|
initial_agent,
|
||||||
|
group_after_work,
|
||||||
|
tool_executor,
|
||||||
|
groupchat,
|
||||||
|
manager,
|
||||||
|
processed_messages,
|
||||||
|
last_agent,
|
||||||
|
group_agent_names,
|
||||||
|
temp_user_list,
|
||||||
|
)
|
||||||
26
mm_agents/coact/autogen/agentchat/group/reply_result.py
Normal file
26
mm_agents/coact/autogen/agentchat/group/reply_result.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["ReplyResult"]
|
||||||
|
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from .context_variables import ContextVariables
|
||||||
|
from .targets.transition_target import TransitionTarget
|
||||||
|
|
||||||
|
|
||||||
|
class ReplyResult(BaseModel):
|
||||||
|
"""Result of a tool call that is used to provide the return message and the target to transition to."""
|
||||||
|
|
||||||
|
message: str
|
||||||
|
target: Optional[TransitionTarget] = None
|
||||||
|
context_variables: Optional[ContextVariables] = None
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
"""The string representation for ReplyResult will be just the message."""
|
||||||
|
return self.message
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Optional, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from ..agent import Agent
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
# Avoid circular import
|
||||||
|
from ..groupchat import GroupChat
|
||||||
|
|
||||||
|
|
||||||
|
class SpeakerSelectionResult(BaseModel):
|
||||||
|
"""Represents a speaker selection result that will be returned to GroupChat._prepare_and_select_agents to determine the next speaker.
|
||||||
|
|
||||||
|
This class can return an Agent, a None to end the conversation, or a string for a speaker selection method.
|
||||||
|
"""
|
||||||
|
|
||||||
|
terminate: Optional[bool] = None
|
||||||
|
agent_name: Optional[str] = None
|
||||||
|
speaker_selection_method: Optional[str] = None
|
||||||
|
|
||||||
|
def get_speaker_selection_result(self, groupchat: "GroupChat") -> Optional[Union[Agent, str]]:
|
||||||
|
"""Get the speaker selection result. If None, the conversation will end."""
|
||||||
|
if self.agent_name is not None:
|
||||||
|
# Find the agent by name in the groupchat
|
||||||
|
for agent in groupchat.agents:
|
||||||
|
if agent.name == self.agent_name:
|
||||||
|
return agent
|
||||||
|
raise ValueError(f"Agent '{self.agent_name}' not found in groupchat.")
|
||||||
|
elif self.speaker_selection_method is not None:
|
||||||
|
return self.speaker_selection_method
|
||||||
|
elif self.terminate is not None and self.terminate:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Unable to establish speaker selection result. No terminate, agent, or speaker selection method provided."
|
||||||
|
)
|
||||||
@@ -0,0 +1,4 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
@@ -0,0 +1,132 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from ....doc_utils import export_module
|
||||||
|
from ...agent import Agent
|
||||||
|
from ..speaker_selection_result import SpeakerSelectionResult
|
||||||
|
from .transition_target import AgentTarget, TransitionTarget
|
||||||
|
from .transition_utils import __AGENT_WRAPPER_PREFIX__
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ...conversable_agent import ConversableAgent
|
||||||
|
from ...groupchat import GroupChat
|
||||||
|
from ..patterns.pattern import Pattern
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["GroupChatConfig", "GroupChatTarget"]
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen.agentchat.group")
|
||||||
|
class GroupChatConfig(BaseModel):
|
||||||
|
"""Configuration for a group chat transition target.
|
||||||
|
|
||||||
|
Note: If context_variables are not passed in, the outer context variables will be passed in"""
|
||||||
|
|
||||||
|
pattern: "Pattern"
|
||||||
|
messages: Union[list[dict[str, Any]], str]
|
||||||
|
max_rounds: int = 20
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen.agentchat.group")
|
||||||
|
class GroupChatTarget(TransitionTarget):
|
||||||
|
"""Target that represents a group chat."""
|
||||||
|
|
||||||
|
group_chat_config: GroupChatConfig
|
||||||
|
|
||||||
|
def can_resolve_for_speaker_selection(self) -> bool:
|
||||||
|
"""Check if the target can resolve for speaker selection. For GroupChatTarget the chat must be encapsulated into an agent."""
|
||||||
|
return False
|
||||||
|
|
||||||
|
def resolve(
|
||||||
|
self,
|
||||||
|
groupchat: "GroupChat",
|
||||||
|
current_agent: "ConversableAgent",
|
||||||
|
user_agent: Optional["ConversableAgent"],
|
||||||
|
) -> SpeakerSelectionResult:
|
||||||
|
"""Resolve to the nested chat configuration."""
|
||||||
|
raise NotImplementedError(
|
||||||
|
"GroupChatTarget does not support the resolve method. An agent should be used to encapsulate this nested chat and then the target changed to an AgentTarget."
|
||||||
|
)
|
||||||
|
|
||||||
|
def display_name(self) -> str:
|
||||||
|
"""Get the display name for the target."""
|
||||||
|
return "a group chat"
|
||||||
|
|
||||||
|
def normalized_name(self) -> str:
|
||||||
|
"""Get a normalized name for the target that has no spaces, used for function calling."""
|
||||||
|
return "group_chat"
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
"""String representation for AgentTarget, can be shown as a function call message."""
|
||||||
|
return "Transfer to group chat"
|
||||||
|
|
||||||
|
def needs_agent_wrapper(self) -> bool:
|
||||||
|
"""Check if the target needs to be wrapped in an agent. GroupChatTarget must be wrapped in an agent."""
|
||||||
|
return True
|
||||||
|
|
||||||
|
def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
|
||||||
|
"""Create a wrapper agent for the group chat."""
|
||||||
|
from autogen.agentchat import initiate_group_chat
|
||||||
|
|
||||||
|
from ...conversable_agent import ConversableAgent # to avoid circular import
|
||||||
|
|
||||||
|
# Create the wrapper agent with a name that identifies it as a wrapped group chat
|
||||||
|
group_chat_agent = ConversableAgent(
|
||||||
|
name=f"{__AGENT_WRAPPER_PREFIX__}group_{parent_agent.name}_{index + 1}",
|
||||||
|
# Copy LLM config from parent agent to ensure it can generate replies if needed
|
||||||
|
llm_config=parent_agent.llm_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store the config directly on the agent
|
||||||
|
group_chat_agent._group_chat_config = self.group_chat_config # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
# Define the reply function that will run the group chat
|
||||||
|
def group_chat_reply(
|
||||||
|
agent: "ConversableAgent",
|
||||||
|
messages: Optional[list[dict[str, Any]]] = None,
|
||||||
|
sender: Optional["Agent"] = None,
|
||||||
|
config: Optional[Any] = None,
|
||||||
|
) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||||
|
"""Run the inner group chat and return its results as a reply."""
|
||||||
|
# Get the configuration stored directly on the agent
|
||||||
|
group_config = agent._group_chat_config # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
# Pull through the second last message from the outer chat (the last message will be the handoff message)
|
||||||
|
# This may need work to make sure we get the right message(s) from the outer chat
|
||||||
|
message = (
|
||||||
|
messages[-2]["content"]
|
||||||
|
if messages and len(messages) >= 2 and "content" in messages[-2]
|
||||||
|
else "No message to pass through."
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Run the group chat with direct agent references from the config
|
||||||
|
result, _, _ = initiate_group_chat(
|
||||||
|
pattern=group_config.pattern,
|
||||||
|
messages=message,
|
||||||
|
max_rounds=group_config.max_rounds,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return the summary from the chat result summary
|
||||||
|
return True, {"content": result.summary}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Handle any errors during execution
|
||||||
|
return True, {"content": f"Error running group chat: {str(e)}"}
|
||||||
|
|
||||||
|
# Register the reply function with the wrapper agent
|
||||||
|
group_chat_agent.register_reply(
|
||||||
|
trigger=[ConversableAgent, None],
|
||||||
|
reply_func=group_chat_reply,
|
||||||
|
remove_other_reply_funcs=True, # Use only this reply function
|
||||||
|
)
|
||||||
|
|
||||||
|
# After the group chat completes, transition back to the parent agent
|
||||||
|
group_chat_agent.handoffs.set_after_work(AgentTarget(parent_agent))
|
||||||
|
|
||||||
|
return group_chat_agent
|
||||||
@@ -0,0 +1,151 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any, Optional, Type, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel, field_validator
|
||||||
|
|
||||||
|
from ....doc_utils import export_module
|
||||||
|
from ..context_str import ContextStr
|
||||||
|
from ..group_tool_executor import GroupToolExecutor
|
||||||
|
from ..speaker_selection_result import SpeakerSelectionResult
|
||||||
|
from .transition_target import TransitionTarget
|
||||||
|
from .transition_utils import __AGENT_WRAPPER_PREFIX__
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
# Avoid circular import
|
||||||
|
from ...conversable_agent import ConversableAgent
|
||||||
|
from ...groupchat import GroupChat
|
||||||
|
|
||||||
|
__all__ = ["GroupManagerTarget"]
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_groupchat_auto_speaker(
|
||||||
|
groupchat: "GroupChat",
|
||||||
|
last_group_agent: "ConversableAgent",
|
||||||
|
group_chat_manager_selection_msg: Optional[Any],
|
||||||
|
) -> None:
|
||||||
|
"""Prepare the group chat for auto speaker selection, includes updating or restore the groupchat speaker selection message.
|
||||||
|
|
||||||
|
Tool Executor and wrapped agents will be removed from the available agents list.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
groupchat (GroupChat): GroupChat instance.
|
||||||
|
last_group_agent ("ConversableAgent"): The last group agent for which the LLM config is used
|
||||||
|
group_chat_manager_selection_msg (GroupManagerSelectionMessage): Optional message to use for the agent selection (in internal group chat).
|
||||||
|
"""
|
||||||
|
from ...groupchat import SELECT_SPEAKER_PROMPT_TEMPLATE
|
||||||
|
|
||||||
|
def substitute_agentlist(template: str) -> str:
|
||||||
|
# Run through group chat's string substitution first for {agentlist}
|
||||||
|
# We need to do this so that the next substitution doesn't fail with agentlist
|
||||||
|
# and we can remove the tool executor and wrapped chats from the available agents list
|
||||||
|
agent_list = [
|
||||||
|
agent
|
||||||
|
for agent in groupchat.agents
|
||||||
|
if not isinstance(agent, GroupToolExecutor) and not agent.name.startswith(__AGENT_WRAPPER_PREFIX__)
|
||||||
|
]
|
||||||
|
|
||||||
|
groupchat.select_speaker_prompt_template = template
|
||||||
|
return groupchat.select_speaker_prompt(agent_list)
|
||||||
|
|
||||||
|
# Use the default speaker selection prompt if one is not specified, otherwise use the specified one
|
||||||
|
groupchat.select_speaker_prompt_template = substitute_agentlist(
|
||||||
|
SELECT_SPEAKER_PROMPT_TEMPLATE
|
||||||
|
if group_chat_manager_selection_msg is None
|
||||||
|
else group_chat_manager_selection_msg.get_message(last_group_agent)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# GroupManagerSelectionMessage protocol and implementations
|
||||||
|
@export_module("autogen.agentchat.group")
|
||||||
|
class GroupManagerSelectionMessage(BaseModel):
|
||||||
|
"""Base class for all GroupManager selection message types."""
|
||||||
|
|
||||||
|
def get_message(self, agent: "ConversableAgent") -> str:
|
||||||
|
"""Get the formatted message."""
|
||||||
|
raise NotImplementedError("Requires subclasses to implement.")
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen.agentchat.group")
|
||||||
|
class GroupManagerSelectionMessageString(GroupManagerSelectionMessage):
|
||||||
|
"""Selection message that uses a plain string template."""
|
||||||
|
|
||||||
|
message: str
|
||||||
|
|
||||||
|
def get_message(self, agent: "ConversableAgent") -> str:
|
||||||
|
"""Get the message string."""
|
||||||
|
return self.message
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen.agentchat.group")
|
||||||
|
class GroupManagerSelectionMessageContextStr(GroupManagerSelectionMessage):
|
||||||
|
"""Selection message that uses a ContextStr template."""
|
||||||
|
|
||||||
|
context_str_template: str
|
||||||
|
|
||||||
|
# We will replace {agentlist} with another term and return it later for use with the internal group chat auto speaker selection
|
||||||
|
# Otherwise our format will fail
|
||||||
|
@field_validator("context_str_template", mode="before")
|
||||||
|
def _replace_agentlist_placeholder(cls: Type["GroupManagerSelectionMessageContextStr"], v: Any) -> Union[str, Any]: # noqa: N805
|
||||||
|
"""Replace {agentlist} placeholder before validation/assignment."""
|
||||||
|
if isinstance(v, str):
|
||||||
|
if "{agentlist}" in v:
|
||||||
|
return v.replace("{agentlist}", "<<agent_list>>") # Perform the replacement
|
||||||
|
else:
|
||||||
|
return v # If no replacement is needed, return the original value
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def get_message(self, agent: "ConversableAgent") -> str:
|
||||||
|
"""Get the formatted message with context variables substituted."""
|
||||||
|
context_str = ContextStr(template=self.context_str_template)
|
||||||
|
format_result = context_str.format(agent.context_variables)
|
||||||
|
if format_result is None:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
return format_result.replace(
|
||||||
|
"<<agent_list>>", "{agentlist}"
|
||||||
|
) # Restore agentlist so it can be substituted by the internal group chat auto speaker selection
|
||||||
|
|
||||||
|
|
||||||
|
class GroupManagerTarget(TransitionTarget):
|
||||||
|
"""Target that represents an agent by name."""
|
||||||
|
|
||||||
|
selection_message: Optional[GroupManagerSelectionMessage] = None
|
||||||
|
|
||||||
|
def can_resolve_for_speaker_selection(self) -> bool:
|
||||||
|
"""Check if the target can resolve for speaker selection."""
|
||||||
|
return True
|
||||||
|
|
||||||
|
def resolve(
|
||||||
|
self,
|
||||||
|
groupchat: "GroupChat",
|
||||||
|
current_agent: "ConversableAgent",
|
||||||
|
user_agent: Optional["ConversableAgent"],
|
||||||
|
) -> SpeakerSelectionResult:
|
||||||
|
"""Resolve to the speaker selection for the group."""
|
||||||
|
if self.selection_message is not None:
|
||||||
|
prepare_groupchat_auto_speaker(groupchat, current_agent, self.selection_message)
|
||||||
|
|
||||||
|
return SpeakerSelectionResult(speaker_selection_method="auto")
|
||||||
|
|
||||||
|
def display_name(self) -> str:
|
||||||
|
"""Get the display name for the target."""
|
||||||
|
return "the group manager"
|
||||||
|
|
||||||
|
def normalized_name(self) -> str:
|
||||||
|
"""Get a normalized name for the target that has no spaces, used for function calling"""
|
||||||
|
return self.display_name()
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
"""String representation for AgentTarget, can be shown as a function call message."""
|
||||||
|
return "Transfer to the group manager"
|
||||||
|
|
||||||
|
def needs_agent_wrapper(self) -> bool:
|
||||||
|
"""Check if the target needs to be wrapped in an agent."""
|
||||||
|
return False
|
||||||
|
|
||||||
|
def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
|
||||||
|
"""Create a wrapper agent for the target if needed."""
|
||||||
|
raise NotImplementedError("GroupManagerTarget does not require wrapping in an agent.")
|
||||||
@@ -0,0 +1,413 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import random
|
||||||
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from ..speaker_selection_result import SpeakerSelectionResult
|
||||||
|
from .transition_utils import __AGENT_WRAPPER_PREFIX__
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
# Avoid circular import
|
||||||
|
from ...conversable_agent import ConversableAgent
|
||||||
|
from ...groupchat import GroupChat
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AgentNameTarget",
|
||||||
|
"AgentTarget",
|
||||||
|
"AskUserTarget",
|
||||||
|
"NestedChatTarget",
|
||||||
|
"RandomAgentTarget",
|
||||||
|
"RevertToUserTarget",
|
||||||
|
"StayTarget",
|
||||||
|
"TerminateTarget",
|
||||||
|
"TransitionTarget",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Common options for transitions
|
||||||
|
# terminate: Terminate the conversation
|
||||||
|
# revert_to_user: Revert to the user agent
|
||||||
|
# stay: Stay with the current agent
|
||||||
|
# group_manager: Use the group manager (auto speaker selection)
|
||||||
|
# ask_user: Use the user manager (ask the user, aka manual)
|
||||||
|
# TransitionOption = Literal["terminate", "revert_to_user", "stay", "group_manager", "ask_user"]
|
||||||
|
|
||||||
|
|
||||||
|
class TransitionTarget(BaseModel):
|
||||||
|
"""Base class for all transition targets across OnCondition, OnContextCondition, and after work."""
|
||||||
|
|
||||||
|
def can_resolve_for_speaker_selection(self) -> bool:
|
||||||
|
"""Check if the target can resolve to an option for speaker selection (Agent, 'None' to end, Str for speaker selection method). In the case of a nested chat, this will return False as it should be encapsulated in an agent."""
|
||||||
|
return False
|
||||||
|
|
||||||
|
def resolve(
|
||||||
|
self,
|
||||||
|
groupchat: "GroupChat",
|
||||||
|
current_agent: "ConversableAgent",
|
||||||
|
user_agent: Optional["ConversableAgent"],
|
||||||
|
) -> SpeakerSelectionResult:
|
||||||
|
"""Resolve to a speaker selection result (Agent, None for termination, or str for speaker selection method)."""
|
||||||
|
raise NotImplementedError("Requires subclasses to implement.")
|
||||||
|
|
||||||
|
def display_name(self) -> str:
|
||||||
|
"""Get the display name for the target."""
|
||||||
|
raise NotImplementedError("Requires subclasses to implement.")
|
||||||
|
|
||||||
|
def normalized_name(self) -> str:
|
||||||
|
"""Get a normalized name for the target that has no spaces, used for function calling"""
|
||||||
|
raise NotImplementedError("Requires subclasses to implement.")
|
||||||
|
|
||||||
|
def needs_agent_wrapper(self) -> bool:
|
||||||
|
"""Check if the target needs to be wrapped in an agent."""
|
||||||
|
raise NotImplementedError("Requires subclasses to implement.")
|
||||||
|
|
||||||
|
def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
|
||||||
|
"""Create a wrapper agent for the target if needed."""
|
||||||
|
raise NotImplementedError("Requires subclasses to implement.")
|
||||||
|
|
||||||
|
|
||||||
|
class AgentTarget(TransitionTarget):
|
||||||
|
"""Target that represents a direct agent reference."""
|
||||||
|
|
||||||
|
agent_name: str
|
||||||
|
|
||||||
|
def __init__(self, agent: "ConversableAgent", **data: Any) -> None: # type: ignore[no-untyped-def]
|
||||||
|
# Store the name from the agent for serialization
|
||||||
|
super().__init__(agent_name=agent.name, **data)
|
||||||
|
|
||||||
|
def can_resolve_for_speaker_selection(self) -> bool:
|
||||||
|
"""Check if the target can resolve for speaker selection."""
|
||||||
|
return True
|
||||||
|
|
||||||
|
def resolve(
|
||||||
|
self,
|
||||||
|
groupchat: "GroupChat",
|
||||||
|
current_agent: "ConversableAgent",
|
||||||
|
user_agent: Optional["ConversableAgent"],
|
||||||
|
) -> SpeakerSelectionResult:
|
||||||
|
"""Resolve to the actual agent object from the groupchat."""
|
||||||
|
return SpeakerSelectionResult(agent_name=self.agent_name)
|
||||||
|
|
||||||
|
def display_name(self) -> str:
|
||||||
|
"""Get the display name for the target."""
|
||||||
|
return f"{self.agent_name}"
|
||||||
|
|
||||||
|
def normalized_name(self) -> str:
|
||||||
|
"""Get a normalized name for the target that has no spaces, used for function calling"""
|
||||||
|
return self.display_name()
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
"""String representation for AgentTarget, can be shown as a function call message."""
|
||||||
|
return f"Transfer to {self.agent_name}"
|
||||||
|
|
||||||
|
def needs_agent_wrapper(self) -> bool:
|
||||||
|
"""Check if the target needs to be wrapped in an agent."""
|
||||||
|
return False
|
||||||
|
|
||||||
|
def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
|
||||||
|
"""Create a wrapper agent for the target if needed."""
|
||||||
|
raise NotImplementedError("AgentTarget does not require wrapping in an agent.")
|
||||||
|
|
||||||
|
|
||||||
|
class AgentNameTarget(TransitionTarget):
|
||||||
|
"""Target that represents an agent by name."""
|
||||||
|
|
||||||
|
agent_name: str
|
||||||
|
|
||||||
|
def __init__(self, agent_name: str, **data: Any) -> None:
|
||||||
|
"""Initialize with agent name as a positional parameter."""
|
||||||
|
super().__init__(agent_name=agent_name, **data)
|
||||||
|
|
||||||
|
def can_resolve_for_speaker_selection(self) -> bool:
|
||||||
|
"""Check if the target can resolve for speaker selection."""
|
||||||
|
return True
|
||||||
|
|
||||||
|
def resolve(
|
||||||
|
self,
|
||||||
|
groupchat: "GroupChat",
|
||||||
|
current_agent: "ConversableAgent",
|
||||||
|
user_agent: Optional["ConversableAgent"],
|
||||||
|
) -> SpeakerSelectionResult:
|
||||||
|
"""Resolve to the agent name string."""
|
||||||
|
return SpeakerSelectionResult(agent_name=self.agent_name)
|
||||||
|
|
||||||
|
def display_name(self) -> str:
|
||||||
|
"""Get the display name for the target."""
|
||||||
|
return f"{self.agent_name}"
|
||||||
|
|
||||||
|
def normalized_name(self) -> str:
|
||||||
|
"""Get a normalized name for the target that has no spaces, used for function calling"""
|
||||||
|
return self.display_name()
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
"""String representation for AgentTarget, can be shown as a function call message."""
|
||||||
|
return f"Transfer to {self.agent_name}"
|
||||||
|
|
||||||
|
def needs_agent_wrapper(self) -> bool:
|
||||||
|
"""Check if the target needs to be wrapped in an agent."""
|
||||||
|
return False
|
||||||
|
|
||||||
|
def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
|
||||||
|
"""Create a wrapper agent for the target if needed."""
|
||||||
|
raise NotImplementedError("AgentNameTarget does not require wrapping in an agent.")
|
||||||
|
|
||||||
|
|
||||||
|
class NestedChatTarget(TransitionTarget):
    """Transition target backed by a nested chat configuration.

    A nested chat cannot be selected as a speaker directly; it must first be
    encapsulated in a dedicated agent via ``create_wrapper_agent``.
    """

    # Configuration for the nested chat: chat_queue plus optional reply
    # function, config, and use_async flag.
    nested_chat_config: dict[str, Any]

    def can_resolve_for_speaker_selection(self) -> bool:
        """Nested chats cannot resolve directly; they need an agent wrapper first."""
        return False

    def resolve(
        self,
        groupchat: "GroupChat",
        current_agent: "ConversableAgent",
        user_agent: Optional["ConversableAgent"],
    ) -> SpeakerSelectionResult:
        """Always raises: wrap the nested chat in an agent and use an AgentTarget instead."""
        raise NotImplementedError(
            "NestedChatTarget does not support the resolve method. An agent should be used to encapsulate this nested chat and then the target changed to an AgentTarget."
        )

    def display_name(self) -> str:
        """Human-readable name of this target."""
        return "a nested chat"

    def normalized_name(self) -> str:
        """Space-free name of this target, suitable for function calling."""
        return "nested_chat"

    def __str__(self) -> str:
        """Representation suitable for showing as a function-call message."""
        return "Transfer to nested chat"

    def needs_agent_wrapper(self) -> bool:
        """Nested chats always require wrapping in an agent."""
        return True

    def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
        """Build the agent that encapsulates this nested chat.

        The wrapper registers the configured nested chat queue and, once the
        nested chat completes, hands control back to ``parent_agent``.
        """
        from ...conversable_agent import ConversableAgent  # to avoid circular import - NEED SOLUTION

        wrapper_name = f"{__AGENT_WRAPPER_PREFIX__}nested_{parent_agent.name}_{index + 1}"
        nested_chat_agent = ConversableAgent(name=wrapper_name)

        reply_func = self.nested_chat_config.get("reply_func_from_nested_chats") or "summary_from_nested_chats"
        nested_chat_agent.register_nested_chats(
            self.nested_chat_config["chat_queue"],
            reply_func_from_nested_chats=reply_func,
            config=self.nested_chat_config.get("config"),
            trigger=lambda sender: True,
            position=0,
            use_async=self.nested_chat_config.get("use_async", False),
        )

        # After the nested chat is complete, transfer back to the parent agent
        nested_chat_agent.handoffs.set_after_work(AgentTarget(parent_agent))

        return nested_chat_agent
|
||||||
|
|
||||||
|
|
||||||
|
class TerminateTarget(TransitionTarget):
    """Transition target that ends the group conversation."""

    def can_resolve_for_speaker_selection(self) -> bool:
        """Termination can always be resolved directly."""
        return True

    def resolve(
        self,
        groupchat: "GroupChat",
        current_agent: "ConversableAgent",
        user_agent: Optional["ConversableAgent"],
    ) -> SpeakerSelectionResult:
        """Signal that the group chat should terminate."""
        return SpeakerSelectionResult(terminate=True)

    def display_name(self) -> str:
        """Human-readable name of this target."""
        return "Terminate"

    def normalized_name(self) -> str:
        """Space-free name of this target, suitable for function calling."""
        return "terminate"

    def __str__(self) -> str:
        """Representation suitable for showing as a function-call message."""
        return "Terminate"

    def needs_agent_wrapper(self) -> bool:
        """Termination never needs an agent wrapper."""
        return False

    def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
        """Unsupported: this target type is never wrapped in an agent."""
        raise NotImplementedError("TerminateTarget does not require wrapping in an agent.")
|
||||||
|
|
||||||
|
|
||||||
|
class StayTarget(TransitionTarget):
    """Transition target that keeps the floor with the current agent."""

    def can_resolve_for_speaker_selection(self) -> bool:
        """Staying put can always be resolved directly."""
        return True

    def resolve(
        self,
        groupchat: "GroupChat",
        current_agent: "ConversableAgent",
        user_agent: Optional["ConversableAgent"],
    ) -> SpeakerSelectionResult:
        """Resolve to the agent that is already speaking."""
        return SpeakerSelectionResult(agent_name=current_agent.name)

    def display_name(self) -> str:
        """Human-readable name of this target."""
        return "Stay"

    def normalized_name(self) -> str:
        """Space-free name of this target, suitable for function calling."""
        return "stay"

    def __str__(self) -> str:
        """Representation suitable for showing as a function-call message."""
        return "Stay with agent"

    def needs_agent_wrapper(self) -> bool:
        """Staying put never needs an agent wrapper."""
        return False

    def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
        """Unsupported: this target type is never wrapped in an agent."""
        raise NotImplementedError("StayTarget does not require wrapping in an agent.")
|
||||||
|
|
||||||
|
|
||||||
|
class RevertToUserTarget(TransitionTarget):
    """Transition target that hands the floor back to the user agent."""

    def can_resolve_for_speaker_selection(self) -> bool:
        """Reverting to the user can always be resolved directly."""
        return True

    def resolve(
        self,
        groupchat: "GroupChat",
        current_agent: "ConversableAgent",
        user_agent: Optional["ConversableAgent"],
    ) -> SpeakerSelectionResult:
        """Resolve to the user agent.

        Raises:
            ValueError: If no user agent was supplied to the chat.
        """
        if user_agent is None:
            raise ValueError("User agent must be provided to the chat for the revert_to_user option.")
        return SpeakerSelectionResult(agent_name=user_agent.name)

    def display_name(self) -> str:
        """Human-readable name of this target."""
        return "Revert to User"

    def normalized_name(self) -> str:
        """Space-free name of this target, suitable for function calling."""
        return "revert_to_user"

    def __str__(self) -> str:
        """Representation suitable for showing as a function-call message."""
        return "Revert to User"

    def needs_agent_wrapper(self) -> bool:
        """Reverting to the user never needs an agent wrapper."""
        return False

    def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
        """Unsupported: this target type is never wrapped in an agent."""
        raise NotImplementedError("RevertToUserTarget does not require wrapping in an agent.")
|
||||||
|
|
||||||
|
|
||||||
|
class AskUserTarget(TransitionTarget):
    """Transition target that asks the user to pick the next speaker."""

    def can_resolve_for_speaker_selection(self) -> bool:
        """Asking the user can always be resolved directly."""
        return True

    def resolve(
        self,
        groupchat: "GroupChat",
        current_agent: "ConversableAgent",
        user_agent: Optional["ConversableAgent"],
    ) -> SpeakerSelectionResult:
        """Resolve to manual speaker selection, prompting the user for input."""
        return SpeakerSelectionResult(speaker_selection_method="manual")

    def display_name(self) -> str:
        """Human-readable name of this target."""
        return "Ask User"

    def normalized_name(self) -> str:
        """Space-free name of this target, suitable for function calling."""
        return "ask_user"

    def __str__(self) -> str:
        """Representation suitable for showing as a function-call message."""
        return "Ask User"

    def needs_agent_wrapper(self) -> bool:
        """Asking the user never needs an agent wrapper."""
        return False

    def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
        """Unsupported: this target type is never wrapped in an agent."""
        raise NotImplementedError("AskUserTarget does not require wrapping in an agent.")
|
||||||
|
|
||||||
|
|
||||||
|
class RandomAgentTarget(TransitionTarget):
    """Transition target that picks a random agent from a fixed list."""

    # Candidate agent names to choose from.
    agent_names: list[str]
    # Name picked at resolution time; placeholder until resolve() has run.
    nominated_name: str = "<Not Randomly Selected Yet>"

    def __init__(self, agents: list["ConversableAgent"], **data: Any) -> None:  # type: ignore[no-untyped-def]
        # Store the name from the agent for serialization
        super().__init__(agent_names=[agent.name for agent in agents], **data)

    def can_resolve_for_speaker_selection(self) -> bool:
        """Random selection can always be resolved directly."""
        return True

    def resolve(
        self,
        groupchat: "GroupChat",
        current_agent: "ConversableAgent",
        user_agent: Optional["ConversableAgent"],
    ) -> SpeakerSelectionResult:
        """Pick a random agent (never the current one) and resolve to it."""
        # Randomly select the next agent
        candidates = [name for name in self.agent_names if name != current_agent.name]
        self.nominated_name = random.choice(candidates)

        return SpeakerSelectionResult(agent_name=self.nominated_name)

    def display_name(self) -> str:
        """Human-readable name of this target (the nominated agent)."""
        return self.nominated_name

    def normalized_name(self) -> str:
        """Space-free name of this target, suitable for function calling."""
        return self.display_name()

    def __str__(self) -> str:
        """Representation suitable for showing as a function-call message."""
        return f"Transfer to {self.nominated_name}"

    def needs_agent_wrapper(self) -> bool:
        """Random selection never needs an agent wrapper."""
        return False

    def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent":
        """Unsupported: this target type is never wrapped in an agent."""
        raise NotImplementedError("RandomAgentTarget does not require wrapping in an agent.")
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Consider adding a SequentialChatTarget class
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
# Prefix for all wrapped agent names
|
||||||
|
__AGENT_WRAPPER_PREFIX__ = "wrapped_"
|
||||||
1694
mm_agents/coact/autogen/agentchat/groupchat.py
Normal file
1694
mm_agents/coact/autogen/agentchat/groupchat.py
Normal file
File diff suppressed because it is too large
Load Diff
3
mm_agents/coact/autogen/agentchat/realtime/__init__.py
Normal file
3
mm_agents/coact/autogen/agentchat/realtime/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from .audio_adapters import TwilioAudioAdapter, WebSocketAudioAdapter
|
||||||
|
from .audio_observer import AudioObserver
|
||||||
|
from .function_observer import FunctionObserver
|
||||||
|
from .realtime_agent import RealtimeAgent
|
||||||
|
from .realtime_observer import RealtimeObserver
|
||||||
|
from .realtime_swarm import register_swarm
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AudioObserver",
|
||||||
|
"FunctionObserver",
|
||||||
|
"RealtimeAgent",
|
||||||
|
"RealtimeObserver",
|
||||||
|
"TwilioAudioAdapter",
|
||||||
|
"WebSocketAudioAdapter",
|
||||||
|
"register_swarm",
|
||||||
|
]
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from .twilio_audio_adapter import TwilioAudioAdapter
|
||||||
|
from .websocket_audio_adapter import WebSocketAudioAdapter
|
||||||
|
|
||||||
|
__all__ = ["TwilioAudioAdapter", "WebSocketAudioAdapter"]
|
||||||
@@ -0,0 +1,148 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
from logging import Logger
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
from .....doc_utils import export_module
|
||||||
|
from ..realtime_events import AudioDelta, RealtimeEvent, SpeechStarted
|
||||||
|
from ..realtime_observer import RealtimeObserver
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..websockets import WebSocketProtocol as WebSocket
|
||||||
|
|
||||||
|
|
||||||
|
LOG_EVENT_TYPES = [
|
||||||
|
"error",
|
||||||
|
"response.content.done",
|
||||||
|
"rate_limits.updated",
|
||||||
|
"response.done",
|
||||||
|
"input_audio_buffer.committed",
|
||||||
|
"input_audio_buffer.speech_stopped",
|
||||||
|
"input_audio_buffer.speech_started",
|
||||||
|
"session.created",
|
||||||
|
]
|
||||||
|
SHOW_TIMING_MATH = False
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen.agentchat.realtime.experimental")
class TwilioAudioAdapter(RealtimeObserver):
    """Adapter for streaming audio from Twilio to OpenAI Realtime API and vice versa."""

    def __init__(self, websocket: "WebSocket", *, logger: Optional[Logger] = None):
        """Adapter for streaming audio from Twilio to OpenAI Realtime API and vice versa.

        Args:
            websocket: the websocket connection to the Twilio service
            logger: the logger to use for logging events
        """
        super().__init__(logger=logger)
        self.websocket = websocket

        # Connection specific state
        self.stream_sid = None  # Twilio stream id, set when the "start" event arrives
        self.latest_media_timestamp = 0  # timestamp (ms, per Twilio media events) of last inbound chunk
        self.last_assistant_item: Optional[str] = None  # id of the assistant item currently playing
        self.mark_queue: list[str] = []  # playback marks sent but not yet acknowledged by Twilio
        self.response_start_timestamp_twilio: Optional[int] = None  # timestamp when current response playback began

    async def on_event(self, event: RealtimeEvent) -> None:
        """Receive events from the OpenAI Realtime API, send audio back to Twilio."""
        logger = self.logger

        if isinstance(event, AudioDelta):
            # Re-encode the decoded audio delta and forward it as a Twilio media message.
            audio_payload = base64.b64encode(base64.b64decode(event.delta)).decode("utf-8")
            audio_delta = {"event": "media", "streamSid": self.stream_sid, "media": {"payload": audio_payload}}
            await self.websocket.send_json(audio_delta)

            # First delta of a response: remember when playback started for later truncation math.
            if self.response_start_timestamp_twilio is None:
                self.response_start_timestamp_twilio = self.latest_media_timestamp
                if SHOW_TIMING_MATH:
                    logger.info(f"Setting start timestamp for new response: {self.response_start_timestamp_twilio}ms")

            # Update last_assistant_item safely
            if event.item_id:
                self.last_assistant_item = event.item_id

            await self.send_mark()

        # Trigger an interruption. Your use case might work better using `input_audio_buffer.speech_stopped`, or combining the two.
        if isinstance(event, SpeechStarted):
            logger.info("Speech start detected.")
            if self.last_assistant_item:
                logger.info(f"Interrupting response with id: {self.last_assistant_item}")
                await self.handle_speech_started_event()

    async def handle_speech_started_event(self) -> None:
        """Handle interruption when the caller's speech starts."""
        logger = self.logger

        logger.info("Handling speech started event.")
        if self.mark_queue and self.response_start_timestamp_twilio is not None:
            # How much of the current response has already been played back.
            elapsed_time = self.latest_media_timestamp - self.response_start_timestamp_twilio
            if SHOW_TIMING_MATH:
                logger.info(
                    f"Calculating elapsed time for truncation: {self.latest_media_timestamp} - {self.response_start_timestamp_twilio} = {elapsed_time}ms"
                )

            if self.last_assistant_item:
                if SHOW_TIMING_MATH:
                    logger.info(f"Truncating item with ID: {self.last_assistant_item}, Truncated at: {elapsed_time}ms")

                # Cut the assistant's audio at the point the caller interrupted.
                await self.realtime_client.truncate_audio(
                    audio_end_ms=elapsed_time,
                    content_index=0,
                    item_id=self.last_assistant_item,
                )

            # Tell Twilio to drop any audio it has buffered but not yet played.
            await self.websocket.send_json({"event": "clear", "streamSid": self.stream_sid})

            # Reset playback bookkeeping for the next response.
            self.mark_queue.clear()
            self.last_assistant_item = None
            self.response_start_timestamp_twilio = None

    async def send_mark(self) -> None:
        """Send a mark of audio interruption to the Twilio websocket."""
        if self.stream_sid:
            mark_event = {"event": "mark", "streamSid": self.stream_sid, "mark": {"name": "responsePart"}}
            await self.websocket.send_json(mark_event)
            self.mark_queue.append("responsePart")

    async def run_loop(self) -> None:
        """Run the adapter loop.

        Reads Twilio messages from the websocket and forwards caller audio to
        the realtime client; tracks stream start and mark acknowledgements.
        """
        logger = self.logger

        async for message in self.websocket.iter_text():
            try:
                data = json.loads(message)
                if data["event"] == "media":
                    self.latest_media_timestamp = int(data["media"]["timestamp"])
                    await self.realtime_client.send_audio(audio=data["media"]["payload"])
                elif data["event"] == "start":
                    # New stream: record its sid and reset per-stream state.
                    self.stream_sid = data["start"]["streamSid"]
                    logger.info(f"Incoming stream has started {self.stream_sid}")
                    self.response_start_timestamp_twilio = None
                    self.latest_media_timestamp = 0
                    self.last_assistant_item = None
                elif data["event"] == "mark":
                    # Twilio acknowledged playback of one queued mark.
                    if self.mark_queue:
                        self.mark_queue.pop(0)
            except Exception as e:
                logger.warning(f"Error processing Twilio message: {e}", stack_info=True)

    async def initialize_session(self) -> None:
        """Control initial session with OpenAI."""
        # Twilio streams use G.711 u-law audio in both directions.
        session_update = {
            "input_audio_format": "g711_ulaw",
            "output_audio_format": "g711_ulaw",
        }
        await self.realtime_client.session_update(session_update)
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
    # Static-typing check only (never runs): verifies that TwilioAudioAdapter
    # is assignable to the RealtimeObserver interface.
    def twilio_audio_adapter(websocket: "WebSocket") -> RealtimeObserver:
        return TwilioAudioAdapter(websocket)
|
||||||
@@ -0,0 +1,139 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
from logging import Logger
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
from .....doc_utils import export_module
|
||||||
|
from ..realtime_events import AudioDelta, RealtimeEvent, SpeechStarted
|
||||||
|
from ..realtime_observer import RealtimeObserver
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..websockets import WebSocketProtocol as WebSocket
|
||||||
|
|
||||||
|
LOG_EVENT_TYPES = [
|
||||||
|
"error",
|
||||||
|
"response.content.done",
|
||||||
|
"rate_limits.updated",
|
||||||
|
"response.done",
|
||||||
|
"input_audio_buffer.committed",
|
||||||
|
"input_audio_buffer.speech_stopped",
|
||||||
|
"input_audio_buffer.speech_started",
|
||||||
|
"session.created",
|
||||||
|
]
|
||||||
|
SHOW_TIMING_MATH = False
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen.agentchat.realtime.experimental")
class WebSocketAudioAdapter(RealtimeObserver):
    """Adapter for streaming audio between a client websocket and the OpenAI Realtime API."""

    def __init__(self, websocket: "WebSocket", *, logger: Optional[Logger] = None) -> None:
        """Adapter for streaming audio between a client websocket and the OpenAI Realtime API.

        Args:
            websocket (WebSocket): The websocket connection.
            logger (Logger): The logger for the observer.
        """
        super().__init__(logger=logger)
        self.websocket = websocket

        # Connection specific state
        self.stream_sid = None  # stream id, set when the "start" event arrives
        self.latest_media_timestamp = 0  # timestamp (ms, per media events) of last inbound chunk
        self.last_assistant_item: Optional[str] = None  # id of the assistant item currently playing
        self.mark_queue: list[str] = []  # playback marks sent but not yet acknowledged
        self.response_start_timestamp_socket: Optional[int] = None  # timestamp when current response playback began

    async def on_event(self, event: RealtimeEvent) -> None:
        """Receive events from the OpenAI Realtime API, send audio back to websocket."""
        logger = self.logger

        if isinstance(event, AudioDelta):
            # Re-encode the decoded audio delta and forward it as a media message.
            audio_payload = base64.b64encode(base64.b64decode(event.delta)).decode("utf-8")
            audio_delta = {"event": "media", "streamSid": self.stream_sid, "media": {"payload": audio_payload}}
            await self.websocket.send_json(audio_delta)

            # First delta of a response: remember when playback started for later truncation math.
            if self.response_start_timestamp_socket is None:
                self.response_start_timestamp_socket = self.latest_media_timestamp
                if SHOW_TIMING_MATH:
                    logger.info(f"Setting start timestamp for new response: {self.response_start_timestamp_socket}ms")

            # Update last_assistant_item safely
            if event.item_id:
                self.last_assistant_item = event.item_id

            await self.send_mark()

        # Trigger an interruption. Your use case might work better using `input_audio_buffer.speech_stopped`, or combining the two.
        if isinstance(event, SpeechStarted):
            logger.info("Speech start detected.")
            if self.last_assistant_item:
                logger.info(f"Interrupting response with id: {self.last_assistant_item}")
                await self.handle_speech_started_event()

    async def handle_speech_started_event(self) -> None:
        """Handle interruption when the caller's speech starts."""
        logger = self.logger
        logger.info("Handling speech started event.")
        if self.mark_queue and self.response_start_timestamp_socket is not None:
            # How much of the current response has already been played back.
            elapsed_time = self.latest_media_timestamp - self.response_start_timestamp_socket
            if SHOW_TIMING_MATH:
                logger.info(
                    f"Calculating elapsed time for truncation: {self.latest_media_timestamp} - {self.response_start_timestamp_socket} = {elapsed_time}ms"
                )

            if self.last_assistant_item:
                if SHOW_TIMING_MATH:
                    logger.info(f"Truncating item with ID: {self.last_assistant_item}, Truncated at: {elapsed_time}ms")

                # Cut the assistant's audio at the point the caller interrupted.
                await self.realtime_client.truncate_audio(
                    audio_end_ms=elapsed_time,
                    content_index=0,
                    item_id=self.last_assistant_item,
                )

            # Tell the client to drop any audio it has buffered but not yet played.
            await self.websocket.send_json({"event": "clear", "streamSid": self.stream_sid})

            # Reset playback bookkeeping for the next response.
            self.mark_queue.clear()
            self.last_assistant_item = None
            self.response_start_timestamp_socket = None

    async def send_mark(self) -> None:
        """Send a playback mark over the websocket and remember it until acknowledged."""
        if self.stream_sid:
            mark_event = {"event": "mark", "streamSid": self.stream_sid, "mark": {"name": "responsePart"}}
            await self.websocket.send_json(mark_event)
            self.mark_queue.append("responsePart")

    async def initialize_session(self) -> None:
        """Control initial session with OpenAI."""
        # Plain websocket clients exchange 16-bit PCM audio in both directions.
        session_update = {"input_audio_format": "pcm16", "output_audio_format": "pcm16"}
        await self.realtime_client.session_update(session_update)

    async def run_loop(self) -> None:
        """Reads data from websocket and sends it to the RealtimeClient."""
        logger = self.logger
        async for message in self.websocket.iter_text():
            try:
                data = json.loads(message)
                if data["event"] == "media":
                    self.latest_media_timestamp = int(data["media"]["timestamp"])
                    await self.realtime_client.send_audio(audio=data["media"]["payload"])
                elif data["event"] == "start":
                    # New stream: record its sid and reset per-stream state.
                    self.stream_sid = data["start"]["streamSid"]
                    logger.info(f"Incoming stream has started {self.stream_sid}")
                    self.response_start_timestamp_socket = None
                    self.latest_media_timestamp = 0
                    self.last_assistant_item = None
                elif data["event"] == "mark":
                    # Client acknowledged playback of one queued mark.
                    if self.mark_queue:
                        self.mark_queue.pop(0)
            except Exception as e:
                logger.warning(f"Failed to process message: {e}", stack_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
|
||||||
|
def websocket_audio_adapter(websocket: "WebSocket") -> RealtimeObserver:
|
||||||
|
return WebSocketAudioAdapter(websocket)
|
||||||
@@ -0,0 +1,42 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
from ....doc_utils import export_module
|
||||||
|
from .realtime_events import InputAudioBufferDelta, RealtimeEvent
|
||||||
|
from .realtime_observer import RealtimeObserver
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from logging import Logger
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen.agentchat.realtime.experimental")
class AudioObserver(RealtimeObserver):
    """Observer for user voice input"""

    def __init__(self, *, logger: Optional["Logger"] = None) -> None:
        """Create the observer; logger handling is delegated to the base class."""
        super().__init__(logger=logger)

    async def on_event(self, event: RealtimeEvent) -> None:
        """Log voice-input events received from the Realtime API.

        Args:
            event (dict[str, Any]): The event from the OpenAI Realtime API.
        """
        if isinstance(event, InputAudioBufferDelta):
            self.logger.info("Received audio buffer delta")

    async def initialize_session(self) -> None:
        """No session initialization is required by this observer."""

    async def run_loop(self) -> None:
        """This observer has no background loop of its own."""
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
    # Static-typing check only (never runs): verifies that AudioObserver
    # is assignable to the RealtimeObserver interface.
    function_observer: RealtimeObserver = AudioObserver()
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from .gemini.client import GeminiRealtimeClient
|
||||||
|
from .oai.base_client import OpenAIRealtimeClient
|
||||||
|
from .realtime_client import RealtimeClientProtocol, Role, get_client
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"GeminiRealtimeClient",
|
||||||
|
"OpenAIRealtimeClient",
|
||||||
|
"RealtimeClientProtocol",
|
||||||
|
"Role",
|
||||||
|
"get_client",
|
||||||
|
]
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from .client import GeminiRealtimeClient
|
||||||
|
|
||||||
|
__all__ = ["GeminiRealtimeClient"]
|
||||||
@@ -0,0 +1,274 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import json
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from logging import Logger, getLogger
|
||||||
|
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||||
|
|
||||||
|
from ......doc_utils import export_module
|
||||||
|
from ......import_utils import optional_import_block, require_optional_import
|
||||||
|
from ......llm_config import LLMConfig
|
||||||
|
from ...realtime_events import AudioDelta, FunctionCall, RealtimeEvent, SessionCreated
|
||||||
|
from ..realtime_client import RealtimeClientBase, Role, register_realtime_client
|
||||||
|
|
||||||
|
with optional_import_block():
|
||||||
|
from websockets.asyncio.client import connect
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from websockets.asyncio.client import ClientConnection
|
||||||
|
|
||||||
|
from ..realtime_client import RealtimeClientProtocol
|
||||||
|
|
||||||
|
__all__ = ["GeminiRealtimeClient"]
|
||||||
|
|
||||||
|
global_logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
HOST = "generativelanguage.googleapis.com"
|
||||||
|
API_VERSION = "v1alpha"
|
||||||
|
|
||||||
|
|
||||||
|
@register_realtime_client()
|
||||||
|
@require_optional_import("websockets", "gemini", except_for=["get_factory", "__init__"])
|
||||||
|
@export_module("autogen.agentchat.realtime.experimental.clients")
|
||||||
|
class GeminiRealtimeClient(RealtimeClientBase):
|
||||||
|
"""(Experimental) Client for Gemini Realtime API."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
llm_config: Union[LLMConfig, dict[str, Any]],
|
||||||
|
logger: Optional[Logger] = None,
|
||||||
|
) -> None:
|
||||||
|
"""(Experimental) Client for Gemini Realtime API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
llm_config: The config for the client.
|
||||||
|
logger: The logger for the client.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self._llm_config = llm_config
|
||||||
|
self._logger = logger
|
||||||
|
|
||||||
|
self._connection: Optional["ClientConnection"] = None
|
||||||
|
config = llm_config["config_list"][0]
|
||||||
|
|
||||||
|
self._model: str = config["model"]
|
||||||
|
self._voice = config.get("voice", "charon")
|
||||||
|
self._temperature: float = config.get("temperature", 0.8) # type: ignore[union-attr]
|
||||||
|
|
||||||
|
self._response_modality = "AUDIO"
|
||||||
|
|
||||||
|
self._api_key = config.get("api_key", None)
|
||||||
|
# todo: add test with base_url just to make sure it works
|
||||||
|
self._base_url: str = config.get(
|
||||||
|
"base_url",
|
||||||
|
f"wss://{HOST}/ws/google.ai.generativelanguage.{API_VERSION}.GenerativeService.BidiGenerateContent?key={self._api_key}",
|
||||||
|
)
|
||||||
|
self._final_config: dict[str, Any] = {}
|
||||||
|
self._pending_session_updates: dict[str, Any] = {}
|
||||||
|
self._is_reading_events = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def logger(self) -> Logger:
|
||||||
|
"""Get the logger for the Gemini Realtime API."""
|
||||||
|
return self._logger or global_logger
|
||||||
|
|
||||||
|
@property
|
||||||
|
def connection(self) -> "ClientConnection":
|
||||||
|
"""Get the Gemini WebSocket connection."""
|
||||||
|
if self._connection is None:
|
||||||
|
raise RuntimeError("Gemini WebSocket is not initialized")
|
||||||
|
return self._connection
|
||||||
|
|
||||||
|
async def send_function_result(self, call_id: str, result: str) -> None:
|
||||||
|
"""Send the result of a function call to the Gemini Realtime API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
call_id (str): The ID of the function call.
|
||||||
|
result (str): The result of the function call.
|
||||||
|
"""
|
||||||
|
msg = {
|
||||||
|
"tool_response": {"function_responses": [{"id": call_id, "response": {"result": {"string_value": result}}}]}
|
||||||
|
}
|
||||||
|
if self._is_reading_events:
|
||||||
|
await self.connection.send(json.dumps(msg))
|
||||||
|
|
||||||
|
async def send_text(self, *, role: Role, text: str, turn_complete: bool = True) -> None:
    """Send a text message to the Gemini Realtime API.

    Args:
        role: The role of the message.
        text: The text of the message.
        turn_complete: A flag indicating if the turn is complete.
    """
    turn = {"role": role, "parts": [{"text": text}]}
    payload = {"client_content": {"turn_complete": turn_complete, "turns": [turn]}}
    # Only forward while the event loop is consuming messages.
    if self._is_reading_events:
        await self.connection.send(json.dumps(payload))
|
||||||
|
|
||||||
|
async def send_audio(self, audio: str) -> None:
    """Send audio to the Gemini Realtime API.

    Args:
        audio (str): The audio to send.
    """
    chunk = {"data": audio, "mime_type": "audio/pcm"}
    payload = {"realtime_input": {"media_chunks": [chunk]}}
    # Queue the delta first so observers see the audio even when sending is skipped.
    await self.queue_input_audio_buffer_delta(audio)
    if self._is_reading_events:
        await self.connection.send(json.dumps(payload))
|
||||||
|
|
||||||
|
async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None:
    """No-op: audio truncation is not natively supported by the Gemini Realtime API; only logs."""
    self.logger.info("This is not natively supported by Gemini Realtime API.")
|
||||||
|
|
||||||
|
async def _initialize_session(self) -> None:
    """Initialize the session with the Gemini Realtime API.

    Builds the one-shot "setup" message from the options recorded by
    session_update() before read_events() started, and sends it over the
    websocket.
    """
    session_config = {
        "setup": {
            "system_instruction": {
                "role": "system",
                # "instructions" accumulated via session_update(); empty string if never set.
                "parts": [{"text": self._pending_session_updates.get("instructions", "")}],
            },
            "model": f"models/{self._model}",
            "tools": [
                {
                    # One declaration per registered tool schema.
                    "function_declarations": [
                        {
                            "name": tool_schema["name"],
                            "description": tool_schema["description"],
                            "parameters": tool_schema["parameters"],
                        }
                        for tool_schema in self._pending_session_updates.get("tools", [])
                    ]
                },
            ],
            "generation_config": {
                "response_modalities": [self._response_modality],
                "speech_config": {"voiceConfig": {"prebuiltVoiceConfig": {"voiceName": self._voice}}},
                "temperature": self._temperature,
            },
        }
    }

    self.logger.info(f"Sending session update: {session_config}")
    await self.connection.send(json.dumps(session_config))
|
||||||
|
|
||||||
|
async def session_update(self, session_options: dict[str, Any]) -> None:
    """Record session updates to be applied when the connection is established.

    Args:
        session_options (dict[str, Any]): The session options to update.
    """
    if not self._is_reading_events:
        # Updates are buffered and sent by _initialize_session().
        self._pending_session_updates.update(session_options)
    else:
        self.logger.warning("Is reading events. Session update will be ignored.")
|
||||||
|
|
||||||
|
@asynccontextmanager
async def connect(self) -> AsyncGenerator[None, None]:
    """Connect to the Gemini Realtime API.

    Opens the websocket for the duration of the ``async with`` block and
    clears the stored connection on exit, including on error.
    """
    try:
        async with connect(
            self._base_url, additional_headers={"Content-Type": "application/json"}
        ) as self._connection:
            yield
    finally:
        self._connection = None
|
||||||
|
|
||||||
|
async def read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
    """Read Events from the Gemini Realtime Client"""
    if self._connection is None:
        raise RuntimeError("Client is not connected, call connect() first.")
    # Send the "setup" message before consuming any server events.
    await self._initialize_session()

    # NOTE(review): this flag is never reset to False when the generator
    # finishes — confirm whether reusing the same client is intended.
    self._is_reading_events = True

    async for event in self._read_events():
        yield event
|
||||||
|
|
||||||
|
async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]:
    """Read messages from the Gemini Realtime connection.

    Yields:
        RealtimeEvent: events parsed from each incoming websocket message.
    """
    async for raw_message in self.connection:
        # Decode as UTF-8, not ASCII: JSON payloads are UTF-8 encoded
        # (RFC 8259) and may contain non-ASCII text, which would make an
        # ascii decode raise UnicodeDecodeError mid-stream.
        message = raw_message.decode("utf-8") if isinstance(raw_message, bytes) else raw_message
        events = self._parse_message(json.loads(message))
        for event in events:
            yield event
|
||||||
|
|
||||||
|
def _parse_message(self, response: dict[str, Any]) -> list[RealtimeEvent]:
    """Parse a message from the Gemini Realtime API.

    Args:
        response (dict[str, Any]): The response to parse.

    Returns:
        list[RealtimeEvent]: The parsed events.
    """
    if "serverContent" in response and "modelTurn" in response["serverContent"]:
        try:
            # Pop the base64 audio out of the raw message so the (potentially
            # large) payload is not duplicated inside raw_message.
            # NOTE(review): an empty "parts" list raises IndexError, which is
            # not caught here — confirm whether that can occur.
            b64data = response["serverContent"]["modelTurn"]["parts"][0]["inlineData"].pop("data")
            return [
                AudioDelta(
                    delta=b64data,
                    item_id=None,
                    raw_message=response,
                )
            ]
        except KeyError:
            # Model turn without inline audio data — nothing to emit.
            return []
    elif "toolCall" in response:
        # One FunctionCall event per requested function call.
        return [
            FunctionCall(
                raw_message=response,
                call_id=call["id"],
                name=call["name"],
                arguments=call["args"],
            )
            for call in response["toolCall"]["functionCalls"]
        ]
    elif "setupComplete" in response:
        return [
            SessionCreated(raw_message=response),
        ]
    else:
        # Unrecognized message: wrap it so callers still see the raw payload.
        return [RealtimeEvent(raw_message=response)]
|
||||||
|
|
||||||
|
@classmethod
def get_factory(
    cls, llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any
) -> Optional[Callable[[], "RealtimeClientProtocol"]]:
    """Create a Realtime API client.

    Args:
        llm_config: The LLM config for the client.
        logger: The logger for the client.
        **kwargs: Additional arguments.

    Returns:
        RealtimeClientProtocol: The Realtime API client is returned if the model matches the pattern
    """
    # Claim the config only when api_type is explicitly "google" and no extra
    # kwargs were passed (other client factories consume specific kwargs).
    # "not kwargs" is the idiomatic equivalent of list(kwargs.keys()) == [].
    if llm_config["config_list"][0].get("api_type") == "google" and not kwargs:
        return lambda: GeminiRealtimeClient(llm_config=llm_config, logger=logger, **kwargs)
    return None
|
||||||
|
|
||||||
|
|
||||||
|
# needed for mypy to check if GeminiRealtimeClient implements RealtimeClientProtocol
if TYPE_CHECKING:
    # Static-only assignment: never executed at runtime.
    _client: RealtimeClientProtocol = GeminiRealtimeClient(llm_config={})
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from .base_client import OpenAIRealtimeClient
|
||||||
|
from .rtc_client import OpenAIRealtimeWebRTCClient
|
||||||
|
|
||||||
|
__all__ = ["OpenAIRealtimeClient", "OpenAIRealtimeWebRTCClient"]
|
||||||
@@ -0,0 +1,220 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from logging import Logger, getLogger
|
||||||
|
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||||
|
|
||||||
|
from ......doc_utils import export_module
|
||||||
|
from ......import_utils import optional_import_block, require_optional_import
|
||||||
|
from ......llm_config import LLMConfig
|
||||||
|
from ...realtime_events import RealtimeEvent
|
||||||
|
from ..realtime_client import RealtimeClientBase, Role, register_realtime_client
|
||||||
|
from .utils import parse_oai_message
|
||||||
|
|
||||||
|
with optional_import_block():
|
||||||
|
from openai import DEFAULT_MAX_RETRIES, NOT_GIVEN, AsyncOpenAI
|
||||||
|
from openai.resources.beta.realtime.realtime import AsyncRealtimeConnection
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..realtime_client import RealtimeClientProtocol
|
||||||
|
|
||||||
|
__all__ = ["OpenAIRealtimeClient"]
|
||||||
|
|
||||||
|
global_logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@register_realtime_client()
@require_optional_import("openai>=1.66.2", "openai-realtime", except_for=["get_factory", "__init__"])
@export_module("autogen.agentchat.realtime.experimental.clients")
class OpenAIRealtimeClient(RealtimeClientBase):
    """(Experimental) Client for OpenAI Realtime API."""

    def __init__(
        self,
        *,
        llm_config: Union[LLMConfig, dict[str, Any]],
        logger: Optional[Logger] = None,
    ) -> None:
        """(Experimental) Client for OpenAI Realtime API.

        Args:
            llm_config: The config for the client.
            logger: the logger to use for logging events
        """
        super().__init__()
        self._llm_config = llm_config
        self._logger = logger

        # Set by connect(); None whenever there is no active realtime session.
        self._connection: Optional["AsyncRealtimeConnection"] = None

        self.config = llm_config["config_list"][0]
        # model is passed to self._client.beta.realtime.connect function later
        self._model: str = self.config["model"]
        self._voice: str = self.config.get("voice", "alloy")
        # NOTE(review): temperature is read from the top-level llm_config here,
        # not from the config_list entry as other settings are — confirm intended.
        self._temperature: float = llm_config.get("temperature", 0.8)  # type: ignore[union-attr]

        # AsyncOpenAI SDK client; created lazily on first connect().
        self._client: Optional["AsyncOpenAI"] = None

    @property
    def logger(self) -> Logger:
        """Get the logger for the OpenAI Realtime API (instance logger if set, else module logger)."""
        return self._logger or global_logger

    @property
    def connection(self) -> "AsyncRealtimeConnection":
        """Get the OpenAI WebSocket connection.

        Raises:
            RuntimeError: if connect() has not been entered yet.
        """
        if self._connection is None:
            raise RuntimeError("OpenAI WebSocket is not initialized")
        return self._connection

    async def send_function_result(self, call_id: str, result: str) -> None:
        """Send the result of a function call to the OpenAI Realtime API.

        Args:
            call_id (str): The ID of the function call.
            result (str): The result of the function call.
        """
        await self.connection.conversation.item.create(
            item={
                "type": "function_call_output",
                "call_id": call_id,
                "output": result,
            },
        )
        # Ask the model to continue now that the tool output is available.
        await self.connection.response.create()

    async def send_text(self, *, role: Role, text: str) -> None:
        """Send a text message to the OpenAI Realtime API.

        Args:
            role (str): The role of the message.
            text (str): The text of the message.
        """
        # Cancel any in-flight response before injecting the new message.
        await self.connection.response.cancel()
        await self.connection.conversation.item.create(
            item={"type": "message", "role": role, "content": [{"type": "input_text", "text": text}]}
        )
        await self.connection.response.create()

    async def send_audio(self, audio: str) -> None:
        """Send audio to the OpenAI Realtime API.

        Args:
            audio (str): The audio to send.
        """
        # Queue the delta locally (for observers/logging) before forwarding it.
        await self.queue_input_audio_buffer_delta(audio)
        await self.connection.input_audio_buffer.append(audio=audio)

    async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None:
        """Truncate audio in the OpenAI Realtime API.

        Args:
            audio_end_ms (int): The end of the audio to truncate.
            content_index (int): The index of the content to truncate.
            item_id (str): The ID of the item to truncate.
        """
        await self.connection.conversation.item.truncate(
            audio_end_ms=audio_end_ms, content_index=content_index, item_id=item_id
        )

    async def _initialize_session(self) -> None:
        """Control initial session with OpenAI (default voice/VAD/modalities/temperature)."""
        session_update = {
            "turn_detection": {"type": "server_vad"},
            "voice": self._voice,
            "modalities": ["audio", "text"],
            "temperature": self._temperature,
        }
        await self.session_update(session_options=session_update)

    async def session_update(self, session_options: dict[str, Any]) -> None:
        """Send a session update to the OpenAI Realtime API.

        Args:
            session_options (dict[str, Any]): The session options to update.
        """
        logger = self.logger
        logger.info(f"Sending session update: {session_options}")
        await self.connection.session.update(session=session_options)  # type: ignore[arg-type]
        logger.info("Sending session update finished")

    @asynccontextmanager
    async def connect(self) -> AsyncGenerator[None, None]:
        """Connect to the OpenAI Realtime API.

        Lazily builds the AsyncOpenAI client, opens the realtime websocket
        for the duration of the ``async with`` block, and clears the stored
        connection on exit.
        """
        try:
            if not self._client:
                self._client = AsyncOpenAI(
                    api_key=self.config.get("api_key", None),
                    organization=self.config.get("organization", None),
                    project=self.config.get("project", None),
                    base_url=self.config.get("base_url", None),
                    websocket_base_url=self.config.get("websocket_base_url", None),
                    timeout=self.config.get("timeout", NOT_GIVEN),
                    max_retries=self.config.get("max_retries", DEFAULT_MAX_RETRIES),
                    default_headers=self.config.get("default_headers", None),
                    default_query=self.config.get("default_query", None),
                )
            async with self._client.beta.realtime.connect(
                model=self._model,
            ) as self._connection:
                # Push default session options before handing control back.
                await self._initialize_session()
                yield
        finally:
            self._connection = None

    async def read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read messages from the OpenAI Realtime API.

        Raises:
            RuntimeError: if called before connect() is entered.
        """
        if self._connection is None:
            raise RuntimeError("Client is not connected, call connect() first.")

        try:
            async for event in self._read_events():
                yield event

        finally:
            self._connection = None

    async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read messages from the OpenAI Realtime API and yield parsed events."""
        async for message in self._connection:
            for event in self._parse_message(message.model_dump()):
                yield event

    def _parse_message(self, message: dict[str, Any]) -> list[RealtimeEvent]:
        """Parse a message from the OpenAI Realtime API.

        Args:
            message (dict[str, Any]): The message to parse.

        Returns:
            RealtimeEvent: The parsed event.
        """
        return [parse_oai_message(message)]

    @classmethod
    def get_factory(
        cls, llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any
    ) -> Optional[Callable[[], "RealtimeClientProtocol"]]:
        """Create a Realtime API client.

        Args:
            llm_config: The config for the client.
            logger: The logger to use for logging events.
            kwargs: Additional arguments.

        Returns:
            RealtimeClientProtocol: The Realtime API client is returned if the model matches the pattern
        """
        # "openai" is the default api_type; this client claims configs with no extra kwargs.
        if llm_config["config_list"][0].get("api_type", "openai") == "openai" and list(kwargs.keys()) == []:
            return lambda: OpenAIRealtimeClient(llm_config=llm_config, logger=logger, **kwargs)
        return None
|
||||||
|
|
||||||
|
|
||||||
|
# needed for mypy to check if OpenAIRealtimeWebRTCClient implements RealtimeClientProtocol
if TYPE_CHECKING:
    # Static-only assignment: never executed at runtime.
    _client: RealtimeClientProtocol = OpenAIRealtimeClient(llm_config={})
|
||||||
@@ -0,0 +1,243 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import json
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from logging import Logger, getLogger
|
||||||
|
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||||
|
|
||||||
|
from autogen.import_utils import optional_import_block, require_optional_import
|
||||||
|
|
||||||
|
from ......doc_utils import export_module
|
||||||
|
from ......llm_config import LLMConfig
|
||||||
|
from ...realtime_events import RealtimeEvent
|
||||||
|
from ..realtime_client import RealtimeClientBase, Role, register_realtime_client
|
||||||
|
from .utils import parse_oai_message
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ...websockets import WebSocketProtocol as WebSocket
|
||||||
|
from ..realtime_client import RealtimeClientProtocol
|
||||||
|
|
||||||
|
with optional_import_block():
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
__all__ = ["OpenAIRealtimeWebRTCClient"]
|
||||||
|
|
||||||
|
global_logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@register_realtime_client()
@require_optional_import("httpx", "openai-realtime", except_for="get_factory")
@export_module("autogen.agentchat.realtime.experimental.clients.oai")
class OpenAIRealtimeWebRTCClient(RealtimeClientBase):
    """(Experimental) Client for OpenAI Realtime API that uses WebRTC protocol."""

    def __init__(
        self,
        *,
        llm_config: Union[LLMConfig, dict[str, Any]],
        websocket: "WebSocket",
        logger: Optional[Logger] = None,
    ) -> None:
        """(Experimental) Client for OpenAI Realtime API.

        Args:
            llm_config: The config for the client.
            websocket: the websocket to use for the connection
            logger: the logger to use for logging events
        """
        super().__init__()
        self._llm_config = llm_config
        self._logger = logger
        # Websocket to the browser-side JS client that holds the actual WebRTC
        # connection to OpenAI; all traffic is relayed through it.
        self._websocket = websocket

        config = llm_config["config_list"][0]
        self._model: str = config["model"]
        self._voice: str = config.get("voice", "alloy")
        # NOTE(review): temperature is read from the top-level llm_config here,
        # not from the config_list entry — confirm intended.
        self._temperature: float = llm_config.get("temperature", 0.8)  # type: ignore[union-attr]
        self._config = config
        # Endpoint used to mint an ephemeral realtime session for the JS client.
        self._base_url = config.get("base_url", "https://api.openai.com/v1/realtime/sessions")

    @property
    def logger(self) -> Logger:
        """Get the logger for the OpenAI Realtime API (instance logger if set, else module logger)."""
        return self._logger or global_logger

    async def send_function_result(self, call_id: str, result: str) -> None:
        """Send the result of a function call to the OpenAI Realtime API.

        Args:
            call_id (str): The ID of the function call.
            result (str): The result of the function call.
        """
        await self._websocket.send_json({
            "type": "conversation.item.create",
            "item": {
                "type": "function_call_output",
                "call_id": call_id,
                "output": result,
            },
        })
        # Ask the model to continue now that the tool output is available.
        await self._websocket.send_json({"type": "response.create"})

    async def send_text(self, *, role: Role, text: str) -> None:
        """Send a text message to the OpenAI Realtime API.

        Args:
            role (str): The role of the message.
            text (str): The text of the message.
        """
        # await self.connection.response.cancel() #why is this here?
        await self._websocket.send_json({
            "type": "response.cancel",
        })
        await self._websocket.send_json({
            "type": "conversation.item.create",
            "item": {"type": "message", "role": role, "content": [{"type": "input_text", "text": text}]},
        })
        # await self.connection.response.create()
        await self._websocket.send_json({"type": "response.create"})

    async def send_audio(self, audio: str) -> None:
        """Send audio to the OpenAI Realtime API.
        in case of WebRTC, audio is already sent by js client, so we just queue it in order to be logged.

        Args:
            audio (str): The audio to send.
        """
        await self.queue_input_audio_buffer_delta(audio)

    async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None:
        """Truncate audio in the OpenAI Realtime API.

        Args:
            audio_end_ms (int): The end of the audio to truncate.
            content_index (int): The index of the content to truncate.
            item_id (str): The ID of the item to truncate.
        """
        await self._websocket.send_json({
            "type": "conversation.item.truncate",
            "content_index": content_index,
            "item_id": item_id,
            "audio_end_ms": audio_end_ms,
        })

    async def session_update(self, session_options: dict[str, Any]) -> None:
        """Send a session update to the OpenAI Realtime API.

        In the case of WebRTC we can not send it directly, but we can send it
        to the javascript over the websocket, and rely on it to send session
        update to OpenAI

        Args:
            session_options (dict[str, Any]): The session options to update.
        """
        logger = self.logger
        logger.info(f"Sending session update: {session_options}")
        # await self.connection.session.update(session=session_options)  # type: ignore[arg-type]
        await self._websocket.send_json({"type": "session.update", "session": session_options})
        logger.info("Sending session update finished")

    def session_init_data(self) -> list[dict[str, Any]]:
        """Control initial session with OpenAI (default voice/VAD/modalities/temperature)."""
        session_update = {
            "turn_detection": {"type": "server_vad"},
            "voice": self._voice,
            "modalities": ["audio", "text"],
            "temperature": self._temperature,
        }
        return [{"type": "session.update", "session": session_update}]

    async def _initialize_session(self) -> None: ...

    @asynccontextmanager
    async def connect(self) -> AsyncGenerator[None, None]:
        """Connect to the OpenAI Realtime API.

        In the case of WebRTC, we pass connection information over the
        websocket, so that javascript on the other end of websocket open
        actual connection to OpenAI
        """
        try:
            base_url = self._base_url
            api_key = self._config.get("api_key", None)
            headers = {
                "Authorization": f"Bearer {api_key}",  # Use os.getenv to get from environment
                "Content-Type": "application/json",
            }
            data = {
                # "model": "gpt-4o-realtime-preview-2024-12-17",
                "model": self._model,
                "voice": self._voice,
            }
            # Mint the ephemeral session/token for the browser-side client.
            async with httpx.AsyncClient() as client:
                response = await client.post(base_url, headers=headers, json=data)
                response.raise_for_status()
                json_data = response.json()
                json_data["model"] = self._model
            if self._websocket is not None:
                session_init = self.session_init_data()
                # Hand the session config to the JS client so it can open the
                # WebRTC connection itself.
                await self._websocket.send_json({"type": "ag2.init", "config": json_data, "init": session_init})
            yield
        finally:
            pass

    async def read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read events from the OpenAI Realtime API."""
        async for event in self._read_events():
            yield event

    async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read messages from the OpenAI Realtime API connection.
        Again, in case of WebRTC, we do not read OpenAI messages directly since we
        do not hold connection to OpenAI. Instead we read messages from the websocket, and javascript
        client on the other side of the websocket that is connected to OpenAI is relaying events to us.
        """
        while True:
            try:
                message_json = await self._websocket.receive_text()
                message = json.loads(message_json)
                for event in self._parse_message(message):
                    yield event
            except Exception as e:
                # Any websocket/parse failure ends the stream.
                self.logger.exception(f"Error reading from connection {e}")
                break

    def _parse_message(self, message: dict[str, Any]) -> list[RealtimeEvent]:
        """Parse a message from the OpenAI Realtime API.

        Args:
            message (dict[str, Any]): The message to parse.

        Returns:
            RealtimeEvent: The parsed event.
        """
        return [parse_oai_message(message)]

    @classmethod
    def get_factory(
        cls, llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any
    ) -> Optional[Callable[[], "RealtimeClientProtocol"]]:
        """Create a Realtime API client.

        Args:
            llm_config: The config for the client.
            logger: The logger to use for logging events.
            **kwargs: Additional arguments.

        Returns:
            RealtimeClientProtocol: The Realtime API client is returned if the model matches the pattern
        """
        # Claims the config only when exactly a "websocket" kwarg is supplied,
        # which distinguishes this client from the plain OpenAI client.
        if llm_config["config_list"][0].get("api_type", "openai") == "openai" and list(kwargs.keys()) == ["websocket"]:
            return lambda: OpenAIRealtimeWebRTCClient(llm_config=llm_config, logger=logger, **kwargs)

        return None
|
||||||
|
|
||||||
|
|
||||||
|
# needed for mypy to check if OpenAIRealtimeWebRTCClient implements RealtimeClientProtocol
if TYPE_CHECKING:

    def _rtc_client(websocket: "WebSocket") -> RealtimeClientProtocol:
        # Static-only helper: never executed at runtime.
        return OpenAIRealtimeWebRTCClient(llm_config={}, websocket=websocket)
|
||||||
@@ -0,0 +1,48 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from ...realtime_events import (
|
||||||
|
AudioDelta,
|
||||||
|
FunctionCall,
|
||||||
|
InputAudioBufferDelta,
|
||||||
|
RealtimeEvent,
|
||||||
|
SessionCreated,
|
||||||
|
SessionUpdated,
|
||||||
|
SpeechStarted,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = ["parse_oai_message"]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_oai_message(message: dict[str, Any]) -> RealtimeEvent:
    """Parse a message from the OpenAI Realtime API.

    Args:
        message (dict[str, Any]): The message to parse.

    Returns:
        RealtimeEvent: The parsed event.
    """
    message_type = message.get("type")
    if message_type == "session.created":
        return SessionCreated(raw_message=message)
    if message_type == "session.updated":
        return SessionUpdated(raw_message=message)
    if message_type == "response.audio.delta":
        return AudioDelta(raw_message=message, delta=message["delta"], item_id=message["item_id"])
    if message_type == "input_audio_buffer.speech_started":
        return SpeechStarted(raw_message=message)
    if message_type == "input_audio_buffer.delta":
        return InputAudioBufferDelta(delta=message["delta"], item_id=None, raw_message=message)
    if message_type == "response.function_call_arguments.done":
        return FunctionCall(
            raw_message=message,
            call_id=message["call_id"],
            name=message["name"],
            arguments=json.loads(message["arguments"]),
        )
    # Anything unrecognized is surfaced as a generic event.
    return RealtimeEvent(raw_message=message)
|
||||||
@@ -0,0 +1,190 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from logging import Logger
|
||||||
|
from typing import Any, AsyncContextManager, Callable, Literal, Optional, Protocol, TypeVar, Union, runtime_checkable
|
||||||
|
|
||||||
|
from asyncer import create_task_group
|
||||||
|
|
||||||
|
from .....doc_utils import export_module
|
||||||
|
from .....llm_config import LLMConfig
|
||||||
|
from ..realtime_events import InputAudioBufferDelta, RealtimeEvent
|
||||||
|
|
||||||
|
__all__ = ["RealtimeClientProtocol", "Role", "get_client", "register_realtime_client"]
|
||||||
|
|
||||||
|
# define role literal type for typing
|
||||||
|
Role = Literal["user", "assistant", "system"]
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
@export_module("autogen.agentchat.realtime.experimental.clients")
class RealtimeClientProtocol(Protocol):
    """Structural interface every Realtime API client must implement (runtime-checkable for isinstance)."""

    async def send_function_result(self, call_id: str, result: str) -> None:
        """Send the result of a function call to a Realtime API.

        Args:
            call_id (str): The ID of the function call.
            result (str): The result of the function call.
        """
        ...

    async def send_text(self, *, role: Role, text: str) -> None:
        """Send a text message to a Realtime API.

        Args:
            role (str): The role of the message.
            text (str): The text of the message.
        """
        ...

    async def send_audio(self, audio: str) -> None:
        """Send audio to a Realtime API.

        Args:
            audio (str): The audio to send.
        """
        ...

    async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None:
        """Truncate audio in a Realtime API.

        Args:
            audio_end_ms (int): The end of the audio to truncate.
            content_index (int): The index of the content to truncate.
            item_id (str): The ID of the item to truncate.
        """
        ...

    async def session_update(self, session_options: dict[str, Any]) -> None:
        """Send a session update to a Realtime API.

        Args:
            session_options (dict[str, Any]): The session options to update.
        """
        ...

    # Async context manager holding the connection open for its duration.
    def connect(self) -> AsyncContextManager[None]: ...

    def read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read events from a Realtime Client."""
        ...

    async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read events from a Realtime connection."""
        ...

    def _parse_message(self, message: dict[str, Any]) -> list[RealtimeEvent]:
        """Parse a message from a Realtime API.

        Args:
            message (dict[str, Any]): The message to parse.

        Returns:
            list[RealtimeEvent]: The parsed events.
        """
        ...

    @classmethod
    def get_factory(
        cls, llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any
    ) -> Optional[Callable[[], "RealtimeClientProtocol"]]:
        """Create a Realtime API client.

        Args:
            llm_config: The config for the client.
            logger: The logger to use for logging events.
            **kwargs: Additional arguments.

        Returns:
            RealtimeClientProtocol: The Realtime API client is returned if the model matches the pattern
        """
        ...
|
||||||
|
|
||||||
|
|
||||||
|
class RealtimeClientBase:
    """Base class providing an async event queue shared by realtime client implementations."""

    def __init__(self):
        # FIFO of RealtimeEvent objects; a None sentinel marks end-of-stream.
        self._eventQueue = asyncio.Queue()

    async def add_event(self, event: Optional[RealtimeEvent]):
        """Enqueue an event; None signals that the connection has finished."""
        await self._eventQueue.put(event)

    async def get_event(self) -> Optional[RealtimeEvent]:
        """Dequeue the next event, waiting until one is available."""
        return await self._eventQueue.get()

    async def _read_from_connection_task(self):
        """Pump every connection event into the queue, then append the None sentinel."""
        async for incoming in self._read_from_connection():
            await self.add_event(incoming)
        await self.add_event(None)

    async def _read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
        """Read events from a Realtime Client."""
        async with create_task_group() as tg:
            # Producer runs concurrently while we drain the queue below.
            tg.start_soon(self._read_from_connection_task)
            try:
                # Stop on the None sentinel; any error also ends the stream silently.
                while (event := await self._eventQueue.get()) is not None:
                    yield event
            except Exception:
                pass

    async def queue_input_audio_buffer_delta(self, audio: str) -> None:
        """queue InputAudioBufferDelta.

        Args:
            audio (str): The audio.
        """
        await self.add_event(InputAudioBufferDelta(delta=audio, item_id=None, raw_message={}))
|
||||||
|
|
||||||
|
|
||||||
|
# Registry of realtime client implementations, keyed by fully qualified class name.
_realtime_client_classes: dict[str, type[RealtimeClientProtocol]] = {}

T = TypeVar("T", bound=RealtimeClientProtocol)


def register_realtime_client() -> Callable[[type[T]], type[T]]:
    """Build a class decorator that registers a Realtime API client implementation.

    Returns:
        Callable[[type[T]], type[T]]: The decorator to register the Realtime API client
    """

    def decorator(client_cls: type[T]) -> type[T]:
        """Record the client class in the module registry under its fully qualified name.

        Args:
            client_cls: The client to register.
        """
        # No `global` needed: the dict is mutated in place, never rebound.
        _realtime_client_classes[f"{client_cls.__module__}.{client_cls.__name__}"] = client_cls
        return client_cls

    return decorator
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen.agentchat.realtime.experimental.clients")
def get_client(llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any) -> "RealtimeClientProtocol":
    """Get a registered Realtime API client.

    Each registered client is offered the config in turn; the first one whose
    ``get_factory`` returns a factory wins.

    Args:
        llm_config: The config for the client.
        logger: The logger to use for logging events.
        **kwargs: Additional arguments.

    Returns:
        RealtimeClientProtocol: The Realtime API client.

    Raises:
        ValueError: If no registered client accepts the given config.
    """
    # Registry is only read here, so no `global` declaration is required.
    for client_cls in _realtime_client_classes.values():
        factory = client_cls.get_factory(llm_config=llm_config, logger=logger, **kwargs)
        if factory:
            return factory()

    raise ValueError("Realtime API client not found.")
|
||||||
@@ -0,0 +1,85 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
|
from asyncer import asyncify
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from ....doc_utils import export_module
|
||||||
|
from .realtime_events import FunctionCall, RealtimeEvent
|
||||||
|
from .realtime_observer import RealtimeObserver
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from logging import Logger
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen.agentchat.realtime.experimental")
class FunctionObserver(RealtimeObserver):
    """Observer for handling function calls from the OpenAI Realtime API."""

    def __init__(self, *, logger: Optional["Logger"] = None) -> None:
        """Observer for handling function calls from the OpenAI Realtime API.

        Args:
            logger (Optional[Logger]): Logger for the observer; falls back to the module logger when None.
        """
        super().__init__(logger=logger)

    async def on_event(self, event: RealtimeEvent) -> None:
        """Handle function call events from the OpenAI Realtime API.

        Only ``FunctionCall`` events are acted upon; every other event type is ignored.

        Args:
            event (RealtimeEvent): The event from the OpenAI Realtime API.
        """
        if isinstance(event, FunctionCall):
            self.logger.info("Received function call event")
            await self.call_function(
                call_id=event.call_id,
                name=event.name,
                kwargs=event.arguments,
            )

    async def call_function(self, call_id: str, name: str, kwargs: dict[str, Any]) -> None:
        """Call a function registered with the agent and send its result to the client.

        The result is serialized to a string (pydantic model dump, JSON, or ``str``)
        before being forwarded via ``send_function_result``.

        Args:
            call_id (str): The ID of the function call.
            name (str): The name of the function to call.
            kwargs (dict[str, Any]): The arguments to pass to the function.
        """
        if name in self.agent.registered_realtime_tools:
            func = self.agent.registered_realtime_tools[name].func
            # Wrap sync functions so they run in a worker thread instead of blocking the event loop.
            func = func if asyncio.iscoroutinefunction(func) else asyncify(func)
            try:
                result = await func(**kwargs)
            except Exception:
                result = "Function call failed"
                # Fix: use exc_info=True so the *exception traceback* is logged
                # (stack_info only captures the current call stack), and log at
                # warning level so failures are not hidden among info messages.
                self.logger.warning(f"Function call failed: {name=}, {kwargs=}", exc_info=True)

            if isinstance(result, BaseModel):
                result = result.model_dump_json()
            elif not isinstance(result, str):
                try:
                    result = json.dumps(result)
                except Exception:
                    # Fall back to plain stringification for non-JSON-serializable results.
                    result = str(result)

            await self.realtime_client.send_function_result(call_id, result)
        else:
            self.logger.warning(f"Function {name} called, but is not registered with the realtime agent.")

    async def initialize_session(self) -> None:
        """Add registered tools to OpenAI with a session update."""
        session_update = {
            "tools": [tool.realtime_tool_schema for tool in self.agent.registered_realtime_tools.values()],
            "tool_choice": "auto",
        }
        await self.realtime_client.session_update(session_update)

    async def run_loop(self) -> None:
        """Run the observer loop; this observer needs no background work."""
        pass
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
    # Static-only check that FunctionObserver satisfies the RealtimeObserver
    # interface; this assignment is never executed at runtime.
    function_observer: RealtimeObserver = FunctionObserver()
|
||||||
@@ -0,0 +1,158 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from logging import Logger, getLogger
|
||||||
|
from typing import Any, Callable, Optional, TypeVar, Union
|
||||||
|
|
||||||
|
from anyio import lowlevel
|
||||||
|
from asyncer import create_task_group
|
||||||
|
|
||||||
|
from ....doc_utils import export_module
|
||||||
|
from ....llm_config import LLMConfig
|
||||||
|
from ....tools import Tool
|
||||||
|
from .clients.realtime_client import RealtimeClientProtocol, get_client
|
||||||
|
from .function_observer import FunctionObserver
|
||||||
|
from .realtime_observer import RealtimeObserver
|
||||||
|
|
||||||
|
# Type variable for functions registered as realtime tools via the decorator API.
F = TypeVar("F", bound=Callable[..., Any])

# Module-level fallback logger used when no logger is supplied to the agent.
global_logger = getLogger(__name__)


@dataclass
class RealtimeAgentCallbacks:
    """Callbacks for the Realtime Agent."""

    # async empty placeholder function; awaiting it yields control once without blocking
    on_observers_ready: Callable[[], Any] = lambda: lowlevel.checkpoint()
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen.agentchat.realtime.experimental")
class RealtimeAgent:
    """(Experimental) Agent that connects a Realtime API client with a set of observers."""

    def __init__(
        self,
        *,
        name: str,
        audio_adapter: Optional[RealtimeObserver] = None,
        system_message: str = "You are a helpful AI Assistant.",
        llm_config: Optional[Union[LLMConfig, dict[str, Any]]] = None,
        logger: Optional[Logger] = None,
        observers: Optional[list[RealtimeObserver]] = None,
        **client_kwargs: Any,
    ):
        """(Experimental) Agent for interacting with the Realtime Clients.

        Args:
            name (str): The name of the agent.
            audio_adapter (Optional[RealtimeObserver] = None): The audio adapter for the agent.
            system_message (str): The system message for the agent.
            llm_config (LLMConfig, dict[str, Any], bool): The config for the agent.
            logger (Optional[Logger]): The logger for the agent.
            observers (Optional[list[RealtimeObserver]]): The additional observers for the agent.
            **client_kwargs (Any): The keyword arguments for the client.
        """
        self._logger = logger
        self._name = name
        self._system_message = system_message

        # Resolve the effective config (falls back to the context-managed current config).
        llm_config = LLMConfig.get_current_llm_config(llm_config)

        # Select the first registered client whose factory accepts this config.
        self._realtime_client: RealtimeClientProtocol = get_client(
            llm_config=llm_config, logger=self.logger, **client_kwargs
        )

        self._registered_realtime_tools: dict[str, Tool] = {}
        # A FunctionObserver is always attached so tool calls are handled.
        self._observers: list[RealtimeObserver] = observers if observers else []
        self._observers.append(FunctionObserver(logger=logger))
        if audio_adapter:
            self._observers.append(audio_adapter)

        self.callbacks = RealtimeAgentCallbacks()

    @property
    def system_message(self) -> str:
        """Get the system message for the agent."""
        return self._system_message

    @property
    def logger(self) -> Logger:
        """Get the logger for the agent."""
        return self._logger or global_logger

    @property
    def realtime_client(self) -> RealtimeClientProtocol:
        """Get the OpenAI Realtime Client."""
        return self._realtime_client

    @property
    def registered_realtime_tools(self) -> dict[str, Tool]:
        """Get the registered realtime tools."""
        return self._registered_realtime_tools

    def register_observer(self, observer: RealtimeObserver) -> None:
        """Register an observer with the Realtime Agent.

        Args:
            observer (RealtimeObserver): The observer to register.
        """
        self._observers.append(observer)

    async def start_observers(self) -> None:
        """Start every observer inside the agent's task group and wait until all are ready.

        NOTE(review): relies on ``self._tg`` being set by ``run()``; must only be
        called from within ``run()``'s task group context.
        """
        for observer in self._observers:
            self._tg.soonify(observer.run)(self)

        # wait for the observers to be ready
        for observer in self._observers:
            await observer.wait_for_ready()

        await self.callbacks.on_observers_ready()

    async def run(self) -> None:
        """Run the agent."""
        # everything is run in the same task group to enable easy cancellation using self._tg.cancel_scope.cancel()
        async with create_task_group() as self._tg:  # noqa: SIM117
            # connect with the client first (establishes a connection and initializes a session)
            async with self._realtime_client.connect():
                # start the observers and wait for them to be ready
                await self.realtime_client.session_update(session_options={"instructions": self.system_message})
                await self.start_observers()

                # iterate over the events
                async for event in self.realtime_client.read_events():
                    for observer in self._observers:
                        await observer.on_event(event)

    def register_realtime_function(
        self,
        *,
        name: Optional[str] = None,
        description: Optional[str] = None,
    ) -> Callable[[Union[F, Tool]], Tool]:
        """Decorator for registering a function to be used by an agent.

        Args:
            name (str): The name of the function.
            description (str): The description of the function.

        Returns:
            Callable[[Union[F, Tool]], Tool]: The decorator for registering a function.
        """

        def _decorator(func_or_tool: Union[F, Tool]) -> Tool:
            """Decorator for registering a function to be used by an agent.

            Args:
                func_or_tool (Union[F, Tool]): The function or tool to register.

            Returns:
                Tool: The registered tool.
            """
            tool = Tool(func_or_tool=func_or_tool, name=name, description=description)

            # Keyed by tool name; a later registration with the same name replaces the earlier one.
            self._registered_realtime_tools[tool.name] = tool

            return tool

        return _decorator
|
||||||
@@ -0,0 +1,42 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class RealtimeEvent(BaseModel):
    """Base class for events parsed from a Realtime API."""

    # Verbatim provider payload this event was parsed from.
    raw_message: dict[str, Any]


class SessionCreated(RealtimeEvent):
    """Emitted when the realtime session has been created."""

    type: Literal["session.created"] = "session.created"


class SessionUpdated(RealtimeEvent):
    """Emitted when the realtime session options have been updated."""

    type: Literal["session.updated"] = "session.updated"


class AudioDelta(RealtimeEvent):
    """A chunk of model-generated output audio."""

    type: Literal["response.audio.delta"] = "response.audio.delta"
    delta: str
    item_id: Any


class InputAudioBufferDelta(RealtimeEvent):
    """A chunk of input audio appended to the input buffer."""

    type: Literal["input_audio_buffer.delta"] = "input_audio_buffer.delta"
    delta: str
    item_id: Any


class SpeechStarted(RealtimeEvent):
    """Emitted when speech is detected in the input audio buffer."""

    type: Literal["input_audio_buffer.speech_started"] = "input_audio_buffer.speech_started"


class FunctionCall(RealtimeEvent):
    """Emitted when the model requests a function (tool) call."""

    type: Literal["response.function_call_arguments.done"] = "response.function_call_arguments.done"
    name: str
    arguments: dict[str, Any]
    call_id: str
|
||||||
@@ -0,0 +1,100 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from logging import Logger, getLogger
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
from anyio import Event
|
||||||
|
|
||||||
|
from ....doc_utils import export_module
|
||||||
|
from .clients.realtime_client import RealtimeClientProtocol
|
||||||
|
from .realtime_events import RealtimeEvent
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .realtime_agent import RealtimeAgent
|
||||||
|
|
||||||
|
__all__ = ["RealtimeObserver"]

# Module-level fallback logger used when no logger is supplied to an observer.
global_logger = getLogger(__name__)


@export_module("autogen.agentchat.realtime.experimental")
class RealtimeObserver(ABC):
    """Observer for the OpenAI Realtime API."""

    def __init__(self, *, logger: Optional[Logger] = None) -> None:
        """Observer for the OpenAI Realtime API.

        Args:
            logger (Logger): The logger for the observer.
        """
        # Set once initialize_session() has completed; see run().
        self._ready_event = Event()
        self._agent: Optional[RealtimeAgent] = None
        self._logger = logger

    @property
    def logger(self) -> Logger:
        """The observer's logger, falling back to the module logger."""
        return self._logger or global_logger

    @property
    def agent(self) -> "RealtimeAgent":
        """The agent attached via run(); raises if not attached yet."""
        if self._agent is None:
            raise RuntimeError("Agent has not been set.")
        return self._agent

    @property
    def realtime_client(self) -> RealtimeClientProtocol:
        """The attached agent's realtime client; raises if the agent or client is missing."""
        if self._agent is None:
            raise RuntimeError("Agent has not been set.")
        if self._agent.realtime_client is None:
            raise RuntimeError("Realtime client has not been set.")

        return self._agent.realtime_client

    async def run(self, agent: "RealtimeAgent") -> None:
        """Run the observer with the agent.

        When implementing, be sure to call `self._ready_event.set()` when the observer is ready to process events.

        Args:
            agent (RealtimeAgent): The realtime agent attached to the observer.
        """
        self._agent = agent
        await self.initialize_session()
        # Unblocks callers waiting in wait_for_ready().
        self._ready_event.set()

        await self.run_loop()

    @abstractmethod
    async def run_loop(self) -> None:
        """Run the loop if needed.

        This method is called after the observer is ready to process events.
        Events will be processed by the on_event method, this is just a hook for additional processing.
        Use initialize_session to set up the session.
        """
        ...

    @abstractmethod
    async def initialize_session(self) -> None:
        """Initialize the session for the observer."""
        ...

    async def wait_for_ready(self) -> None:
        """Wait until the observer has finished initializing its session."""
        await self._ready_event.wait()

    @abstractmethod
    async def on_event(self, event: RealtimeEvent) -> None:
        """Handle an event from the OpenAI Realtime API.

        Args:
            event (RealtimeEvent): The event from the OpenAI Realtime API.
        """
        ...

    async def on_close(self) -> None:
        """Handle close of RealtimeClient."""
        ...
|
||||||
@@ -0,0 +1,483 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import warnings
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union
|
||||||
|
|
||||||
|
import anyio
|
||||||
|
from asyncer import asyncify, create_task_group, syncify
|
||||||
|
|
||||||
|
from ....agentchat.contrib.swarm_agent import AfterWorkOption, initiate_swarm_chat
|
||||||
|
from ....cache import AbstractCache
|
||||||
|
from ....code_utils import content_str
|
||||||
|
from ....doc_utils import export_module
|
||||||
|
from ... import Agent, ChatResult, ConversableAgent, LLMAgent
|
||||||
|
from ...utils import consolidate_chat_info, gather_usage_summary
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .clients import Role
|
||||||
|
from .realtime_agent import RealtimeAgent
|
||||||
|
|
||||||
|
__all__ = ["register_swarm"]

# System prompt for the realtime voice assistant that coordinates the swarm.
SWARM_SYSTEM_MESSAGE = (
    "You are a helpful voice assistant. Your task is to listen to user and to coordinate the tasks based on his/her inputs."
    "You can and will communicate using audio output only."
)

# Role used when relaying a swarm question to the realtime user.
QUESTION_ROLE: "Role" = "user"
# Template instructing the realtime model to read a swarm question aloud verbatim;
# the question text is substituted into '{}' via str.format.
QUESTION_MESSAGE = (
    "I have a question/information for myself. DO NOT ANSWER YOURSELF, GET THE ANSWER FROM ME. "
    "repeat the question to me **WITH AUDIO OUTPUT** and AFTER YOU GET THE ANSWER FROM ME call 'answer_task_question' with the answer in first person\n\n"
    "IMPORTANT: repeat just the question, without any additional information or context\n\n"
    "The question is: '{}'\n\n"
)
# Seconds to wait for the user's spoken answer before giving up.
QUESTION_TIMEOUT_SECONDS = 20

logger = logging.getLogger(__name__)

# Type variable for functions wrapped by decorators in this module.
F = TypeVar("F", bound=Callable[..., Any])
|
||||||
|
|
||||||
|
|
||||||
|
def message_to_dict(message: Union[dict[str, Any], str]) -> dict[str, Any]:
    """Coerce a message into dictionary form.

    A dict is returned unchanged, a plain string becomes ``{"content": <string>}``,
    and anything else is converted via ``dict()``.
    """
    if isinstance(message, dict):
        return message
    if isinstance(message, str):
        return {"content": message}
    return dict(message)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_oai_message(message: Union[dict[str, Any], str], role: str, adressee: Agent) -> dict[str, Any]:
    """
    Parse a message into an OpenAI-compatible message format.

    Args:
        message: The message to parse.
        role: The role associated with the message.
        adressee: The agent that will receive the message.

    Returns:
        The parsed message in OpenAI-compatible format.

    Raises:
        ValueError: If the message lacks required fields like 'content', 'function_call', or 'tool_calls'.
    """
    msg = message_to_dict(message)

    # Copy only the recognized, non-None fields.
    copied_keys = ("content", "function_call", "tool_calls", "tool_responses", "tool_call_id", "name", "context")
    oai_message = {key: msg[key] for key in copied_keys if msg.get(key) is not None}

    # A message must carry content unless it is a pure function/tool invocation.
    if "content" not in oai_message:
        if "function_call" not in oai_message and "tool_calls" not in oai_message:
            raise ValueError("Message must have either 'content', 'function_call', or 'tool_calls' field.")
        oai_message["content"] = None

    # Determine and assign the role.
    incoming_role = msg.get("role")
    if incoming_role in ("function", "tool"):
        oai_message["role"] = incoming_role
        # Ensure all tool responses have string content.
        for tool_response in oai_message.get("tool_responses", []):
            tool_response["content"] = str(tool_response["content"])
    elif "override_role" in msg:
        oai_message["role"] = msg["override_role"]
    else:
        oai_message["role"] = role

    # Messages that invoke functions/tools are always attributed to the assistant.
    if oai_message.get("function_call") or oai_message.get("tool_calls"):
        oai_message["role"] = "assistant"

    # Default the sender name to the addressee's name when absent.
    oai_message.setdefault("name", adressee.name)

    return oai_message
|
||||||
|
|
||||||
|
|
||||||
|
class SwarmableAgent(Agent):
|
||||||
|
"""A class for an agent that can participate in a swarm chat."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
system_message: str = "You are a helpful AI Assistant.",
|
||||||
|
is_termination_msg: Optional[Callable[..., bool]] = None,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
silent: Optional[bool] = None,
|
||||||
|
):
|
||||||
|
self._oai_messages: dict[Agent, Any] = defaultdict(list)
|
||||||
|
|
||||||
|
self._system_message = system_message
|
||||||
|
self._description = description if description is not None else system_message
|
||||||
|
self._is_termination_msg = (
|
||||||
|
is_termination_msg
|
||||||
|
if is_termination_msg is not None
|
||||||
|
else (lambda x: content_str(x.get("content")) == "TERMINATE")
|
||||||
|
)
|
||||||
|
self.silent = silent
|
||||||
|
|
||||||
|
self._name = name
|
||||||
|
|
||||||
|
# Initialize standalone client cache object.
|
||||||
|
self.client_cache = None
|
||||||
|
self.previous_cache = None
|
||||||
|
|
||||||
|
self.reply_at_receive: dict[Agent, bool] = defaultdict(bool)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def system_message(self) -> str:
|
||||||
|
return self._system_message
|
||||||
|
|
||||||
|
def update_system_message(self, system_message: str) -> None:
|
||||||
|
"""Update this agent's system message.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
system_message (str): system message for inference.
|
||||||
|
"""
|
||||||
|
self._system_message = system_message
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return self._name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return self._description
|
||||||
|
|
||||||
|
def send(
|
||||||
|
self,
|
||||||
|
message: Union[dict[str, Any], str],
|
||||||
|
recipient: Agent,
|
||||||
|
request_reply: Optional[bool] = None,
|
||||||
|
silent: Optional[bool] = False,
|
||||||
|
) -> None:
|
||||||
|
self._oai_messages[recipient].append(parse_oai_message(message, "assistant", recipient))
|
||||||
|
recipient.receive(message, self, request_reply)
|
||||||
|
|
||||||
|
def receive(
|
||||||
|
self,
|
||||||
|
message: Union[dict[str, Any], str],
|
||||||
|
sender: Agent,
|
||||||
|
request_reply: Optional[bool] = None,
|
||||||
|
silent: Optional[bool] = False,
|
||||||
|
) -> None:
|
||||||
|
self._oai_messages[sender].append(parse_oai_message(message, "user", self))
|
||||||
|
if request_reply is False or (request_reply is None and self.reply_at_receive[sender] is False):
|
||||||
|
return
|
||||||
|
reply = self.generate_reply(messages=self.chat_messages[sender], sender=sender)
|
||||||
|
if reply is not None:
|
||||||
|
self.send(reply, sender, silent=silent)
|
||||||
|
|
||||||
|
def generate_reply(
|
||||||
|
self,
|
||||||
|
messages: Optional[list[dict[str, Any]]] = None,
|
||||||
|
sender: Optional["Agent"] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Union[str, dict[str, Any], None]:
|
||||||
|
if messages is None:
|
||||||
|
if sender is None:
|
||||||
|
raise ValueError("Either messages or sender must be provided.")
|
||||||
|
messages = self._oai_messages[sender]
|
||||||
|
|
||||||
|
_, reply = self.check_termination_and_human_reply(messages=messages, sender=sender, config=None)
|
||||||
|
|
||||||
|
return reply
|
||||||
|
|
||||||
|
def check_termination_and_human_reply(
|
||||||
|
self,
|
||||||
|
messages: Optional[list[dict[str, Any]]] = None,
|
||||||
|
sender: Optional[Agent] = None,
|
||||||
|
config: Optional[Any] = None,
|
||||||
|
) -> tuple[bool, Union[str, None]]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def initiate_chat(
|
||||||
|
self,
|
||||||
|
recipient: ConversableAgent,
|
||||||
|
message: Union[dict[str, Any], str],
|
||||||
|
clear_history: bool = True,
|
||||||
|
silent: Optional[bool] = False,
|
||||||
|
cache: Optional[AbstractCache] = None,
|
||||||
|
summary_args: Optional[dict[str, Any]] = {},
|
||||||
|
**kwargs: dict[str, Any],
|
||||||
|
) -> ChatResult:
|
||||||
|
_chat_info = locals().copy()
|
||||||
|
_chat_info["sender"] = self
|
||||||
|
consolidate_chat_info(_chat_info, uniform_sender=self)
|
||||||
|
recipient._raise_exception_on_async_reply_functions()
|
||||||
|
recipient.previous_cache = recipient.client_cache # type: ignore[attr-defined]
|
||||||
|
recipient.client_cache = cache # type: ignore[attr-defined, assignment]
|
||||||
|
|
||||||
|
self._prepare_chat(recipient, clear_history)
|
||||||
|
self.send(message, recipient, silent=silent)
|
||||||
|
summary = self._last_msg_as_summary(self, recipient, summary_args)
|
||||||
|
|
||||||
|
recipient.client_cache = recipient.previous_cache # type: ignore[attr-defined]
|
||||||
|
recipient.previous_cache = None # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
chat_result = ChatResult(
|
||||||
|
chat_history=self.chat_messages[recipient],
|
||||||
|
summary=summary,
|
||||||
|
cost=gather_usage_summary([self, recipient]), # type: ignore[arg-type]
|
||||||
|
human_input=[],
|
||||||
|
)
|
||||||
|
return chat_result
|
||||||
|
|
||||||
|
async def a_generate_reply(
|
||||||
|
self,
|
||||||
|
messages: Optional[list[dict[str, Any]]] = None,
|
||||||
|
sender: Optional["Agent"] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Union[str, dict[str, Any], None]:
|
||||||
|
return self.generate_reply(messages=messages, sender=sender, **kwargs)
|
||||||
|
|
||||||
|
async def a_receive(
|
||||||
|
self,
|
||||||
|
message: Union[dict[str, Any], str],
|
||||||
|
sender: "Agent",
|
||||||
|
request_reply: Optional[bool] = None,
|
||||||
|
) -> None:
|
||||||
|
self.receive(message, sender, request_reply)
|
||||||
|
|
||||||
|
async def a_send(
|
||||||
|
self,
|
||||||
|
message: Union[dict[str, Any], str],
|
||||||
|
recipient: "Agent",
|
||||||
|
request_reply: Optional[bool] = None,
|
||||||
|
) -> None:
|
||||||
|
self.send(message, recipient, request_reply)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def chat_messages(self) -> dict[Agent, list[dict[str, Any]]]:
|
||||||
|
"""A dictionary of conversations from agent to list of messages."""
|
||||||
|
return self._oai_messages
|
||||||
|
|
||||||
|
def last_message(self, agent: Optional[Agent] = None) -> Optional[dict[str, Any]]:
|
||||||
|
if agent is None:
|
||||||
|
n_conversations = len(self._oai_messages)
|
||||||
|
if n_conversations == 0:
|
||||||
|
return None
|
||||||
|
if n_conversations == 1:
|
||||||
|
for conversation in self._oai_messages.values():
|
||||||
|
return conversation[-1] # type: ignore[no-any-return]
|
||||||
|
raise ValueError("More than one conversation is found. Please specify the sender to get the last message.")
|
||||||
|
if agent not in self._oai_messages():
|
||||||
|
raise KeyError(
|
||||||
|
f"The agent '{agent.name}' is not present in any conversation. No history available for this agent."
|
||||||
|
)
|
||||||
|
return self._oai_messages[agent][-1] # type: ignore[no-any-return]
|
||||||
|
|
||||||
|
def _prepare_chat(
|
||||||
|
self,
|
||||||
|
recipient: ConversableAgent,
|
||||||
|
clear_history: bool,
|
||||||
|
prepare_recipient: bool = True,
|
||||||
|
reply_at_receive: bool = True,
|
||||||
|
) -> None:
|
||||||
|
self.reply_at_receive[recipient] = reply_at_receive
|
||||||
|
if clear_history:
|
||||||
|
self._oai_messages[recipient].clear()
|
||||||
|
if prepare_recipient:
|
||||||
|
recipient._prepare_chat(self, clear_history, False, reply_at_receive) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
def _raise_exception_on_async_reply_functions(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def set_ui_tools(self, tools: Optional[list] = None) -> None:
|
||||||
|
"""Set UI tools for the agent."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def unset_ui_tools(self) -> None:
    """Unset UI tools for the agent.

    No-op in this implementation.
    """
    return None
|
||||||
|
|
||||||
|
@staticmethod
def _last_msg_as_summary(sender: Agent, recipient: Agent, summary_args: Optional[dict[str, Any]]) -> str:
    """Summarize the chat as the recipient's last message, stripping any "TERMINATE" markers.

    ``summary_args`` is accepted for interface compatibility and ignored.
    Falls back to an empty string (with a warning) when no message is available.
    """
    summary = ""
    try:
        last = recipient.last_message(sender)["content"]  # type: ignore[attr-defined]
        if isinstance(last, str):
            summary = last.replace("TERMINATE", "")
        elif isinstance(last, list):
            # Multimodal content: keep only the text parts.
            parts = [
                part["text"].replace("TERMINATE", "")
                for part in last
                if isinstance(part, dict) and "text" in part
            ]
            summary = "\n".join(parts)
    except (IndexError, AttributeError) as e:
        warnings.warn(f"Cannot extract summary using last_msg: {e}. Using an empty str as summary.", UserWarning)
    return summary
|
||||||
|
|
||||||
|
|
||||||
|
# check that the SwarmableAgent class is implementing LLMAgent protocol
if TYPE_CHECKING:

    def _create_swarmable_agent(
        name: str,
        system_message: str,
        is_termination_msg: Optional[Callable[..., bool]],
        description: Optional[str],
        silent: Optional[bool],
    ) -> LLMAgent:
        # Type-check-only helper: returning a SwarmableAgent where an LLMAgent
        # is declared makes the type checker verify the protocol is satisfied.
        # Never executed at runtime (TYPE_CHECKING is False outside checkers).
        return SwarmableAgent(
            name=name,
            system_message=system_message,
            is_termination_msg=is_termination_msg,
            description=description,
            silent=silent,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class SwarmableRealtimeAgent(SwarmableAgent):
    """Bridges a RealtimeAgent into a swarm chat.

    Acts as the swarm's "user" side: questions produced by the swarm are
    relayed to the realtime client as text, and the realtime side pushes the
    user's answer back via ``set_answer`` (registered as the
    ``answer_task_question`` realtime function).
    """

    def __init__(
        self,
        realtime_agent: "RealtimeAgent",
        initial_agent: ConversableAgent,
        agents: list[ConversableAgent],
        question_message: Optional[str] = None,
    ) -> None:
        # Swarm wiring: the agent that opens the chat and its peers.
        self._initial_agent = initial_agent
        self._agents = agents
        self._realtime_agent = realtime_agent

        # Event/value pair used to hand the realtime answer back to the swarm.
        self._answer_event: anyio.Event = anyio.Event()
        self._answer: str = ""
        # Template used when relaying a swarm question to the user.
        self.question_message = question_message or QUESTION_MESSAGE

        super().__init__(
            name=realtime_agent._name,
            is_termination_msg=None,
            description=None,
            silent=None,
        )

    def reset_answer(self) -> None:
        """Reset the answer event."""
        # A fresh Event is required because anyio events cannot be un-set.
        self._answer_event = anyio.Event()

    def set_answer(self, answer: str) -> str:
        """Set the answer to the question."""
        self._answer = answer
        self._answer_event.set()
        return "Answer set successfully."

    async def get_answer(self) -> str:
        """Get the answer to the question."""
        await self._answer_event.wait()
        return self._answer

    async def ask_question(self, question: str, question_timeout: int) -> None:
        """Send a question for the user to the agent and wait for the answer.

        If the answer is not received within the timeout, the question is repeated.

        Args:
            question: The question to ask the user.
            question_timeout: The time in seconds to wait for the answer.
        """
        self.reset_answer()
        realtime_client = self._realtime_agent._realtime_client
        await realtime_client.send_text(role=QUESTION_ROLE, text=question)

        # Polls once per second; True as soon as an answer arrived, False on timeout.
        async def _check_event_set(timeout: int = question_timeout) -> bool:
            for _ in range(timeout):
                if self._answer_event.is_set():
                    return True
                await anyio.sleep(1)
            return False

        # Re-send the question until it gets answered.
        while not await _check_event_set():
            await realtime_client.send_text(role=QUESTION_ROLE, text=question)

    def check_termination_and_human_reply(
        self,
        messages: Optional[list[dict[str, Any]]] = None,
        sender: Optional[Agent] = None,
        config: Optional[Any] = None,
    ) -> tuple[bool, Optional[str]]:
        """Check if the conversation should be terminated and if the agent should reply.

        Called when its agents turn in the chat conversation.

        Args:
            messages (list[dict[str, Any]]): The messages in the conversation.
            sender (Agent): The agent that sent the message.
            config (Optional[Any]): The configuration for the agent.
        """
        if not messages:
            return False, None

        # Relay the latest swarm message to the realtime user and block (via
        # syncify) until ask_question's polling loop observes an answer.
        async def get_input() -> None:
            async with create_task_group() as tg:
                tg.soonify(self.ask_question)(
                    self.question_message.format(messages[-1]["content"]),
                    question_timeout=QUESTION_TIMEOUT_SECONDS,
                )

        syncify(get_input)()

        # The answer recorded by set_answer() becomes this agent's reply.
        return True, {"role": "user", "content": self._answer}  # type: ignore[return-value]

    def start_chat(self) -> None:
        # The swarm chat is started from the realtime side once observers are
        # ready; starting it directly is unsupported.
        raise NotImplementedError

    def configure_realtime_agent(self, system_message: Optional[str]) -> None:
        """Wire the underlying RealtimeAgent to this swarm adapter.

        Overrides the realtime agent's system message, exposes ``set_answer``
        as a realtime function, and schedules the swarm chat to start once the
        realtime observers are ready.
        """
        realtime_agent = self._realtime_agent

        logger = realtime_agent.logger
        if not system_message:
            # A system message customized in __init__ would be silently
            # replaced below - warn, then fall back to the swarm default.
            if realtime_agent.system_message != "You are a helpful AI Assistant.":
                logger.warning(
                    "Overriding system message set up in `__init__`, please use `system_message` parameter of the `register_swarm` function instead."
                )
            system_message = SWARM_SYSTEM_MESSAGE

        realtime_agent._system_message = system_message

        # Expose set_answer to the realtime model as a callable function.
        realtime_agent.register_realtime_function(
            name="answer_task_question", description="Answer question from the task"
        )(self.set_answer)

        # Launch the swarm chat in the realtime agent's task group once ready.
        async def on_observers_ready() -> None:
            self._realtime_agent._tg.soonify(asyncify(initiate_swarm_chat))(
                initial_agent=self._initial_agent,
                agents=self._agents,
                user_agent=self,  # type: ignore[arg-type]
                messages="Find out what the user wants.",
                after_work=AfterWorkOption.REVERT_TO_USER,
            )

        self._realtime_agent.callbacks.on_observers_ready = on_observers_ready
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen.agentchat.realtime.experimental")
def register_swarm(
    *,
    realtime_agent: "RealtimeAgent",
    initial_agent: ConversableAgent,
    agents: list[ConversableAgent],
    system_message: Optional[str] = None,
    question_message: Optional[str] = None,
) -> None:
    """Attach a swarm of agents to a RealtimeAgent.

    Builds a SwarmableRealtimeAgent around *realtime_agent* and configures it,
    so that the swarm chat starts once the realtime observers are ready.

    Args:
        realtime_agent (RealtimeAgent): The RealtimeAgent to drive the swarm.
        initial_agent (ConversableAgent): The initial agent of the swarm.
        agents (list[ConversableAgent]): The agents in the swarm.
        system_message (Optional[str]): System message override for the
            realtime agent; if None, the default system message is used.
        question_message (Optional[str]): Template used when relaying
            questions to the user; if None, the default QUESTION_MESSAGE is used.
    """
    adapter = SwarmableRealtimeAgent(
        realtime_agent=realtime_agent,
        initial_agent=initial_agent,
        agents=agents,
        question_message=question_message,
    )
    adapter.configure_realtime_agent(system_message=system_message)
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from typing import Any, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
__all__ = ["WebSocketProtocol"]
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
class WebSocketProtocol(Protocol):
    """Structural type for WebSocket-like objects, modelled after FastAPI's WebSocket.

    Any object providing these four methods satisfies the protocol; thanks to
    ``@runtime_checkable``, ``isinstance`` checks work too (method presence only).
    """

    async def send_json(self, data: Any, mode: str = "text") -> None:
        """Serialize *data* as JSON and send it over the socket."""

    async def receive_json(self, mode: str = "text") -> Any:
        """Receive one message and deserialize it from JSON."""

    async def receive_text(self) -> str:
        """Receive one raw text message."""

    def iter_text(self) -> AsyncIterator[str]:
        """Iterate asynchronously over incoming text messages."""
|
||||||
21
mm_agents/coact/autogen/agentchat/realtime_agent/__init__.py
Normal file
21
mm_agents/coact/autogen/agentchat/realtime_agent/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from ..realtime.experimental import (
|
||||||
|
FunctionObserver,
|
||||||
|
RealtimeAgent,
|
||||||
|
RealtimeObserver,
|
||||||
|
TwilioAudioAdapter,
|
||||||
|
WebSocketAudioAdapter,
|
||||||
|
register_swarm,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"FunctionObserver",
|
||||||
|
"RealtimeAgent",
|
||||||
|
"RealtimeObserver",
|
||||||
|
"TwilioAudioAdapter",
|
||||||
|
"WebSocketAudioAdapter",
|
||||||
|
"register_swarm",
|
||||||
|
]
|
||||||
111
mm_agents/coact/autogen/agentchat/user_proxy_agent.py
Normal file
111
mm_agents/coact/autogen/agentchat/user_proxy_agent.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
from typing import Any, Callable, Literal, Optional, Union
|
||||||
|
|
||||||
|
from ..doc_utils import export_module
|
||||||
|
from ..llm_config import LLMConfig
|
||||||
|
from ..runtime_logging import log_new_agent, logging_enabled
|
||||||
|
from .conversable_agent import ConversableAgent
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen")
class UserProxyAgent(ConversableAgent):
    """(In preview) A proxy agent for the user, that can execute code and provide feedback to the other agents.

    A `ConversableAgent` preconfigured with `human_input_mode="ALWAYS"` and
    `llm_config=False`: by default it prompts the human on every incoming
    message, executes code, and never auto-replies via an LLM. To customize:
    register an auto reply with `register_reply`, override `get_human_input`
    to change how input is collected, or override `execute_code_blocks`,
    `run_code`, and `execute_function` to change code/function execution.
    """

    # Default UserProxyAgent.description values, based on human_input_mode
    DEFAULT_USER_PROXY_AGENT_DESCRIPTIONS = {
        "ALWAYS": "An attentive HUMAN user who can answer questions about the task, and can perform tasks such as running Python code or inputting command line commands at a Linux terminal and reporting back the execution results.",
        "TERMINATE": "A user that can run Python code or input command line commands at a Linux terminal and report back the execution results.",
        "NEVER": "A computer terminal that performs no other action than running Python scripts (provided to it quoted in ```python code blocks), or sh shell scripts (provided to it quoted in ```sh code blocks).",
    }

    def __init__(
        self,
        name: str,
        is_termination_msg: Optional[Callable[[dict[str, Any]], bool]] = None,
        max_consecutive_auto_reply: Optional[int] = None,
        human_input_mode: Literal["ALWAYS", "TERMINATE", "NEVER"] = "ALWAYS",
        function_map: Optional[dict[str, Callable[..., Any]]] = None,
        code_execution_config: Optional[Union[dict[str, Any], Literal[False]]] = None,
        default_auto_reply: Optional[Union[str, dict[str, Any]]] = "",
        llm_config: Optional[Union[LLMConfig, dict[str, Any], Literal[False]]] = False,
        system_message: Optional[Union[str, list[str]]] = "",
        description: Optional[str] = None,
        **kwargs: Any,
    ):
        """Args:
        name (str): name of the agent.
        is_termination_msg (function): takes a message dict (possible keys:
            "content", "role", "name", "function_call") and returns True when
            that message should terminate the conversation.
        max_consecutive_auto_reply (int): cap on consecutive auto replies
            (None -> class attribute MAX_CONSECUTIVE_AUTO_REPLY). Only
            relevant when human_input_mode is not "ALWAYS".
        human_input_mode (str): "ALWAYS" prompts the human on every message;
            "TERMINATE" prompts only on a termination message or when the
            auto-reply cap is reached; "NEVER" never prompts.
        function_map (dict[str, callable]): function names -> callables.
        code_execution_config (dict or False or None): code-execution config
            (keys: work_dir, use_docker, timeout, last_n_messages); False
            disables code execution; None (default) means default settings.
        default_auto_reply (str or dict or None): reply used when no code
            execution or LLM-based reply is generated.
        llm_config (LLMConfig or dict or False or None): llm inference
            configuration; False (default) disables llm-based auto reply.
        system_message (str or List): system message for ChatCompletion
            inference; only used when llm_config is not False.
        description (str): short description used by other agents (e.g. the
            GroupChatManager) to decide when to call upon this agent; defaults
            to the mode-specific DEFAULT_USER_PROXY_AGENT_DESCRIPTIONS entry.
        **kwargs (dict): forwarded to ConversableAgent.
        """
        # BUG FIX: the previous default was a mutable `{}` shared by every
        # instance; use a None sentinel and create a fresh dict per call.
        if code_execution_config is None:
            code_execution_config = {}
        super().__init__(
            name=name,
            system_message=system_message,
            is_termination_msg=is_termination_msg,
            max_consecutive_auto_reply=max_consecutive_auto_reply,
            human_input_mode=human_input_mode,
            function_map=function_map,
            code_execution_config=code_execution_config,
            llm_config=llm_config,
            default_auto_reply=default_auto_reply,
            description=(
                description if description is not None else self.DEFAULT_USER_PROXY_AGENT_DESCRIPTIONS[human_input_mode]
            ),
            **kwargs,
        )

        if logging_enabled():
            log_new_agent(self, locals())
|
||||||
206
mm_agents/coact/autogen/agentchat/utils.py
Normal file
206
mm_agents/coact/autogen/agentchat/utils.py
Normal file
@@ -0,0 +1,206 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
import re
|
||||||
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
|
from ..doc_utils import export_module
|
||||||
|
from .agent import Agent
|
||||||
|
|
||||||
|
|
||||||
|
def consolidate_chat_info(
    chat_info: Union[dict[str, Any], list[dict[str, Any]]], uniform_sender: Optional[Agent] = None
) -> None:
    """Validate one or more chat descriptions.

    Each entry must name a recipient (and a sender, unless *uniform_sender* is
    given) and may carry a ``summary_method``; when that method is
    ``"reflection_with_llm"``, at least one side must have an LLM client.
    """
    entries = [chat_info] if isinstance(chat_info, dict) else chat_info
    for entry in entries:
        if uniform_sender is None:
            assert "sender" in entry, "sender must be provided."
            sender = entry["sender"]
        else:
            sender = uniform_sender
        assert "recipient" in entry, "recipient must be provided."
        summary_method = entry.get("summary_method")
        assert (
            summary_method is None or callable(summary_method) or summary_method in ("last_msg", "reflection_with_llm")
        ), "summary_method must be a string chosen from 'reflection_with_llm' or 'last_msg' or a callable, or None."
        if summary_method == "reflection_with_llm":
            # Reflection needs an LLM on at least one side of the chat.
            assert sender.client is not None or entry["recipient"].client is not None, (
                "llm client must be set in either the recipient or sender when summary_method is reflection_with_llm."
            )
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen")
def gather_usage_summary(agents: list[Agent]) -> dict[str, dict[str, Any]]:
    """Aggregate token-usage and cost summaries across *agents*.

    Args:
        agents: (list): List of agents whose clients' usage should be combined.

    Returns:
        A dictionary with two keys:
          - "usage_including_cached_inference": cost/token totals including
            tokens served from cached inference.
          - "usage_excluding_cached_inference": the same totals excluding
            cached tokens (never larger than the former).

    Note:
        If none of the agents incurred any cost (not having a client), both
        entries are ``{'total_cost': 0}``.
    """

    def _merge(target: dict[str, Any], source: dict[str, Any]) -> None:
        # A client that never ran reports None - nothing to add.
        if source is None:
            return
        target["total_cost"] += source.get("total_cost", 0)
        for model, stats in source.items():
            if model == "total_cost":
                continue
            if model not in target:
                target[model] = stats.copy()
            else:
                bucket = target[model]
                bucket["cost"] += stats.get("cost", 0)
                bucket["prompt_tokens"] += stats.get("prompt_tokens", 0)
                bucket["completion_tokens"] += stats.get("completion_tokens", 0)
                bucket["total_tokens"] += stats.get("total_tokens", 0)

    usage_including_cached_inference: dict[str, Any] = {"total_cost": 0}
    usage_excluding_cached_inference: dict[str, Any] = {"total_cost": 0}

    for agent in agents:
        if getattr(agent, "client", None):
            _merge(usage_including_cached_inference, agent.client.total_usage_summary)  # type: ignore[attr-defined]
            _merge(usage_excluding_cached_inference, agent.client.actual_usage_summary)  # type: ignore[attr-defined]

    return {
        "usage_including_cached_inference": usage_including_cached_inference,
        "usage_excluding_cached_inference": usage_excluding_cached_inference,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def parse_tags_from_content(tag: str, content: Union[str, list[dict[str, Any]]]) -> list[dict[str, Any]]:
    """Parse HTML-style tags from message contents.

    Looks for ``<tag ...>`` occurrences in *content* and returns, for each hit,
    a dict with keys 'tag' (the tag name), 'attr' (the parsed attribute
    mapping) and 'match' (the underlying regular expression match object).

    Examples:
        `<img http://example.com/image.png> -> [{"tag": "img", "attr": {"src": "http://example.com/image.png"}, "match": re.Match}]`

    Args:
        tag (str): The HTML style tag to be parsed.
        content (Union[str, list[dict[str, Any]]]): The message content to
            parse; a string, or a list of multimodal content items.

    Returns:
        list[dict[str, Any]]: One dictionary per parsed tag occurrence.

    Raises:
        ValueError: If the content is not a string or a list.
    """
    if isinstance(content, str):
        return _parse_tags_from_text(tag, content)
    # Handles case for multimodal messages: only "text" items can carry tags.
    if isinstance(content, list):
        hits: list[dict[str, Any]] = []
        for item in content:
            if item.get("type") == "text":
                hits.extend(_parse_tags_from_text(tag, item["text"]))
        return hits
    raise ValueError(f"content must be str or list, but got {type(content)}")
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_tags_from_text(tag: str, text: str) -> list[dict[str, Any]]:
    """Find every ``<tag ...>`` occurrence in *text* and parse its attributes."""
    pattern = re.compile(f"<{tag} (.*?)>")
    return [
        {"tag": tag, "attr": _parse_attributes_from_tags(m.group(1).strip()), "match": m}
        for m in pattern.finditer(text)
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_attributes_from_tags(tag_content: str) -> dict[str, str]:
    """Split the inside of a tag into an attribute dict.

    Quoted ``key="value"`` pairs become entries; every other token is folded,
    space-separated, into the implicit ``src`` attribute.
    """
    tokens = _reconstruct_attributes(re.findall(r"([^ ]+)", tag_content))

    attributes: dict[str, str] = {}

    def _add_to_src(value: Any) -> None:
        # Unkeyed fragments accumulate into a single space-joined "src" value.
        if "src" in attributes:
            attributes["src"] += f" {value}"
        else:
            attributes["src"] = value

    for token in tokens:
        if "=" not in token:
            _add_to_src(token)
            continue

        key, value = token.split("=", 1)
        if value.startswith(("'", '"')):
            attributes[key] = value[1:-1]  # remove quotes
        else:
            # Unquoted value: the whole token is treated as part of "src".
            _add_to_src(token)

    return attributes
|
||||||
|
|
||||||
|
|
||||||
|
def _reconstruct_attributes(attrs: list[str]) -> list[str]:
|
||||||
|
"""Reconstructs attributes from a list of strings where some attributes may be split across multiple elements."""
|
||||||
|
|
||||||
|
def is_attr(attr: str) -> bool:
|
||||||
|
if "=" in attr:
|
||||||
|
_, value = attr.split("=", 1)
|
||||||
|
if value.startswith("'") or value.startswith('"'):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
reconstructed = []
|
||||||
|
found_attr = False
|
||||||
|
for attr in attrs:
|
||||||
|
if is_attr(attr):
|
||||||
|
reconstructed.append(attr)
|
||||||
|
found_attr = True
|
||||||
|
else:
|
||||||
|
if found_attr:
|
||||||
|
reconstructed[-1] += f" {attr}"
|
||||||
|
found_attr = True
|
||||||
|
elif reconstructed:
|
||||||
|
reconstructed[-1] += f" {attr}"
|
||||||
|
else:
|
||||||
|
reconstructed.append(attr)
|
||||||
|
return reconstructed
|
||||||
596
mm_agents/coact/autogen/code_utils.py
Normal file
596
mm_agents/coact/autogen/code_utils.py
Normal file
@@ -0,0 +1,596 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
import re
|
||||||
|
import string
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import venv
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, TimeoutError
|
||||||
|
from hashlib import md5
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from typing import Callable, Optional, Union
|
||||||
|
|
||||||
|
import docker
|
||||||
|
|
||||||
|
from .types import UserMessageImageContentPart, UserMessageTextContentPart
|
||||||
|
|
||||||
|
# Unique module-level sentinel (compared by identity at call sites).
SENTINEL = object()
DEFAULT_MODEL = "gpt-4"
FAST_MODEL = "gpt-3.5-turbo"
# Regular expression for finding a code block
# ```[ \t]*(\w+)?[ \t]*\r?\n(.*?)[ \t]*\r?\n``` Matches multi-line code blocks.
# The [ \t]* matches the potential spaces before language name.
# The (\w+)? matches the language, where the ? indicates it is optional.
# The [ \t]* matches the potential spaces (not newlines) after language name.
# The \r?\n makes sure there is a linebreak after ```.
# The (.*?) matches the code itself (non-greedy).
# The \r?\n makes sure there is a linebreak before ```.
# The [ \t]* matches the potential spaces before closing ``` (the spec allows indentation).
CODE_BLOCK_PATTERN = r"```[ \t]*(\w+)?[ \t]*\r?\n(.*?)\r?\n[ \t]*```"
# Default working directory for executed code: "<this package dir>/extensions".
WORKING_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extensions")
UNKNOWN = "unknown"  # language label used when a code block's language can't be inferred
TIMEOUT_MSG = "Timeout"
DEFAULT_TIMEOUT = 600  # seconds
WIN32 = sys.platform == "win32"
PATH_SEPARATOR = (WIN32 and "\\") or "/"
# NOTE(review): presumably code-fence language tags treated as Python - confirm at use sites.
PYTHON_VARIANTS = ["python", "Python", "py"]

logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def content_str(content: Union[str, list[Union[UserMessageTextContentPart, UserMessageImageContentPart]], None]) -> str:
    """Flatten an OpenAI-style message ``content`` field into plain text.

    Strings pass through unchanged, ``None`` becomes "", and multimodal lists
    are concatenated with every image part replaced by the "<image>" token.

    Args:
        content: A string, a list of text/image_url content dicts, or None.

    Returns:
        str: The concatenated text representation.

    Raises:
        TypeError: If *content* (or a list element) has an unexpected type.
        ValueError: If a dict element carries an unknown "type" value.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        raise TypeError(f"content must be None, str, or list, but got {type(content)}")

    pieces: list[str] = []
    for part in content:
        if not isinstance(part, dict):
            raise TypeError("Wrong content format: every element should be dict if the content is a list.")
        assert "type" in part, "Wrong content format. Missing 'type' key in content's dict."
        kind = part["type"]
        if kind == "text":
            pieces.append(part["text"])
        elif kind in ("input_image", "image_url"):
            # Images are represented by a placeholder token in the flat text.
            pieces.append("<image>")
        else:
            raise ValueError(f"Wrong content format: unknown type {kind} within the content")
    return "".join(pieces)
|
||||||
|
|
||||||
|
|
||||||
|
def infer_lang(code: str) -> str:
    """Best-effort guess whether *code* is a shell command or Python source.

    TODO: make it robust.
    """
    # Invoking an interpreter or pip is a shell command.
    if code.startswith(("python ", "pip", "python3 ")):
        return "sh"

    # Otherwise treat it as Python iff it compiles.
    try:
        compile(code, "test", "exec")
        return "python"
    except SyntaxError:
        # not a valid python code
        return UNKNOWN
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: In the future move, to better support https://spec.commonmark.org/0.30/#fenced-code-blocks
# perhaps by using a full Markdown parser.
def extract_code(
    text: Union[str, list], pattern: str = CODE_BLOCK_PATTERN, detect_single_line_code: bool = False
) -> list[tuple[str, str]]:
    """Extract code from a text.

    Args:
        text (str or List): The content to extract code from. The content can be
            a string or a list, as returned by standard GPT or multimodal GPT.
        pattern (str, optional): The regular expression pattern for finding the
            code block. Defaults to CODE_BLOCK_PATTERN.
        detect_single_line_code (bool, optional): Enable the new feature for
            extracting single line code. Defaults to False.

    Returns:
        list: A list of tuples, each containing the language and the code.
            If there is no code block in the input text, the language would be "unknown".
            If there is code block but the language is not specified, the language would be "".
    """
    # Normalize multimodal content (a list of parts) into a single string first.
    text = content_str(text)
    if not detect_single_line_code:
        match = re.findall(pattern, text, flags=re.DOTALL)
        # No fenced block at all: return the whole text as one "unknown" block.
        return match if match else [(UNKNOWN, text)]

    # Extract both multi-line and single-line code block, separated by the | operator
    # `([^`]+)`: Matches inline code.
    # NOTE(review): this branch compiles CODE_BLOCK_PATTERN directly and ignores
    # the `pattern` argument — confirm whether that is intentional.
    code_pattern = re.compile(CODE_BLOCK_PATTERN + r"|`([^`]+)`")
    code_blocks = code_pattern.findall(text)

    # Extract the individual code blocks and languages from the matched groups
    extracted = []
    for lang, group1, group2 in code_blocks:
        if group1:
            # Fenced block: keep its (possibly empty) language tag.
            extracted.append((lang.strip(), group1.strip()))
        elif group2:
            # Inline `code`: no language information is available.
            extracted.append(("", group2.strip()))

    return extracted
|
||||||
|
|
||||||
|
|
||||||
|
def timeout_handler(signum, frame):
    """Signal handler that aborts the interrupted work with ``TimeoutError``.

    Matches the ``signal`` handler signature; both arguments are unused.
    """
    del signum, frame  # required by the handler signature, intentionally ignored
    raise TimeoutError("Timed out!")
|
||||||
|
|
||||||
|
|
||||||
|
def get_powershell_command():
    """Locate an available PowerShell executable.

    Tries the Windows ``powershell`` binary first, then falls back to the
    cross-platform ``pwsh``. Returns the name of the first one that exits
    with status 0. Falls through (returning ``None``) if a binary launches
    but exits non-zero.

    Raises:
        FileNotFoundError: If neither ``powershell`` nor ``pwsh`` is installed.
        NotADirectoryError: If the PATH entry for PowerShell is invalid.
        PermissionError: If the current user may not execute PowerShell.
    """
    try:
        result = subprocess.run(["powershell", "$PSVersionTable.PSVersion.Major"], capture_output=True, text=True)
        if result.returncode == 0:
            return "powershell"
    except (FileNotFoundError, NotADirectoryError):
        # This means that 'powershell' command is not found so now we try looking for 'pwsh'
        try:
            result = subprocess.run(
                ["pwsh", "-Command", "$PSVersionTable.PSVersion.Major"], capture_output=True, text=True
            )
            if result.returncode == 0:
                return "pwsh"
        # BUG FIX: previously caught FileExistsError, which subprocess.run never
        # raises for a missing executable, so a missing 'pwsh' escaped as a raw
        # FileNotFoundError instead of this friendly message.
        except FileNotFoundError as e:
            raise FileNotFoundError(
                "Neither powershell.exe nor pwsh.exe is present in the system. "
                "Please install PowerShell and try again. "
            ) from e
        except NotADirectoryError as e:
            raise NotADirectoryError(
                "PowerShell is either not installed or its path is not given "
                "properly in the environment variable PATH. Please check the "
                "path and try again. "
            ) from e
    except PermissionError as e:
        raise PermissionError("No permission to run powershell.") from e
|
||||||
|
|
||||||
|
|
||||||
|
def _cmd(lang: str) -> str:
    """Return the command-line executable used to run code in *lang*.

    Raises:
        NotImplementedError: If *lang* has no known interpreter mapping.
    """
    if lang in PYTHON_VARIANTS:
        return "python"
    if lang.startswith("python") or lang in ("bash", "sh"):
        return lang
    if lang == "shell":
        return "sh"
    if lang == "javascript":
        return "node"
    if lang in ("ps1", "pwsh", "powershell"):
        # Resolve whichever PowerShell flavour is installed on this machine.
        return get_powershell_command()

    raise NotImplementedError(f"{lang} not recognized in code execution")
|
||||||
|
|
||||||
|
|
||||||
|
def is_docker_running() -> bool:
    """Check whether a Docker daemon is reachable.

    Returns:
        bool: True if docker is running; False otherwise.
    """
    try:
        # Any failure to connect or ping surfaces as DockerException.
        docker.from_env().ping()
    except docker.errors.DockerException:
        return False
    return True
|
||||||
|
|
||||||
|
|
||||||
|
def in_docker_container() -> bool:
    """Report whether this process is itself running inside a Docker container.

    Returns:
        bool: True when the Docker marker file exists; False otherwise.
    """
    marker = "/.dockerenv"  # created by Docker at the container filesystem root
    return os.path.exists(marker)
|
||||||
|
|
||||||
|
|
||||||
|
def decide_use_docker(use_docker: Optional[bool]) -> Optional[bool]:
    """Resolve the effective ``use_docker`` setting.

    An explicit value wins unchanged; when *use_docker* is None, the
    ``AUTOGEN_USE_DOCKER`` environment variable (default ``"True"``)
    decides, accepting truthy/falsy spellings or the literal string "none".

    Raises:
        ValueError: If the environment variable holds an unrecognized value.
    """
    if use_docker is not None:
        return use_docker

    raw = os.environ.get("AUTOGEN_USE_DOCKER", "True")
    # Case-insensitive comparison.
    lowered = raw.lower()

    if lowered in {"1", "true", "yes", "t"}:
        return True
    if lowered in {"0", "false", "no", "f"}:
        return False
    if lowered == "none":  # Special case for 'None' as a string
        return None

    # Any other value is a configuration error.
    raise ValueError(
        f'Invalid value for AUTOGEN_USE_DOCKER: {raw}. Please set AUTOGEN_USE_DOCKER to "1/True/yes", "0/False/no", or "None".'
    )
|
||||||
|
|
||||||
|
|
||||||
|
def check_can_use_docker_or_throw(use_docker) -> None:
    """Raise if docker execution was requested but no docker is available.

    A ``None`` value means "undecided" and is accepted as-is. Running inside
    a container or next to a live daemon both satisfy the requirement.

    Raises:
        RuntimeError: If *use_docker* is truthy yet neither condition holds.
    """
    if use_docker is None:
        return
    inside_container = in_docker_container()
    daemon_available = is_docker_running()
    if use_docker and not inside_container and not daemon_available:
        raise RuntimeError(
            "Code execution is set to be run in docker (default behaviour) but docker is not running.\n"
            "The options available are:\n"
            "- Make sure docker is running (advised approach for code execution)\n"
            '- Set "use_docker": False in code_execution_config\n'
            '- Set AUTOGEN_USE_DOCKER to "0/False/no" in your environment variables'
        )
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_filename_for_docker_tag(filename: str) -> str:
|
||||||
|
"""Convert a filename to a valid docker tag.
|
||||||
|
See https://docs.docker.com/engine/reference/commandline/tag/ for valid tag
|
||||||
|
format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename (str): The filename to be converted.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The sanitized Docker tag.
|
||||||
|
"""
|
||||||
|
# Replace any character not allowed with an underscore
|
||||||
|
allowed_chars = set(string.ascii_letters + string.digits + "_.-")
|
||||||
|
sanitized = "".join(char if char in allowed_chars else "_" for char in filename)
|
||||||
|
|
||||||
|
# Ensure it does not start with a period or a dash
|
||||||
|
if sanitized.startswith(".") or sanitized.startswith("-"):
|
||||||
|
sanitized = "_" + sanitized[1:]
|
||||||
|
|
||||||
|
# Truncate if longer than 128 characters
|
||||||
|
return sanitized[:128]
|
||||||
|
|
||||||
|
|
||||||
|
def execute_code(
    code: Optional[str] = None,
    timeout: Optional[int] = None,
    filename: Optional[str] = None,
    work_dir: Optional[str] = None,
    use_docker: Union[list[str], str, bool] = SENTINEL,
    lang: Optional[str] = "python",
) -> tuple[int, str, Optional[str]]:
    """Execute code in a docker container.
    This function is not tested on MacOS.

    Args:
        code (Optional, str): The code to execute.
            If None, the code from the file specified by filename will be executed.
            Either code or filename must be provided.
        timeout (Optional, int): The maximum execution time in seconds.
            If None, a default timeout will be used. The default timeout is 600 seconds. On Windows, the timeout is not enforced when use_docker=False.
        filename (Optional, str): The file name to save the code or where the code is stored when `code` is None.
            If None, a file with a randomly generated name will be created.
            The randomly generated file will be deleted after execution.
            The file name must be a relative path. Relative paths are relative to the working directory.
        work_dir (Optional, str): The working directory for the code execution.
            If None, a default working directory will be used.
            The default working directory is the "extensions" directory under
            "path_to_autogen".
        use_docker (list, str or bool): The docker image to use for code execution.
            Default is True, which means the code will be executed in a docker container. A default list of images will be used.
            If a list or a str of image name(s) is provided, the code will be executed in a docker container
            with the first image successfully pulled.
            If False, the code will be executed in the current environment.
            Expected behaviour:
                - If `use_docker` is not set (i.e. left default to True) or is explicitly set to True and the docker package is available, the code will run in a Docker container.
                - If `use_docker` is not set (i.e. left default to True) or is explicitly set to True but the Docker package is missing or docker isn't running, an error will be raised.
                - If `use_docker` is explicitly set to False, the code will run natively.
            If the code is executed in the current environment,
            the code must be trusted.
        lang (Optional, str): The language of the code. Default is "python".

    Returns:
        int: 0 if the code executes successfully.
        str: The error message if the code fails to execute; the stdout otherwise.
        image: The docker image name after container run when docker is used.
    """
    if all((code is None, filename is None)):
        error_msg = f"Either {code=} or {filename=} must be provided."
        logger.error(error_msg)
        raise AssertionError(error_msg)

    running_inside_docker = in_docker_container()
    docker_running = is_docker_running()

    # SENTINEL is used to indicate that the user did not explicitly set the argument
    if use_docker is SENTINEL:
        use_docker = decide_use_docker(use_docker=None)
    check_can_use_docker_or_throw(use_docker)

    timeout = timeout or DEFAULT_TIMEOUT
    original_filename = filename
    # On native Windows execution there is no `sh`; fall back to PowerShell.
    if WIN32 and lang in ["sh", "shell"] and (not use_docker):
        lang = "ps1"
    if filename is None:
        code_hash = md5(code.encode()).hexdigest()
        # create a file with a automatically generated name
        filename = f"tmp_code_{code_hash}.{'py' if lang.startswith('python') else lang}"
    if work_dir is None:
        work_dir = WORKING_DIR

    filepath = os.path.join(work_dir, filename)
    file_dir = os.path.dirname(filepath)
    os.makedirs(file_dir, exist_ok=True)

    if code is not None:
        with open(filepath, "w", encoding="utf-8") as fout:
            fout.write(code)

    if not use_docker or running_inside_docker:
        # Native execution path (also used when already running in a docker container).
        # BUG FIX: the file previously contained a garbled literal where the
        # saved filename should be interpolated into the Windows command.
        cmd = [
            sys.executable if lang.startswith("python") else _cmd(lang),
            f".\\{filename}" if WIN32 else filename,
        ]
        with ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(
                subprocess.run,
                cmd,
                cwd=work_dir,
                capture_output=True,
                text=True,
            )
            try:
                # Enforce the timeout by waiting on the worker thread's future.
                result = future.result(timeout=timeout)
            except TimeoutError:
                if original_filename is None:
                    os.remove(filepath)
                return 1, TIMEOUT_MSG, None
        if original_filename is None:
            os.remove(filepath)
        if result.returncode:
            logs = result.stderr
            # Scrub local paths from error output so logs are machine-independent.
            if original_filename is None:
                abs_path = str(pathlib.Path(filepath).absolute())
                logs = logs.replace(str(abs_path), "").replace(filename, "")
            else:
                abs_path = str(pathlib.Path(work_dir).absolute()) + PATH_SEPARATOR
                logs = logs.replace(str(abs_path), "")
        else:
            logs = result.stdout
        return result.returncode, logs, None

    # create a docker client
    if use_docker and not docker_running:
        raise RuntimeError(
            "Docker package is missing or docker is not running. Please make sure docker is running or set use_docker=False."
        )

    client = docker.from_env()

    image_list = (
        ["python:3-slim", "python:3", "python:3-windowsservercore"]
        if use_docker is True
        else [use_docker]
        if isinstance(use_docker, str)
        else use_docker
    )
    for image in image_list:
        # check if the image exists locally; otherwise try to pull it
        try:
            client.images.get(image)
            break
        except docker.errors.ImageNotFound:
            # pull the image
            print("Pulling image", image)
            try:
                client.images.pull(image)
                break
            except docker.errors.DockerException:
                print("Failed to pull image", image)
    # get a randomized str based on current time to wrap the exit code
    exit_code_str = f"exitcode{time.time()}"
    abs_path = pathlib.Path(work_dir).absolute()
    # BUG FIX: restore the `{filename}` interpolation that was garbled in the
    # shell command below; without it the container would try to run a file
    # literally named "(unknown)".
    cmd = [
        "sh",
        "-c",
        f'{_cmd(lang)} "{filename}"; exit_code=$?; echo -n {exit_code_str}; echo -n $exit_code; echo {exit_code_str}',
    ]
    # create a docker container
    container = client.containers.run(
        image,
        command=cmd,
        working_dir="/workspace",
        detach=True,
        # get absolute path to the working directory
        volumes={abs_path: {"bind": "/workspace", "mode": "rw"}},
    )
    start_time = time.time()
    while container.status != "exited" and time.time() - start_time < timeout:
        # Reload the container object
        container.reload()
    if container.status != "exited":
        container.stop()
        container.remove()
        if original_filename is None:
            os.remove(filepath)
        return 1, TIMEOUT_MSG, image
    # get the container logs
    logs = container.logs().decode("utf-8").rstrip()
    # commit the image
    tag = _sanitize_filename_for_docker_tag(filename)
    container.commit(repository="python", tag=tag)
    # remove the container
    container.remove()
    # check if the code executed successfully
    exit_code = container.attrs["State"]["ExitCode"]
    if exit_code == 0:
        # extract the exit code from the logs
        pattern = re.compile(f"{exit_code_str}(\\d+){exit_code_str}")
        match = pattern.search(logs)
        exit_code = 1 if match is None else int(match.group(1))
        # remove the exit code from the logs
        logs = logs if match is None else pattern.sub("", logs)

    if original_filename is None:
        os.remove(filepath)
    if exit_code:
        logs = logs.replace(f"/workspace/{filename if original_filename is None else ''}", "")
    # return the exit code, logs and image
    return exit_code, logs, f"python:{tag}"
|
||||||
|
|
||||||
|
|
||||||
|
# Completion-request configuration used to auto-generate assertion statements
# from a function signature/docstring; `{definition}` is filled in by the caller.
_GENERATE_ASSERTIONS_CONFIG = {
    "prompt": """Given the signature and docstring, write the exactly same number of assertion(s) for the provided example(s) in the docstring, without assertion messages.

func signature:
{definition}
assertions:""",
    # FAST_MODEL: cheaper model is sufficient for this short generation task.
    "model": FAST_MODEL,
    "max_tokens": 256,
    # Stop at the first blank line so only the assertions are returned.
    "stop": "\n\n",
}
|
||||||
|
|
||||||
|
|
||||||
|
def _remove_check(response):
|
||||||
|
"""Remove the check function from the response."""
|
||||||
|
# find the position of the check function
|
||||||
|
pos = response.find("def check(")
|
||||||
|
if pos == -1:
|
||||||
|
return response
|
||||||
|
return response[:pos]
|
||||||
|
|
||||||
|
|
||||||
|
def eval_function_completions(
    responses: list[str],
    definition: str,
    test: Optional[str] = None,
    entry_point: Optional[str] = None,
    assertions: Optional[Union[str, Callable[[str], tuple[str, float]]]] = None,
    timeout: Optional[float] = 3,
    use_docker: Optional[bool] = True,
) -> dict:
    """`(openai<1)` Select a response from a list of responses for the function completion task (using generated assertions), and/or evaluate if the task is successful using a gold test.

    Args:
        responses: The list of responses.
        definition: The input definition.
        test: The test code.
        entry_point: The name of the function.
        assertions: The assertion code which serves as a filter of the responses, or an assertion generator.
            When provided, only the responses that pass the assertions will be considered for the actual test (if provided).
        timeout: The timeout for executing the code.
        use_docker: Whether to use docker for code execution.

    Returns:
        dict: The success metrics.
    """
    n = len(responses)
    if assertions is None:
        # No assertion filter: run the gold test against every response.
        success_list = []
        for i in range(n):
            response = _remove_check(responses[i])
            # A response starting with "def" is a full function; otherwise it is a
            # body completion that must be prefixed with the definition.
            code = (
                f"{response}\n{test}\ncheck({entry_point})"
                if response.startswith("def")
                else f"{definition}{response}\n{test}\ncheck({entry_point})"
            )
            success = execute_code(code, timeout=timeout, use_docker=use_docker)[0] == 0
            success_list.append(success)
        # expected_success = probability that at least one of n i.i.d. samples passes.
        return {
            "expected_success": 1 - pow(1 - sum(success_list) / n, n),
            "success": any(s for s in success_list),
        }
    if callable(assertions) and n > 1:
        # assertion generator: produces the assertions plus their generation cost.
        assertions, gen_cost = assertions(definition)
    else:
        # NOTE(review): when `assertions` is a string and n > 1 it is discarded
        # here (reset to None) before the filtering loop below — confirm this
        # is intentional.
        assertions, gen_cost = None, 0
    if n > 1 or test is None:
        # Select the first response that passes the assertion filter.
        for i in range(n):
            response = responses[i] = _remove_check(responses[i])
            code = (
                f"{response}\n{assertions}" if response.startswith("def") else f"{definition}{response}\n{assertions}"
            )
            succeed_assertions = execute_code(code, timeout=timeout, use_docker=use_docker)[0] == 0
            if succeed_assertions:
                break
    else:
        # just test, no need to check assertions
        succeed_assertions = False
        i, response = 0, responses[0]
    if test is None:
        # No gold test: report only the assertion-filter outcome.
        return {
            "index_selected": i,
            "succeed_assertions": succeed_assertions,
            "gen_cost": gen_cost,
            "assertions": assertions,
        }
    # Evaluate the selected response against the gold test.
    code_test = (
        f"{response}\n{test}\ncheck({entry_point})"
        if response.startswith("def")
        else f"{definition}{response}\n{test}\ncheck({entry_point})"
    )
    success = execute_code(code_test, timeout=timeout, use_docker=use_docker)[0] == 0
    return {
        "index_selected": i,
        "succeed_assertions": succeed_assertions,
        "success": success,
        "gen_cost": gen_cost,
        "assertions": assertions,
    }
|
||||||
|
|
||||||
|
|
||||||
|
# Prompt template for function-completion requests; `{definition}` is the
# function signature plus docstring to complete.
_FUNC_COMPLETION_PROMPT = "# Python 3{definition}"
# Stop sequences that mark the end of a single function body.
_FUNC_COMPLETION_STOP = ["\nclass", "\ndef", "\nif", "\nprint"]
# Escalating sequence of completion configs: cheap deterministic attempt first,
# then sampled attempts, then the default (stronger) model.
_IMPLEMENT_CONFIGS = [
    {"model": FAST_MODEL, "prompt": _FUNC_COMPLETION_PROMPT, "temperature": 0, "cache_seed": 0},
    {"model": FAST_MODEL, "prompt": _FUNC_COMPLETION_PROMPT, "stop": _FUNC_COMPLETION_STOP, "n": 7, "cache_seed": 0},
    {"model": DEFAULT_MODEL, "prompt": _FUNC_COMPLETION_PROMPT, "temperature": 0, "cache_seed": 1},
    {"model": DEFAULT_MODEL, "prompt": _FUNC_COMPLETION_PROMPT, "stop": _FUNC_COMPLETION_STOP, "n": 2, "cache_seed": 2},
    {"model": DEFAULT_MODEL, "prompt": _FUNC_COMPLETION_PROMPT, "stop": _FUNC_COMPLETION_STOP, "n": 1, "cache_seed": 2},
]
|
||||||
|
|
||||||
|
|
||||||
|
def create_virtual_env(dir_path: str, **env_args) -> SimpleNamespace:
    """Creates a python virtual environment and returns the context.

    Args:
        dir_path (str): Directory path where the env will be created.
        **env_args: Any extra args to pass to the `EnvBuilder`

    Returns:
        SimpleNamespace: the virtual env context object.
    """
    # With no explicit builder options, default to installing pip.
    builder = venv.EnvBuilder(**(env_args or {"with_pip": True}))
    builder.create(dir_path)
    return builder.ensure_directories(dir_path)
|
||||||
22
mm_agents/coact/autogen/coding/__init__.py
Normal file
22
mm_agents/coact/autogen/coding/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Original portions of this file are derived from https://github.com/microsoft/autogen under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
from .base import CodeBlock, CodeExecutor, CodeExtractor, CodeResult
|
||||||
|
from .docker_commandline_code_executor import DockerCommandLineCodeExecutor
|
||||||
|
from .factory import CodeExecutorFactory
|
||||||
|
from .local_commandline_code_executor import LocalCommandLineCodeExecutor
|
||||||
|
from .markdown_code_extractor import MarkdownCodeExtractor
|
||||||
|
|
||||||
|
# Public API of the autogen.coding package.
__all__ = (
    "CodeBlock",
    "CodeExecutor",
    "CodeExecutorFactory",
    "CodeExtractor",
    "CodeResult",
    "DockerCommandLineCodeExecutor",
    "LocalCommandLineCodeExecutor",
    "MarkdownCodeExtractor",
)
|
||||||
119
mm_agents/coact/autogen/coding/base.py
Normal file
119
mm_agents/coact/autogen/coding/base.py
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any, Literal, Optional, Protocol, TypedDict, Union, runtime_checkable
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from ..doc_utils import export_module
|
||||||
|
from ..types import UserMessageImageContentPart, UserMessageTextContentPart
|
||||||
|
|
||||||
|
__all__ = ("CodeBlock", "CodeExecutionConfig", "CodeExecutor", "CodeExtractor", "CodeResult")
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen.coding")
|
||||||
|
class CodeBlock(BaseModel):
|
||||||
|
"""(Experimental) A class that represents a code block."""
|
||||||
|
|
||||||
|
code: str = Field(description="The code to execute.")
|
||||||
|
|
||||||
|
language: str = Field(description="The language of the code.")
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen.coding")
|
||||||
|
class CodeResult(BaseModel):
|
||||||
|
"""(Experimental) A class that represents the result of a code execution."""
|
||||||
|
|
||||||
|
exit_code: int = Field(description="The exit code of the code execution.")
|
||||||
|
|
||||||
|
output: str = Field(description="The output of the code execution.")
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen.coding")
|
||||||
|
class CodeExtractor(Protocol):
|
||||||
|
"""(Experimental) A code extractor class that extracts code blocks from a message."""
|
||||||
|
|
||||||
|
def extract_code_blocks(
|
||||||
|
self, message: Optional[Union[str, list[Union[UserMessageTextContentPart, UserMessageImageContentPart]]]]
|
||||||
|
) -> list[CodeBlock]:
|
||||||
|
"""(Experimental) Extract code blocks from a message.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message (str): The message to extract code blocks from.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[CodeBlock]: The extracted code blocks.
|
||||||
|
"""
|
||||||
|
... # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
|
# runtime_checkable allows isinstance() checks against this Protocol.
@runtime_checkable
@export_module("autogen.coding")
class CodeExecutor(Protocol):
    """(Experimental) A code executor class that executes code blocks and returns the result."""

    @property
    def code_extractor(self) -> CodeExtractor:
        """(Experimental) The code extractor used by this code executor."""
        ...  # pragma: no cover

    def execute_code_blocks(self, code_blocks: list[CodeBlock]) -> CodeResult:
        """(Experimental) Execute code blocks and return the result.

        This method should be implemented by the code executor.

        Args:
            code_blocks (List[CodeBlock]): The code blocks to execute.

        Returns:
            CodeResult: The result of the code execution.
        """
        ...  # pragma: no cover

    def restart(self) -> None:
        """(Experimental) Restart the code executor.

        This method should be implemented by the code executor.

        This method is called when the agent is reset.
        """
        ...  # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
|
class IPythonCodeResult(CodeResult):
    """(Experimental) A code result class for IPython code executor."""

    # Paths of files generated by the executed code blocks; empty by default.
    output_files: list[str] = Field(
        default_factory=list,
        description="The list of files that the executed code blocks generated.",
    )
|
||||||
|
|
||||||
|
|
||||||
|
# Configuration mapping for code execution; ``total=False`` makes every key optional.
CodeExecutionConfig = TypedDict(
    "CodeExecutionConfig",
    {
        # Executor to use: a well-known name or a CodeExecutor instance.
        "executor": Union[Literal["ipython-embedded", "commandline-local"], CodeExecutor],
        # How many trailing messages to consider, or "auto".
        "last_n_messages": Union[int, Literal["auto"]],
        "timeout": int,
        # Docker usage: bool flag, a single image name, or a list of candidate images.
        "use_docker": Union[bool, str, list[str]],
        "work_dir": str,
        # Executor-specific keyword arguments, keyed by executor name.
        "ipython-embedded": Mapping[str, Any],
        "commandline-local": Mapping[str, Any],
    },
    total=False,
)
|
||||||
|
|
||||||
|
|
||||||
|
class CommandLineCodeResult(CodeResult):
    """(Experimental) A code result class for command line code executor."""

    # Path of the file the executed code block was written to, if any.
    code_file: Optional[str] = Field(
        default=None,
        description="The file that the executed code block was saved to.",
    )
|
||||||
@@ -0,0 +1,268 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import atexit
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import uuid
|
||||||
|
from hashlib import md5
|
||||||
|
from pathlib import Path
|
||||||
|
from time import sleep
|
||||||
|
from types import TracebackType
|
||||||
|
from typing import Any, ClassVar, Optional, Union
|
||||||
|
|
||||||
|
import docker
|
||||||
|
from docker.errors import ImageNotFound
|
||||||
|
|
||||||
|
from ..code_utils import TIMEOUT_MSG, _cmd
|
||||||
|
from ..doc_utils import export_module
|
||||||
|
from .base import CodeBlock, CodeExecutor, CodeExtractor, CommandLineCodeResult
|
||||||
|
from .markdown_code_extractor import MarkdownCodeExtractor
|
||||||
|
from .utils import _get_file_name_from_content, silence_pip
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 11):
|
||||||
|
from typing import Self
|
||||||
|
else:
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
|
||||||
|
def _wait_for_ready(container: Any, timeout: int = 60, stop_time: float = 0.1) -> None:
|
||||||
|
elapsed_time = 0.0
|
||||||
|
while container.status != "running" and elapsed_time < timeout:
|
||||||
|
sleep(stop_time)
|
||||||
|
elapsed_time += stop_time
|
||||||
|
container.reload()
|
||||||
|
continue
|
||||||
|
if container.status != "running":
|
||||||
|
raise ValueError("Container failed to start")
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ("DockerCommandLineCodeExecutor",)


@export_module("autogen.coding")
class DockerCommandLineCodeExecutor(CodeExecutor):
    # Default per-language policy: True -> execute the block, False -> only save it to a file.
    DEFAULT_EXECUTION_POLICY: ClassVar[dict[str, bool]] = {
        "bash": True,
        "shell": True,
        "sh": True,
        "pwsh": True,
        "powershell": True,
        "ps1": True,
        "python": True,
        "javascript": False,
        "html": False,
        "css": False,
    }
    # Short language names normalized to their canonical form before policy lookup.
    LANGUAGE_ALIASES: ClassVar[dict[str, str]] = {"py": "python", "js": "javascript"}

    def __init__(
        self,
        image: str = "python:3-slim",
        container_name: Optional[str] = None,
        timeout: int = 60,
        work_dir: Optional[Union[Path, str]] = None,
        bind_dir: Optional[Union[Path, str]] = None,
        auto_remove: bool = True,
        stop_container: bool = True,
        execution_policies: Optional[dict[str, bool]] = None,
    ):
        """(Experimental) A code executor class that executes code through
        a command line environment in a Docker container.

        The executor first saves each code block in a file in the working
        directory, and then executes the code file in the container.
        The executor executes the code blocks in the order they are received.
        Currently, the executor only supports Python and shell scripts.
        For Python code, use the language "python" for the code block.
        For shell scripts, use the language "bash", "shell", or "sh" for the code
        block.

        Args:
            image: Docker image to use for code execution. Defaults to "python:3-slim".
            container_name: Name of the Docker container which is created. If None, will autogenerate a name. Defaults to None.
            timeout: The timeout for code execution. Defaults to 60.
            work_dir: The working directory for the code execution. Defaults to Path(".").
            bind_dir: The directory that will be bound to the code executor container. Useful for cases where you want to spawn
                the container from within a container. Defaults to work_dir.
            auto_remove: If true, will automatically remove the Docker container when it is stopped. Defaults to True.
            stop_container: If true, will automatically stop the
                container when stop is called, when the context manager exits or when
                the Python process exits with atexit. Defaults to True.
            execution_policies: A dictionary mapping language names to boolean values that determine
                whether code in that language should be executed. True means code in that language
                will be executed, False means it will only be saved to a file. This overrides the
                default execution policies. Defaults to None.

        Raises:
            ValueError: On argument error, or if the container fails to start.
        """
        work_dir = work_dir if work_dir is not None else Path()

        if timeout < 1:
            raise ValueError("Timeout must be greater than or equal to 1.")

        if isinstance(work_dir, str):
            work_dir = Path(work_dir)
        work_dir.mkdir(exist_ok=True)

        if bind_dir is None:
            bind_dir = work_dir
        elif isinstance(bind_dir, str):
            bind_dir = Path(bind_dir)

        client = docker.from_env()
        # Check if the image exists locally; pull it if not.
        try:
            client.images.get(image)
        except ImageNotFound:
            logging.info(f"Pulling image {image}...")
            # Let the docker exception escape if this fails.
            client.images.pull(image)

        if container_name is None:
            container_name = f"autogen-code-exec-{uuid.uuid4()}"

        # Start a container from the image, ready to exec commands later.
        self._container = client.containers.create(
            image,
            name=container_name,
            entrypoint="/bin/sh",
            tty=True,
            auto_remove=auto_remove,
            volumes={str(bind_dir.resolve()): {"bind": "/workspace", "mode": "rw"}},
            working_dir="/workspace",
        )
        self._container.start()

        _wait_for_ready(self._container)

        def cleanup() -> None:
            # Stop the container if it still exists; unregister so atexit does not
            # call us a second time after an explicit stop().
            try:
                container = client.containers.get(container_name)
                container.stop()
            except docker.errors.NotFound:
                pass
            atexit.unregister(cleanup)

        if stop_container:
            atexit.register(cleanup)

        self._cleanup = cleanup

        # Check if the container is running.
        if self._container.status != "running":
            raise ValueError(f"Failed to start container from image {image}. Logs: {self._container.logs()}")

        self._timeout = timeout
        self._work_dir: Path = work_dir
        self._bind_dir: Path = bind_dir
        self.execution_policies = self.DEFAULT_EXECUTION_POLICY.copy()
        if execution_policies is not None:
            self.execution_policies.update(execution_policies)

    @property
    def timeout(self) -> int:
        """(Experimental) The timeout for code execution."""
        return self._timeout

    @property
    def work_dir(self) -> Path:
        """(Experimental) The working directory for the code execution."""
        return self._work_dir

    @property
    def bind_dir(self) -> Path:
        """(Experimental) The binding directory for the code execution container."""
        return self._bind_dir

    @property
    def code_extractor(self) -> CodeExtractor:
        """(Experimental) Export a code extractor that can be used by an agent."""
        return MarkdownCodeExtractor()

    def execute_code_blocks(self, code_blocks: list[CodeBlock]) -> CommandLineCodeResult:
        """(Experimental) Execute the code blocks and return the result.

        Args:
            code_blocks (List[CodeBlock]): The code blocks to execute.

        Returns:
            CommandLineCodeResult: The result of the code execution.
        """
        if len(code_blocks) == 0:
            raise ValueError("No code blocks to execute.")

        outputs = []
        files = []
        last_exit_code = 0
        for code_block in code_blocks:
            lang = self.LANGUAGE_ALIASES.get(code_block.language.lower(), code_block.language.lower())
            if lang not in self.DEFAULT_EXECUTION_POLICY:
                outputs.append(f"Unsupported language {lang}\n")
                last_exit_code = 1
                break

            execute_code = self.execution_policies.get(lang, False)
            code = silence_pip(code_block.code, lang)

            # Check if there is a filename comment; reject paths escaping the workspace.
            try:
                filename = _get_file_name_from_content(code, self._work_dir)
            except ValueError:
                outputs.append("Filename is not in the workspace")
                last_exit_code = 1
                break

            if not filename:
                # Derive a stable name from the code content so identical blocks reuse a file.
                filename = f"tmp_code_{md5(code.encode()).hexdigest()}.{lang}"

            code_path = self._work_dir / filename
            with code_path.open("w", encoding="utf-8") as fout:
                fout.write(code)
            files.append(code_path)

            if not execute_code:
                outputs.append(f"Code saved to {code_path!s}\n")
                continue

            # Use the container's `timeout` binary; exit code 124 signals a timeout.
            command = ["timeout", str(self._timeout), _cmd(lang), filename]
            result = self._container.exec_run(command)
            exit_code = result.exit_code
            output = result.output.decode("utf-8")
            if exit_code == 124:
                output += "\n" + TIMEOUT_MSG
            outputs.append(output)

            last_exit_code = exit_code
            if exit_code != 0:
                break

        code_file = str(files[0]) if files else None
        return CommandLineCodeResult(exit_code=last_exit_code, output="".join(outputs), code_file=code_file)

    def restart(self) -> None:
        """(Experimental) Restart the code executor."""
        self._container.restart()
        if self._container.status != "running":
            raise ValueError(f"Failed to restart container. Logs: {self._container.logs()}")

    def stop(self) -> None:
        """(Experimental) Stop the code executor."""
        self._cleanup()

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        self.stop()
|
||||||
47
mm_agents/coact/autogen/coding/factory.py
Normal file
47
mm_agents/coact/autogen/coding/factory.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
from ..doc_utils import export_module
|
||||||
|
from .base import CodeExecutionConfig, CodeExecutor
|
||||||
|
|
||||||
|
__all__ = ("CodeExecutorFactory",)


@export_module("autogen.coding")
class CodeExecutorFactory:
    """(Experimental) A factory class for creating code executors."""

    @staticmethod
    def create(code_execution_config: CodeExecutionConfig) -> CodeExecutor:
        """(Experimental) Get a code executor based on the code execution config.

        Args:
            code_execution_config (Dict): The code execution config,
                which is a dictionary that must contain the key "executor".
                The value of the key "executor" can be either a string
                or an instance of CodeExecutor, in which case the code
                executor is returned directly.

        Returns:
            CodeExecutor: The code executor.

        Raises:
            ValueError: If the code executor is unknown or not specified.
        """
        executor = code_execution_config.get("executor")
        if isinstance(executor, CodeExecutor):
            # If the executor is already an instance of CodeExecutor, return it.
            return executor
        if executor is None:
            # Distinguish "missing" from "unknown" as promised in the docstring.
            raise ValueError("Code executor is not specified in the code execution config.")
        if executor == "ipython-embedded":
            # Imported lazily so optional dependencies are only required when used.
            from .jupyter.embedded_ipython_code_executor import EmbeddedIPythonCodeExecutor

            return EmbeddedIPythonCodeExecutor(**code_execution_config.get("ipython-embedded", {}))
        elif executor == "commandline-local":
            from .local_commandline_code_executor import LocalCommandLineCodeExecutor

            return LocalCommandLineCodeExecutor(**code_execution_config.get("commandline-local", {}))
        else:
            raise ValueError(f"Unknown code executor {executor}")
|
||||||
202
mm_agents/coact/autogen/coding/func_with_reqs.py
Normal file
202
mm_agents/coact/autogen/coding/func_with_reqs.py
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from importlib.abc import SourceLoader
|
||||||
|
from textwrap import dedent, indent
|
||||||
|
from typing import Any, Callable, Generic, TypeVar, Union
|
||||||
|
|
||||||
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
P = ParamSpec("P")
|
||||||
|
|
||||||
|
|
||||||
|
def _to_code(func: Union["FunctionWithRequirements[T, P]", Callable[P, T], "FunctionWithRequirementsStr"]) -> str:
    """Render *func* as Python source text."""
    if isinstance(func, FunctionWithRequirementsStr):
        # String-based functions already carry their own source.
        return func.func

    source = inspect.getsource(func)
    if source.startswith("@"):
        # Strip the decorator line so only the bare definition remains.
        first_newline = source.index("\n")
        source = source[first_newline + 1 :]
    return source
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Alias:
|
||||||
|
name: str
|
||||||
|
alias: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ImportFromModule:
|
||||||
|
module: str
|
||||||
|
imports: list[Union[str, Alias]]
|
||||||
|
|
||||||
|
|
||||||
|
Import = Union[str, ImportFromModule, Alias]
|
||||||
|
|
||||||
|
|
||||||
|
def _import_to_str(im: Import) -> str:
|
||||||
|
if isinstance(im, str):
|
||||||
|
return f"import {im}"
|
||||||
|
elif isinstance(im, Alias):
|
||||||
|
return f"import {im.name} as {im.alias}"
|
||||||
|
else:
|
||||||
|
|
||||||
|
def to_str(i: Union[str, Alias]) -> str:
|
||||||
|
if isinstance(i, str):
|
||||||
|
return i
|
||||||
|
else:
|
||||||
|
return f"{i.name} as {i.alias}"
|
||||||
|
|
||||||
|
imports = ", ".join(map(to_str, im.imports))
|
||||||
|
return f"from {im.module} import {imports}"
|
||||||
|
|
||||||
|
|
||||||
|
class _StringLoader(SourceLoader):
|
||||||
|
def __init__(self, data: str):
|
||||||
|
self.data = data
|
||||||
|
|
||||||
|
def get_source(self, fullname: str) -> str:
|
||||||
|
return self.data
|
||||||
|
|
||||||
|
def get_data(self, path: str) -> bytes:
|
||||||
|
return self.data.encode("utf-8")
|
||||||
|
|
||||||
|
def get_filename(self, fullname: str) -> str:
|
||||||
|
return "<not a real path>/" + fullname + ".py"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class FunctionWithRequirementsStr:
    """A function supplied as source text, plus the packages and imports it requires.

    The source is compiled eagerly in ``__init__`` so errors surface early; the
    compiled callable is kept only for introspection (e.g. stub generation) —
    instances are deliberately not directly callable.
    """

    func: str
    _compiled_func: Callable[..., Any]
    _func_name: str
    python_packages: list[str] = field(default_factory=list)
    global_imports: list[Import] = field(default_factory=list)

    def __init__(
        self,
        func: str,
        python_packages: Union[list[str], None] = None,
        global_imports: Union[list[Import], None] = None,
    ):
        """Compile *func* and record its requirements.

        Raises:
            ValueError: If the source cannot be compiled, or does not contain
                exactly one function.
        """
        self.func = func
        # None defaults avoid the shared mutable-default-argument pitfall.
        self.python_packages = python_packages if python_packages is not None else []
        self.global_imports = global_imports if global_imports is not None else []

        module_name = "func_module"
        loader = _StringLoader(func)
        spec = importlib.util.spec_from_loader(module_name, loader)
        if spec is None:
            raise ValueError("Could not create spec")
        module = importlib.util.module_from_spec(spec)
        if spec.loader is None:
            raise ValueError("Could not create loader")

        try:
            spec.loader.exec_module(module)
        except Exception as e:
            raise ValueError(f"Could not compile function: {e}") from e

        functions = inspect.getmembers(module, inspect.isfunction)
        if len(functions) != 1:
            raise ValueError("The string must contain exactly one function")

        self._func_name, self._compiled_func = functions[0]

    def __call__(self, *args: Any, **kwargs: Any) -> None:
        raise NotImplementedError("String based function with requirement objects are not directly callable")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class FunctionWithRequirements(Generic[T, P]):
    """A directly-callable function bundled with its package and import requirements."""

    func: Callable[P, T]
    python_packages: list[str] = field(default_factory=list)
    global_imports: list[Import] = field(default_factory=list)

    @classmethod
    def from_callable(
        cls,
        func: Callable[P, T],
        python_packages: Union[list[str], None] = None,
        global_imports: Union[list[Import], None] = None,
    ) -> "FunctionWithRequirements[T, P]":
        """Wrap an existing callable with its requirements (None defaults avoid shared mutable defaults)."""
        return cls(
            python_packages=python_packages if python_packages is not None else [],
            global_imports=global_imports if global_imports is not None else [],
            func=func,
        )

    @staticmethod
    def from_str(
        func: str,
        python_packages: Union[list[str], None] = None,
        global_imports: Union[list[Import], None] = None,
    ) -> FunctionWithRequirementsStr:
        """Build a string-based function-with-requirements from source text."""
        return FunctionWithRequirementsStr(
            func=func,
            python_packages=python_packages if python_packages is not None else [],
            global_imports=global_imports if global_imports is not None else [],
        )

    # Type this based on F
    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
        return self.func(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def with_requirements(
    python_packages: Union[list[str], None] = None, global_imports: Union[list[Import], None] = None
) -> Callable[[Callable[P, T]], FunctionWithRequirements[T, P]]:
    """Decorate a function with package and import requirements

    Args:
        python_packages (List[str], optional): Packages required to function. Can include version info. Defaults to [].
        global_imports (List[Import], optional): Required imports. Defaults to [].

    Returns:
        Callable[[Callable[P, T]], FunctionWithRequirements[T, P]]: The decorated function
    """
    # None defaults avoid the shared mutable-default-argument pitfall.
    packages = python_packages if python_packages is not None else []
    imports = global_imports if global_imports is not None else []

    def wrapper(func: Callable[P, T]) -> FunctionWithRequirements[T, P]:
        func_with_reqs = FunctionWithRequirements(python_packages=packages, global_imports=imports, func=func)

        # Preserve __name__/__doc__/etc. of the wrapped function.
        functools.update_wrapper(func_with_reqs, func)
        return func_with_reqs

    return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def _build_python_functions_file(
    funcs: list[Union[FunctionWithRequirements[Any, P], Callable[..., Any], FunctionWithRequirementsStr]],
) -> str:
    """Assemble a Python module containing *funcs* preceded by their global imports."""
    # First collect all global imports (deduplicated via a set).
    global_imports: set[str] = set()
    for func in funcs:
        if isinstance(func, (FunctionWithRequirements, FunctionWithRequirementsStr)):
            global_imports.update(map(_import_to_str, func.global_imports))

    # Sort for deterministic output; set iteration order is unspecified.
    content = "\n".join(sorted(global_imports)) + "\n\n"

    for func in funcs:
        content += _to_code(func) + "\n\n"

    return content
|
||||||
|
|
||||||
|
|
||||||
|
def to_stub(func: Union[Callable[..., Any], FunctionWithRequirementsStr]) -> str:
    """Generate a stub for a function as a string

    Args:
        func (Callable[..., Any]): The function to generate a stub for

    Returns:
        str: The stub for the function
    """
    if isinstance(func, FunctionWithRequirementsStr):
        # Delegate to the compiled callable so inspect can see a real signature.
        return to_stub(func._compiled_func)

    lines = [f"def {func.__name__}{inspect.signature(func)}:"]
    doc = func.__doc__
    if doc:
        quoted = '"""' + dedent(doc) + '"""'
        lines.append(indent(quoted, "    "))
    lines.append("    ...")
    return "\n".join(lines)
|
||||||
23
mm_agents/coact/autogen/coding/jupyter/__init__.py
Normal file
23
mm_agents/coact/autogen/coding/jupyter/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Original portions of this file are derived from https://github.com/microsoft/autogen under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
from .base import JupyterConnectable, JupyterConnectionInfo
|
||||||
|
from .docker_jupyter_server import DockerJupyterServer
|
||||||
|
from .embedded_ipython_code_executor import EmbeddedIPythonCodeExecutor
|
||||||
|
from .jupyter_client import JupyterClient
|
||||||
|
from .jupyter_code_executor import JupyterCodeExecutor
|
||||||
|
from .local_jupyter_server import LocalJupyterServer
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DockerJupyterServer",
|
||||||
|
"EmbeddedIPythonCodeExecutor",
|
||||||
|
"JupyterClient",
|
||||||
|
"JupyterCodeExecutor",
|
||||||
|
"JupyterConnectable",
|
||||||
|
"JupyterConnectionInfo",
|
||||||
|
"LocalJupyterServer",
|
||||||
|
]
|
||||||
36
mm_agents/coact/autogen/coding/jupyter/base.py
Normal file
36
mm_agents/coact/autogen/coding/jupyter/base.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
from ...doc_utils import export_module
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
@export_module("autogen.coding.jupyter")
class JupyterConnectionInfo:
    """(Experimental) Connection details for reaching a Jupyter kernel gateway."""

    host: str
    """`str` - Host of the Jupyter gateway server"""
    use_https: bool
    """`bool` - Whether to use HTTPS"""
    port: Optional[int] = None
    """`Optional[int]` - Port of the Jupyter gateway server. If None, the default port is used"""
    token: Optional[str] = None
    """`Optional[str]` - Token for authentication. If None, no token is used"""
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
@export_module("autogen.coding.jupyter")
class JupyterConnectable(Protocol):
    """(Experimental) Structural type for objects that expose Jupyter connection info."""

    @property
    def connection_info(self) -> JupyterConnectionInfo:
        """Return the connection information for this connectable."""
        pass
|
||||||
167
mm_agents/coact/autogen/coding/jupyter/docker_jupyter_server.py
Normal file
167
mm_agents/coact/autogen/coding/jupyter/docker_jupyter_server.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import atexit
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
import sys
|
||||||
|
import uuid
|
||||||
|
from pathlib import Path
|
||||||
|
from types import TracebackType
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import docker
|
||||||
|
|
||||||
|
from ...doc_utils import export_module
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 11):
|
||||||
|
from typing import Self
|
||||||
|
else:
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
from ..docker_commandline_code_executor import _wait_for_ready
|
||||||
|
from .base import JupyterConnectable, JupyterConnectionInfo
|
||||||
|
from .import_utils import require_jupyter_kernel_gateway_installed
|
||||||
|
from .jupyter_client import JupyterClient
|
||||||
|
|
||||||
|
|
||||||
|
@require_jupyter_kernel_gateway_installed()
@export_module("autogen.coding.jupyter")
class DockerJupyterServer(JupyterConnectable):
    # Dockerfile used when no custom image is supplied; builds a kernel-gateway image.
    DEFAULT_DOCKERFILE = """FROM quay.io/jupyter/docker-stacks-foundation

SHELL ["/bin/bash", "-o", "pipefail", "-c"]

USER ${NB_UID}
RUN mamba install --yes jupyter_kernel_gateway ipykernel && \
    mamba clean --all -f -y && \
    fix-permissions "${CONDA_DIR}" && \
    fix-permissions "/home/${NB_USER}"

ENV TOKEN="UNSET"
CMD python -m jupyter kernelgateway --KernelGatewayApp.ip=0.0.0.0 \
    --KernelGatewayApp.port=8888 \
    --KernelGatewayApp.auth_token="${TOKEN}" \
    --JupyterApp.answer_yes=true \
    --JupyterWebsocketPersonality.list_kernels=true

EXPOSE 8888

WORKDIR "${HOME}"
"""

    class GenerateToken:
        # Sentinel type: pass an instance to request a freshly generated token.
        pass

    def __init__(
        self,
        *,
        custom_image_name: Optional[str] = None,
        container_name: Optional[str] = None,
        auto_remove: bool = True,
        stop_container: bool = True,
        docker_env: Optional[dict[str, str]] = None,
        token: str | GenerateToken = GenerateToken(),
    ):
        """Start a Jupyter kernel gateway server in a Docker container.

        Args:
            custom_image_name (Optional[str], optional): Custom image to use. If this is None,
                then the bundled image will be built and used. The default image is based on
                quay.io/jupyter/docker-stacks-foundation and extended to include jupyter_kernel_gateway
            container_name (Optional[str], optional): Name of the container to start.
                A name will be generated if None.
            auto_remove (bool, optional): If true the Docker container will be deleted
                when it is stopped.
            stop_container (bool, optional): If true the container will be stopped,
                either by program exit or using the context manager
            docker_env (Dict[str, str], optional): Extra environment variables to pass
                to the running Docker container. Defaults to no extra variables
                (a None default avoids the shared mutable-default pitfall).
            token (Union[str, GenerateToken], optional): Token to use for authentication.
                If GenerateToken is used, a random token will be generated. Empty string
                will be unauthenticated.
        """
        if container_name is None:
            container_name = f"autogen-jupyterkernelgateway-{uuid.uuid4()}"

        client = docker.from_env()
        if custom_image_name is None:
            image_name = "autogen-jupyterkernelgateway"
            # Make sure the image exists; build the bundled one if not.
            try:
                client.images.get(image_name)
            except docker.errors.ImageNotFound:
                # Build the image from this script's directory.
                here = Path(__file__).parent
                dockerfile = io.BytesIO(self.DEFAULT_DOCKERFILE.encode("utf-8"))
                logging.info(f"Image {image_name} not found. Building it now.")
                client.images.build(path=here, fileobj=dockerfile, tag=image_name)
                logging.info(f"Image {image_name} built successfully.")
        else:
            image_name = custom_image_name
            # Check if the custom image exists; never build it implicitly.
            try:
                client.images.get(image_name)
            except docker.errors.ImageNotFound:
                raise ValueError(f"Custom image {image_name} does not exist")

        if isinstance(token, DockerJupyterServer.GenerateToken):
            self._token = secrets.token_hex(32)
        else:
            self._token = token

        # Run the container with the auth token injected via the environment.
        env = {"TOKEN": self._token}
        if docker_env is not None:
            env.update(docker_env)
        container = client.containers.run(
            image_name,
            detach=True,
            auto_remove=auto_remove,
            environment=env,
            publish_all_ports=True,
            name=container_name,
        )
        _wait_for_ready(container)
        container_ports = container.ports
        # The gateway listens on 8888 inside the container; find the published host port.
        self._port = int(container_ports["8888/tcp"][0]["HostPort"])
        self._container_id = container.id

        def cleanup() -> None:
            # Stop the container if it still exists; unregister so atexit does not
            # call us a second time after an explicit stop().
            try:
                inner_container = client.containers.get(container.id)
                inner_container.stop()
            except docker.errors.NotFound:
                pass

            atexit.unregister(cleanup)

        if stop_container:
            atexit.register(cleanup)

        self._cleanup_func = cleanup
        self._stop_container = stop_container

    @property
    def connection_info(self) -> JupyterConnectionInfo:
        """Connection details for the running gateway (always reached via 127.0.0.1)."""
        return JupyterConnectionInfo(host="127.0.0.1", use_https=False, port=self._port, token=self._token)

    def stop(self) -> None:
        """Stop (and, if auto_remove, delete) the underlying container."""
        self._cleanup_func()

    def get_client(self) -> JupyterClient:
        """Return a JupyterClient bound to this server's connection info."""
        return JupyterClient(self.connection_info)

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self, exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
    ) -> None:
        self.stop()
|
||||||
@@ -0,0 +1,182 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import uuid
|
||||||
|
from pathlib import Path
|
||||||
|
from queue import Empty
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
from ...doc_utils import export_module
|
||||||
|
from ...import_utils import optional_import_block, require_optional_import
|
||||||
|
from ..base import CodeBlock, CodeExtractor, IPythonCodeResult
|
||||||
|
from ..markdown_code_extractor import MarkdownCodeExtractor
|
||||||
|
from .import_utils import require_jupyter_kernel_gateway_installed
|
||||||
|
|
||||||
|
with optional_import_block():
|
||||||
|
from jupyter_client import KernelManager # type: ignore[attr-defined]
|
||||||
|
from jupyter_client.kernelspec import KernelSpecManager
|
||||||
|
|
||||||
|
__all__ = ["EmbeddedIPythonCodeExecutor"]
|
||||||
|
|
||||||
|
|
||||||
|
@require_optional_import("jupyter_client", "jupyter-executor")
|
||||||
|
@require_jupyter_kernel_gateway_installed()
|
||||||
|
@export_module("autogen.coding.jupyter")
|
||||||
|
class EmbeddedIPythonCodeExecutor(BaseModel):
|
||||||
|
"""(Experimental) A code executor class that executes code statefully using an embedded
|
||||||
|
IPython kernel managed by this class.
|
||||||
|
|
||||||
|
**This will execute LLM generated code on the local machine.**
|
||||||
|
|
||||||
|
Each execution is stateful and can access variables created from previous
|
||||||
|
executions in the same session. The kernel must be installed before using
|
||||||
|
this class. The kernel can be installed using the following command:
|
||||||
|
`python -m ipykernel install --user --name {kernel_name}`
|
||||||
|
where `kernel_name` is the name of the kernel to install.
|
||||||
|
"""
|
||||||
|
|
||||||
|
timeout: int = Field(default=60, ge=1, description="The timeout for code execution.")
|
||||||
|
kernel_name: str = Field(default="python3", description="The kernel name to use. Make sure it is installed.")
|
||||||
|
output_dir: str = Field(default=".", description="The directory to save output files.")
|
||||||
|
|
||||||
|
@field_validator("output_dir")
|
||||||
|
@classmethod
|
||||||
|
def _output_dir_must_exist(cls, value: str) -> str:
|
||||||
|
if not os.path.exists(value):
|
||||||
|
raise ValueError(f"Output directory {value} does not exist.")
|
||||||
|
return value
|
||||||
|
|
||||||
|
def __init__(self, **kwargs: Any):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
# Check if the kernel is installed.
|
||||||
|
if self.kernel_name not in KernelSpecManager().find_kernel_specs():
|
||||||
|
raise ValueError(
|
||||||
|
f"Kernel {self.kernel_name} is not installed. "
|
||||||
|
"Please first install it with "
|
||||||
|
f"`python -m ipykernel install --user --name {self.kernel_name}`."
|
||||||
|
)
|
||||||
|
self._kernel_manager = KernelManager(kernel_name=self.kernel_name)
|
||||||
|
self._kernel_manager.start_kernel()
|
||||||
|
self._kernel_client = self._kernel_manager.client()
|
||||||
|
self._kernel_client.start_channels()
|
||||||
|
self._timeout = self.timeout
|
||||||
|
self._kernel_name = self.kernel_name
|
||||||
|
self._output_dir = Path(self.output_dir)
|
||||||
|
|
||||||
|
@property
def code_extractor(self) -> CodeExtractor:
    """(Experimental) Export a code extractor that can be used by an agent.

    Code blocks are expected as markdown fenced blocks in agent messages.
    """
    return MarkdownCodeExtractor()
|
||||||
|
|
||||||
|
def execute_code_blocks(self, code_blocks: list[CodeBlock]) -> IPythonCodeResult:
    """(Experimental) Execute a list of code blocks and return the result.

    This method executes a list of code blocks as cells in an IPython kernel
    managed by this class.
    See: https://jupyter-client.readthedocs.io/en/stable/messaging.html
    for the message protocol.

    Args:
        code_blocks (List[CodeBlock]): A list of code blocks to execute.

    Returns:
        IPythonCodeResult: The result of the code execution. ``exit_code`` is 1
        on the first kernel error or timeout, 0 otherwise.
    """
    self._kernel_client.wait_for_ready()
    outputs = []
    output_files = []
    for code_block in code_blocks:
        code = self._process_code(code_block.code)
        self._kernel_client.execute(code, store_history=True)
        # Drain iopub messages until the kernel reports it is idle again;
        # each message type carries a different kind of cell output.
        while True:
            try:
                msg = self._kernel_client.get_iopub_msg(timeout=self._timeout)
                msg_type = msg["msg_type"]
                content = msg["content"]
                if msg_type in ["execute_result", "display_data"]:
                    for data_type, data in content["data"].items():
                        if data_type == "text/plain":
                            # Output is a text.
                            outputs.append(data)
                        elif data_type.startswith("image/"):
                            # Output is an image: persist it and record the path.
                            path = self._save_image(data)
                            outputs.append(f"Image data saved to {path}")
                            output_files.append(path)
                        elif data_type == "text/html":
                            # Output is an html document: persist it and record the path.
                            path = self._save_html(data)
                            outputs.append(f"HTML data saved to {path}")
                            output_files.append(path)
                        else:
                            # Output raw data for any other MIME type.
                            outputs.append(json.dumps(data))
                elif msg_type == "stream":
                    # stdout/stderr text produced by the cell.
                    outputs.append(content["text"])
                elif msg_type == "error":
                    # Abort on the first error; remaining code blocks are skipped.
                    return IPythonCodeResult(
                        exit_code=1,
                        output=f"ERROR: {content['ename']}: {content['evalue']}\n{content['traceback']}",
                    )
                if msg_type == "status" and content["execution_state"] == "idle":
                    break
            # handle time outs.
            except Empty:
                return IPythonCodeResult(
                    exit_code=1,
                    output=f"ERROR: Timeout waiting for output from code block: {code_block.code}",
                )
    # We return the full output.
    return IPythonCodeResult(
        exit_code=0, output="\n".join([str(output) for output in outputs]), output_files=output_files
    )
|
||||||
|
|
||||||
|
def restart(self) -> None:
    """(Experimental) Restart a new session.

    Tears down the current kernel and its channels, then starts a fresh
    kernel so no state from previous executions survives.
    """
    self._kernel_client.stop_channels()
    self._kernel_manager.shutdown_kernel()
    # Recreate the manager/client pair exactly as __init__ does.
    self._kernel_manager = KernelManager(kernel_name=self.kernel_name)
    self._kernel_manager.start_kernel()
    self._kernel_client = self._kernel_manager.client()
    self._kernel_client.start_channels()
|
||||||
|
|
||||||
|
def _save_image(self, image_data_base64: str) -> str:
|
||||||
|
"""Save image data to a file."""
|
||||||
|
image_data = base64.b64decode(image_data_base64)
|
||||||
|
# Randomly generate a filename.
|
||||||
|
filename = f"{uuid.uuid4().hex}.png"
|
||||||
|
path = os.path.join(self.output_dir, filename)
|
||||||
|
with open(path, "wb") as f:
|
||||||
|
f.write(image_data)
|
||||||
|
return os.path.abspath(path)
|
||||||
|
|
||||||
|
def _save_html(self, html_data: str) -> str:
|
||||||
|
"""Save html data to a file."""
|
||||||
|
# Randomly generate a filename.
|
||||||
|
filename = f"{uuid.uuid4().hex}.html"
|
||||||
|
path = os.path.join(self.output_dir, filename)
|
||||||
|
with open(path, "w") as f:
|
||||||
|
f.write(html_data)
|
||||||
|
return os.path.abspath(path)
|
||||||
|
|
||||||
|
def _process_code(self, code: str) -> str:
|
||||||
|
"""Process code before execution."""
|
||||||
|
# Find lines that start with `! pip install` and make sure "-qqq" flag is added.
|
||||||
|
lines = code.split("\n")
|
||||||
|
for i, line in enumerate(lines):
|
||||||
|
# use regex to find lines that start with `! pip install` or `!pip install`.
|
||||||
|
match = re.search(r"^! ?pip install", line)
|
||||||
|
if match is not None and "-qqq" not in line:
|
||||||
|
lines[i] = line.replace(match.group(0), match.group(0) + " -qqq")
|
||||||
|
return "\n".join(lines)
|
||||||
82
mm_agents/coact/autogen/coding/jupyter/import_utils.py
Normal file
82
mm_agents/coact/autogen/coding/jupyter/import_utils.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
from functools import lru_cache
|
||||||
|
from logging import getLogger
|
||||||
|
from typing import Callable, TypeVar
|
||||||
|
|
||||||
|
from ...import_utils import patch_object
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
__all__ = ["require_jupyter_kernel_gateway_installed", "skip_on_missing_jupyter_kernel_gateway"]
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
def is_jupyter_kernel_gateway_installed() -> bool:
    """Check if jupyter-kernel-gateway is installed.

    Probes the `jupyter kernelgateway --version` CLI; the cached result is
    reused for the lifetime of the process.
    """
    try:
        subprocess.run(
            ["jupyter", "kernelgateway", "--version"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=True,
        )
    except (subprocess.CalledProcessError, FileNotFoundError):
        # Either the jupyter CLI is missing or the kernelgateway subcommand failed.
        logger.warning(
            "jupyter-kernel-gateway is required for JupyterCodeExecutor, please install it with `pip install ag2[jupyter-executor]`"
        )
        return False
    return True
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
def require_jupyter_kernel_gateway_installed() -> Callable[[T], T]:
    """Decorator that checks if jupyter-kernel-gateway is installed before function execution.

    Returns:
        Callable[[T], T]: A decorator function that either:
        - Returns the original function unchanged if jupyter-kernel-gateway is installed
        - Returns a patched version of the function that will raise a helpful error indicating the missing dependency when called
    """
    # Probe once, at decorator-creation time (result is lru-cached anyway).
    gateway_present = is_jupyter_kernel_gateway_installed()

    def decorator(o: T) -> T:
        if gateway_present:
            return o
        return patch_object(o, missing_modules={}, dep_target="jupyter-executor")

    return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def skip_on_missing_jupyter_kernel_gateway() -> Callable[[T], T]:
    """Decorator to skip a test if an optional module is missing"""
    # Add pytest.mark.jupyter_executor decorator
    mark_name = "jupyter_executor"
    # Probe once, at decorator-creation time (result is lru-cached anyway).
    gateway_present = is_jupyter_kernel_gateway_installed()

    def decorator(o: T) -> T:
        import pytest

        marked = getattr(pytest.mark, mark_name)(o)
        if gateway_present:
            return marked  # type: ignore[no-any-return]
        return pytest.mark.skip(  # type: ignore[return-value,no-any-return]
            reason="jupyter-kernel-gateway is required for JupyterCodeExecutor, please install it with `pip install ag2[jupyter-executor]`"
        )(marked)

    return decorator
|
||||||
231
mm_agents/coact/autogen/coding/jupyter/jupyter_client.py
Normal file
231
mm_agents/coact/autogen/coding/jupyter/jupyter_client.py
Normal file
@@ -0,0 +1,231 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from types import TracebackType
|
||||||
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
|
from ...doc_utils import export_module
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 11):
|
||||||
|
from typing import Self
|
||||||
|
else:
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from requests.adapters import HTTPAdapter, Retry
|
||||||
|
|
||||||
|
from ...import_utils import optional_import_block, require_optional_import
|
||||||
|
from .base import JupyterConnectionInfo
|
||||||
|
|
||||||
|
with optional_import_block():
|
||||||
|
import websocket
|
||||||
|
from websocket import WebSocket
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen.coding.jupyter")
class JupyterClient:
    """(Experimental) A client for communicating with a Jupyter gateway server."""

    def __init__(self, connection_info: JupyterConnectionInfo):
        """(Experimental) A client for communicating with a Jupyter gateway server.

        Args:
            connection_info (JupyterConnectionInfo): Connection information
        """
        self._connection_info = connection_info
        self._session = requests.Session()
        # Retry transient failures; mount the adapter for both schemes so an
        # HTTPS gateway gets the same retry behavior as a plain-HTTP one
        # (previously only "http://" was mounted).
        retries = Retry(total=5, backoff_factor=0.1)
        adapter = HTTPAdapter(max_retries=retries)
        self._session.mount("http://", adapter)
        self._session.mount("https://", adapter)

    def _get_headers(self) -> dict[str, str]:
        """Return the auth headers; empty when the gateway uses no token."""
        if self._connection_info.token is None:
            return {}
        return {"Authorization": f"token {self._connection_info.token}"}

    def _get_api_base_url(self) -> str:
        """Base URL for REST calls, e.g. ``http://host:port`` (port optional)."""
        protocol = "https" if self._connection_info.use_https else "http"
        port = f":{self._connection_info.port}" if self._connection_info.port else ""
        return f"{protocol}://{self._connection_info.host}{port}"

    def _get_ws_base_url(self) -> str:
        """Base URL for websocket connections.

        Uses ``wss`` when the gateway is served over HTTPS so the kernel
        channel is encrypted like the REST traffic (previously always ``ws``,
        which either fails or downgrades security against an HTTPS gateway).
        """
        protocol = "wss" if self._connection_info.use_https else "ws"
        port = f":{self._connection_info.port}" if self._connection_info.port else ""
        return f"{protocol}://{self._connection_info.host}{port}"

    def list_kernel_specs(self) -> dict[str, dict[str, str]]:
        """Return the gateway's installed kernel specs (REST ``/api/kernelspecs``)."""
        response = self._session.get(f"{self._get_api_base_url()}/api/kernelspecs", headers=self._get_headers())
        return cast(dict[str, dict[str, str]], response.json())

    def list_kernels(self) -> list[dict[str, str]]:
        """Return the currently running kernels (REST ``/api/kernels``)."""
        response = self._session.get(f"{self._get_api_base_url()}/api/kernels", headers=self._get_headers())
        return cast(list[dict[str, str]], response.json())

    def start_kernel(self, kernel_spec_name: str) -> str:
        """Start a new kernel.

        Args:
            kernel_spec_name (str): Name of the kernel spec to start

        Returns:
            str: ID of the started kernel
        """
        response = self._session.post(
            f"{self._get_api_base_url()}/api/kernels",
            headers=self._get_headers(),
            json={"name": kernel_spec_name},
        )
        return cast(str, response.json()["id"])

    def delete_kernel(self, kernel_id: str) -> None:
        """Shut down and remove the kernel with the given id."""
        response = self._session.delete(
            f"{self._get_api_base_url()}/api/kernels/{kernel_id}", headers=self._get_headers()
        )
        response.raise_for_status()

    def restart_kernel(self, kernel_id: str) -> None:
        """Restart the kernel with the given id, discarding its state."""
        response = self._session.post(
            f"{self._get_api_base_url()}/api/kernels/{kernel_id}/restart", headers=self._get_headers()
        )
        response.raise_for_status()

    @require_optional_import("websocket", "jupyter-executor")
    def get_kernel_client(self, kernel_id: str) -> JupyterKernelClient:
        """Open a websocket to the kernel's channels endpoint and wrap it in a client."""
        ws_url = f"{self._get_ws_base_url()}/api/kernels/{kernel_id}/channels"
        ws = websocket.create_connection(ws_url, header=self._get_headers())
        return JupyterKernelClient(ws)
||||||
|
|
||||||
|
|
||||||
|
@require_optional_import("websocket", "jupyter-executor")
class JupyterKernelClient:
    """(Experimental) A client for communicating with a Jupyter kernel."""

    @dataclass
    class ExecutionResult:
        # Result of a single execute() call.

        @dataclass
        class DataItem:
            # One rich-output payload (e.g. image or HTML), base64/text per MIME type.
            mime_type: str
            data: str

        is_ok: bool  # False on kernel error or timeout.
        output: str  # Concatenated text output (or an "ERROR: ..." message).
        data_items: list[DataItem]  # Non-text outputs collected during execution.

    def __init__(self, websocket: WebSocket):  # type: ignore[no-any-unimported]
        """Wrap an already-open websocket to a kernel's channels endpoint."""
        # A fresh session id scopes all messages sent by this client.
        self._session_id: str = uuid.uuid4().hex
        self._websocket: WebSocket = websocket  # type: ignore[no-any-unimported]

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self, exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
    ) -> None:
        # Context-manager exit closes the websocket unconditionally.
        self.stop()

    def stop(self) -> None:
        """Close the underlying websocket connection."""
        self._websocket.close()

    def _send_message(self, *, content: dict[str, Any], channel: str, message_type: str) -> str:
        """Send one Jupyter-protocol message and return its generated msg_id.

        The envelope follows the Jupyter messaging spec (protocol version 5.0);
        replies are matched back to the request via parent_header.msg_id.
        """
        timestamp = datetime.datetime.now().isoformat()
        message_id = uuid.uuid4().hex
        message = {
            "header": {
                "username": "autogen",
                "version": "5.0",
                "session": self._session_id,
                "msg_id": message_id,
                "msg_type": message_type,
                "date": timestamp,
            },
            "parent_header": {},
            "channel": channel,
            "content": content,
            "metadata": {},
            "buffers": {},
        }
        self._websocket.send_text(json.dumps(message))
        return message_id

    def _receive_message(self, timeout_seconds: Optional[float]) -> Optional[dict[str, Any]]:
        """Receive one message, or None if the websocket read times out."""
        self._websocket.settimeout(timeout_seconds)
        try:
            data = self._websocket.recv()
            # Frames may arrive as bytes; normalize to str before JSON decoding.
            if isinstance(data, bytes):
                data = data.decode("utf-8")
            return cast(dict[str, Any], json.loads(data))
        except websocket.WebSocketTimeoutException:
            return None

    def wait_for_ready(self, timeout_seconds: Optional[float] = None) -> bool:
        """Block until the kernel answers a kernel_info_request; False on timeout."""
        message_id = self._send_message(content={}, channel="shell", message_type="kernel_info_request")
        while True:
            message = self._receive_message(timeout_seconds)
            # This means we timed out with no new messages.
            if message is None:
                return False
            if (
                message.get("parent_header", {}).get("msg_id") == message_id
                and message["msg_type"] == "kernel_info_reply"
            ):
                return True

    def execute(self, code: str, timeout_seconds: Optional[float] = None) -> ExecutionResult:
        """Execute `code` in the kernel and collect its outputs.

        Sends an execute_request on the shell channel, then consumes messages
        until the kernel goes idle, an error is reported, or the read times out.
        """
        message_id = self._send_message(
            content={
                "code": code,
                "silent": False,
                "store_history": True,
                "user_expressions": {},
                "allow_stdin": False,  # no interactive input() from executed code
                "stop_on_error": True,
            },
            channel="shell",
            message_type="execute_request",
        )

        text_output = []
        data_output = []
        while True:
            message = self._receive_message(timeout_seconds)
            if message is None:
                return JupyterKernelClient.ExecutionResult(
                    is_ok=False, output="ERROR: Timeout waiting for output from code block.", data_items=[]
                )

            # Ignore messages that are not for this execution.
            if message.get("parent_header", {}).get("msg_id") != message_id:
                continue

            msg_type = message["msg_type"]
            content = message["content"]
            if msg_type in ["execute_result", "display_data"]:
                for data_type, data in content["data"].items():
                    if data_type == "text/plain":
                        text_output.append(data)
                    elif data_type.startswith("image/") or data_type == "text/html":
                        # Rich output is returned as-is for the caller to persist.
                        data_output.append(self.ExecutionResult.DataItem(mime_type=data_type, data=data))
                    else:
                        text_output.append(json.dumps(data))
            elif msg_type == "stream":
                # stdout/stderr text produced by the code.
                text_output.append(content["text"])
            elif msg_type == "error":
                # Output is an error.
                return JupyterKernelClient.ExecutionResult(
                    is_ok=False,
                    output=f"ERROR: {content['ename']}: {content['evalue']}\n{content['traceback']}",
                    data_items=[],
                )
            if msg_type == "status" and content["execution_state"] == "idle":
                break

        return JupyterKernelClient.ExecutionResult(
            is_ok=True, output="\n".join([str(output) for output in text_output]), data_items=data_output
        )
|
||||||
160
mm_agents/coact/autogen/coding/jupyter/jupyter_code_executor.py
Normal file
160
mm_agents/coact/autogen/coding/jupyter/jupyter_code_executor.py
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import uuid
|
||||||
|
from pathlib import Path
|
||||||
|
from types import TracebackType
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from ...doc_utils import export_module
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 11):
|
||||||
|
from typing import Self
|
||||||
|
else:
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
from ..base import CodeBlock, CodeExecutor, CodeExtractor, IPythonCodeResult
|
||||||
|
from ..markdown_code_extractor import MarkdownCodeExtractor
|
||||||
|
from ..utils import silence_pip
|
||||||
|
from .base import JupyterConnectable, JupyterConnectionInfo
|
||||||
|
from .jupyter_client import JupyterClient
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen.coding.jupyter")
class JupyterCodeExecutor(CodeExecutor):
    def __init__(
        self,
        jupyter_server: Union[JupyterConnectable, JupyterConnectionInfo],
        kernel_name: str = "python3",
        timeout: int = 60,
        output_dir: Union[Path, str] = Path(),
    ):
        """(Experimental) A code executor class that executes code statefully using
        a Jupyter server supplied to this class.

        Each execution is stateful and can access variables created from previous
        executions in the same session.

        Args:
            jupyter_server (Union[JupyterConnectable, JupyterConnectionInfo]): The Jupyter server to use.
            timeout (int): The timeout for code execution, by default 60.
            kernel_name (str): The kernel name to use. Make sure it is installed.
                By default, it is "python3".
            output_dir (str): The directory to save output files, by default ".".

        Raises:
            ValueError: If the timeout is below 1, the output directory does not
                exist, the server argument has the wrong type, or the requested
                kernel is not installed on the gateway.
        """
        if timeout < 1:
            raise ValueError("Timeout must be greater than or equal to 1.")

        if isinstance(output_dir, str):
            output_dir = Path(output_dir)

        if not output_dir.exists():
            raise ValueError(f"Output directory {output_dir} does not exist.")

        # Accept either a live server object or bare connection info.
        if isinstance(jupyter_server, JupyterConnectable):
            self._connection_info = jupyter_server.connection_info
        elif isinstance(jupyter_server, JupyterConnectionInfo):
            self._connection_info = jupyter_server
        else:
            raise ValueError("jupyter_server must be a JupyterConnectable or JupyterConnectionInfo.")

        self._jupyter_client = JupyterClient(self._connection_info)
        available_kernels = self._jupyter_client.list_kernel_specs()
        if kernel_name not in available_kernels["kernelspecs"]:
            raise ValueError(f"Kernel {kernel_name} is not installed.")

        # One kernel per executor instance; this is what makes sessions stateful.
        self._kernel_id = self._jupyter_client.start_kernel(kernel_name)
        self._kernel_name = kernel_name
        self._jupyter_kernel_client = self._jupyter_client.get_kernel_client(self._kernel_id)
        self._timeout = timeout
        self._output_dir = output_dir

    @property
    def code_extractor(self) -> CodeExtractor:
        """(Experimental) Export a code extractor that can be used by an agent."""
        return MarkdownCodeExtractor()

    def execute_code_blocks(self, code_blocks: list[CodeBlock]) -> IPythonCodeResult:
        """(Experimental) Execute a list of code blocks and return the result.

        This method executes a list of code blocks as cells in the Jupyter kernel.
        See: https://jupyter-client.readthedocs.io/en/stable/messaging.html
        for the message protocol.

        Args:
            code_blocks (List[CodeBlock]): A list of code blocks to execute.

        Returns:
            IPythonCodeResult: The result of the code execution. ``exit_code``
            is 1 on the first failing block, 0 otherwise.
        """
        self._jupyter_kernel_client.wait_for_ready()
        outputs = []
        output_files = []
        for code_block in code_blocks:
            # Quiet pip installs before running the block.
            code = silence_pip(code_block.code, code_block.language)
            result = self._jupyter_kernel_client.execute(code, timeout_seconds=self._timeout)
            if result.is_ok:
                outputs.append(result.output)
                for data in result.data_items:
                    if data.mime_type == "image/png":
                        # Persist PNG output and report where it was saved.
                        path = self._save_image(data.data)
                        outputs.append(f"Image data saved to {path}")
                        output_files.append(path)
                    elif data.mime_type == "text/html":
                        path = self._save_html(data.data)
                        outputs.append(f"HTML data saved to {path}")
                        output_files.append(path)
                    else:
                        # Other MIME types are returned inline as JSON text.
                        outputs.append(json.dumps(data.data))
            else:
                # Stop at the first failing block.
                return IPythonCodeResult(
                    exit_code=1,
                    output=f"ERROR: {result.output}",
                )

        return IPythonCodeResult(
            exit_code=0, output="\n".join([str(output) for output in outputs]), output_files=output_files
        )

    def restart(self) -> None:
        """(Experimental) Restart a new session.

        Restarts the remote kernel (discarding its state) and re-opens the
        websocket client against the same kernel id.
        """
        self._jupyter_client.restart_kernel(self._kernel_id)
        self._jupyter_kernel_client = self._jupyter_client.get_kernel_client(self._kernel_id)

    def _save_image(self, image_data_base64: str) -> str:
        """Save image data to a file."""
        image_data = base64.b64decode(image_data_base64)
        # Randomly generate a filename.
        filename = f"{uuid.uuid4().hex}.png"
        path = os.path.join(self._output_dir, filename)
        with open(path, "wb") as f:
            f.write(image_data)
        return os.path.abspath(path)

    def _save_html(self, html_data: str) -> str:
        """Save html data to a file."""
        # Randomly generate a filename.
        filename = f"{uuid.uuid4().hex}.html"
        path = os.path.join(self._output_dir, filename)
        with open(path, "w") as f:
            f.write(html_data)
        return os.path.abspath(path)

    def stop(self) -> None:
        """Stop the kernel."""
        self._jupyter_client.delete_kernel(self._kernel_id)

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self, exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
    ) -> None:
        # Context-manager exit always deletes the kernel.
        self.stop()
|
||||||
172
mm_agents/coact/autogen/coding/jupyter/local_jupyter_server.py
Normal file
172
mm_agents/coact/autogen/coding/jupyter/local_jupyter_server.py
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import atexit
|
||||||
|
import json
|
||||||
|
import secrets
|
||||||
|
import signal
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from types import TracebackType
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from ...doc_utils import export_module
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 11):
|
||||||
|
from typing import Self
|
||||||
|
else:
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
from .base import JupyterConnectable, JupyterConnectionInfo
|
||||||
|
from .import_utils import require_jupyter_kernel_gateway_installed
|
||||||
|
from .jupyter_client import JupyterClient
|
||||||
|
|
||||||
|
|
||||||
|
@require_jupyter_kernel_gateway_installed()
|
||||||
|
@export_module("autogen.coding.jupyter")
|
||||||
|
class LocalJupyterServer(JupyterConnectable):
|
||||||
|
class GenerateToken:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
ip: str = "127.0.0.1",
|
||||||
|
port: Optional[int] = None,
|
||||||
|
token: str | GenerateToken = GenerateToken(),
|
||||||
|
log_file: str = "jupyter_gateway.log",
|
||||||
|
log_level: str = "INFO",
|
||||||
|
log_max_bytes: int = 1048576,
|
||||||
|
log_backup_count: int = 3,
|
||||||
|
):
|
||||||
|
"""Runs a Jupyter Kernel Gateway server locally.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ip (str, optional): IP address to bind to. Defaults to "127.0.0.1".
|
||||||
|
port (Optional[int], optional): Port to use, if None it automatically selects a port. Defaults to None.
|
||||||
|
token (Union[str, GenerateToken], optional): Token to use for Jupyter server. By default will generate a token. Using None will use no token for authentication. Defaults to GenerateToken().
|
||||||
|
log_file (str, optional): File for Jupyter Kernel Gateway logs. Defaults to "jupyter_gateway.log".
|
||||||
|
log_level (str, optional): Level for Jupyter Kernel Gateway logs. Defaults to "INFO".
|
||||||
|
log_max_bytes (int, optional): Max logfile size. Defaults to 1048576.
|
||||||
|
log_backup_count (int, optional): Number of backups for rotating log. Defaults to 3.
|
||||||
|
"""
|
||||||
|
# Remove as soon as https://github.com/jupyter-server/kernel_gateway/issues/398 is fixed
|
||||||
|
if sys.platform == "win32":
|
||||||
|
raise ValueError("LocalJupyterServer is not supported on Windows due to kernelgateway bug.")
|
||||||
|
|
||||||
|
# Check Jupyter gateway server is installed
|
||||||
|
try:
|
||||||
|
subprocess.run(
|
||||||
|
[sys.executable, "-m", "jupyter", "kernelgateway", "--version"],
|
||||||
|
check=True,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
)
|
||||||
|
except subprocess.CalledProcessError:
|
||||||
|
raise ValueError(
|
||||||
|
"Jupyter gateway server is not installed. Please install it with `pip install jupyter_kernel_gateway`."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.ip: str = ip
|
||||||
|
|
||||||
|
if isinstance(token, LocalJupyterServer.GenerateToken):
|
||||||
|
token = secrets.token_hex(32)
|
||||||
|
|
||||||
|
self.token: str = token
|
||||||
|
self._subprocess: subprocess.Popen[str]
|
||||||
|
logging_config = {
|
||||||
|
"handlers": {
|
||||||
|
"file": {
|
||||||
|
"class": "logging.handlers.RotatingFileHandler",
|
||||||
|
"level": log_level,
|
||||||
|
"maxBytes": log_max_bytes,
|
||||||
|
"backupCount": log_backup_count,
|
||||||
|
"filename": log_file,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"loggers": {"KernelGatewayApp": {"level": log_level, "handlers": ["file", "console"]}},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Run Jupyter gateway server with detached subprocess
|
||||||
|
args = [
|
||||||
|
sys.executable,
|
||||||
|
"-m",
|
||||||
|
"jupyter",
|
||||||
|
"kernelgateway",
|
||||||
|
"--KernelGatewayApp.ip",
|
||||||
|
ip,
|
||||||
|
"--KernelGatewayApp.auth_token",
|
||||||
|
token,
|
||||||
|
"--JupyterApp.answer_yes",
|
||||||
|
"true",
|
||||||
|
"--JupyterApp.logging_config",
|
||||||
|
json.dumps(logging_config),
|
||||||
|
"--JupyterWebsocketPersonality.list_kernels",
|
||||||
|
"true",
|
||||||
|
]
|
||||||
|
if port is not None:
|
||||||
|
args.extend(["--KernelGatewayApp.port", str(port)])
|
||||||
|
args.extend(["--KernelGatewayApp.port_retries", "0"])
|
||||||
|
self._subprocess = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||||
|
|
||||||
|
# Satisfy mypy, we know this is not None because we passed PIPE
|
||||||
|
assert self._subprocess.stderr is not None
|
||||||
|
# Read stderr until we see "is available at" or the process has exited with an error
|
||||||
|
stderr = ""
|
||||||
|
while True:
|
||||||
|
result = self._subprocess.poll()
|
||||||
|
if result is not None:
|
||||||
|
stderr += self._subprocess.stderr.read()
|
||||||
|
raise ValueError(f"Jupyter gateway server failed to start with exit code: {result}. stderr:\n{stderr}")
|
||||||
|
line = self._subprocess.stderr.readline()
|
||||||
|
stderr += line
|
||||||
|
|
||||||
|
if "ERROR:" in line:
|
||||||
|
error_info = line.split("ERROR:")[1]
|
||||||
|
raise ValueError(f"Jupyter gateway server failed to start. {error_info}")
|
||||||
|
|
||||||
|
if "is available at" in line:
|
||||||
|
# We need to extract what port it settled on
|
||||||
|
# Example output:
|
||||||
|
# Jupyter Kernel Gateway 3.0.0 is available at http://127.0.0.1:8890
|
||||||
|
if port is None:
|
||||||
|
port = int(line.split(":")[-1])
|
||||||
|
self.port: int = port
|
||||||
|
|
||||||
|
break
|
||||||
|
|
||||||
|
# Poll the subprocess to check if it is still running
|
||||||
|
result = self._subprocess.poll()
|
||||||
|
if result is not None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Jupyter gateway server failed to start. Please check the logs ({log_file}) for more information."
|
||||||
|
)
|
||||||
|
|
||||||
|
atexit.register(self.stop)
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
if self._subprocess.poll() is None:
|
||||||
|
if sys.platform == "win32":
|
||||||
|
self._subprocess.send_signal(signal.CTRL_C_EVENT)
|
||||||
|
else:
|
||||||
|
self._subprocess.send_signal(signal.SIGINT)
|
||||||
|
self._subprocess.wait()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def connection_info(self) -> JupyterConnectionInfo:
|
||||||
|
return JupyterConnectionInfo(host=self.ip, use_https=False, port=self.port, token=self.token)
|
||||||
|
|
||||||
|
def get_client(self) -> JupyterClient:
|
||||||
|
return JupyterClient(self.connection_info)
|
||||||
|
|
||||||
|
def __enter__(self) -> Self:
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(
    self, exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
) -> None:
    """Exit the context manager by stopping the gateway server.

    Returns None implicitly, so exceptions raised inside the `with` body
    are not suppressed.
    """
    self.stop()
|
||||||
@@ -0,0 +1,405 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import warnings
|
||||||
|
from hashlib import md5
|
||||||
|
from pathlib import Path
|
||||||
|
from string import Template
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from typing import Any, Callable, ClassVar, Optional, Union
|
||||||
|
|
||||||
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
|
from ..code_utils import PYTHON_VARIANTS, TIMEOUT_MSG, WIN32, _cmd
|
||||||
|
from ..doc_utils import export_module
|
||||||
|
from .base import CodeBlock, CodeExecutor, CodeExtractor, CommandLineCodeResult
|
||||||
|
from .func_with_reqs import (
|
||||||
|
FunctionWithRequirements,
|
||||||
|
FunctionWithRequirementsStr,
|
||||||
|
_build_python_functions_file,
|
||||||
|
to_stub,
|
||||||
|
)
|
||||||
|
from .markdown_code_extractor import MarkdownCodeExtractor
|
||||||
|
from .utils import _get_file_name_from_content, silence_pip
|
||||||
|
|
||||||
|
__all__ = ("LocalCommandLineCodeExecutor",)
|
||||||
|
|
||||||
|
A = ParamSpec("A")
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen.coding")
class LocalCommandLineCodeExecutor(CodeExecutor):
    # Languages this executor recognizes. Whether a block in one of these
    # languages is actually executed (vs. only saved to disk) is decided by
    # the execution policy below.
    SUPPORTED_LANGUAGES: ClassVar[list[str]] = [
        "bash",
        "shell",
        "sh",
        "pwsh",
        "powershell",
        "ps1",
        "python",
        "javascript",
        "html",
        "css",
    ]
    # Default policy: shell/PowerShell/Python blocks are executed; web-asset
    # languages (javascript, html, css) are saved only.
    DEFAULT_EXECUTION_POLICY: ClassVar[dict[str, bool]] = {
        "bash": True,
        "shell": True,
        "sh": True,
        "pwsh": True,
        "powershell": True,
        "ps1": True,
        "python": True,
        "javascript": False,
        "html": False,
        "css": False,
    }

    # string.Template text; `$module_name` and `$functions` are substituted by
    # format_functions_for_prompt. Do not alter the literal: it is runtime text.
    FUNCTION_PROMPT_TEMPLATE: ClassVar[
        str
    ] = """You have access to the following user defined functions. They can be accessed from the module called `$module_name` by their function names.

For example, if there was a function called `foo` you could import it by writing `from $module_name import foo`

$functions"""
|
||||||
|
|
||||||
|
def __init__(
    self,
    timeout: int = 60,
    virtual_env_context: Optional[SimpleNamespace] = None,
    work_dir: Union[Path, str] = Path(),
    functions: Optional[
        list[Union[FunctionWithRequirements[Any, A], Callable[..., Any], FunctionWithRequirementsStr]]
    ] = None,
    functions_module: str = "functions",
    execution_policies: Optional[dict[str, bool]] = None,
):
    """(Experimental) A code executor class that executes or saves LLM generated code a local command line
    environment.

    **This will execute or save LLM generated code on the local machine.**

    Each code block is saved as a file in the working directory. Depending on the execution policy,
    the code may be executed in a separate process.
    The code blocks are executed or save in the order they are received.
    Command line code is sanitized against a list of dangerous commands to prevent self-destructive commands from being executed,
    which could potentially affect the user's environment. Supported languages include Python, shell scripts (bash, shell, sh),
    PowerShell (pwsh, powershell, ps1), HTML, CSS, and JavaScript.
    Execution policies determine whether each language's code blocks are executed or saved only.

    ## Execution with a Python virtual environment
    A python virtual env can be used to execute code and install dependencies. This has the added benefit of not polluting the
    base environment with unwanted modules.
    ```python
    from autogen.code_utils import create_virtual_env
    from autogen.coding import LocalCommandLineCodeExecutor

    venv_dir = ".venv"
    venv_context = create_virtual_env(venv_dir)

    executor = LocalCommandLineCodeExecutor(virtual_env_context=venv_context)
    ```

    Args:
        timeout (int): The timeout for code execution, default is 60 seconds.
        virtual_env_context (Optional[SimpleNamespace]): The virtual environment context to use.
        work_dir (Union[Path, str]): The working directory for code execution, defaults to the current directory.
        functions (Optional[List[...]]): A list of callable functions available to the executor. Defaults to None (no functions).
        functions_module (str): The module name under which functions are accessible.
        execution_policies (Optional[Dict[str, bool]]): A dictionary mapping languages to execution policies (True for execution, False for saving only). Defaults to class-wide DEFAULT_EXECUTION_POLICY.

    Raises:
        ValueError: if timeout < 1 or functions_module is not a valid identifier.
    """
    if timeout < 1:
        raise ValueError("Timeout must be greater than or equal to 1.")

    if isinstance(work_dir, str):
        work_dir = Path(work_dir)

    if not functions_module.isidentifier():
        raise ValueError("Module name must be a valid Python identifier")

    # Fix: the original used a mutable default argument (functions=[]). The
    # shared default list is stored on the instance as self._functions, so a
    # mutation through one instance would leak into every other instance
    # constructed with the default. A None sentinel is backward-compatible.
    if functions is None:
        functions = []

    self._functions_module = functions_module

    # Note: exist_ok only; parent directories are NOT created (matches original).
    work_dir.mkdir(exist_ok=True)

    self._timeout = timeout
    self._work_dir: Path = work_dir
    self._virtual_env_context: Optional[SimpleNamespace] = virtual_env_context

    self._functions = functions
    # Setup could take some time so we intentionally wait for the first code block to do it.
    if len(functions) > 0:
        self._setup_functions_complete = False
    else:
        self._setup_functions_complete = True

    # Copy so per-instance overrides never mutate the class-level default.
    self.execution_policies = self.DEFAULT_EXECUTION_POLICY.copy()
    if execution_policies is not None:
        self.execution_policies.update(execution_policies)
|
||||||
|
|
||||||
|
def format_functions_for_prompt(self, prompt_template: str = FUNCTION_PROMPT_TEMPLATE) -> str:
    """(Experimental) Format the functions for a prompt.

    The template includes two variables:
    - `$module_name`: The module name.
    - `$functions`: The functions formatted as stubs with two newlines between each function.

    Args:
        prompt_template (str): The prompt template. Default is the class default.

    Returns:
        str: The formatted prompt.
    """
    stubs = "\n\n".join(to_stub(fn) for fn in self._functions)
    return Template(prompt_template).substitute(module_name=self._functions_module, functions=stubs)
|
||||||
|
|
||||||
|
@property
def functions_module(self) -> str:
    """(Experimental) The module name for the functions."""
    return self._functions_module

@property
def functions(
    self,
) -> list[Union[FunctionWithRequirements[Any, A], Callable[..., Any], FunctionWithRequirementsStr]]:
    """(Experimental) The functions that are available to the code executor."""
    # Returns the live list, not a copy; mutations are visible to the executor.
    return self._functions

@property
def timeout(self) -> int:
    """(Experimental) The timeout for code execution."""
    return self._timeout

@property
def work_dir(self) -> Path:
    """(Experimental) The working directory for the code execution."""
    return self._work_dir

@property
def code_extractor(self) -> CodeExtractor:
    """(Experimental) Export a code extractor that can be used by an agent."""
    # A fresh extractor instance is constructed on every access.
    return MarkdownCodeExtractor()
|
||||||
|
|
||||||
|
@staticmethod
def sanitize_command(lang: str, code: str) -> None:
    """Reject shell code containing known-destructive commands.

    This approach acknowledges that while Docker or similar
    containerization/sandboxing technologies provide a robust layer of security,
    not all users may have Docker installed or may choose not to use it.
    Therefore, having a baseline level of protection helps mitigate risks for users who,
    either out of choice or necessity, run code outside of a sandboxed environment.

    Args:
        lang: The code block's language tag; only bash/shell/sh are checked.
        code: The code to scan.

    Raises:
        ValueError: if any dangerous pattern is found in shell code.
    """
    # Only shell dialects are scanned; other languages pass through untouched.
    if lang not in ("bash", "shell", "sh"):
        return
    checks = (
        (r"\brm\s+-rf\b", "Use of 'rm -rf' command is not allowed."),
        (r"\bmv\b.*?\s+/dev/null", "Moving files to /dev/null is not allowed."),
        (r"\bdd\b", "Use of 'dd' command is not allowed."),
        (r">\s*/dev/sd[a-z][1-9]?", "Overwriting disk blocks directly is not allowed."),
        (r":\(\)\{\s*:\|\:&\s*\};:", "Fork bombs are not allowed."),
    )
    for regex, message in checks:
        if re.search(regex, code):
            raise ValueError(f"Potentially dangerous command detected: {message}")
|
||||||
|
|
||||||
|
def _setup_functions(self) -> None:
    """Write the user functions to a module file in the work dir, pip-install
    their declared requirements, and import-check the generated module.

    Raises:
        ValueError: if pip install times out or fails, or if the generated
            functions module fails to load (syntax/import errors).
    """
    func_file_content = _build_python_functions_file(self._functions)
    func_file = self._work_dir / f"{self._functions_module}.py"
    func_file.write_text(func_file_content)

    # Collect requirements
    lists_of_packages = [x.python_packages for x in self._functions if isinstance(x, FunctionWithRequirements)]
    flattened_packages = [item for sublist in lists_of_packages for item in sublist]
    # Deduplicate; NOTE this makes install order nondeterministic (set ordering).
    required_packages = list(set(flattened_packages))
    if len(required_packages) > 0:
        logging.info("Ensuring packages are installed in executor.")
        # Use the venv's interpreter when one is configured so installs land there.
        py_executable = self._virtual_env_context.env_exe if self._virtual_env_context else sys.executable
        cmd = [py_executable, "-m", "pip", "install"] + required_packages
        try:
            result = subprocess.run(
                cmd,
                cwd=self._work_dir,
                capture_output=True,
                text=True,
                timeout=float(self._timeout),
                encoding="utf-8",
            )
        except subprocess.TimeoutExpired as e:
            raise ValueError("Pip install timed out") from e
        if result.returncode != 0:
            raise ValueError(f"Pip install failed. {result.stdout}, {result.stderr}")
    # Attempt to load the function file to check for syntax errors, imports etc.
    exec_result = self._execute_code_dont_check_setup([CodeBlock(code=func_file_content, language="python")])
    if exec_result.exit_code != 0:
        raise ValueError(f"Functions failed to load: {exec_result.output}")
    self._setup_functions_complete = True
|
||||||
|
|
||||||
|
def execute_code_blocks(self, code_blocks: list[CodeBlock]) -> CommandLineCodeResult:
    """(Experimental) Execute the code blocks and return the result.

    Args:
        code_blocks (List[CodeBlock]): The code blocks to execute.

    Returns:
        CommandLineCodeResult: The result of the code execution.
    """
    # User functions are installed lazily, on the first execution only.
    setup_done = self._setup_functions_complete
    if not setup_done:
        self._setup_functions()
    return self._execute_code_dont_check_setup(code_blocks)
|
||||||
|
|
||||||
|
def _execute_code_dont_check_setup(self, code_blocks: list[CodeBlock]) -> CommandLineCodeResult:
    """Execute (or save) code blocks sequentially without re-running function setup.

    Each block is sanitized, written to a file in the working directory, and —
    depending on the execution policy for its language — executed in a
    subprocess or saved only. Processing stops at the first block that fails,
    times out, or uses an unsupported language.

    Args:
        code_blocks (list[CodeBlock]): The code blocks to execute or save.

    Returns:
        CommandLineCodeResult: Accumulated stdout/stderr, the exit code of the
            last processed block (124 on timeout, 1 on unsupported language),
            and the path of the first file written (or None).
    """
    logs_all = ""
    file_names = []
    # Fix: initialize exitcode so an empty code_blocks list (or a first block
    # that only gets saved) cannot hit an UnboundLocalError at the return below.
    exitcode = 0
    for code_block in code_blocks:
        lang, code = code_block.language, code_block.code
        lang = lang.lower()

        LocalCommandLineCodeExecutor.sanitize_command(lang, code)
        code = silence_pip(code, lang)

        # Normalize python variants (e.g. "py") to "python".
        if lang in PYTHON_VARIANTS:
            lang = "python"

        # POSIX shell dialects are run through PowerShell on Windows.
        if WIN32 and lang in ["sh", "shell"]:
            lang = "ps1"

        if lang not in self.SUPPORTED_LANGUAGES:
            # In case the language is not supported, we return an error message.
            exitcode = 1
            logs_all += "\n" + f"unknown language {lang}"
            break

        execute_code = self.execution_policies.get(lang, False)
        try:
            # Check if there is a filename comment
            filename = _get_file_name_from_content(code, self._work_dir)
        except ValueError:
            return CommandLineCodeResult(exit_code=1, output="Filename is not in the workspace")

        if filename is None:
            # create a file with an automatically generated name
            code_hash = md5(code.encode()).hexdigest()
            filename = f"tmp_code_{code_hash}.{'py' if lang.startswith('python') else lang}"
        written_file = (self._work_dir / filename).resolve()
        with written_file.open("w", encoding="utf-8") as f:
            f.write(code)
        file_names.append(written_file)

        if not execute_code:
            # Just return a message that the file is saved.
            logs_all += f"Code saved to {written_file!s}\n"
            exitcode = 0
            continue

        program = _cmd(lang)
        cmd = [program, str(written_file.absolute())]
        env = os.environ.copy()

        if self._virtual_env_context:
            # Prepend the venv's bin directory so its interpreter/tools win lookup.
            virtual_env_abs_path = os.path.abspath(self._virtual_env_context.bin_path)
            path_with_virtualenv = rf"{virtual_env_abs_path}{os.pathsep}{env['PATH']}"
            env["PATH"] = path_with_virtualenv
            if WIN32:
                activation_script = os.path.join(virtual_env_abs_path, "activate.bat")
                cmd = [activation_script, "&&", *cmd]

        try:
            result = subprocess.run(
                cmd,
                cwd=self._work_dir,
                capture_output=True,
                text=True,
                timeout=float(self._timeout),
                env=env,
                encoding="utf-8",
            )
        except subprocess.TimeoutExpired:
            logs_all += "\n" + TIMEOUT_MSG
            # Same exit code as the timeout command on linux.
            exitcode = 124
            break

        logs_all += result.stderr
        logs_all += result.stdout
        exitcode = result.returncode

        if exitcode != 0:
            break

    code_file = str(file_names[0]) if len(file_names) > 0 else None
    return CommandLineCodeResult(exit_code=exitcode, output=logs_all, code_file=code_file)
|
||||||
|
|
||||||
|
def restart(self) -> None:
    """(Experimental) Restart the code executor."""
    # This executor keeps no persistent process between executions, so a
    # restart is meaningless; warn instead of failing.
    warnings.warn("Restarting local command line code executor is not supported. No action is taken.")
|
||||||
|
|
||||||
|
|
||||||
|
# From stack overflow: https://stackoverflow.com/a/52087847/2214524
class _DeprecatedClassMeta(type):
    """Metaclass that turns a class into a warning-emitting alias.

    A class that uses this metaclass and sets ``_DeprecatedClassMeta__alias``
    becomes an alias: instantiating it constructs the aliased class instead
    (after a DeprecationWarning), and isinstance/issubclass checks delegate
    to the alias target.
    """

    def __new__(cls, name, bases, classdict, *args, **kwargs):  # type: ignore[no-untyped-def]
        # The alias target, if the class being created declares one.
        alias = classdict.get("_DeprecatedClassMeta__alias")

        if alias is not None:

            def new(cls, *args, **kwargs):  # type: ignore[no-untyped-def]
                alias = cls._DeprecatedClassMeta__alias

                if alias is not None:
                    # stacklevel=2 attributes the warning to the caller's frame.
                    warnings.warn(
                        f"{cls.__name__} has been renamed to {alias.__name__}, the alias will be removed in the future",
                        DeprecationWarning,
                        stacklevel=2,
                    )

                # Construct the real (renamed) class instead of the alias.
                return alias(*args, **kwargs)

            classdict["__new__"] = new
            classdict["_DeprecatedClassMeta__alias"] = alias

        fixed_bases = []

        # Replace any deprecated base classes with their alias targets,
        # warning once for each deprecated base encountered.
        for b in bases:
            alias = getattr(b, "_DeprecatedClassMeta__alias", None)

            if alias is not None:
                warnings.warn(
                    f"{b.__name__} has been renamed to {alias.__name__}, the alias will be removed in the future",
                    DeprecationWarning,
                    stacklevel=2,
                )

            # Avoid duplicate base classes.
            b = alias or b
            if b not in fixed_bases:
                fixed_bases.append(b)

        fixed_bases = tuple(fixed_bases)  # type: ignore[assignment]

        return super().__new__(cls, name, fixed_bases, classdict, *args, **kwargs)  # type: ignore[call-overload]

    def __instancecheck__(cls, instance):  # type: ignore[no-untyped-def]
        # An object is an "instance" if either its runtime type or its declared
        # __class__ passes the (alias-delegating) subclass check.
        return any(cls.__subclasscheck__(c) for c in {type(instance), instance.__class__})  # type: ignore[no-untyped-call]

    def __subclasscheck__(cls, subclass):  # type: ignore[no-untyped-def]
        # Delegate subclass checks to the alias target.
        if subclass is cls:
            return True
        else:
            return issubclass(subclass, cls._DeprecatedClassMeta__alias)  # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
|
# Deprecated alias kept for backward compatibility with the old capitalization.
class LocalCommandlineCodeExecutor(metaclass=_DeprecatedClassMeta):
    """LocalCommandlineCodeExecutor renamed to LocalCommandLineCodeExecutor"""

    # Instantiation emits a DeprecationWarning and constructs the new class.
    _DeprecatedClassMeta__alias = LocalCommandLineCodeExecutor
|
||||||
|
|
||||||
|
|
||||||
|
# Deprecated alias kept for backward compatibility with the old capitalization.
class CommandlineCodeResult(metaclass=_DeprecatedClassMeta):
    """CommandlineCodeResult renamed to CommandLineCodeResult"""

    # Instantiation emits a DeprecationWarning and constructs the new class.
    _DeprecatedClassMeta__alias = CommandLineCodeResult
|
||||||
45
mm_agents/coact/autogen/coding/markdown_code_extractor.py
Normal file
45
mm_agents/coact/autogen/coding/markdown_code_extractor.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
import re
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from ..code_utils import CODE_BLOCK_PATTERN, UNKNOWN, content_str, infer_lang
|
||||||
|
from ..doc_utils import export_module
|
||||||
|
from ..types import UserMessageImageContentPart, UserMessageTextContentPart
|
||||||
|
from .base import CodeBlock, CodeExtractor
|
||||||
|
|
||||||
|
__all__ = ("MarkdownCodeExtractor",)
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen.coding")
class MarkdownCodeExtractor(CodeExtractor):
    """(Experimental) A class that extracts code blocks from a message using Markdown syntax."""

    def extract_code_blocks(
        self, message: Union[str, list[Union[UserMessageTextContentPart, UserMessageImageContentPart]], None]
    ) -> list[CodeBlock]:
        """(Experimental) Extract code blocks from a message. If no code blocks are found,
        return an empty list.

        Args:
            message (str): The message to extract code blocks from.

        Returns:
            List[CodeBlock]: The extracted code blocks or an empty list.
        """
        text = content_str(message)
        found = re.findall(CODE_BLOCK_PATTERN, text, flags=re.DOTALL)
        blocks: list[CodeBlock] = []
        for language, body in found:
            # An empty fence tag means the language must be inferred from the code.
            if language == "":
                language = infer_lang(body)
            # Unidentifiable languages are normalized back to the empty string.
            if language == UNKNOWN:
                language = ""
            blocks.append(CodeBlock(code=body, language=language))
        return blocks
|
||||||
56
mm_agents/coact/autogen/coding/utils.py
Normal file
56
mm_agents/coact/autogen/coding/utils.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/ag2ai/ag2 are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
# Will return the filename relative to the workspace path
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
# First-line filename markers, one per comment style: HTML (<!-- ... -->),
# C-style block (/* ... */), // line comments, and # line comments. The
# optional "filename:" prefix is group 1; the path itself is group 2.
filename_patterns = [
    re.compile(r"^<!-- (filename:)?(.+?) -->", re.DOTALL),
    re.compile(r"^/\* (filename:)?(.+?) \*/", re.DOTALL),
    re.compile(r"^// (filename:)?(.+?)$", re.DOTALL),
    re.compile(r"^# (filename:)?(.+?)$", re.DOTALL),
]
|
||||||
|
|
||||||
|
|
||||||
|
# Raises ValueError if the file is not in the workspace
def _get_file_name_from_content(code: str, workspace_path: Path) -> Optional[str]:
    """Return the workspace-relative filename declared in the first line of *code*.

    Returns None when the first line carries no recognized filename comment.

    Raises:
        ValueError: if the declared path resolves outside *workspace_path*.
    """
    header = code.split("\n")[0].strip()
    # TODO - support other languages
    for pattern in filename_patterns:
        found = pattern.match(header)
        if found is None:
            continue
        declared = found.group(2).strip()

        # Resolve relative paths against the workspace root.
        candidate = Path(declared)
        if not candidate.is_absolute():
            candidate = workspace_path / candidate
        candidate = candidate.resolve()
        # relative_to raises ValueError when the path escapes the workspace.
        return str(candidate.relative_to(workspace_path.resolve()))
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def silence_pip(code: str, lang: str) -> str:
    """Apply -qqq flag to pip install commands."""
    if lang == "python":
        pattern = r"^! ?pip install"
    elif lang in ("bash", "shell", "sh", "pwsh", "powershell", "ps1"):
        pattern = r"^pip install"
    else:
        # Languages that cannot invoke pip are returned untouched.
        return code

    # Append "-qqq" right after the pip-install prefix on every matching line
    # that does not already carry the flag.
    quieted = []
    for line in code.split("\n"):
        hit = re.search(pattern, line)
        if hit is not None and "-qqq" not in line:
            line = line.replace(hit.group(0), hit.group(0) + " -qqq")
        quieted.append(line)
    return "\n".join(quieted)
|
||||||
34
mm_agents/coact/autogen/doc_utils.py
Normal file
34
mm_agents/coact/autogen/doc_utils.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
__all__ = ["export_module"]
|
||||||
|
|
||||||
|
from typing import Callable, Optional, TypeVar
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
# Global dictionary to store export module mappings
|
||||||
|
# Key: original symbol name (qualified by module)
|
||||||
|
# Value: target module where it should be documented
|
||||||
|
_PDOC_MODULE_EXPORT_MAPPINGS: dict[str, str] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def export_module(module: str) -> Callable[[T], T]:
    """Decorator factory: register the decorated object to be documented under *module*.

    The mapping from the object's fully-qualified name to *module* is recorded
    in the module-level _PDOC_MODULE_EXPORT_MAPPINGS registry; the decorated
    object itself is returned unchanged.
    """

    def decorator(obj: T) -> T:
        # Only objects with a known defining module can be remapped.
        defining_module = getattr(obj, "__module__", None)
        if defining_module:
            qualified_name = f"{defining_module}.{obj.__name__}"
            _PDOC_MODULE_EXPORT_MAPPINGS[qualified_name] = module
        return obj

    return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def get_target_module(obj: object) -> Optional[str]:
    """Get the target module where an object should be documented.

    Returns None when the object was never registered via ``export_module``,
    or when it lacks the attributes needed to build its qualified name.
    """
    module = getattr(obj, "__module__", None)
    name = getattr(obj, "__name__", None)
    # Fix: the original guarded only __module__ but read __name__ unguarded,
    # raising AttributeError for objects (e.g. class instances) that carry a
    # module but no __name__.
    if module is None or name is None:
        return None

    fqn = f"{module}.{name}"
    return _PDOC_MODULE_EXPORT_MAPPINGS.get(fqn)
|
||||||
7
mm_agents/coact/autogen/events/__init__.py
Normal file
7
mm_agents/coact/autogen/events/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
from .base_event import BaseEvent, get_annotated_type_for_event_classes, wrap_event
|
||||||
|
from .helpers import deprecated_by
|
||||||
|
|
||||||
|
__all__ = ["BaseEvent", "deprecated_by", "get_annotated_type_for_event_classes", "wrap_event"]
|
||||||
1014
mm_agents/coact/autogen/events/agent_events.py
Normal file
1014
mm_agents/coact/autogen/events/agent_events.py
Normal file
File diff suppressed because it is too large
Load Diff
99
mm_agents/coact/autogen/events/base_event.py
Normal file
99
mm_agents/coact/autogen/events/base_event.py
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
|
||||||
|
from abc import ABC
|
||||||
|
from typing import Annotated, Any, Callable, Literal, Optional, Union
|
||||||
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, create_model
|
||||||
|
|
||||||
|
from ..doc_utils import export_module
|
||||||
|
|
||||||
|
__all__ = ["BaseEvent", "get_annotated_type_for_event_classes", "get_event_classes", "wrap_event"]
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen.events")
class BaseEvent(BaseModel, ABC):
    """Abstract base for all events; every event carries a unique ``uuid``."""

    # Unique identifier for this event; auto-generated when not supplied.
    uuid: UUID

    def __init__(self, uuid: Optional[UUID] = None, **kwargs: Any) -> None:
        # Generate a fresh UUID when the caller does not provide one, then
        # delegate field validation to pydantic's BaseModel.__init__.
        uuid = uuid or uuid4()
        super().__init__(uuid=uuid, **kwargs)

    def print(self, f: Optional[Callable[..., Any]] = None) -> None:
        """Print event

        Args:
            f (Optional[Callable[..., Any]], optional): Print function. If none, python's default print will be used.
        """
        # Base implementation is a deliberate no-op ("...").
        ...
|
||||||
|
|
||||||
|
|
||||||
|
def camel2snake(name: str) -> str:
    """Convert a CamelCase name to snake_case, e.g. "UsageSummaryEvent" -> "usage_summary_event"."""
    pieces = []
    for ch in name:
        if ch.isupper():
            # Each uppercase letter starts a new underscore-separated word.
            pieces.append("_")
            pieces.append(ch.lower())
        else:
            pieces.append(ch)
    # Drop the spurious leading underscore produced by an uppercase first letter.
    return "".join(pieces).lstrip("_")
|
||||||
|
|
||||||
|
|
||||||
|
# Registry of wrapped event classes keyed by their snake_case type name;
# populated by wrap_event and consumed by get_annotated_type_for_event_classes.
_event_classes: dict[str, type[BaseModel]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen.events")
def wrap_event(event_cls: type[BaseEvent]) -> type[BaseModel]:
    """Wrap a event class with a type field to be used in a union type

    This is needed for proper serialization and deserialization of events in a union type.

    The wrapper carries a Literal ``type`` discriminator (the snake_case class
    name minus the "Event" suffix) and a ``content`` field holding the event.

    Args:
        event_cls (type[BaseEvent]): Event class to wrap

    Raises:
        ValueError: if the class name does not end with 'Event'.
    """
    global _event_classes

    if not event_cls.__name__.endswith("Event"):
        raise ValueError("Event class name must end with 'Event'")

    # e.g. UsageSummaryEvent -> "usage_summary" (snake_case, suffix stripped).
    type_name = camel2snake(event_cls.__name__)
    type_name = type_name[: -len("_event")]

    class WrapperBase(BaseModel):
        # these types are generated dynamically so we need to disable the type checker
        type: Literal[type_name] = type_name  # type: ignore[valid-type]
        content: event_cls  # type: ignore[valid-type]

        def __init__(self, *args: Any, **data: Any):
            # Already in wrapper shape ({"type": ..., "content": ...}): pass through.
            if set(data.keys()) == {"type", "content"} and "content" in data:
                super().__init__(*args, **data)
            else:
                # Otherwise build the inner event from the raw kwargs, forwarding
                # an explicit "content" kwarg into the event when present.
                if "content" in data:
                    content = data.pop("content")
                    super().__init__(*args, content=event_cls(*args, **data, content=content), **data)
                else:
                    super().__init__(content=event_cls(*args, **data), **data)

        def print(self, f: Optional[Callable[..., Any]] = None) -> None:
            # Delegate printing to the wrapped event instance.
            self.content.print(f)  # type: ignore[attr-defined]

    # Give the wrapper the original class's name for readable reprs/schemas.
    wrapper_cls = create_model(event_cls.__name__, __base__=WrapperBase)

    # Preserve the original class's docstring and other attributes
    wrapper_cls.__doc__ = event_cls.__doc__
    wrapper_cls.__module__ = event_cls.__module__

    # Copy any other relevant attributes/metadata from the original class
    if hasattr(event_cls, "__annotations__"):
        wrapper_cls.__annotations__ = event_cls.__annotations__

    _event_classes[type_name] = wrapper_cls

    return wrapper_cls
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen.events")
def get_annotated_type_for_event_classes() -> type[Any]:
    """Build a discriminated union over all registered event wrapper classes,
    using their ``type`` field as the pydantic discriminator."""
    # this is a dynamic type so we need to disable the type checker
    union_type = Union[tuple(_event_classes.values())]  # type: ignore[valid-type]
    return Annotated[union_type, Field(discriminator="type")]  # type: ignore[return-value]
|
||||||
|
|
||||||
|
|
||||||
|
def get_event_classes() -> dict[str, type[BaseModel]]:
    """Return the live registry of wrapped event classes keyed by type name."""
    return _event_classes
|
||||||
167
mm_agents/coact/autogen/events/client_events.py
Normal file
167
mm_agents/coact/autogen/events/client_events.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from typing import Any, Callable, Literal, Optional, Union
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from .base_event import BaseEvent, wrap_event
|
||||||
|
|
||||||
|
__all__ = ["UsageSummaryEvent"]
|
||||||
|
|
||||||
|
|
||||||
|
class ModelUsageSummary(BaseModel):
    """Model usage summary: token counts and cost for a single model."""

    model: str
    """Model name."""
    completion_tokens: int
    """Number of tokens used for completion."""
    cost: float
    """Cost of the completion."""
    prompt_tokens: int
    """Number of tokens used for prompt."""
    total_tokens: int
    """Total number of tokens used."""
|
||||||
|
|
||||||
|
|
||||||
|
class ActualUsageSummary(BaseModel):
    """Actual usage summary."""

    # None (rather than an empty list) indicates no usage was recorded.
    usages: Optional[list[ModelUsageSummary]] = None
    """List of model usage summaries."""
    total_cost: Optional[float] = None
    """Total cost."""
|
||||||
|
|
||||||
|
|
||||||
|
class TotalUsageSummary(BaseModel):
    """Total usage summary."""

    # None (rather than an empty list) indicates no usage was recorded.
    usages: Optional[list[ModelUsageSummary]] = None
    """List of model usage summaries."""
    total_cost: Optional[float] = None
    """Total cost."""
|
||||||
|
|
||||||
|
|
||||||
|
# Display mode for usage summaries: "actual", "total", or "both".
Mode = Literal["both", "total", "actual"]
|
||||||
|
|
||||||
|
|
||||||
|
def _change_usage_summary_format(
|
||||||
|
actual_usage_summary: Optional[dict[str, Any]] = None, total_usage_summary: Optional[dict[str, Any]] = None
|
||||||
|
) -> dict[str, dict[str, Any]]:
|
||||||
|
summary: dict[str, Any] = {}
|
||||||
|
|
||||||
|
for usage_type, usage_summary in {"actual": actual_usage_summary, "total": total_usage_summary}.items():
|
||||||
|
if usage_summary is None:
|
||||||
|
summary[usage_type] = {"usages": None, "total_cost": None}
|
||||||
|
continue
|
||||||
|
|
||||||
|
usage_summary_altered_format: dict[str, list[dict[str, Any]]] = {"usages": []}
|
||||||
|
for k, v in usage_summary.items():
|
||||||
|
if isinstance(k, str) and isinstance(v, dict):
|
||||||
|
current_usage = {key: value for key, value in v.items()}
|
||||||
|
current_usage["model"] = k
|
||||||
|
usage_summary_altered_format["usages"].append(current_usage)
|
||||||
|
else:
|
||||||
|
usage_summary_altered_format[k] = v
|
||||||
|
summary[usage_type] = usage_summary_altered_format
|
||||||
|
|
||||||
|
return summary
|
||||||
|
|
||||||
|
|
||||||
|
@wrap_event
|
||||||
|
class UsageSummaryEvent(BaseEvent):
|
||||||
|
"""Usage summary message."""
|
||||||
|
|
||||||
|
actual: ActualUsageSummary
|
||||||
|
"""Actual usage summary."""
|
||||||
|
total: TotalUsageSummary
|
||||||
|
"""Total usage summary."""
|
||||||
|
mode: Mode
|
||||||
|
"""Mode to display the usage summary."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
uuid: Optional[UUID] = None,
|
||||||
|
actual_usage_summary: Optional[dict[str, Any]] = None,
|
||||||
|
total_usage_summary: Optional[dict[str, Any]] = None,
|
||||||
|
mode: Mode = "both",
|
||||||
|
):
|
||||||
|
# print(f"{actual_usage_summary=}")
|
||||||
|
# print(f"{total_usage_summary=}")
|
||||||
|
|
||||||
|
summary_dict = _change_usage_summary_format(actual_usage_summary, total_usage_summary)
|
||||||
|
|
||||||
|
super().__init__(uuid=uuid, **summary_dict, mode=mode)
|
||||||
|
|
||||||
|
def _print_usage(
|
||||||
|
self,
|
||||||
|
usage_summary: Union[ActualUsageSummary, TotalUsageSummary],
|
||||||
|
usage_type: str = "total",
|
||||||
|
f: Optional[Callable[..., Any]] = None,
|
||||||
|
) -> None:
|
||||||
|
f = f or print
|
||||||
|
word_from_type = "including" if usage_type == "total" else "excluding"
|
||||||
|
if usage_summary.usages is None or len(usage_summary.usages) == 0:
|
||||||
|
f("No actual cost incurred (all completions are using cache).", flush=True)
|
||||||
|
return
|
||||||
|
|
||||||
|
f(f"Usage summary {word_from_type} cached usage: ", flush=True)
|
||||||
|
f(f"Total cost: {round(usage_summary.total_cost, 5)}", flush=True) # type: ignore [arg-type]
|
||||||
|
|
||||||
|
for usage in usage_summary.usages:
|
||||||
|
f(
|
||||||
|
f"* Model '{usage.model}': cost: {round(usage.cost, 5)}, prompt_tokens: {usage.prompt_tokens}, completion_tokens: {usage.completion_tokens}, total_tokens: {usage.total_tokens}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def print(self, f: Optional[Callable[..., Any]] = None) -> None:
|
||||||
|
f = f or print
|
||||||
|
|
||||||
|
if self.total.usages is None:
|
||||||
|
f('No usage summary. Please call "create" first.', flush=True)
|
||||||
|
return
|
||||||
|
|
||||||
|
f("-" * 100, flush=True)
|
||||||
|
if self.mode == "both":
|
||||||
|
self._print_usage(self.actual, "actual", f)
|
||||||
|
f()
|
||||||
|
if self.total.model_dump_json() != self.actual.model_dump_json():
|
||||||
|
self._print_usage(self.total, "total", f)
|
||||||
|
else:
|
||||||
|
f(
|
||||||
|
"All completions are non-cached: the total cost with cached completions is the same as actual cost.",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
elif self.mode == "total":
|
||||||
|
self._print_usage(self.total, "total", f)
|
||||||
|
elif self.mode == "actual":
|
||||||
|
self._print_usage(self.actual, "actual", f)
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Invalid mode: {self.mode}, choose from "actual", "total", ["actual", "total"]')
|
||||||
|
f("-" * 100, flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
@wrap_event
|
||||||
|
class StreamEvent(BaseEvent):
|
||||||
|
"""Stream event."""
|
||||||
|
|
||||||
|
content: str
|
||||||
|
"""Content of the event."""
|
||||||
|
|
||||||
|
def __init__(self, *, uuid: Optional[UUID] = None, content: str) -> None:
|
||||||
|
super().__init__(uuid=uuid, content=content)
|
||||||
|
|
||||||
|
def print(self, f: Optional[Callable[..., Any]] = None) -> None:
|
||||||
|
f = f or print
|
||||||
|
|
||||||
|
# Set the terminal text color to green
|
||||||
|
f("\033[32m", end="")
|
||||||
|
|
||||||
|
f(self.content, end="", flush=True)
|
||||||
|
|
||||||
|
# Reset the terminal text color
|
||||||
|
f("\033[0m\n")
|
||||||
36
mm_agents/coact/autogen/events/helpers.py
Normal file
36
mm_agents/coact/autogen/events/helpers.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
import logging
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def deprecated_by(
|
||||||
|
new_class: type[BaseModel],
|
||||||
|
param_mapping: dict[str, str] = None,
|
||||||
|
) -> Callable[[type[BaseModel]], Callable[..., BaseModel]]:
|
||||||
|
param_mapping = param_mapping or {}
|
||||||
|
|
||||||
|
def decorator(
|
||||||
|
old_class: type[BaseModel],
|
||||||
|
param_mapping: dict[str, str] = param_mapping,
|
||||||
|
) -> Callable[..., BaseModel]:
|
||||||
|
@wraps(old_class)
|
||||||
|
def wrapper(*args, **kwargs) -> BaseModel:
|
||||||
|
logger.warning(
|
||||||
|
f"{old_class.__name__} is deprecated by {new_class.__name__}. Please import it from {new_class.__module__} and use it instead."
|
||||||
|
)
|
||||||
|
# Translate old parameters to new parameters
|
||||||
|
new_kwargs = {param_mapping.get(k, k): v for k, v in kwargs.items()}
|
||||||
|
|
||||||
|
# Pass the translated parameters to the new class
|
||||||
|
return new_class(*args, **new_kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
46
mm_agents/coact/autogen/events/print_event.py
Normal file
46
mm_agents/coact/autogen/events/print_event.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from .base_event import BaseEvent, wrap_event
|
||||||
|
|
||||||
|
|
||||||
|
@wrap_event
|
||||||
|
class PrintEvent(BaseEvent):
|
||||||
|
"""Print message"""
|
||||||
|
|
||||||
|
objects: list[str]
|
||||||
|
"""List of objects to print"""
|
||||||
|
sep: str
|
||||||
|
"""Separator between objects"""
|
||||||
|
end: str
|
||||||
|
"""End of the print"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, *objects: Any, sep: str = " ", end: str = "\n", flush: bool = False, uuid: Optional[UUID] = None
|
||||||
|
):
|
||||||
|
objects_as_string = [self._to_json(x) for x in objects]
|
||||||
|
|
||||||
|
super().__init__(uuid=uuid, objects=objects_as_string, sep=sep, end=end)
|
||||||
|
|
||||||
|
def _to_json(self, obj: Any) -> str:
|
||||||
|
if isinstance(obj, str):
|
||||||
|
return obj
|
||||||
|
|
||||||
|
if hasattr(obj, "model_dump_json"):
|
||||||
|
return obj.model_dump_json() # type: ignore [no-any-return]
|
||||||
|
try:
|
||||||
|
return json.dumps(obj)
|
||||||
|
except Exception:
|
||||||
|
return str(obj)
|
||||||
|
# return repr(obj)
|
||||||
|
|
||||||
|
def print(self, f: Optional[Callable[..., Any]] = None) -> None:
|
||||||
|
f = f or print
|
||||||
|
|
||||||
|
f(*self.objects, sep=self.sep, end=self.end, flush=True)
|
||||||
73
mm_agents/coact/autogen/exception_utils.py
Normal file
73
mm_agents/coact/autogen/exception_utils.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .doc_utils import export_module
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AgentNameConflictError",
|
||||||
|
"InvalidCarryOverTypeError",
|
||||||
|
"ModelToolNotSupportedError",
|
||||||
|
"NoEligibleSpeakerError",
|
||||||
|
"SenderRequiredError",
|
||||||
|
"UndefinedNextAgentError",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen")
|
||||||
|
class AgentNameConflictError(Exception): # noqa: N818
|
||||||
|
def __init__(self, msg: str = "Found multiple agents with the same name.", *args: Any, **kwargs: Any):
|
||||||
|
super().__init__(msg, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen")
|
||||||
|
class NoEligibleSpeakerError(Exception): # noqa: N818
|
||||||
|
"""Exception raised for early termination of a GroupChat."""
|
||||||
|
|
||||||
|
def __init__(self, message: str = "No eligible speakers."):
|
||||||
|
self.message = message
|
||||||
|
super().__init__(self.message)
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen")
|
||||||
|
class SenderRequiredError(Exception): # noqa: N818
|
||||||
|
"""Exception raised when the sender is required but not provided."""
|
||||||
|
|
||||||
|
def __init__(self, message: str = "Sender is required but not provided."):
|
||||||
|
self.message = message
|
||||||
|
super().__init__(self.message)
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen")
|
||||||
|
class InvalidCarryOverTypeError(Exception): # noqa: N818
|
||||||
|
"""Exception raised when the carryover type is invalid."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, message: str = "Carryover should be a string or a list of strings. Not adding carryover to the message."
|
||||||
|
):
|
||||||
|
self.message = message
|
||||||
|
super().__init__(self.message)
|
||||||
|
|
||||||
|
|
||||||
|
@export_module("autogen")
|
||||||
|
class UndefinedNextAgentError(Exception): # noqa: N818
|
||||||
|
"""Exception raised when the provided next agents list does not overlap with agents in the group."""
|
||||||
|
|
||||||
|
def __init__(self, message: str = "The provided agents list does not overlap with agents in the group."):
|
||||||
|
self.message = message
|
||||||
|
super().__init__(self.message)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelToolNotSupportedError(Exception):
|
||||||
|
"""Exception raised when attempting to use tools with models that do not support them."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
):
|
||||||
|
self.message = f"Tools are not supported with {model} models. Refer to the documentation at https://platform.openai.com/docs/guides/reasoning#limitations"
|
||||||
|
super().__init__(self.message)
|
||||||
16
mm_agents/coact/autogen/fast_depends/__init__.py
Normal file
16
mm_agents/coact/autogen/fast_depends/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/https://github.com/Lancetnik/FastDepends are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
from .dependencies import Provider, dependency_provider
|
||||||
|
from .use import Depends, inject
|
||||||
|
|
||||||
|
__all__ = (
|
||||||
|
"Depends",
|
||||||
|
"Provider",
|
||||||
|
"dependency_provider",
|
||||||
|
"inject",
|
||||||
|
)
|
||||||
80
mm_agents/coact/autogen/fast_depends/_compat.py
Normal file
80
mm_agents/coact/autogen/fast_depends/_compat.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Portions derived from https://github.com/https://github.com/Lancetnik/FastDepends are under the MIT License.
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from importlib.metadata import version as get_version
|
||||||
|
from typing import Any, Dict, Optional, Tuple, Type
|
||||||
|
|
||||||
|
from pydantic import BaseModel, create_model
|
||||||
|
from pydantic.version import VERSION as PYDANTIC_VERSION
|
||||||
|
|
||||||
|
__all__ = (
|
||||||
|
"PYDANTIC_V2",
|
||||||
|
"BaseModel",
|
||||||
|
"ConfigDict",
|
||||||
|
"ExceptionGroup",
|
||||||
|
"create_model",
|
||||||
|
"evaluate_forwardref",
|
||||||
|
"get_config_base",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
|
||||||
|
|
||||||
|
default_pydantic_config = {"arbitrary_types_allowed": True}
|
||||||
|
|
||||||
|
evaluate_forwardref: Any
|
||||||
|
# isort: off
|
||||||
|
if PYDANTIC_V2:
|
||||||
|
from pydantic import ConfigDict
|
||||||
|
from pydantic._internal._typing_extra import ( # type: ignore[no-redef]
|
||||||
|
eval_type_lenient as evaluate_forwardref,
|
||||||
|
)
|
||||||
|
|
||||||
|
def model_schema(model: Type[BaseModel]) -> Dict[str, Any]:
|
||||||
|
return model.model_json_schema()
|
||||||
|
|
||||||
|
def get_config_base(config_data: Optional[ConfigDict] = None) -> ConfigDict:
|
||||||
|
return config_data or ConfigDict(**default_pydantic_config) # type: ignore[typeddict-item]
|
||||||
|
|
||||||
|
def get_aliases(model: Type[BaseModel]) -> Tuple[str, ...]:
|
||||||
|
return tuple(f.alias or name for name, f in model.model_fields.items())
|
||||||
|
|
||||||
|
class CreateBaseModel(BaseModel):
|
||||||
|
"""Just to support FastStream < 0.3.7."""
|
||||||
|
|
||||||
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
|
else:
|
||||||
|
from pydantic.typing import evaluate_forwardref as evaluate_forwardref # type: ignore[no-redef]
|
||||||
|
from pydantic.config import get_config, ConfigDict, BaseConfig
|
||||||
|
|
||||||
|
def get_config_base(config_data: Optional[ConfigDict] = None) -> Type[BaseConfig]: # type: ignore[misc,no-any-unimported]
|
||||||
|
return get_config(config_data or ConfigDict(**default_pydantic_config)) # type: ignore[typeddict-item,no-any-unimported,no-any-return]
|
||||||
|
|
||||||
|
def model_schema(model: Type[BaseModel]) -> Dict[str, Any]:
|
||||||
|
return model.schema()
|
||||||
|
|
||||||
|
def get_aliases(model: Type[BaseModel]) -> Tuple[str, ...]:
|
||||||
|
return tuple(f.alias or name for name, f in model.__fields__.items()) # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
class CreateBaseModel(BaseModel): # type: ignore[no-redef]
|
||||||
|
"""Just to support FastStream < 0.3.7."""
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
|
ANYIO_V3 = get_version("anyio").startswith("3.")
|
||||||
|
|
||||||
|
if ANYIO_V3:
|
||||||
|
from anyio import ExceptionGroup as ExceptionGroup # type: ignore[attr-defined]
|
||||||
|
else:
|
||||||
|
if sys.version_info < (3, 11):
|
||||||
|
from exceptiongroup import ExceptionGroup as ExceptionGroup
|
||||||
|
else:
|
||||||
|
ExceptionGroup = ExceptionGroup
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user