From b9681557573b1d4f848c7b628521e8214a7a19bd Mon Sep 17 00:00:00 2001 From: Linxin Song Date: Wed, 30 Jul 2025 19:35:20 -0700 Subject: [PATCH] CoACT initialize (#292) --- desktop_env/controllers/python.py | 76 + desktop_env/server/main.py | 225 + mm_agents/coact/OAI_CONFIG_LIST | 27 + mm_agents/coact/__init__.py | 0 mm_agents/coact/autogen/__init__.py | 81 + mm_agents/coact/autogen/agentchat/__init__.py | 38 + mm_agents/coact/autogen/agentchat/agent.py | 182 + .../autogen/agentchat/assistant_agent.py | 85 + mm_agents/coact/autogen/agentchat/chat.py | 309 ++ .../autogen/agentchat/contrib/__init__.py | 5 + .../contrib/capabilities/__init__.py | 5 + .../contrib/capabilities/agent_capability.py | 20 + .../contrib/capabilities/generate_images.py | 301 ++ .../contrib/capabilities/teachability.py | 393 ++ .../contrib/capabilities/text_compressors.py | 66 + .../contrib/capabilities/tools_capability.py | 22 + .../capabilities/transform_messages.py | 93 + .../contrib/capabilities/transforms.py | 579 +++ .../contrib/capabilities/transforms_util.py | 122 + .../contrib/capabilities/vision_capability.py | 212 + .../autogen/agentchat/contrib/img_utils.py | 411 ++ .../contrib/multimodal_conversable_agent.py | 153 + .../autogen/agentchat/conversable_agent.py | 4023 +++++++++++++++++ .../coact/autogen/agentchat/group/__init__.py | 64 + .../agentchat/group/available_condition.py | 91 + .../agentchat/group/context_condition.py | 77 + .../agentchat/group/context_expression.py | 238 + .../autogen/agentchat/group/context_str.py | 41 + .../agentchat/group/context_variables.py | 192 + .../agentchat/group/group_tool_executor.py | 202 + .../autogen/agentchat/group/group_utils.py | 636 +++ .../coact/autogen/agentchat/group/handoffs.py | 320 ++ .../autogen/agentchat/group/llm_condition.py | 93 + .../agentchat/group/multi_agent_chat.py | 237 + .../autogen/agentchat/group/on_condition.py | 58 + .../agentchat/group/on_context_condition.py | 54 + .../agentchat/group/patterns/__init__.py | 18 + .../autogen/agentchat/group/patterns/auto.py | 159 + .../agentchat/group/patterns/manual.py | 176 + .../agentchat/group/patterns/pattern.py | 294 ++ .../agentchat/group/patterns/random.py | 106 + .../agentchat/group/patterns/round_robin.py | 117 + .../autogen/agentchat/group/reply_result.py | 26 + .../group/speaker_selection_result.py | 41 + .../agentchat/group/targets/__init__.py | 4 + .../group/targets/group_chat_target.py | 132 + .../group/targets/group_manager_target.py | 151 + .../group/targets/transition_target.py | 413 ++ .../group/targets/transition_utils.py | 6 + .../coact/autogen/agentchat/groupchat.py | 1694 +++++++ .../autogen/agentchat/realtime/__init__.py | 3 + .../realtime/experimental/__init__.py | 20 + .../experimental/audio_adapters/__init__.py | 8 + .../audio_adapters/twilio_audio_adapter.py | 148 + .../audio_adapters/websocket_audio_adapter.py | 139 + .../realtime/experimental/audio_observer.py | 42 + .../realtime/experimental/clients/__init__.py | 15 + .../experimental/clients/gemini/__init__.py | 7 + .../experimental/clients/gemini/client.py | 274 ++ .../experimental/clients/oai/__init__.py | 8 + .../experimental/clients/oai/base_client.py | 220 + .../experimental/clients/oai/rtc_client.py | 243 + .../experimental/clients/oai/utils.py | 48 + .../experimental/clients/realtime_client.py | 190 + .../experimental/function_observer.py | 85 + .../realtime/experimental/realtime_agent.py | 158 + .../realtime/experimental/realtime_events.py | 42 + .../experimental/realtime_observer.py | 100 + 
.../realtime/experimental/realtime_swarm.py | 483 ++ .../realtime/experimental/websockets.py | 21 + .../agentchat/realtime_agent/__init__.py | 21 + .../autogen/agentchat/user_proxy_agent.py | 111 + mm_agents/coact/autogen/agentchat/utils.py | 206 + mm_agents/coact/autogen/code_utils.py | 596 +++ mm_agents/coact/autogen/coding/__init__.py | 22 + mm_agents/coact/autogen/coding/base.py | 119 + .../docker_commandline_code_executor.py | 268 ++ mm_agents/coact/autogen/coding/factory.py | 47 + .../coact/autogen/coding/func_with_reqs.py | 202 + .../coact/autogen/coding/jupyter/__init__.py | 23 + .../coact/autogen/coding/jupyter/base.py | 36 + .../coding/jupyter/docker_jupyter_server.py | 167 + .../jupyter/embedded_ipython_code_executor.py | 182 + .../autogen/coding/jupyter/import_utils.py | 82 + .../autogen/coding/jupyter/jupyter_client.py | 231 + .../coding/jupyter/jupyter_code_executor.py | 160 + .../coding/jupyter/local_jupyter_server.py | 172 + .../coding/local_commandline_code_executor.py | 405 ++ .../autogen/coding/markdown_code_extractor.py | 45 + mm_agents/coact/autogen/coding/utils.py | 56 + mm_agents/coact/autogen/doc_utils.py | 34 + mm_agents/coact/autogen/events/__init__.py | 7 + .../coact/autogen/events/agent_events.py | 1014 +++++ mm_agents/coact/autogen/events/base_event.py | 99 + .../coact/autogen/events/client_events.py | 167 + mm_agents/coact/autogen/events/helpers.py | 36 + mm_agents/coact/autogen/events/print_event.py | 46 + mm_agents/coact/autogen/exception_utils.py | 73 + .../coact/autogen/fast_depends/__init__.py | 16 + .../coact/autogen/fast_depends/_compat.py | 80 + .../autogen/fast_depends/core/__init__.py | 14 + .../coact/autogen/fast_depends/core/build.py | 225 + .../coact/autogen/fast_depends/core/model.py | 576 +++ .../fast_depends/dependencies/__init__.py | 15 + .../fast_depends/dependencies/model.py | 29 + .../fast_depends/dependencies/provider.py | 39 + .../autogen/fast_depends/library/__init__.py | 10 + .../autogen/fast_depends/library/model.py | 46 + mm_agents/coact/autogen/fast_depends/py.typed | 6 + .../coact/autogen/fast_depends/schema.py | 66 + mm_agents/coact/autogen/fast_depends/use.py | 280 ++ mm_agents/coact/autogen/fast_depends/utils.py | 187 + mm_agents/coact/autogen/formatting_utils.py | 83 + mm_agents/coact/autogen/function_utils.py | 13 + mm_agents/coact/autogen/graph_utils.py | 178 + mm_agents/coact/autogen/import_utils.py | 526 +++ mm_agents/coact/autogen/interop/__init__.py | 22 + .../coact/autogen/interop/crewai/__init__.py | 7 + .../coact/autogen/interop/crewai/crewai.py | 88 + .../coact/autogen/interop/interoperability.py | 71 + .../coact/autogen/interop/interoperable.py | 46 + .../autogen/interop/langchain/__init__.py | 8 + .../langchain/langchain_chat_model_factory.py | 155 + .../interop/langchain/langchain_tool.py | 82 + .../coact/autogen/interop/litellm/__init__.py | 7 + .../interop/litellm/litellm_config_factory.py | 179 + .../autogen/interop/pydantic_ai/__init__.py | 7 + .../interop/pydantic_ai/pydantic_ai.py | 168 + mm_agents/coact/autogen/interop/registry.py | 69 + mm_agents/coact/autogen/io/__init__.py | 15 + mm_agents/coact/autogen/io/base.py | 151 + mm_agents/coact/autogen/io/console.py | 56 + .../coact/autogen/io/processors/__init__.py | 12 + mm_agents/coact/autogen/io/processors/base.py | 21 + .../io/processors/console_event_processor.py | 56 + mm_agents/coact/autogen/io/run_response.py | 293 ++ .../coact/autogen/io/thread_io_stream.py | 63 + mm_agents/coact/autogen/io/websockets.py | 213 + mm_agents/coact/autogen/json_utils.py | 
43 + mm_agents/coact/autogen/llm_config.py | 382 ++ mm_agents/coact/autogen/logger/__init__.py | 11 + mm_agents/coact/autogen/logger/base_logger.py | 128 + mm_agents/coact/autogen/logger/file_logger.py | 261 ++ .../coact/autogen/logger/logger_factory.py | 42 + .../coact/autogen/logger/logger_utils.py | 57 + .../coact/autogen/logger/sqlite_logger.py | 523 +++ mm_agents/coact/autogen/messages/__init__.py | 7 + .../coact/autogen/messages/agent_messages.py | 948 ++++ .../coact/autogen/messages/base_message.py | 107 + .../coact/autogen/messages/client_messages.py | 171 + .../coact/autogen/messages/print_message.py | 49 + mm_agents/coact/autogen/oai/__init__.py | 53 + mm_agents/coact/autogen/oai/anthropic.py | 714 +++ mm_agents/coact/autogen/oai/bedrock.py | 628 +++ mm_agents/coact/autogen/oai/cerebras.py | 299 ++ mm_agents/coact/autogen/oai/client.py | 1444 ++++++ mm_agents/coact/autogen/oai/client_utils.py | 169 + mm_agents/coact/autogen/oai/cohere.py | 479 ++ mm_agents/coact/autogen/oai/gemini.py | 1007 +++++ mm_agents/coact/autogen/oai/gemini_types.py | 156 + mm_agents/coact/autogen/oai/groq.py | 305 ++ mm_agents/coact/autogen/oai/mistral.py | 303 ++ .../coact/autogen/oai/oai_models/__init__.py | 11 + .../coact/autogen/oai/oai_models/_models.py | 16 + .../autogen/oai/oai_models/chat_completion.py | 87 + .../oai/oai_models/chat_completion_audio.py | 32 + .../oai/oai_models/chat_completion_message.py | 86 + .../chat_completion_message_tool_call.py | 37 + .../chat_completion_token_logprob.py | 63 + .../oai/oai_models/completion_usage.py | 60 + mm_agents/coact/autogen/oai/ollama.py | 643 +++ mm_agents/coact/autogen/oai/openai_utils.py | 881 ++++ mm_agents/coact/autogen/oai/together.py | 370 ++ mm_agents/coact/autogen/retrieve_utils.py | 491 ++ mm_agents/coact/autogen/runtime_logging.py | 160 + mm_agents/coact/autogen/token_count_utils.py | 265 ++ mm_agents/coact/autogen/tools/__init__.py | 20 + .../coact/autogen/tools/contrib/__init__.py | 9 + .../autogen/tools/contrib/time/__init__.py | 7 + .../coact/autogen/tools/contrib/time/time.py | 41 + .../autogen/tools/dependency_injection.py | 254 ++ .../autogen/tools/experimental/__init__.py | 48 + .../experimental/browser_use/__init__.py | 7 + .../experimental/browser_use/browser_use.py | 161 + .../tools/experimental/crawl4ai/__init__.py | 7 + .../tools/experimental/crawl4ai/crawl4ai.py | 153 + .../experimental/deep_research/__init__.py | 7 + .../deep_research/deep_research.py | 328 ++ .../tools/experimental/duckduckgo/__init__.py | 7 + .../duckduckgo/duckduckgo_search.py | 109 + .../tools/experimental/google/__init__.py | 14 + .../google/authentication/__init__.py | 11 + .../credentials_hosted_provider.py | 43 + .../credentials_local_provider.py | 91 + .../authentication/credentials_provider.py | 35 + .../experimental/google/drive/__init__.py | 9 + .../google/drive/drive_functions.py | 124 + .../experimental/google/drive/toolkit.py | 88 + .../tools/experimental/google/model.py | 17 + .../experimental/google/toolkit_protocol.py | 19 + .../experimental/google_search/__init__.py | 8 + .../google_search/google_search.py | 93 + .../google_search/youtube_search.py | 181 + .../experimental/messageplatform/__init__.py | 17 + .../messageplatform/discord/__init__.py | 7 + .../messageplatform/discord/discord.py | 288 ++ .../messageplatform/slack/__init__.py | 7 + .../messageplatform/slack/slack.py | 391 ++ .../messageplatform/telegram/__init__.py | 7 + .../messageplatform/telegram/telegram.py | 275 ++ .../tools/experimental/perplexity/__init__.py | 7 + 
.../perplexity/perplexity_search.py | 260 ++ .../tools/experimental/reliable/__init__.py | 10 + .../tools/experimental/reliable/reliable.py | 1316 ++++++ .../tools/experimental/tavily/__init__.py | 7 + .../experimental/tavily/tavily_search.py | 183 + .../web_search_preview/__init__.py | 7 + .../web_search_preview/web_search_preview.py | 114 + .../tools/experimental/wikipedia/__init__.py | 7 + .../tools/experimental/wikipedia/wikipedia.py | 287 ++ .../coact/autogen/tools/function_utils.py | 411 ++ mm_agents/coact/autogen/tools/tool.py | 187 + mm_agents/coact/autogen/tools/toolkit.py | 86 + mm_agents/coact/autogen/types.py | 29 + mm_agents/coact/coding_agent.py | 78 + mm_agents/coact/cua_agent.py | 328 ++ mm_agents/coact/operator_agent.py | 305 ++ run_coact.py | 266 ++ 228 files changed, 42386 insertions(+) create mode 100644 mm_agents/coact/OAI_CONFIG_LIST create mode 100644 mm_agents/coact/__init__.py create mode 100644 mm_agents/coact/autogen/__init__.py create mode 100644 mm_agents/coact/autogen/agentchat/__init__.py create mode 100644 mm_agents/coact/autogen/agentchat/agent.py create mode 100644 mm_agents/coact/autogen/agentchat/assistant_agent.py create mode 100644 mm_agents/coact/autogen/agentchat/chat.py create mode 100644 mm_agents/coact/autogen/agentchat/contrib/__init__.py create mode 100644 mm_agents/coact/autogen/agentchat/contrib/capabilities/__init__.py create mode 100644 mm_agents/coact/autogen/agentchat/contrib/capabilities/agent_capability.py create mode 100644 mm_agents/coact/autogen/agentchat/contrib/capabilities/generate_images.py create mode 100644 mm_agents/coact/autogen/agentchat/contrib/capabilities/teachability.py create mode 100644 mm_agents/coact/autogen/agentchat/contrib/capabilities/text_compressors.py create mode 100644 mm_agents/coact/autogen/agentchat/contrib/capabilities/tools_capability.py create mode 100644 mm_agents/coact/autogen/agentchat/contrib/capabilities/transform_messages.py create mode 100644 mm_agents/coact/autogen/agentchat/contrib/capabilities/transforms.py create mode 100644 mm_agents/coact/autogen/agentchat/contrib/capabilities/transforms_util.py create mode 100644 mm_agents/coact/autogen/agentchat/contrib/capabilities/vision_capability.py create mode 100644 mm_agents/coact/autogen/agentchat/contrib/img_utils.py create mode 100644 mm_agents/coact/autogen/agentchat/contrib/multimodal_conversable_agent.py create mode 100644 mm_agents/coact/autogen/agentchat/conversable_agent.py create mode 100644 mm_agents/coact/autogen/agentchat/group/__init__.py create mode 100644 mm_agents/coact/autogen/agentchat/group/available_condition.py create mode 100644 mm_agents/coact/autogen/agentchat/group/context_condition.py create mode 100644 mm_agents/coact/autogen/agentchat/group/context_expression.py create mode 100644 mm_agents/coact/autogen/agentchat/group/context_str.py create mode 100644 mm_agents/coact/autogen/agentchat/group/context_variables.py create mode 100644 mm_agents/coact/autogen/agentchat/group/group_tool_executor.py create mode 100644 mm_agents/coact/autogen/agentchat/group/group_utils.py create mode 100644 mm_agents/coact/autogen/agentchat/group/handoffs.py create mode 100644 mm_agents/coact/autogen/agentchat/group/llm_condition.py create mode 100644 mm_agents/coact/autogen/agentchat/group/multi_agent_chat.py create mode 100644 mm_agents/coact/autogen/agentchat/group/on_condition.py create mode 100644 mm_agents/coact/autogen/agentchat/group/on_context_condition.py create mode 100644 mm_agents/coact/autogen/agentchat/group/patterns/__init__.py 
create mode 100644 mm_agents/coact/autogen/agentchat/group/patterns/auto.py create mode 100644 mm_agents/coact/autogen/agentchat/group/patterns/manual.py create mode 100644 mm_agents/coact/autogen/agentchat/group/patterns/pattern.py create mode 100644 mm_agents/coact/autogen/agentchat/group/patterns/random.py create mode 100644 mm_agents/coact/autogen/agentchat/group/patterns/round_robin.py create mode 100644 mm_agents/coact/autogen/agentchat/group/reply_result.py create mode 100644 mm_agents/coact/autogen/agentchat/group/speaker_selection_result.py create mode 100644 mm_agents/coact/autogen/agentchat/group/targets/__init__.py create mode 100644 mm_agents/coact/autogen/agentchat/group/targets/group_chat_target.py create mode 100644 mm_agents/coact/autogen/agentchat/group/targets/group_manager_target.py create mode 100644 mm_agents/coact/autogen/agentchat/group/targets/transition_target.py create mode 100644 mm_agents/coact/autogen/agentchat/group/targets/transition_utils.py create mode 100644 mm_agents/coact/autogen/agentchat/groupchat.py create mode 100644 mm_agents/coact/autogen/agentchat/realtime/__init__.py create mode 100644 mm_agents/coact/autogen/agentchat/realtime/experimental/__init__.py create mode 100644 mm_agents/coact/autogen/agentchat/realtime/experimental/audio_adapters/__init__.py create mode 100644 mm_agents/coact/autogen/agentchat/realtime/experimental/audio_adapters/twilio_audio_adapter.py create mode 100644 mm_agents/coact/autogen/agentchat/realtime/experimental/audio_adapters/websocket_audio_adapter.py create mode 100644 mm_agents/coact/autogen/agentchat/realtime/experimental/audio_observer.py create mode 100644 mm_agents/coact/autogen/agentchat/realtime/experimental/clients/__init__.py create mode 100644 mm_agents/coact/autogen/agentchat/realtime/experimental/clients/gemini/__init__.py create mode 100644 mm_agents/coact/autogen/agentchat/realtime/experimental/clients/gemini/client.py create mode 100644 mm_agents/coact/autogen/agentchat/realtime/experimental/clients/oai/__init__.py create mode 100644 mm_agents/coact/autogen/agentchat/realtime/experimental/clients/oai/base_client.py create mode 100644 mm_agents/coact/autogen/agentchat/realtime/experimental/clients/oai/rtc_client.py create mode 100644 mm_agents/coact/autogen/agentchat/realtime/experimental/clients/oai/utils.py create mode 100644 mm_agents/coact/autogen/agentchat/realtime/experimental/clients/realtime_client.py create mode 100644 mm_agents/coact/autogen/agentchat/realtime/experimental/function_observer.py create mode 100644 mm_agents/coact/autogen/agentchat/realtime/experimental/realtime_agent.py create mode 100644 mm_agents/coact/autogen/agentchat/realtime/experimental/realtime_events.py create mode 100644 mm_agents/coact/autogen/agentchat/realtime/experimental/realtime_observer.py create mode 100644 mm_agents/coact/autogen/agentchat/realtime/experimental/realtime_swarm.py create mode 100644 mm_agents/coact/autogen/agentchat/realtime/experimental/websockets.py create mode 100644 mm_agents/coact/autogen/agentchat/realtime_agent/__init__.py create mode 100644 mm_agents/coact/autogen/agentchat/user_proxy_agent.py create mode 100644 mm_agents/coact/autogen/agentchat/utils.py create mode 100644 mm_agents/coact/autogen/code_utils.py create mode 100644 mm_agents/coact/autogen/coding/__init__.py create mode 100644 mm_agents/coact/autogen/coding/base.py create mode 100644 mm_agents/coact/autogen/coding/docker_commandline_code_executor.py create mode 100644 mm_agents/coact/autogen/coding/factory.py create mode 
100644 mm_agents/coact/autogen/coding/func_with_reqs.py create mode 100644 mm_agents/coact/autogen/coding/jupyter/__init__.py create mode 100644 mm_agents/coact/autogen/coding/jupyter/base.py create mode 100644 mm_agents/coact/autogen/coding/jupyter/docker_jupyter_server.py create mode 100644 mm_agents/coact/autogen/coding/jupyter/embedded_ipython_code_executor.py create mode 100644 mm_agents/coact/autogen/coding/jupyter/import_utils.py create mode 100644 mm_agents/coact/autogen/coding/jupyter/jupyter_client.py create mode 100644 mm_agents/coact/autogen/coding/jupyter/jupyter_code_executor.py create mode 100644 mm_agents/coact/autogen/coding/jupyter/local_jupyter_server.py create mode 100644 mm_agents/coact/autogen/coding/local_commandline_code_executor.py create mode 100644 mm_agents/coact/autogen/coding/markdown_code_extractor.py create mode 100644 mm_agents/coact/autogen/coding/utils.py create mode 100644 mm_agents/coact/autogen/doc_utils.py create mode 100644 mm_agents/coact/autogen/events/__init__.py create mode 100644 mm_agents/coact/autogen/events/agent_events.py create mode 100644 mm_agents/coact/autogen/events/base_event.py create mode 100644 mm_agents/coact/autogen/events/client_events.py create mode 100644 mm_agents/coact/autogen/events/helpers.py create mode 100644 mm_agents/coact/autogen/events/print_event.py create mode 100644 mm_agents/coact/autogen/exception_utils.py create mode 100644 mm_agents/coact/autogen/fast_depends/__init__.py create mode 100644 mm_agents/coact/autogen/fast_depends/_compat.py create mode 100644 mm_agents/coact/autogen/fast_depends/core/__init__.py create mode 100644 mm_agents/coact/autogen/fast_depends/core/build.py create mode 100644 mm_agents/coact/autogen/fast_depends/core/model.py create mode 100644 mm_agents/coact/autogen/fast_depends/dependencies/__init__.py create mode 100644 mm_agents/coact/autogen/fast_depends/dependencies/model.py create mode 100644 mm_agents/coact/autogen/fast_depends/dependencies/provider.py create mode 100644 mm_agents/coact/autogen/fast_depends/library/__init__.py create mode 100644 mm_agents/coact/autogen/fast_depends/library/model.py create mode 100644 mm_agents/coact/autogen/fast_depends/py.typed create mode 100644 mm_agents/coact/autogen/fast_depends/schema.py create mode 100644 mm_agents/coact/autogen/fast_depends/use.py create mode 100644 mm_agents/coact/autogen/fast_depends/utils.py create mode 100644 mm_agents/coact/autogen/formatting_utils.py create mode 100644 mm_agents/coact/autogen/function_utils.py create mode 100644 mm_agents/coact/autogen/graph_utils.py create mode 100644 mm_agents/coact/autogen/import_utils.py create mode 100644 mm_agents/coact/autogen/interop/__init__.py create mode 100644 mm_agents/coact/autogen/interop/crewai/__init__.py create mode 100644 mm_agents/coact/autogen/interop/crewai/crewai.py create mode 100644 mm_agents/coact/autogen/interop/interoperability.py create mode 100644 mm_agents/coact/autogen/interop/interoperable.py create mode 100644 mm_agents/coact/autogen/interop/langchain/__init__.py create mode 100644 mm_agents/coact/autogen/interop/langchain/langchain_chat_model_factory.py create mode 100644 mm_agents/coact/autogen/interop/langchain/langchain_tool.py create mode 100644 mm_agents/coact/autogen/interop/litellm/__init__.py create mode 100644 mm_agents/coact/autogen/interop/litellm/litellm_config_factory.py create mode 100644 mm_agents/coact/autogen/interop/pydantic_ai/__init__.py create mode 100644 mm_agents/coact/autogen/interop/pydantic_ai/pydantic_ai.py create mode 
100644 mm_agents/coact/autogen/interop/registry.py create mode 100644 mm_agents/coact/autogen/io/__init__.py create mode 100644 mm_agents/coact/autogen/io/base.py create mode 100644 mm_agents/coact/autogen/io/console.py create mode 100644 mm_agents/coact/autogen/io/processors/__init__.py create mode 100644 mm_agents/coact/autogen/io/processors/base.py create mode 100644 mm_agents/coact/autogen/io/processors/console_event_processor.py create mode 100644 mm_agents/coact/autogen/io/run_response.py create mode 100644 mm_agents/coact/autogen/io/thread_io_stream.py create mode 100644 mm_agents/coact/autogen/io/websockets.py create mode 100644 mm_agents/coact/autogen/json_utils.py create mode 100644 mm_agents/coact/autogen/llm_config.py create mode 100644 mm_agents/coact/autogen/logger/__init__.py create mode 100644 mm_agents/coact/autogen/logger/base_logger.py create mode 100644 mm_agents/coact/autogen/logger/file_logger.py create mode 100644 mm_agents/coact/autogen/logger/logger_factory.py create mode 100644 mm_agents/coact/autogen/logger/logger_utils.py create mode 100644 mm_agents/coact/autogen/logger/sqlite_logger.py create mode 100644 mm_agents/coact/autogen/messages/__init__.py create mode 100644 mm_agents/coact/autogen/messages/agent_messages.py create mode 100644 mm_agents/coact/autogen/messages/base_message.py create mode 100644 mm_agents/coact/autogen/messages/client_messages.py create mode 100644 mm_agents/coact/autogen/messages/print_message.py create mode 100644 mm_agents/coact/autogen/oai/__init__.py create mode 100644 mm_agents/coact/autogen/oai/anthropic.py create mode 100644 mm_agents/coact/autogen/oai/bedrock.py create mode 100644 mm_agents/coact/autogen/oai/cerebras.py create mode 100644 mm_agents/coact/autogen/oai/client.py create mode 100644 mm_agents/coact/autogen/oai/client_utils.py create mode 100644 mm_agents/coact/autogen/oai/cohere.py create mode 100644 mm_agents/coact/autogen/oai/gemini.py create mode 100644 mm_agents/coact/autogen/oai/gemini_types.py create mode 100644 mm_agents/coact/autogen/oai/groq.py create mode 100644 mm_agents/coact/autogen/oai/mistral.py create mode 100644 mm_agents/coact/autogen/oai/oai_models/__init__.py create mode 100644 mm_agents/coact/autogen/oai/oai_models/_models.py create mode 100644 mm_agents/coact/autogen/oai/oai_models/chat_completion.py create mode 100644 mm_agents/coact/autogen/oai/oai_models/chat_completion_audio.py create mode 100644 mm_agents/coact/autogen/oai/oai_models/chat_completion_message.py create mode 100644 mm_agents/coact/autogen/oai/oai_models/chat_completion_message_tool_call.py create mode 100644 mm_agents/coact/autogen/oai/oai_models/chat_completion_token_logprob.py create mode 100644 mm_agents/coact/autogen/oai/oai_models/completion_usage.py create mode 100644 mm_agents/coact/autogen/oai/ollama.py create mode 100644 mm_agents/coact/autogen/oai/openai_utils.py create mode 100644 mm_agents/coact/autogen/oai/together.py create mode 100644 mm_agents/coact/autogen/retrieve_utils.py create mode 100644 mm_agents/coact/autogen/runtime_logging.py create mode 100644 mm_agents/coact/autogen/token_count_utils.py create mode 100644 mm_agents/coact/autogen/tools/__init__.py create mode 100644 mm_agents/coact/autogen/tools/contrib/__init__.py create mode 100644 mm_agents/coact/autogen/tools/contrib/time/__init__.py create mode 100644 mm_agents/coact/autogen/tools/contrib/time/time.py create mode 100644 mm_agents/coact/autogen/tools/dependency_injection.py create mode 100644 
mm_agents/coact/autogen/tools/experimental/__init__.py create mode 100644 mm_agents/coact/autogen/tools/experimental/browser_use/__init__.py create mode 100644 mm_agents/coact/autogen/tools/experimental/browser_use/browser_use.py create mode 100644 mm_agents/coact/autogen/tools/experimental/crawl4ai/__init__.py create mode 100644 mm_agents/coact/autogen/tools/experimental/crawl4ai/crawl4ai.py create mode 100644 mm_agents/coact/autogen/tools/experimental/deep_research/__init__.py create mode 100644 mm_agents/coact/autogen/tools/experimental/deep_research/deep_research.py create mode 100644 mm_agents/coact/autogen/tools/experimental/duckduckgo/__init__.py create mode 100644 mm_agents/coact/autogen/tools/experimental/duckduckgo/duckduckgo_search.py create mode 100644 mm_agents/coact/autogen/tools/experimental/google/__init__.py create mode 100644 mm_agents/coact/autogen/tools/experimental/google/authentication/__init__.py create mode 100644 mm_agents/coact/autogen/tools/experimental/google/authentication/credentials_hosted_provider.py create mode 100644 mm_agents/coact/autogen/tools/experimental/google/authentication/credentials_local_provider.py create mode 100644 mm_agents/coact/autogen/tools/experimental/google/authentication/credentials_provider.py create mode 100644 mm_agents/coact/autogen/tools/experimental/google/drive/__init__.py create mode 100644 mm_agents/coact/autogen/tools/experimental/google/drive/drive_functions.py create mode 100644 mm_agents/coact/autogen/tools/experimental/google/drive/toolkit.py create mode 100644 mm_agents/coact/autogen/tools/experimental/google/model.py create mode 100644 mm_agents/coact/autogen/tools/experimental/google/toolkit_protocol.py create mode 100644 mm_agents/coact/autogen/tools/experimental/google_search/__init__.py create mode 100644 mm_agents/coact/autogen/tools/experimental/google_search/google_search.py create mode 100644 mm_agents/coact/autogen/tools/experimental/google_search/youtube_search.py create mode 100644 mm_agents/coact/autogen/tools/experimental/messageplatform/__init__.py create mode 100644 mm_agents/coact/autogen/tools/experimental/messageplatform/discord/__init__.py create mode 100644 mm_agents/coact/autogen/tools/experimental/messageplatform/discord/discord.py create mode 100644 mm_agents/coact/autogen/tools/experimental/messageplatform/slack/__init__.py create mode 100644 mm_agents/coact/autogen/tools/experimental/messageplatform/slack/slack.py create mode 100644 mm_agents/coact/autogen/tools/experimental/messageplatform/telegram/__init__.py create mode 100644 mm_agents/coact/autogen/tools/experimental/messageplatform/telegram/telegram.py create mode 100644 mm_agents/coact/autogen/tools/experimental/perplexity/__init__.py create mode 100644 mm_agents/coact/autogen/tools/experimental/perplexity/perplexity_search.py create mode 100644 mm_agents/coact/autogen/tools/experimental/reliable/__init__.py create mode 100644 mm_agents/coact/autogen/tools/experimental/reliable/reliable.py create mode 100644 mm_agents/coact/autogen/tools/experimental/tavily/__init__.py create mode 100644 mm_agents/coact/autogen/tools/experimental/tavily/tavily_search.py create mode 100644 mm_agents/coact/autogen/tools/experimental/web_search_preview/__init__.py create mode 100644 mm_agents/coact/autogen/tools/experimental/web_search_preview/web_search_preview.py create mode 100644 mm_agents/coact/autogen/tools/experimental/wikipedia/__init__.py create mode 100644 mm_agents/coact/autogen/tools/experimental/wikipedia/wikipedia.py create mode 100644 
mm_agents/coact/autogen/tools/function_utils.py create mode 100644 mm_agents/coact/autogen/tools/tool.py create mode 100644 mm_agents/coact/autogen/tools/toolkit.py create mode 100644 mm_agents/coact/autogen/types.py create mode 100644 mm_agents/coact/coding_agent.py create mode 100644 mm_agents/coact/cua_agent.py create mode 100644 mm_agents/coact/operator_agent.py create mode 100644 run_coact.py diff --git a/desktop_env/controllers/python.py b/desktop_env/controllers/python.py index c572083..743dd09 100644 --- a/desktop_env/controllers/python.py +++ b/desktop_env/controllers/python.py @@ -136,6 +136,82 @@ class PythonController: logger.error("Failed to execute command.") return None + + def run_python_script(self, script: str) -> Optional[Dict[str, Any]]: + """ + Executes a python script on the server. + """ + payload = json.dumps({"code": script}) + + for _ in range(self.retry_times): + try: + response = requests.post(self.http_server + "/run_python", headers={'Content-Type': 'application/json'}, + data=payload, timeout=90) + if response.status_code == 200: + return response.json() + else: + return {"status": "error", "message": "Failed to execute command.", "output": None, "error": response.json()["error"]} + except requests.exceptions.ReadTimeout: + break + except Exception: + logger.error("An error occurred while trying to execute the command: %s", traceback.format_exc()) + logger.info("Retrying to execute command.") + time.sleep(self.retry_interval) + + logger.error("Failed to execute command.") + return {"status": "error", "message": "Failed to execute command.", "output": "", "error": "Retry limit reached."} + + def run_bash_script(self, script: str, timeout: int = 30, working_dir: Optional[str] = None) -> Optional[Dict[str, Any]]: + """ + Executes a bash script on the server. + + :param script: The bash script content (can be multi-line) + :param timeout: Execution timeout in seconds (default: 30) + :param working_dir: Working directory for script execution (optional) + :return: Dictionary with status, output, error, and returncode, or None if failed + """ + payload = json.dumps({ + "script": script, + "timeout": timeout, + "working_dir": working_dir + }) + + for _ in range(self.retry_times): + try: + response = requests.post( + self.http_server + "/run_bash_script", + headers={'Content-Type': 'application/json'}, + data=payload, + timeout=timeout + 100 # Add buffer to HTTP timeout + ) + if response.status_code == 200: + result = response.json() + logger.info("Bash script executed successfully with return code: %d", result.get("returncode", -1)) + return result + else: + logger.error("Failed to execute bash script. 
Status code: %d, response: %s", + response.status_code, response.text) + logger.info("Retrying to execute bash script.") + except requests.exceptions.ReadTimeout: + logger.error("Bash script execution timed out") + return { + "status": "error", + "output": "", + "error": f"Script execution timed out after {timeout} seconds", + "returncode": -1 + } + except Exception as e: + logger.error("An error occurred while trying to execute the bash script: %s", e) + logger.info("Retrying to execute bash script.") + time.sleep(self.retry_interval) + + logger.error("Failed to execute bash script after %d retries.", self.retry_times) + return { + "status": "error", + "output": "", + "error": f"Failed to execute bash script after {self.retry_times} retries", + "returncode": -1 + } def execute_action(self, action: Dict[str, Any]): """ diff --git a/desktop_env/server/main.py b/desktop_env/server/main.py index bc0d8c8..0fc4036 100644 --- a/desktop_env/server/main.py +++ b/desktop_env/server/main.py @@ -1568,5 +1568,230 @@ def end_recording(): return abort(500, description=f"Recording failed. The output file is missing or empty. ffmpeg stderr: {error_output}") +@app.route("/run_python", methods=['POST']) +def run_python(): + data = request.json + code = data.get('code', None) + + if not code: + return jsonify({'status': 'error', 'message': 'Code not supplied!'}), 400 + + # Create a temporary file to save the Python code + import tempfile + import uuid + + # Generate unique filename + temp_filename = f"/tmp/python_exec_{uuid.uuid4().hex}.py" + + try: + # Write code to temporary file + with open(temp_filename, 'w') as f: + f.write(code) + + # Execute the file using subprocess to capture all output + result = subprocess.run( + ['/usr/bin/python3', temp_filename], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + timeout=30 # 30 second timeout + ) + + # Clean up the temporary file + try: + os.remove(temp_filename) + except: + pass # Ignore cleanup errors + + # Prepare response + output = result.stdout + error_output = result.stderr + + # Combine output and errors if both exist + combined_message = output + if error_output: + combined_message += ('\n' + error_output) if output else error_output + + # Determine status based on return code and errors + if result.returncode != 0: + status = 'error' + if not error_output: + # If no stderr but non-zero return code, add a generic error message + error_output = f"Process exited with code {result.returncode}" + combined_message = combined_message + '\n' + error_output if combined_message else error_output + else: + status = 'success' + + return jsonify({ + 'status': status, + 'message': combined_message, + 'need_more': False, # Not applicable for file execution + 'output': output, # stdout only + 'error': error_output, # stderr only + 'return_code': result.returncode + }) + + except subprocess.TimeoutExpired: + # Clean up the temporary file on timeout + try: + os.remove(temp_filename) + except: + pass + + return jsonify({ + 'status': 'error', + 'message': 'Execution timeout: Code took too long to execute', + 'error': 'TimeoutExpired', + 'need_more': False, + 'output': None, + }), 500 + + except Exception as e: + # Clean up the temporary file on error + try: + os.remove(temp_filename) + except: + pass + + # Capture the exception details + return jsonify({ + 'status': 'error', + 'message': f'Execution error: {str(e)}', + 'error': traceback.format_exc(), + 'need_more': False, + 'output': None, + }), 500 + + +@app.route("/run_bash_script", methods=['POST']) 
+def run_bash_script():
+    data = request.json
+    script = data.get('script', None)
+    timeout = data.get('timeout', 100)  # Default timeout of 100 seconds when the client does not supply one
+    working_dir = data.get('working_dir', None)
+
+    if not script:
+        return jsonify({
+            'status': 'error',
+            'output': 'Script not supplied!',
+            'error': "",  # Always empty as requested
+            'returncode': -1
+        }), 400
+
+    # Expand user directory if provided
+    if working_dir:
+        working_dir = os.path.expanduser(working_dir)
+        if not os.path.exists(working_dir):
+            return jsonify({
+                'status': 'error',
+                'output': f'Working directory does not exist: {working_dir}',
+                'error': "",  # Always empty as requested
+                'returncode': -1
+            }), 400
+
+    # Create a temporary script file
+    import tempfile
+    with tempfile.NamedTemporaryFile(mode='w', suffix='.sh', delete=False) as tmp_file:
+        if "#!/bin/bash" not in script:
+            script = "#!/bin/bash\n\n" + script
+        tmp_file.write(script)
+        tmp_file_path = tmp_file.name
+
+    try:
+        # Make the script executable
+        os.chmod(tmp_file_path, 0o755)
+
+        # Execute the script
+        if platform_name == "Windows":
+            # On Windows, run via bash (Git Bash, WSL, etc.) if available; the sh fallback below handles a missing bash
+            flags = subprocess.CREATE_NO_WINDOW
+            # Try to use bash if available (Git Bash, WSL, etc.)
+            result = subprocess.run(
+                ['bash', tmp_file_path],
+                stdout=subprocess.PIPE,
+                stderr=subprocess.STDOUT,  # Merge stderr into stdout
+                text=True,
+                timeout=timeout,
+                cwd=working_dir,
+                creationflags=flags,
+                shell=False
+            )
+        else:
+            # On Unix-like systems, use bash directly
+            flags = 0
+            result = subprocess.run(
+                ['/bin/bash', tmp_file_path],
+                stdout=subprocess.PIPE,
+                stderr=subprocess.STDOUT,  # Merge stderr into stdout
+                text=True,
+                timeout=timeout,
+                cwd=working_dir,
+                creationflags=flags,
+                shell=False
+            )
+
+        # Log the command execution for trajectory recording
+        _append_event("BashScript",
+                      {"script": script, "output": result.stdout, "error": "", "returncode": result.returncode},
+                      ts=time.time())
+
+        return jsonify({
+            'status': 'success' if result.returncode == 0 else 'error',
+            'output': result.stdout,  # Contains both stdout and stderr merged
+            'error': "",  # Always empty as requested
+            'returncode': result.returncode
+        })
+
+    except subprocess.TimeoutExpired:
+        return jsonify({
+            'status': 'error',
+            'output': f'Script execution timed out after {timeout} seconds',
+            'error': "",  # Always empty as requested
+            'returncode': -1
+        }), 500
+    except FileNotFoundError:
+        # Bash not found, try with sh
+        try:
+            result = subprocess.run(
+                ['sh', tmp_file_path],
+                stdout=subprocess.PIPE,
+                stderr=subprocess.STDOUT,  # Merge stderr into stdout
+                text=True,
+                timeout=timeout,
+                cwd=working_dir,
+                shell=False
+            )
+
+            _append_event("BashScript",
+                          {"script": script, "output": result.stdout, "error": "", "returncode": result.returncode},
+                          ts=time.time())
+
+            return jsonify({
+                'status': 'success' if result.returncode == 0 else 'error',
+                'output': result.stdout,  # Contains both stdout and stderr merged
+                'error': "",  # Always empty as requested
+                'returncode': result.returncode,
+            })
+        except Exception as e:
+            return jsonify({
+                'status': 'error',
+                'output': f'Failed to execute script: {str(e)}',
+                'error': "",  # Always empty as requested
+                'returncode': -1
+            }), 500
+    except Exception as e:
+        return jsonify({
+            'status': 'error',
+            'output': f'Failed to execute script: {str(e)}',
+            'error': "",  # Always empty as requested
+            'returncode': -1
+        }), 500
+    finally:
+        # Clean up the temporary file
+        try:
+            os.unlink(tmp_file_path)
+        except:
+            pass
+
 if __name__ ==
'__main__': app.run(debug=True, host="0.0.0.0") diff --git a/mm_agents/coact/OAI_CONFIG_LIST b/mm_agents/coact/OAI_CONFIG_LIST new file mode 100644 index 0000000..7f96c0e --- /dev/null +++ b/mm_agents/coact/OAI_CONFIG_LIST @@ -0,0 +1,27 @@ +[ + { + "model": "gpt-4o", + "api_key": "KEY", + "tags": ["gpt-4o", "code", "explainer"] + }, + { + "model": "o3", + "api_key": "KEY", + "tags": ["o3", "coding", "explainer"] + }, + { + "model": "gpt-4.1", + "api_key": "KEY", + "tags": ["gpt-4.1", "coding", "explainer"] + }, + { + "model": "o4-mini", + "api_key": "KEY", + "tags": ["o4-mini", "coding", "explainer"] + }, + { + "model": "o3-mini", + "api_key": "KEY", + "tags": ["o3-mini", "coding", "explainer"] + } +] \ No newline at end of file diff --git a/mm_agents/coact/__init__.py b/mm_agents/coact/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mm_agents/coact/autogen/__init__.py b/mm_agents/coact/autogen/__init__.py new file mode 100644 index 0000000..6681e5e --- /dev/null +++ b/mm_agents/coact/autogen/__init__.py @@ -0,0 +1,81 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 +# +# Portions derived from https://github.com/microsoft/autogen are under the MIT License. +# SPDX-License-Identifier: MIT +import logging + +from .agentchat import ( + Agent, + AssistantAgent, + ChatResult, + ConversableAgent, + GroupChat, + GroupChatManager, + UpdateSystemMessage, + UserProxyAgent, + gather_usage_summary, + initiate_chats, + register_function, +) +from .agentchat.group.context_expression import ContextExpression +from .code_utils import DEFAULT_MODEL, FAST_MODEL +from .exception_utils import ( + AgentNameConflictError, + InvalidCarryOverTypeError, + NoEligibleSpeakerError, + SenderRequiredError, + UndefinedNextAgentError, +) +from .llm_config import LLMConfig +from .oai import ( + Cache, + ModelClient, + OpenAIWrapper, + config_list_from_dotenv, + config_list_from_json, + config_list_from_models, + config_list_gpt4_gpt35, + config_list_openai_aoai, + filter_config, + get_config_list, +) + +# Set the root logger. +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +__all__ = [ + "DEFAULT_MODEL", + "FAST_MODEL", + "Agent", + "AgentNameConflictError", + "AssistantAgent", + "Cache", + "ChatResult", + "ContextExpression", + "ConversableAgent", + "GroupChat", + "GroupChatManager", + "InvalidCarryOverTypeError", + "LLMConfig", + "ModelClient", + "NoEligibleSpeakerError", + "OpenAIWrapper", + "SenderRequiredError", + "UndefinedNextAgentError", + "UpdateSystemMessage", + "UserProxyAgent", + "config_list_from_dotenv", + "config_list_from_json", + "config_list_from_models", + "config_list_gpt4_gpt35", + "config_list_openai_aoai", + "filter_config", + "gather_usage_summary", + "get_config_list", + "initiate_chats", + "register_function", +] diff --git a/mm_agents/coact/autogen/agentchat/__init__.py b/mm_agents/coact/autogen/agentchat/__init__.py new file mode 100644 index 0000000..dcc4508 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/__init__.py @@ -0,0 +1,38 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 +# +# Portions derived from https://github.com/microsoft/autogen are under the MIT License. 
+# SPDX-License-Identifier: MIT
+from .agent import Agent, LLMAgent
+from .assistant_agent import AssistantAgent
+from .chat import ChatResult, a_initiate_chats, initiate_chats
+
+from .conversable_agent import ConversableAgent, UpdateSystemMessage, register_function
+from .group.multi_agent_chat import a_initiate_group_chat, a_run_group_chat, initiate_group_chat, run_group_chat
+from .groupchat import GroupChat, GroupChatManager
+from .user_proxy_agent import UserProxyAgent
+from .utils import gather_usage_summary
+
+__all__ = [
+    "Agent",
+    "AssistantAgent",
+    "ChatResult",
+    "ConversableAgent",
+    "GroupChat",
+    "GroupChatManager",
+    "LLMAgent",
+    "UpdateSystemMessage",
+    "UserProxyAgent",
+    "a_initiate_chats",
+    "a_initiate_group_chat",
+    "a_run_group_chat",
+    "gather_usage_summary",
+    "initiate_chats",
+    "initiate_group_chat",
+    "register_function",
+    "run_group_chat",
+]
diff --git a/mm_agents/coact/autogen/agentchat/agent.py b/mm_agents/coact/autogen/agentchat/agent.py
new file mode 100644
index 0000000..6e034e3
--- /dev/null
+++ b/mm_agents/coact/autogen/agentchat/agent.py
@@ -0,0 +1,182 @@
+# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
+# SPDX-License-Identifier: MIT
+from typing import TYPE_CHECKING, Any, Optional, Protocol, TypeVar, Union, runtime_checkable
+
+from ..doc_utils import export_module
+
+__all__ = ["Agent", "LLMAgent", "LLMMessageType"]
+
+Tool = TypeVar("Tool")
+
+LLMMessageType = dict[str, Any]
+
+DEFAULT_SUMMARY_METHOD = "last_msg"
+
+
+@runtime_checkable
+@export_module("autogen")
+class Agent(Protocol):
+    """(In preview) A protocol for Agent.
+
+    An agent can communicate with other agents and perform actions.
+    Different agents can differ in what actions they perform in the `receive` method.
+    """
+
+    @property
+    def name(self) -> str:
+        """The name of the agent."""
+        ...
+
+    @property
+    def description(self) -> str:
+        """The description of the agent. Used for the agent's introduction in
+        a group chat setting.
+        """
+        ...
+
+    def send(
+        self,
+        message: Union[dict[str, Any], str],
+        recipient: "Agent",
+        request_reply: Optional[bool] = None,
+    ) -> None:
+        """Send a message to another agent.
+
+        Args:
+            message (dict or str): the message to send. If a dict, it should be
+                a JSON-serializable and follows the OpenAI's ChatCompletion schema.
+            recipient (Agent): the recipient of the message.
+            request_reply (bool): whether to request a reply from the recipient.
+        """
+        ...
+
+    async def a_send(
+        self,
+        message: Union[dict[str, Any], str],
+        recipient: "Agent",
+        request_reply: Optional[bool] = None,
+    ) -> None:
+        """(Async) Send a message to another agent.
+
+        Args:
+            message (dict or str): the message to send. If a dict, it should be
+                a JSON-serializable and follows the OpenAI's ChatCompletion schema.
+            recipient (Agent): the recipient of the message.
+            request_reply (bool): whether to request a reply from the recipient.
+        """
+        ...
+
+    def receive(
+        self,
+        message: Union[dict[str, Any], str],
+        sender: "Agent",
+        request_reply: Optional[bool] = None,
+    ) -> None:
+        """Receive a message from another agent.
+
+        Args:
+            message (dict or str): the message received. If a dict, it should be
+                a JSON-serializable and follows the OpenAI's ChatCompletion schema.
+            sender (Agent): the sender of the message.
+ request_reply (bool): whether the sender requests a reply. + """ + + async def a_receive( + self, + message: Union[dict[str, Any], str], + sender: "Agent", + request_reply: Optional[bool] = None, + ) -> None: + """(Async) Receive a message from another agent. + + Args: + message (dict or str): the message received. If a dict, it should be + a JSON-serializable and follows the OpenAI's ChatCompletion schema. + sender (Agent): the sender of the message. + request_reply (bool): whether the sender requests a reply. + """ + ... + + def generate_reply( + self, + messages: Optional[list[dict[str, Any]]] = None, + sender: Optional["Agent"] = None, + **kwargs: Any, + ) -> Union[str, dict[str, Any], None]: + """Generate a reply based on the received messages. + + Args: + messages (list[dict[str, Any]]): a list of messages received from other agents. + The messages are dictionaries that are JSON-serializable and + follows the OpenAI's ChatCompletion schema. + sender: sender of an Agent instance. + **kwargs: Additional keyword arguments. + + Returns: + str or dict or None: the generated reply. If None, no reply is generated. + """ + + async def a_generate_reply( + self, + messages: Optional[list[dict[str, Any]]] = None, + sender: Optional["Agent"] = None, + **kwargs: Any, + ) -> Union[str, dict[str, Any], None]: + """(Async) Generate a reply based on the received messages. + + Args: + messages (list[dict[str, Any]]): a list of messages received from other agents. + The messages are dictionaries that are JSON-serializable and + follows the OpenAI's ChatCompletion schema. + sender: sender of an Agent instance. + **kwargs: Additional keyword arguments. + + Returns: + str or dict or None: the generated reply. If None, no reply is generated. + """ + ... + + def set_ui_tools(self, tools: list[Tool]) -> None: + """Set the UI tools for the agent. + + Args: + tools: a list of UI tools to set. + """ + ... + + def unset_ui_tools(self, tools: list[Tool]) -> None: + """Unset the UI tools for the agent. + + Args: + tools: a list of UI tools to set. + """ + ... + + +@runtime_checkable +@export_module("autogen") +class LLMAgent(Agent, Protocol): + """(In preview) A protocol for an LLM agent.""" + + @property + def system_message(self) -> str: + """The system message of this agent.""" + + def update_system_message(self, system_message: str) -> None: + """Update this agent's system message. + + Args: + system_message (str): system message for inference. + """ + + +if TYPE_CHECKING: + # mypy will fail if Conversable agent does not implement Agent protocol + from .conversable_agent import ConversableAgent + + def _check_protocol_implementation(agent: ConversableAgent) -> Agent: + return agent diff --git a/mm_agents/coact/autogen/agentchat/assistant_agent.py b/mm_agents/coact/autogen/agentchat/assistant_agent.py new file mode 100644 index 0000000..60cefb2 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/assistant_agent.py @@ -0,0 +1,85 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 +# +# Portions derived from https://github.com/microsoft/autogen are under the MIT License. 
+# SPDX-License-Identifier: MIT +from typing import Any, Callable, Literal, Optional, Union + +from ..doc_utils import export_module +from ..llm_config import LLMConfig +from ..runtime_logging import log_new_agent, logging_enabled +from .conversable_agent import ConversableAgent + + +@export_module("autogen") +class AssistantAgent(ConversableAgent): + """(In preview) Assistant agent, designed to solve a task with LLM. + + AssistantAgent is a subclass of ConversableAgent configured with a default system message. + The default system message is designed to solve a task with LLM, + including suggesting python code blocks and debugging. + `human_input_mode` is default to "NEVER" + and `code_execution_config` is default to False. + This agent doesn't execute code by default, and expects the user to execute the code. + """ + + DEFAULT_SYSTEM_MESSAGE = """You are a helpful AI assistant. +Solve tasks using your coding and language skills. +In the following cases, suggest python code (in a python coding block) or shell script (in a sh coding block) for the user to execute. + 1. When you need to collect info, use the code to output the info you need, for example, browse or search the web, download/read a file, print the content of a webpage or a file, get the current date/time, check the operating system. After sufficient info is printed and the task is ready to be solved based on your language skill, you can solve the task by yourself. + 2. When you need to perform some task with code, use the code to perform the task and output the result. Finish the task smartly. +Solve the task step by step if you need to. If a plan is not provided, explain your plan first. Be clear which step uses code, and which step uses your language skill. +When using code, you must indicate the script type in the code block. The user cannot provide any other feedback or perform any other action beyond executing the code you suggest. The user can't modify your code. So do not suggest incomplete code which requires users to modify. Don't use a code block if it's not intended to be executed by the user. +If you want the user to save the code in a file before executing it, put # filename: inside the code block as the first line. Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use 'print' function for the output when relevant. Check the execution result returned by the user. +If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, collect additional info you need, and think of a different approach to try. +When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible. +Reply "TERMINATE" in the end when everything is done. + """ + + DEFAULT_DESCRIPTION = "A helpful and general-purpose AI assistant that has strong language skills, Python skills, and Linux command line skills." 
+ + def __init__( + self, + name: str, + system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE, + llm_config: Optional[Union[LLMConfig, dict[str, Any], Literal[False]]] = None, + is_termination_msg: Optional[Callable[[dict[str, Any]], bool]] = None, + max_consecutive_auto_reply: Optional[int] = None, + human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER", + description: Optional[str] = None, + **kwargs: Any, + ): + """Args: + name (str): agent name. + system_message (str): system message for the ChatCompletion inference. + Please override this attribute if you want to reprogram the agent. + llm_config (dict or False or None): llm inference configuration. + Please refer to [OpenAIWrapper.create](https://docs.ag2.ai/latest/docs/api-reference/autogen/OpenAIWrapper/#autogen.OpenAIWrapper.create) + for available options. + is_termination_msg (function): a function that takes a message in the form of a dictionary + and returns a boolean value indicating if this received message is a termination message. + The dict can contain the following keys: "content", "role", "name", "function_call". + max_consecutive_auto_reply (int): the maximum number of consecutive auto replies. + default to None (no limit provided, class attribute MAX_CONSECUTIVE_AUTO_REPLY will be used as the limit in this case). + The limit only plays a role when human_input_mode is not "ALWAYS". + **kwargs (dict): Please refer to other kwargs in + [ConversableAgent](https://docs.ag2.ai/latest/docs/api-reference/autogen/ConversableAgent). + """ + super().__init__( + name, + system_message, + is_termination_msg, + max_consecutive_auto_reply, + human_input_mode, + llm_config=llm_config, + description=description, + **kwargs, + ) + if logging_enabled(): + log_new_agent(self, locals()) + + # Update the provided description if None, and we are using the default system_message, + # then use the default description. + if description is None and system_message == self.DEFAULT_SYSTEM_MESSAGE: + self.description = self.DEFAULT_DESCRIPTION diff --git a/mm_agents/coact/autogen/agentchat/chat.py b/mm_agents/coact/autogen/agentchat/chat.py new file mode 100644 index 0000000..0577d7b --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/chat.py @@ -0,0 +1,309 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 +# +# Portions derived from https://github.com/microsoft/autogen are under the MIT License. +# SPDX-License-Identifier: MIT +import asyncio +import datetime +import logging +import warnings +from collections import defaultdict +from dataclasses import dataclass +from functools import partial +from typing import Any + +from ..doc_utils import export_module +from ..events.agent_events import PostCarryoverProcessingEvent +from ..io.base import IOStream +from .utils import consolidate_chat_info + +logger = logging.getLogger(__name__) +Prerequisite = tuple[int, int] + +__all__ = ["ChatResult", "a_initiate_chats", "initiate_chats"] + + +@dataclass +@export_module("autogen") +class ChatResult: + """(Experimental) The result of a chat. Almost certain to be changed.""" + + chat_id: int = None + """chat id""" + chat_history: list[dict[str, Any]] = None + """The chat history.""" + summary: str = None + """A summary obtained from the chat.""" + cost: dict[str, dict[str, Any]] = ( + None # keys: "usage_including_cached_inference", "usage_excluding_cached_inference" + ) + """The cost of the chat. 
+ The value for each usage type is a dictionary containing cost information for that specific type. + - "usage_including_cached_inference": Cost information on the total usage, including the tokens in cached inference. + - "usage_excluding_cached_inference": Cost information on the usage of tokens, excluding the tokens in cache. No larger than "usage_including_cached_inference". + """ + human_input: list[str] = None + """A list of human input solicited during the chat.""" + + +def _validate_recipients(chat_queue: list[dict[str, Any]]) -> None: + """Validate recipients exits and warn repetitive recipients.""" + receipts_set = set() + for chat_info in chat_queue: + assert "recipient" in chat_info, "recipient must be provided." + receipts_set.add(chat_info["recipient"]) + if len(receipts_set) < len(chat_queue): + warnings.warn( + "Repetitive recipients detected: The chat history will be cleared by default if a recipient appears more than once. To retain the chat history, please set 'clear_history=False' in the configuration of the repeating agent.", + UserWarning, + ) + + +def __create_async_prerequisites(chat_queue: list[dict[str, Any]]) -> list[Prerequisite]: + """Create list of Prerequisite (prerequisite_chat_id, chat_id)""" + prerequisites = [] + for chat_info in chat_queue: + if "chat_id" not in chat_info: + raise ValueError("Each chat must have a unique id for async multi-chat execution.") + chat_id = chat_info["chat_id"] + pre_chats = chat_info.get("prerequisites", []) + for pre_chat_id in pre_chats: + if not isinstance(pre_chat_id, int): + raise ValueError("Prerequisite chat id is not int.") + prerequisites.append((chat_id, pre_chat_id)) + return prerequisites + + +def __find_async_chat_order(chat_ids: set[int], prerequisites: list[Prerequisite]) -> list[int]: + """Find chat order for async execution based on the prerequisite chats + + Args: + chat_ids: A set of all chat IDs that need to be scheduled + prerequisites: A list of tuples (chat_id, prerequisite_chat_id) where each tuple indicates that chat_id depends on prerequisite_chat_id + + Returns: + list: a list of chat_id in order. + """ + edges = defaultdict(set) + indegree = defaultdict(int) + for pair in prerequisites: + chat, pre = pair[0], pair[1] + if chat not in edges[pre]: + indegree[chat] += 1 + edges[pre].add(chat) + bfs = [i for i in chat_ids if i not in indegree] + chat_order = [] + steps = len(indegree) + for _ in range(steps + 1): + if not bfs: + break + chat_order.extend(bfs) + nxt = [] + for node in bfs: + if node in edges: + for course in edges[node]: + indegree[course] -= 1 + if indegree[course] == 0: + nxt.append(course) + indegree.pop(course) + edges.pop(node) + bfs = nxt + + if indegree: + return [] + return chat_order + + +def _post_process_carryover_item(carryover_item): + if isinstance(carryover_item, str): + return carryover_item + elif isinstance(carryover_item, dict) and "content" in carryover_item: + return str(carryover_item["content"]) + else: + return str(carryover_item) + + +def __post_carryover_processing(chat_info: dict[str, Any]) -> None: + iostream = IOStream.get_default() + + if "message" not in chat_info: + warnings.warn( + "message is not provided in a chat_queue entry. input() will be called to get the initial message.", + UserWarning, + ) + + iostream.send(PostCarryoverProcessingEvent(chat_info=chat_info)) + + +@export_module("autogen") +def initiate_chats(chat_queue: list[dict[str, Any]]) -> list[ChatResult]: + """Initiate a list of chats. 
+ + Args: + chat_queue (List[Dict]): A list of dictionaries containing the information about the chats. + + Each dictionary should contain the input arguments for + [`ConversableAgent.initiate_chat`](../ConversableAgent#initiate-chat). + For example: + - `"sender"` - the sender agent. + - `"recipient"` - the recipient agent. + - `"clear_history"` (bool) - whether to clear the chat history with the agent. + Default is True. + - `"silent"` (bool or None) - (Experimental) whether to print the messages in this + conversation. Default is False. + - `"cache"` (Cache or None) - the cache client to use for this conversation. + Default is None. + - `"max_turns"` (int or None) - maximum number of turns for the chat. If None, the chat + will continue until a termination condition is met. Default is None. + - `"summary_method"` (str or callable) - a string or callable specifying the method to get + a summary from the chat. Default is DEFAULT_summary_method, i.e., "last_msg". + - `"summary_args"` (dict) - a dictionary of arguments to be passed to the summary_method. + Default is {}. + - `"message"` (str, callable or None) - if None, input() will be called to get the + initial message. + - `**context` - additional context information to be passed to the chat. + - `"carryover"` - It can be used to specify the carryover information to be passed + to this chat. If provided, we will combine this carryover with the "message" content when + generating the initial chat message in `generate_init_message`. + - `"finished_chat_indexes_to_exclude_from_carryover"` - It can be used by specifying a list of indexes of the finished_chats list, + from which to exclude the summaries for carryover. If 'finished_chat_indexes_to_exclude_from_carryover' is not provided or an empty list, + then summary from all the finished chats will be taken. + + Returns: + (list): a list of ChatResult objects corresponding to the finished chats in the chat_queue. + """ + consolidate_chat_info(chat_queue) + _validate_recipients(chat_queue) + current_chat_queue = chat_queue.copy() + finished_chats = [] + while current_chat_queue: + chat_info = current_chat_queue.pop(0) + _chat_carryover = chat_info.get("carryover", []) + finished_chat_indexes_to_exclude_from_carryover = chat_info.get( + "finished_chat_indexes_to_exclude_from_carryover", [] + ) + + if isinstance(_chat_carryover, str): + _chat_carryover = [_chat_carryover] + chat_info["carryover"] = _chat_carryover + [ + r.summary for i, r in enumerate(finished_chats) if i not in finished_chat_indexes_to_exclude_from_carryover + ] + + if not chat_info.get("silent", False): + __post_carryover_processing(chat_info) + + sender = chat_info["sender"] + chat_res = sender.initiate_chat(**chat_info) + finished_chats.append(chat_res) + return finished_chats + + +def __system_now_str(): + ct = datetime.datetime.now() + return f" System time at {ct}. " + + +def _on_chat_future_done(chat_future: asyncio.Future, chat_id: int): + """Update ChatResult when async Task for Chat is completed.""" + logger.debug(f"Update chat {chat_id} result on task completion." + __system_now_str()) + chat_result = chat_future.result() + chat_result.chat_id = chat_id + + +async def _dependent_chat_future( + chat_id: int, chat_info: dict[str, Any], prerequisite_chat_futures: dict[int, asyncio.Future] +) -> asyncio.Task: + """Create an async Task for each chat.""" + logger.debug(f"Create Task for chat {chat_id}." 
+ __system_now_str()) + _chat_carryover = chat_info.get("carryover", []) + finished_chat_indexes_to_exclude_from_carryover = chat_info.get( + "finished_chat_indexes_to_exclude_from_carryover", [] + ) + finished_chats = dict() + for chat in prerequisite_chat_futures: + chat_future = prerequisite_chat_futures[chat] + if chat_future.cancelled(): + raise RuntimeError(f"Chat {chat} is cancelled.") + + # wait for prerequisite chat results for the new chat carryover + finished_chats[chat] = await chat_future + + if isinstance(_chat_carryover, str): + _chat_carryover = [_chat_carryover] + data = [ + chat_result.summary + for chat_id, chat_result in finished_chats.items() + if chat_id not in finished_chat_indexes_to_exclude_from_carryover + ] + chat_info["carryover"] = _chat_carryover + data + if not chat_info.get("silent", False): + __post_carryover_processing(chat_info) + + sender = chat_info["sender"] + chat_res_future = asyncio.create_task(sender.a_initiate_chat(**chat_info)) + call_back_with_args = partial(_on_chat_future_done, chat_id=chat_id) + chat_res_future.add_done_callback(call_back_with_args) + logger.debug(f"Task for chat {chat_id} created." + __system_now_str()) + return chat_res_future + + +async def a_initiate_chats(chat_queue: list[dict[str, Any]]) -> dict[int, ChatResult]: + """(async) Initiate a list of chats. + + Args: + chat_queue (List[Dict]): A list of dictionaries containing the information about the chats. + + Each dictionary should contain the input arguments for + [`ConversableAgent.initiate_chat`](../../../ConversableAgent#initiate-chat). + For example: + - `"sender"` - the sender agent. + - `"recipient"` - the recipient agent. + - `"clear_history"` (bool) - whether to clear the chat history with the agent. + Default is True. + - `"silent"` (bool or None) - (Experimental) whether to print the messages in this + conversation. Default is False. + - `"cache"` (Cache or None) - the cache client to use for this conversation. + Default is None. + - `"max_turns"` (int or None) - maximum number of turns for the chat. If None, the chat + will continue until a termination condition is met. Default is None. + - `"summary_method"` (str or callable) - a string or callable specifying the method to get + a summary from the chat. Default is DEFAULT_summary_method, i.e., "last_msg". + - `"summary_args"` (dict) - a dictionary of arguments to be passed to the summary_method. + Default is {}. + - `"message"` (str, callable or None) - if None, input() will be called to get the + initial message. + - `**context` - additional context information to be passed to the chat. + - `"carryover"` - It can be used to specify the carryover information to be passed + to this chat. If provided, we will combine this carryover with the "message" content when + generating the initial chat message in `generate_init_message`. + - `"finished_chat_indexes_to_exclude_from_carryover"` - It can be used by specifying a list of indexes of the finished_chats list, + from which to exclude the summaries for carryover. If 'finished_chat_indexes_to_exclude_from_carryover' is not provided or an empty list, + then summary from all the finished chats will be taken. + + + Returns: + - (Dict): a dict of ChatId: ChatResult corresponding to the finished chats in the chat_queue. 
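+
+    Example (illustrative only; `manager`, `writer`, and `reviewer` are hypothetical agents):
+
+        ```python
+        results = await a_initiate_chats([
+            {"chat_id": 1, "sender": manager, "recipient": writer, "message": "Draft a short summary.", "max_turns": 2},
+            {"chat_id": 2, "sender": manager, "recipient": reviewer, "message": "Review the draft.", "prerequisites": [1], "max_turns": 2},
+        ])
+        print(results[2].summary)  # chat 2 runs after chat 1 and receives its summary as carryover
+        ```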
+ """ + consolidate_chat_info(chat_queue) + _validate_recipients(chat_queue) + chat_book = {chat_info["chat_id"]: chat_info for chat_info in chat_queue} + num_chats = chat_book.keys() + prerequisites = __create_async_prerequisites(chat_queue) + chat_order_by_id = __find_async_chat_order(num_chats, prerequisites) + finished_chat_futures = dict() + for chat_id in chat_order_by_id: + chat_info = chat_book[chat_id] + prerequisite_chat_ids = chat_info.get("prerequisites", []) + pre_chat_futures = dict() + for pre_chat_id in prerequisite_chat_ids: + pre_chat_future = finished_chat_futures[pre_chat_id] + pre_chat_futures[pre_chat_id] = pre_chat_future + current_chat_future = await _dependent_chat_future(chat_id, chat_info, pre_chat_futures) + finished_chat_futures[chat_id] = current_chat_future + await asyncio.gather(*list(finished_chat_futures.values())) + finished_chats = dict() + for chat in finished_chat_futures: + chat_result = finished_chat_futures[chat].result() + finished_chats[chat] = chat_result + return finished_chats diff --git a/mm_agents/coact/autogen/agentchat/contrib/__init__.py b/mm_agents/coact/autogen/agentchat/contrib/__init__.py new file mode 100644 index 0000000..a80fb86 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/contrib/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +__all__: list[str] = [] diff --git a/mm_agents/coact/autogen/agentchat/contrib/capabilities/__init__.py b/mm_agents/coact/autogen/agentchat/contrib/capabilities/__init__.py new file mode 100644 index 0000000..a80fb86 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/contrib/capabilities/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +__all__: list[str] = [] diff --git a/mm_agents/coact/autogen/agentchat/contrib/capabilities/agent_capability.py b/mm_agents/coact/autogen/agentchat/contrib/capabilities/agent_capability.py new file mode 100644 index 0000000..ee7bdf7 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/contrib/capabilities/agent_capability.py @@ -0,0 +1,20 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 +# +# Portions derived from https://github.com/microsoft/autogen are under the MIT License. +# SPDX-License-Identifier: MIT +from ...assistant_agent import ConversableAgent + + +class AgentCapability: + """Base class for composable capabilities that can be added to an agent.""" + + def __init__(self): + pass + + def add_to_agent(self, agent: ConversableAgent): + """Adds a particular capability to the given agent. Must be implemented by the capability subclass. + An implementation will typically call agent.register_hook() one or more times. See teachability.py as an example. 
+ """ + raise NotImplementedError diff --git a/mm_agents/coact/autogen/agentchat/contrib/capabilities/generate_images.py b/mm_agents/coact/autogen/agentchat/contrib/capabilities/generate_images.py new file mode 100644 index 0000000..d98eb58 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/contrib/capabilities/generate_images.py @@ -0,0 +1,301 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 +# +# Portions derived from https://github.com/microsoft/autogen are under the MIT License. +# SPDX-License-Identifier: MIT +import re +from typing import Any, Literal, Optional, Protocol, Union + +from .... import Agent, ConversableAgent, code_utils +from ....cache import AbstractCache +from ....import_utils import optional_import_block, require_optional_import +from ....llm_config import LLMConfig +from .. import img_utils +from ..capabilities.agent_capability import AgentCapability +from ..text_analyzer_agent import TextAnalyzerAgent + +with optional_import_block(): + from PIL.Image import Image + from openai import OpenAI + +SYSTEM_MESSAGE = "You've been given the special ability to generate images." +DESCRIPTION_MESSAGE = "This agent has the ability to generate images." + +PROMPT_INSTRUCTIONS = """In detail, please summarize the provided prompt to generate the image described in the TEXT. +DO NOT include any advice. RESPOND like the following example: +EXAMPLE: Blue background, 3D shapes, ... +""" + + +class ImageGenerator(Protocol): + """This class defines an interface for image generators. + + Concrete implementations of this protocol must provide a `generate_image` method that takes a string prompt as + input and returns a PIL Image object. + + NOTE: Current implementation does not allow you to edit a previously existing image. + """ + + def generate_image(self, prompt: str) -> "Image": + """Generates an image based on the provided prompt. + + Args: + prompt: A string describing the desired image. + + Returns: + A PIL Image object representing the generated image. + + Raises: + ValueError: If the image generation fails. + """ + ... + + def cache_key(self, prompt: str) -> str: + """Generates a unique cache key for the given prompt. + + This key can be used to store and retrieve generated images based on the prompt. + + Args: + prompt: A string describing the desired image. + + Returns: + A unique string that can be used as a cache key. + """ + ... + + +@require_optional_import("PIL", "unknown") +@require_optional_import("openai>=1.66.2", "openai") +class DalleImageGenerator: + """Generates images using OpenAI's DALL-E models. + + This class provides a convenient interface for generating images based on textual prompts using OpenAI's DALL-E + models. It allows you to specify the DALL-E model, resolution, quality, and the number of images to generate. + + Note: Current implementation does not allow you to edit a previously existing image. + """ + + def __init__( + self, + llm_config: Union[LLMConfig, dict[str, Any]], + resolution: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"] = "1024x1024", + quality: Literal["standard", "hd"] = "standard", + num_images: int = 1, + ): + """Args: + llm_config (LLMConfig or dict): llm config, must contain a valid dalle model and OpenAI API key in config_list. + resolution (str): The resolution of the image you want to generate. Must be one of "256x256", "512x512", "1024x1024", "1792x1024", "1024x1792". 
+ quality (str): The quality of the image you want to generate. Must be one of "standard", "hd". + num_images (int): The number of images to generate. + """ + config_list = llm_config["config_list"] + _validate_dalle_model(config_list[0]["model"]) + _validate_resolution_format(resolution) + + self._model = config_list[0]["model"] + self._resolution = resolution + self._quality = quality + self._num_images = num_images + self._dalle_client = OpenAI(api_key=config_list[0]["api_key"]) + + def generate_image(self, prompt: str) -> "Image": + response = self._dalle_client.images.generate( + model=self._model, + prompt=prompt, + size=self._resolution, + quality=self._quality, + n=self._num_images, + ) + + image_url = response.data[0].url + if image_url is None: + raise ValueError("Failed to generate image.") + + return img_utils.get_pil_image(image_url) + + def cache_key(self, prompt: str) -> str: + keys = (prompt, self._model, self._resolution, self._quality, self._num_images) + return ",".join([str(k) for k in keys]) + + +@require_optional_import("PIL", "unknown") +class ImageGeneration(AgentCapability): + """This capability allows a ConversableAgent to generate images based on the message received from other Agents. + + 1. Utilizes a TextAnalyzerAgent to analyze incoming messages to identify requests for image generation and + extract relevant details. + 2. Leverages the provided ImageGenerator (e.g., DalleImageGenerator) to create the image. + 3. Optionally caches generated images for faster retrieval in future conversations. + + NOTE: This capability increases the token usage of the agent, as it uses TextAnalyzerAgent to analyze every + message received by the agent. + + Example: + ```python + import autogen + from autogen.agentchat.contrib.capabilities.image_generation import ImageGeneration + + # Assuming you have llm configs configured for the LLMs you want to use and Dalle. + # Create the agent + agent = autogen.ConversableAgent( + name="dalle", llm_config={...}, max_consecutive_auto_reply=3, human_input_mode="NEVER" + ) + + # Create an ImageGenerator with desired settings + dalle_gen = generate_images.DalleImageGenerator(llm_config={...}) + + # Add the ImageGeneration capability to the agent + agent.add_capability(ImageGeneration(image_generator=dalle_gen)) + ``` + """ + + def __init__( + self, + image_generator: ImageGenerator, + cache: Optional[AbstractCache] = None, + text_analyzer_llm_config: Optional[Union[LLMConfig, dict[str, Any]]] = None, + text_analyzer_instructions: str = PROMPT_INSTRUCTIONS, + verbosity: int = 0, + register_reply_position: int = 2, + ): + """Args: + image_generator (ImageGenerator): The image generator you would like to use to generate images. + cache (None or AbstractCache): The cache client to use to store and retrieve generated images. If None, + no caching will be used. + text_analyzer_llm_config (LLMConfig or Dict or None): The LLM config for the text analyzer. If None, the LLM config will + be retrieved from the agent you're adding the ability to. + text_analyzer_instructions (str): Instructions provided to the TextAnalyzerAgent used to analyze + incoming messages and extract the prompt for image generation. The default instructions focus on + summarizing the prompt. You can customize the instructions to achieve more granular control over prompt + extraction. + Example: 'Extract specific details from the message, like desired objects, styles, or backgrounds.' + verbosity (int): The verbosity level. Defaults to 0 and must be greater than or equal to 0. 
The text + analyzer llm calls will be silent if verbosity is less than 2. + register_reply_position (int): The position of the reply function in the agent's list of reply functions. + This capability registers a new reply function to handle messages with image generation requests. + Defaults to 2 to place it after the check termination and human reply for a ConversableAgent. + """ + self._image_generator = image_generator + self._cache = cache + self._text_analyzer_llm_config = text_analyzer_llm_config + self._text_analyzer_instructions = text_analyzer_instructions + self._verbosity = verbosity + self._register_reply_position = register_reply_position + + self._agent: Optional[ConversableAgent] = None + self._text_analyzer: Optional[TextAnalyzerAgent] = None + + def add_to_agent(self, agent: ConversableAgent): + """Adds the Image Generation capability to the specified ConversableAgent. + + This function performs the following modifications to the agent: + + 1. Registers a reply function: A new reply function is registered with the agent to handle messages that + potentially request image generation. This function analyzes the message and triggers image generation if + necessary. + 2. Creates an Agent (TextAnalyzerAgent): This is used to analyze messages for image generation requirements. + 3. Updates System Message: The agent's system message is updated to include a message indicating the + capability to generate images has been added. + 4. Updates Description: The agent's description is updated to reflect the addition of the Image Generation + capability. This might be helpful in certain use cases, like group chats. + + Args: + agent (ConversableAgent): The ConversableAgent to add the capability to. + """ + self._agent = agent + + agent.register_reply([Agent, None], self._image_gen_reply, position=self._register_reply_position) + + self._text_analyzer_llm_config = self._text_analyzer_llm_config or agent.llm_config + self._text_analyzer = TextAnalyzerAgent(llm_config=self._text_analyzer_llm_config) + + agent.update_system_message(agent.system_message + "\n" + SYSTEM_MESSAGE) + agent.description += "\n" + DESCRIPTION_MESSAGE + + def _image_gen_reply( + self, + recipient: ConversableAgent, + messages: Optional[list[dict[str, Any]]], + sender: Optional[Agent] = None, + config: Optional[Any] = None, + ) -> tuple[bool, Optional[Union[str, dict[str, Any]]]]: + if messages is None: + return False, None + + last_message = code_utils.content_str(messages[-1]["content"]) + + if not last_message: + return False, None + + if self._should_generate_image(last_message): + prompt = self._extract_prompt(last_message) + + image = self._cache_get(prompt) + if image is None: + image = self._image_generator.generate_image(prompt) + self._cache_set(prompt, image) + + return True, self._generate_content_message(prompt, image) + + else: + return False, None + + def _should_generate_image(self, message: str) -> bool: + assert self._text_analyzer is not None + + instructions = """ + Does any part of the TEXT ask the agent to generate an image? + The TEXT must explicitly mention that the image must be generated. + Answer with just one word, yes or no. 
+ """ + analysis = self._text_analyzer.analyze_text(message, instructions) + + return "yes" in self._extract_analysis(analysis).lower() + + def _extract_prompt(self, last_message) -> str: + assert self._text_analyzer is not None + + analysis = self._text_analyzer.analyze_text(last_message, self._text_analyzer_instructions) + return self._extract_analysis(analysis) + + def _cache_get(self, prompt: str) -> Optional["Image"]: + if self._cache: + key = self._image_generator.cache_key(prompt) + cached_value = self._cache.get(key) + + if cached_value: + return img_utils.get_pil_image(cached_value) + + def _cache_set(self, prompt: str, image: "Image"): + if self._cache: + key = self._image_generator.cache_key(prompt) + self._cache.set(key, img_utils.pil_to_data_uri(image)) + + def _extract_analysis(self, analysis: Optional[Union[str, dict[str, Any]]]) -> str: + if isinstance(analysis, dict): + return code_utils.content_str(analysis["content"]) + else: + return code_utils.content_str(analysis) + + def _generate_content_message(self, prompt: str, image: "Image") -> dict[str, Any]: + return { + "content": [ + {"type": "text", "text": f"I generated an image with the prompt: {prompt}"}, + {"type": "image_url", "image_url": {"url": img_utils.pil_to_data_uri(image)}}, + ] + } + + +# Helpers +def _validate_resolution_format(resolution: str): + """Checks if a string is in a valid resolution format (e.g., "1024x768").""" + pattern = r"^\d+x\d+$" # Matches a pattern of digits, "x", and digits + matched_resolution = re.match(pattern, resolution) + if matched_resolution is None: + raise ValueError(f"Invalid resolution format: {resolution}") + + +def _validate_dalle_model(model: str): + if model not in ["dall-e-3", "dall-e-2"]: + raise ValueError(f"Invalid DALL-E model: {model}. Must be 'dall-e-3' or 'dall-e-2'") diff --git a/mm_agents/coact/autogen/agentchat/contrib/capabilities/teachability.py b/mm_agents/coact/autogen/agentchat/contrib/capabilities/teachability.py new file mode 100644 index 0000000..b6e6634 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/contrib/capabilities/teachability.py @@ -0,0 +1,393 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 +# +# Portions derived from https://github.com/microsoft/autogen are under the MIT License. +# SPDX-License-Identifier: MIT +import os +import pickle +from typing import Any, Optional, Union + +from ....formatting_utils import colored +from ....import_utils import optional_import_block, require_optional_import +from ....llm_config import LLMConfig +from ...assistant_agent import ConversableAgent +from ..text_analyzer_agent import TextAnalyzerAgent +from .agent_capability import AgentCapability + +with optional_import_block(): + import chromadb + from chromadb.config import Settings + + +class Teachability(AgentCapability): + """Teachability uses a vector database to give an agent the ability to remember user teachings, + where the user is any caller (human or not) sending messages to the teachable agent. + Teachability is designed to be composable with other agent capabilities. + To make any conversable agent teachable, instantiate both the agent and the Teachability class, + then pass the agent to teachability.add_to_agent(agent). + Note that teachable agents in a group chat must be given unique path_to_db_dir values. 
+ + When adding Teachability to an agent, the following are modified: + - The agent's system message is appended with a note about the agent's new ability. + - A hook is added to the agent's `process_last_received_message` hookable method, + and the hook potentially modifies the last of the received messages to include earlier teachings related to the message. + Added teachings do not propagate into the stored message history. + If new user teachings are detected, they are added to new memos in the vector database. + """ + + def __init__( + self, + verbosity: Optional[int] = 0, + reset_db: Optional[bool] = False, + path_to_db_dir: Optional[str] = "./tmp/teachable_agent_db", + recall_threshold: Optional[float] = 1.5, + max_num_retrievals: Optional[int] = 10, + llm_config: Optional[Union[LLMConfig, dict[str, Any], bool]] = None, + ): + """Args: + verbosity (Optional, int): # 0 (default) for basic info, 1 to add memory operations, 2 for analyzer messages, 3 for memo lists. + reset_db (Optional, bool): True to clear the DB before starting. Default False. + path_to_db_dir (Optional, str): path to the directory where this particular agent's DB is stored. Default "./tmp/teachable_agent_db" + recall_threshold (Optional, float): The maximum distance for retrieved memos, where 0.0 is exact match. Default 1.5. Larger values allow more (but less relevant) memos to be recalled. + max_num_retrievals (Optional, int): The maximum number of memos to retrieve from the DB. Default 10. + llm_config (LLMConfig or dict or False): llm inference configuration passed to TextAnalyzerAgent. + If None, TextAnalyzerAgent uses llm_config from the teachable agent. + """ + self.verbosity = verbosity + self.path_to_db_dir = path_to_db_dir + self.recall_threshold = recall_threshold + self.max_num_retrievals = max_num_retrievals + self.llm_config = llm_config + + self.analyzer = None + self.teachable_agent = None + + # Create the memo store. + self.memo_store = MemoStore(self.verbosity, reset_db, self.path_to_db_dir) + + def add_to_agent(self, agent: ConversableAgent): + """Adds teachability to the given agent.""" + self.teachable_agent = agent + + # Register a hook for processing the last message. + agent.register_hook(hookable_method="process_last_received_message", hook=self.process_last_received_message) + + # Was an llm_config passed to the constructor? + if self.llm_config is None: + # No. Use the agent's llm_config. + self.llm_config = agent.llm_config + assert self.llm_config, "Teachability requires a valid llm_config." + + # Create the analyzer agent. + self.analyzer = TextAnalyzerAgent(llm_config=self.llm_config) + + # Append extra info to the system message. + agent.update_system_message( + agent.system_message + + "\nYou've been given the special ability to remember user teachings from prior conversations." + ) + + def prepopulate_db(self): + """Adds a few arbitrary memos to the DB.""" + self.memo_store.prepopulate() + + def process_last_received_message(self, text: Union[dict[str, Any], str]): + """Appends any relevant memos to the message text, and stores any apparent teachings in new memos. + Uses TextAnalyzerAgent to make decisions about memo storage and retrieval. + """ + # Try to retrieve relevant memos from the DB. + expanded_text = text + if self.memo_store.last_memo_id > 0: + expanded_text = self._consider_memo_retrieval(text) + + # Try to store any user teachings in new memos to be used in the future. + self._consider_memo_storage(text) + + # Return the (possibly) expanded message text. 
+ return expanded_text + + def _consider_memo_storage(self, comment: Union[dict[str, Any], str]): + """Decides whether to store something from one user comment in the DB.""" + memo_added = False + + # Check for a problem-solution pair. + response = self._analyze( + comment, + "Does any part of the TEXT ask the agent to perform a task or solve a problem? Answer with just one word, yes or no.", + ) + if "yes" in response.lower(): + # Can we extract advice? + advice = self._analyze( + comment, + "Briefly copy any advice from the TEXT that may be useful for a similar but different task in the future. But if no advice is present, just respond with 'none'.", + ) + if "none" not in advice.lower(): + # Yes. Extract the task. + task = self._analyze( + comment, + "Briefly copy just the task from the TEXT, then stop. Don't solve it, and don't include any advice.", + ) + # Generalize the task. + general_task = self._analyze( + task, + "Summarize very briefly, in general terms, the type of task described in the TEXT. Leave out details that might not appear in a similar problem.", + ) + # Add the task-advice (problem-solution) pair to the vector DB. + if self.verbosity >= 1: + print(colored("\nREMEMBER THIS TASK-ADVICE PAIR", "light_yellow")) + self.memo_store.add_input_output_pair(general_task, advice) + memo_added = True + + # Check for information to be learned. + response = self._analyze( + comment, + "Does the TEXT contain information that could be committed to memory? Answer with just one word, yes or no.", + ) + if "yes" in response.lower(): + # Yes. What question would this information answer? + question = self._analyze( + comment, + "Imagine that the user forgot this information in the TEXT. How would they ask you for this information? Include no other text in your response.", + ) + # Extract the information. + answer = self._analyze( + comment, "Copy the information from the TEXT that should be committed to memory. Add no explanation." + ) + # Add the question-answer pair to the vector DB. + if self.verbosity >= 1: + print(colored("\nREMEMBER THIS QUESTION-ANSWER PAIR", "light_yellow")) + self.memo_store.add_input_output_pair(question, answer) + memo_added = True + + # Were any memos added? + if memo_added: + # Yes. Save them to disk. + self.memo_store._save_memos() + + def _consider_memo_retrieval(self, comment: Union[dict[str, Any], str]): + """Decides whether to retrieve memos from the DB, and add them to the chat context.""" + # First, use the comment directly as the lookup key. + if self.verbosity >= 1: + print(colored("\nLOOK FOR RELEVANT MEMOS, AS QUESTION-ANSWER PAIRS", "light_yellow")) + memo_list = self._retrieve_relevant_memos(comment) + + # Next, if the comment involves a task, then extract and generalize the task before using it as the lookup key. + response = self._analyze( + comment, + "Does any part of the TEXT ask the agent to perform a task or solve a problem? Answer with just one word, yes or no.", + ) + if "yes" in response.lower(): + if self.verbosity >= 1: + print(colored("\nLOOK FOR RELEVANT MEMOS, AS TASK-ADVICE PAIRS", "light_yellow")) + # Extract the task. + task = self._analyze( + comment, "Copy just the task from the TEXT, then stop. Don't solve it, and don't include any advice." + ) + # Generalize the task. + general_task = self._analyze( + task, + "Summarize very briefly, in general terms, the type of task described in the TEXT. Leave out details that might not appear in a similar problem.", + ) + # Append any relevant memos. 
+ memo_list.extend(self._retrieve_relevant_memos(general_task)) + + # De-duplicate the memo list. + memo_list = list(set(memo_list)) + + # Append the memos to the text of the last message. + return comment + self._concatenate_memo_texts(memo_list) + + def _retrieve_relevant_memos(self, input_text: str) -> list: + """Returns semantically related memos from the DB.""" + memo_list = self.memo_store.get_related_memos( + input_text, n_results=self.max_num_retrievals, threshold=self.recall_threshold + ) + + if self.verbosity >= 1: # noqa: SIM102 + # Was anything retrieved? + if len(memo_list) == 0: + # No. Look at the closest memo. + print(colored("\nTHE CLOSEST MEMO IS BEYOND THE THRESHOLD:", "light_yellow")) + self.memo_store.get_nearest_memo(input_text) + print() # Print a blank line. The memo details were printed by get_nearest_memo(). + + # Create a list of just the memo output_text strings. + memo_list = [memo[1] for memo in memo_list] + return memo_list + + def _concatenate_memo_texts(self, memo_list: list) -> str: + """Concatenates the memo texts into a single string for inclusion in the chat context.""" + memo_texts = "" + if len(memo_list) > 0: + info = "\n# Memories that might help\n" + for memo in memo_list: + info = info + "- " + memo + "\n" + if self.verbosity >= 1: + print(colored("\nMEMOS APPENDED TO LAST MESSAGE...\n" + info + "\n", "light_yellow")) + memo_texts = memo_texts + "\n" + info + return memo_texts + + def _analyze(self, text_to_analyze: Union[dict[str, Any], str], analysis_instructions: Union[dict[str, Any], str]): + """Asks TextAnalyzerAgent to analyze the given text according to specific instructions.""" + self.analyzer.reset() # Clear the analyzer's list of messages. + self.teachable_agent.send( + recipient=self.analyzer, message=text_to_analyze, request_reply=False, silent=(self.verbosity < 2) + ) # Put the message in the analyzer's list. + self.teachable_agent.send( + recipient=self.analyzer, message=analysis_instructions, request_reply=True, silent=(self.verbosity < 2) + ) # Request the reply. + return self.teachable_agent.last_message(self.analyzer)["content"] + + +@require_optional_import("chromadb", "teachable") +class MemoStore: + """Provides memory storage and retrieval for a teachable agent, using a vector database. + Each DB entry (called a memo) is a pair of strings: an input text and an output text. + The input text might be a question, or a task to perform. + The output text might be an answer to the question, or advice on how to perform the task. + Vector embeddings are currently supplied by Chroma's default Sentence Transformers. + """ + + def __init__( + self, + verbosity: Optional[int] = 0, + reset: Optional[bool] = False, + path_to_db_dir: Optional[str] = "./tmp/teachable_agent_db", + ): + """Args: + - verbosity (Optional, int): 1 to print memory operations, 0 to omit them. 3+ to print memo lists. + - reset (Optional, bool): True to clear the DB before starting. Default False. + - path_to_db_dir (Optional, str): path to the directory where the DB is stored. + """ + self.verbosity = verbosity + self.path_to_db_dir = path_to_db_dir + + # Load or create the vector DB on disk. + settings = Settings( + anonymized_telemetry=False, allow_reset=True, is_persistent=True, persist_directory=path_to_db_dir + ) + self.db_client = chromadb.Client(settings) + self.vec_db = self.db_client.create_collection("memos", get_or_create=True) # The collection is the DB. + + # Load or create the associated memo dict on disk. 
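+        # uid_text_dict maps each memo id (a stringified integer) to its (input_text, output_text) pair,
+        # mirroring the documents stored in the vector DB collection.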
+ self.path_to_dict = os.path.join(path_to_db_dir, "uid_text_dict.pkl") + self.uid_text_dict = {} + self.last_memo_id = 0 + if (not reset) and os.path.exists(self.path_to_dict): + print(colored("\nLOADING MEMORY FROM DISK", "light_green")) + print(colored(f" Location = {self.path_to_dict}", "light_green")) + with open(self.path_to_dict, "rb") as f: + self.uid_text_dict = pickle.load(f) + self.last_memo_id = len(self.uid_text_dict) + if self.verbosity >= 3: + self.list_memos() + + # Clear the DB if requested. + if reset: + self.reset_db() + + def list_memos(self): + """Prints the contents of MemoStore.""" + print(colored("LIST OF MEMOS", "light_green")) + for uid, text in self.uid_text_dict.items(): + input_text, output_text = text + print( + colored( + f" ID: {uid}\n INPUT TEXT: {input_text}\n OUTPUT TEXT: {output_text}", + "light_green", + ) + ) + + def _save_memos(self): + """Saves self.uid_text_dict to disk.""" + with open(self.path_to_dict, "wb") as file: + pickle.dump(self.uid_text_dict, file) + + def reset_db(self): + """Forces immediate deletion of the DB's contents, in memory and on disk.""" + print(colored("\nCLEARING MEMORY", "light_green")) + self.db_client.delete_collection("memos") + self.vec_db = self.db_client.create_collection("memos") + self.uid_text_dict = {} + self._save_memos() + + def add_input_output_pair(self, input_text: str, output_text: str): + """Adds an input-output pair to the vector DB.""" + self.last_memo_id += 1 + self.vec_db.add(documents=[input_text], ids=[str(self.last_memo_id)]) + self.uid_text_dict[str(self.last_memo_id)] = input_text, output_text + if self.verbosity >= 1: + print( + colored( + f"\nINPUT-OUTPUT PAIR ADDED TO VECTOR DATABASE:\n ID\n {self.last_memo_id}\n INPUT\n {input_text}\n OUTPUT\n {output_text}\n", + "light_yellow", + ) + ) + if self.verbosity >= 3: + self.list_memos() + + def get_nearest_memo(self, query_text: str): + """Retrieves the nearest memo to the given query text.""" + results = self.vec_db.query(query_texts=[query_text], n_results=1) + uid, input_text, distance = results["ids"][0][0], results["documents"][0][0], results["distances"][0][0] + input_text_2, output_text = self.uid_text_dict[uid] + assert input_text == input_text_2 + if self.verbosity >= 1: + print( + colored( + f"\nINPUT-OUTPUT PAIR RETRIEVED FROM VECTOR DATABASE:\n INPUT1\n {input_text}\n OUTPUT\n {output_text}\n DISTANCE\n {distance}", + "light_yellow", + ) + ) + return input_text, output_text, distance + + def get_related_memos(self, query_text: str, n_results: int, threshold: Union[int, float]): + """Retrieves memos that are related to the given query text within the specified distance threshold.""" + if n_results > len(self.uid_text_dict): + n_results = len(self.uid_text_dict) + results = self.vec_db.query(query_texts=[query_text], n_results=n_results) + memos = [] + num_results = len(results["ids"][0]) + for i in range(num_results): + uid, input_text, distance = results["ids"][0][i], results["documents"][0][i], results["distances"][0][i] + if distance < threshold: + input_text_2, output_text = self.uid_text_dict[uid] + assert input_text == input_text_2 + if self.verbosity >= 1: + print( + colored( + f"\nINPUT-OUTPUT PAIR RETRIEVED FROM VECTOR DATABASE:\n INPUT1\n {input_text}\n OUTPUT\n {output_text}\n DISTANCE\n {distance}", + "light_yellow", + ) + ) + memos.append((input_text, output_text, distance)) + return memos + + def prepopulate(self): + """Adds a few arbitrary examples to the vector DB, just to make retrieval less trivial.""" + if self.verbosity 
>= 1: + print(colored("\nPREPOPULATING MEMORY", "light_green")) + examples = [] + examples.append({"text": "When I say papers I mean research papers, which are typically pdfs.", "label": "yes"}) + examples.append({"text": "Please verify that each paper you listed actually uses langchain.", "label": "no"}) + examples.append({"text": "Tell gpt the output should still be latex code.", "label": "no"}) + examples.append({"text": "Hint: convert pdfs to text and then answer questions based on them.", "label": "yes"}) + examples.append({ + "text": "To create a good PPT, include enough content to make it interesting.", + "label": "yes", + }) + examples.append({ + "text": "No, for this case the columns should be aspects and the rows should be frameworks.", + "label": "no", + }) + examples.append({"text": "When writing code, remember to include any libraries that are used.", "label": "yes"}) + examples.append({"text": "Please summarize the papers by Eric Horvitz on bounded rationality.", "label": "no"}) + examples.append({"text": "Compare the h-index of Daniel Weld and Oren Etzioni.", "label": "no"}) + examples.append({ + "text": "Double check to be sure that the columns in a table correspond to what was asked for.", + "label": "yes", + }) + for example in examples: + self.add_input_output_pair(example["text"], example["label"]) + self._save_memos() diff --git a/mm_agents/coact/autogen/agentchat/contrib/capabilities/text_compressors.py b/mm_agents/coact/autogen/agentchat/contrib/capabilities/text_compressors.py new file mode 100644 index 0000000..ef4b6ef --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/contrib/capabilities/text_compressors.py @@ -0,0 +1,66 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 +# +# Portions derived from https://github.com/microsoft/autogen are under the MIT License. +# SPDX-License-Identifier: MIT +from typing import Any, Protocol + +from ....import_utils import optional_import_block, require_optional_import + +with optional_import_block() as result: + import llmlingua + from llmlingua import PromptCompressor + + +class TextCompressor(Protocol): + """Defines a protocol for text compression to optimize agent interactions.""" + + def compress_text(self, text: str, **compression_params) -> dict[str, Any]: + """This method takes a string as input and returns a dictionary containing the compressed text and other + relevant information. The compressed text should be stored under the 'compressed_text' key in the dictionary. + To calculate the number of saved tokens, the dictionary should include 'origin_tokens' and 'compressed_tokens' keys. + """ + ... + + +@require_optional_import("llmlingua", "long-context") +class LLMLingua: + """Compresses text messages using LLMLingua for improved efficiency in processing and response generation. + + NOTE: The effectiveness of compression and the resultant token savings can vary based on the content of the messages + and the specific configurations used for the PromptCompressor. + """ + + def __init__( + self, + prompt_compressor_kwargs: dict = dict( + model_name="microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank", + use_llmlingua2=True, + device_map="cpu", + ), + structured_compression: bool = False, + ) -> None: + """Args: + prompt_compressor_kwargs (dict): A dictionary of keyword arguments for the PromptCompressor. 
Defaults to a + dictionary with model_name set to "microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank", + use_llmlingua2 set to True, and device_map set to "cpu". + structured_compression (bool): A flag indicating whether to use structured compression. If True, the + structured_compress_prompt method of the PromptCompressor is used. Otherwise, the compress_prompt method + is used. Defaults to False. + dictionary. + + Raises: + ImportError: If the llmlingua library is not installed. + """ + self._prompt_compressor = PromptCompressor(**prompt_compressor_kwargs) + + assert isinstance(self._prompt_compressor, llmlingua.PromptCompressor) + self._compression_method = ( + self._prompt_compressor.structured_compress_prompt + if structured_compression + else self._prompt_compressor.compress_prompt + ) + + def compress_text(self, text: str, **compression_params) -> dict[str, Any]: + return self._compression_method([text], **compression_params) diff --git a/mm_agents/coact/autogen/agentchat/contrib/capabilities/tools_capability.py b/mm_agents/coact/autogen/agentchat/contrib/capabilities/tools_capability.py new file mode 100644 index 0000000..bd69396 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/contrib/capabilities/tools_capability.py @@ -0,0 +1,22 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +from ....agentchat import ConversableAgent +from ....tools import Tool + + +class ToolsCapability: + """Adding a list of tools as composable capabilities to a single agent. + This class can be inherited from to allow code to run at the point of creating or adding the capability. + + Note: both caller and executor of the tools are the same agent. + """ + + def __init__(self, tool_list: list[Tool]): + self.tools = [tool for tool in tool_list] + + def add_to_agent(self, agent: ConversableAgent): + """Add tools to the given agent.""" + for tool in self.tools: + tool.register_tool(agent=agent) diff --git a/mm_agents/coact/autogen/agentchat/contrib/capabilities/transform_messages.py b/mm_agents/coact/autogen/agentchat/contrib/capabilities/transform_messages.py new file mode 100644 index 0000000..ad6c537 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/contrib/capabilities/transform_messages.py @@ -0,0 +1,93 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 +# +# Portions derived from https://github.com/microsoft/autogen are under the MIT License. +# SPDX-License-Identifier: MIT +import copy +from typing import TYPE_CHECKING, Any + +from ....formatting_utils import colored +from .transforms import MessageTransform + +if TYPE_CHECKING: + from ...conversable_agent import ConversableAgent + + +class TransformMessages: + """Agent capability for transforming messages before reply generation. + + This capability allows you to apply a series of message transformations to + a ConversableAgent's incoming messages before they are processed for response + generation. This is useful for tasks such as: + + - Limiting the number of messages considered for context. + - Truncating messages to meet token limits. + - Filtering sensitive information. + - Customizing message formatting. + + To use `TransformMessages`: + + 1. Create message transformations (e.g., `MessageHistoryLimiter`, `MessageTokenLimiter`). + 2. Instantiate `TransformMessages` with a list of these transformations. + 3. 
Add the `TransformMessages` instance to your `ConversableAgent` using `add_to_agent`. + + NOTE: Order of message transformations is important. You could get different results based on + the order of transformations. + + Example: + ```python + from agentchat import ConversableAgent + from agentchat.contrib.capabilities import TransformMessages, MessageHistoryLimiter, MessageTokenLimiter + + max_messages = MessageHistoryLimiter(max_messages=2) + truncate_messages = MessageTokenLimiter(max_tokens=500) + transform_messages = TransformMessages(transforms=[max_messages, truncate_messages]) + + agent = ConversableAgent(...) + transform_messages.add_to_agent(agent) + ``` + """ + + def __init__(self, *, transforms: list[MessageTransform] = [], verbose: bool = True): + """Args: + transforms: A list of message transformations to apply. + verbose: Whether to print logs of each transformation or not. + """ + self._transforms = transforms + self._verbose = verbose + + def add_to_agent(self, agent: "ConversableAgent"): + """Adds the message transformations capability to the specified ConversableAgent. + + This function performs the following modifications to the agent: + + 1. Registers a hook that automatically transforms all messages before they are processed for + response generation. + """ + agent.register_hook(hookable_method="process_all_messages_before_reply", hook=self._transform_messages) + + def _transform_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + post_transform_messages = copy.deepcopy(messages) + system_message = None + + if messages[0]["role"] == "system": + system_message = copy.deepcopy(messages[0]) + post_transform_messages.pop(0) + + for transform in self._transforms: + # deepcopy in case pre_transform_messages will later be used for logs printing + pre_transform_messages = ( + copy.deepcopy(post_transform_messages) if self._verbose else post_transform_messages + ) + post_transform_messages = transform.apply_transform(pre_transform_messages) + + if self._verbose: + logs_str, had_effect = transform.get_logs(pre_transform_messages, post_transform_messages) + if had_effect: + print(colored(logs_str, "yellow")) + + if system_message: + post_transform_messages.insert(0, system_message) + + return post_transform_messages diff --git a/mm_agents/coact/autogen/agentchat/contrib/capabilities/transforms.py b/mm_agents/coact/autogen/agentchat/contrib/capabilities/transforms.py new file mode 100644 index 0000000..cedb1cb --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/contrib/capabilities/transforms.py @@ -0,0 +1,579 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 +# +# Portions derived from https://github.com/microsoft/autogen are under the MIT License. +# SPDX-License-Identifier: MIT +import copy +import sys +from typing import Any, Optional, Protocol, Union + +import tiktoken +from termcolor import colored + +from .... import token_count_utils +from ....cache import AbstractCache, Cache +from ....types import MessageContentType +from . import transforms_util +from .text_compressors import LLMLingua, TextCompressor + + +class MessageTransform(Protocol): + """Defines a contract for message transformation. + + Classes implementing this protocol should provide an `apply_transform` method + that takes a list of messages and returns the transformed list. 
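+
+    A minimal conforming implementation (hypothetical, for illustration) might look like:
+
+    ```python
+    class DropSystemMessages:
+        def apply_transform(self, messages):
+            return [msg for msg in messages if msg.get("role") != "system"]
+
+        def get_logs(self, pre_transform_messages, post_transform_messages):
+            removed = len(pre_transform_messages) - len(post_transform_messages)
+            return f"Removed {removed} system message(s).", removed > 0
+    ```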
+ """ + + def apply_transform(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Applies a transformation to a list of messages. + + Args: + messages: A list of dictionaries representing messages. + + Returns: + A new list of dictionaries containing the transformed messages. + """ + ... + + def get_logs( + self, pre_transform_messages: list[dict[str, Any]], post_transform_messages: list[dict[str, Any]] + ) -> tuple[str, bool]: + """Creates the string including the logs of the transformation + + Alongside the string, it returns a boolean indicating whether the transformation had an effect or not. + + Args: + pre_transform_messages: A list of dictionaries representing messages before the transformation. + post_transform_messages: A list of dictionaries representig messages after the transformation. + + Returns: + A tuple with a string with the logs and a flag indicating whether the transformation had an effect or not. + """ + ... + + +class MessageHistoryLimiter: + """Limits the number of messages considered by an agent for response generation. + + This transform keeps only the most recent messages up to the specified maximum number of messages (max_messages). + It trims the conversation history by removing older messages, retaining only the most recent messages. + """ + + def __init__( + self, + max_messages: Optional[int] = None, + keep_first_message: bool = False, + exclude_names: Optional[list[str]] = None, + ): + """Args: + max_messages Optional[int]: Maximum number of messages to keep in the context. Must be greater than 0 if not None. + keep_first_message bool: Whether to keep the original first message in the conversation history. + Defaults to False. + exclude_names Optional[list[str]]: List of message sender names to exclude from the message history. + Messages from these senders will be filtered out before applying the message limit. Defaults to None. + """ + self._validate_max_messages(max_messages) + self._max_messages = max_messages + self._keep_first_message = keep_first_message + self._exclude_names = exclude_names + + def apply_transform(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Truncates the conversation history to the specified maximum number of messages. + + This method returns a new list containing the most recent messages up to the specified + maximum number of messages (max_messages). If max_messages is None, it returns the + original list of messages unmodified. + + Args: + messages (List[Dict]): The list of messages representing the conversation history. + + Returns: + List[Dict]: A new list containing the most recent messages up to the specified maximum. + """ + + exclude_names = getattr(self, "_exclude_names", None) + + filtered = [msg for msg in messages if msg.get("name") not in exclude_names] if exclude_names else messages + + if self._max_messages is None or len(filtered) <= self._max_messages: + return filtered + + truncated_messages = [] + remaining_count = self._max_messages + + # Start with the first message if we need to keep it + if self._keep_first_message and filtered: + truncated_messages = [filtered[0]] + remaining_count -= 1 + + # Loop through messages in reverse + for i in range(len(filtered) - 1, 0, -1): + if remaining_count > 1: + truncated_messages.insert(1 if self._keep_first_message else 0, filtered[i]) + if remaining_count == 1: # noqa: SIM102 + # If there's only 1 slot left and it's a 'tools' message, ignore it. 
+ if filtered[i].get("role") != "tool": + truncated_messages.insert(1, filtered[i]) + + remaining_count -= 1 + if remaining_count == 0: + break + + return truncated_messages + + def get_logs( + self, pre_transform_messages: list[dict[str, Any]], post_transform_messages: list[dict[str, Any]] + ) -> tuple[str, bool]: + pre_transform_messages_len = len(pre_transform_messages) + post_transform_messages_len = len(post_transform_messages) + + if post_transform_messages_len < pre_transform_messages_len: + logs_str = ( + f"Removed {pre_transform_messages_len - post_transform_messages_len} messages. " + f"Number of messages reduced from {pre_transform_messages_len} to {post_transform_messages_len}." + ) + return logs_str, True + return "No messages were removed.", False + + def _validate_max_messages(self, max_messages: Optional[int]): + if max_messages is not None and max_messages < 1: + raise ValueError("max_messages must be None or greater than 1") + + +class MessageTokenLimiter: + """Truncates messages to meet token limits for efficient processing and response generation. + + This transformation applies two levels of truncation to the conversation history: + + 1. Truncates each individual message to the maximum number of tokens specified by max_tokens_per_message. + 2. Truncates the overall conversation history to the maximum number of tokens specified by max_tokens. + + NOTE: Tokens are counted using the encoder for the specified model. Different models may yield different token + counts for the same text. + + NOTE: For multimodal LLMs, the token count may be inaccurate as it does not account for the non-text input + (e.g images). + + The truncation process follows these steps in order: + + 1. The minimum tokens threshold (`min_tokens`) is checked (0 by default). If the total number of tokens in messages + is less than this threshold, then the messages are returned as is. In other case, the following process is applied. + 2. Messages are processed in reverse order (newest to oldest). + 3. Individual messages are truncated based on max_tokens_per_message. For multimodal messages containing both text + and other types of content, only the text content is truncated. + 4. The overall conversation history is truncated based on the max_tokens limit. Once the accumulated token count + exceeds this limit, the current message being processed get truncated to meet the total token count and any + remaining messages get discarded. + 5. The truncated conversation history is reconstructed by prepending the messages to a new list to preserve the + original message order. + """ + + def __init__( + self, + max_tokens_per_message: Optional[int] = None, + max_tokens: Optional[int] = None, + min_tokens: Optional[int] = None, + model: str = "gpt-3.5-turbo-0613", + filter_dict: Optional[dict[str, Any]] = None, + exclude_filter: bool = True, + ): + """Args: + max_tokens_per_message (None or int): Maximum number of tokens to keep in each message. + Must be greater than or equal to 0 if not None. + max_tokens (Optional[int]): Maximum number of tokens to keep in the chat history. + Must be greater than or equal to 0 if not None. + min_tokens (Optional[int]): Minimum number of tokens in messages to apply the transformation. + Must be greater than or equal to 0 if not None. + model (str): The target OpenAI model for tokenization alignment. + filter_dict (None or dict): A dictionary to filter out messages that you want/don't want to compress. + If None, no filters will be applied. 
+ exclude_filter (bool): If exclude filter is True (the default value), messages that match the filter will be + excluded from token truncation. If False, messages that match the filter will be truncated. + """ + self._model = model + self._max_tokens_per_message = self._validate_max_tokens(max_tokens_per_message) + self._max_tokens = self._validate_max_tokens(max_tokens) + self._min_tokens = self._validate_min_tokens(min_tokens, max_tokens) + self._filter_dict = filter_dict + self._exclude_filter = exclude_filter + + def apply_transform(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Applies token truncation to the conversation history. + + Args: + messages (List[Dict]): The list of messages representing the conversation history. + + Returns: + List[Dict]: A new list containing the truncated messages up to the specified token limits. + """ + assert self._max_tokens_per_message is not None + assert self._max_tokens is not None + assert self._min_tokens is not None + + # if the total number of tokens in the messages is less than the min_tokens, return the messages as is + if not transforms_util.min_tokens_reached(messages, self._min_tokens): + return messages + + temp_messages = copy.deepcopy(messages) + processed_messages = [] + processed_messages_tokens = 0 + + for msg in reversed(temp_messages): + # Some messages may not have content. + if not transforms_util.is_content_right_type(msg.get("content")): + processed_messages.insert(0, msg) + continue + + if not transforms_util.should_transform_message(msg, self._filter_dict, self._exclude_filter): + processed_messages.insert(0, msg) + processed_messages_tokens += transforms_util.count_text_tokens(msg["content"]) + continue + + expected_tokens_remained = self._max_tokens - processed_messages_tokens - self._max_tokens_per_message + + # If adding this message would exceed the token limit, truncate the last message to meet the total token + # limit and discard all remaining messages + if expected_tokens_remained < 0: + msg["content"] = self._truncate_str_to_tokens( + msg["content"], self._max_tokens - processed_messages_tokens + ) + processed_messages.insert(0, msg) + break + + msg["content"] = self._truncate_str_to_tokens(msg["content"], self._max_tokens_per_message) + msg_tokens = transforms_util.count_text_tokens(msg["content"]) + + # prepend the message to the list to preserve order + processed_messages_tokens += msg_tokens + processed_messages.insert(0, msg) + + return processed_messages + + def get_logs( + self, pre_transform_messages: list[dict[str, Any]], post_transform_messages: list[dict[str, Any]] + ) -> tuple[str, bool]: + pre_transform_messages_tokens = sum( + transforms_util.count_text_tokens(msg["content"]) for msg in pre_transform_messages if "content" in msg + ) + post_transform_messages_tokens = sum( + transforms_util.count_text_tokens(msg["content"]) for msg in post_transform_messages if "content" in msg + ) + + if post_transform_messages_tokens < pre_transform_messages_tokens: + logs_str = ( + f"Truncated {pre_transform_messages_tokens - post_transform_messages_tokens} tokens. 
" + f"Number of tokens reduced from {pre_transform_messages_tokens} to {post_transform_messages_tokens}" + ) + return logs_str, True + return "No tokens were truncated.", False + + def _truncate_str_to_tokens(self, contents: Union[str, list], n_tokens: int) -> Union[str, list]: + if isinstance(contents, str): + return self._truncate_tokens(contents, n_tokens) + elif isinstance(contents, list): + return self._truncate_multimodal_text(contents, n_tokens) + else: + raise ValueError(f"Contents must be a string or a list of dictionaries. Received type: {type(contents)}") + + def _truncate_multimodal_text(self, contents: list[dict[str, Any]], n_tokens: int) -> list[dict[str, Any]]: + """Truncates text content within a list of multimodal elements, preserving the overall structure.""" + tmp_contents = [] + for content in contents: + if content["type"] == "text": + truncated_text = self._truncate_tokens(content["text"], n_tokens) + tmp_contents.append({"type": "text", "text": truncated_text}) + else: + tmp_contents.append(content) + return tmp_contents + + def _truncate_tokens(self, text: str, n_tokens: int) -> str: + encoding = tiktoken.encoding_for_model(self._model) # Get the appropriate tokenizer + + encoded_tokens = encoding.encode(text) + truncated_tokens = encoded_tokens[:n_tokens] + truncated_text = encoding.decode(truncated_tokens) # Decode back to text + + return truncated_text + + def _validate_max_tokens(self, max_tokens: Optional[int] = None) -> Optional[int]: + if max_tokens is not None and max_tokens < 0: + raise ValueError("max_tokens and max_tokens_per_message must be None or greater than or equal to 0") + + try: + allowed_tokens = token_count_utils.get_max_token_limit(self._model) + except Exception: + print(colored(f"Model {self._model} not found in token_count_utils.", "yellow")) + allowed_tokens = None + + if max_tokens is not None and allowed_tokens is not None and max_tokens > allowed_tokens: + print( + colored( + f"Max token was set to {max_tokens}, but {self._model} can only accept {allowed_tokens} tokens. Capping it to {allowed_tokens}.", + "yellow", + ) + ) + return allowed_tokens + + return max_tokens if max_tokens is not None else sys.maxsize + + def _validate_min_tokens(self, min_tokens: Optional[int], max_tokens: Optional[int]) -> int: + if min_tokens is None: + return 0 + if min_tokens < 0: + raise ValueError("min_tokens must be None or greater than or equal to 0.") + if max_tokens is not None and min_tokens > max_tokens: + raise ValueError("min_tokens must not be more than max_tokens.") + return min_tokens + + +class TextMessageCompressor: + """A transform for compressing text messages in a conversation history. + + It uses a specified text compression method to reduce the token count of messages, which can lead to more efficient + processing and response generation by downstream models. + """ + + def __init__( + self, + text_compressor: Optional[TextCompressor] = None, + min_tokens: Optional[int] = None, + compression_params: dict = dict(), + cache: Optional[AbstractCache] = None, + filter_dict: Optional[dict[str, Any]] = None, + exclude_filter: bool = True, + ): + """Args: + text_compressor (TextCompressor or None): An instance of a class that implements the TextCompressor + protocol. If None, it defaults to LLMLingua. + min_tokens (int or None): Minimum number of tokens in messages to apply the transformation. Must be greater + than or equal to 0 if not None. If None, no threshold-based compression is applied. 
+            compression_params (dict): A dictionary of arguments for the compression method. Defaults to an empty
+                dictionary.
+            cache (None or AbstractCache): The cache client to use to store and retrieve previously compressed messages.
+                If None, no caching will be used.
+            filter_dict (None or dict): A dictionary to filter out messages that you want/don't want to compress.
+                If None, no filters will be applied.
+            exclude_filter (bool): If exclude filter is True (the default value), messages that match the filter will be
+                excluded from compression. If False, messages that match the filter will be compressed.
+        """
+        if text_compressor is None:
+            text_compressor = LLMLingua()
+
+        self._validate_min_tokens(min_tokens)
+
+        self._text_compressor = text_compressor
+        self._min_tokens = min_tokens
+        self._compression_args = compression_params
+        self._filter_dict = filter_dict
+        self._exclude_filter = exclude_filter
+
+        if cache is None:
+            self._cache = Cache.disk()
+        else:
+            self._cache = cache
+
+        # Track token savings from the most recent transform to simplify log generation
+        self._recent_tokens_savings = 0
+
+    def apply_transform(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        """Applies compression to messages in a conversation history based on the specified configuration.
+
+        The function processes each message according to the `compression_params` and `min_tokens` settings, applying
+        the specified compression configuration and returning a new list of messages with reduced token counts
+        where possible.
+
+        Args:
+            messages (List[Dict]): A list of message dictionaries to be compressed.
+
+        Returns:
+            List[Dict]: A list of dictionaries with the message content compressed according to the configured
+                method and scope.
+        """
+        # Make sure there is at least one message
+        if not messages:
+            return messages
+
+        # if the total number of tokens in the messages is less than the min_tokens, return the messages as is
+        if not transforms_util.min_tokens_reached(messages, self._min_tokens):
+            return messages
+
+        total_savings = 0
+        processed_messages = messages.copy()
+        for message in processed_messages:
+            # Some messages may not have content.
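+            # Filtered-out and text-empty messages are skipped as well; for the remaining messages a cached
+            # compression result is reused when available before calling the text compressor.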
+ if not transforms_util.is_content_right_type(message.get("content")): + continue + + if not transforms_util.should_transform_message(message, self._filter_dict, self._exclude_filter): + continue + + if transforms_util.is_content_text_empty(message["content"]): + continue + + cache_key = transforms_util.cache_key(message["content"], self._min_tokens) + cached_content = transforms_util.cache_content_get(self._cache, cache_key) + if cached_content is not None: + message["content"], savings = cached_content + else: + message["content"], savings = self._compress(message["content"]) + + transforms_util.cache_content_set(self._cache, cache_key, message["content"], savings) + + assert isinstance(savings, int) + total_savings += savings + + self._recent_tokens_savings = total_savings + return processed_messages + + def get_logs( + self, pre_transform_messages: list[dict[str, Any]], post_transform_messages: list[dict[str, Any]] + ) -> tuple[str, bool]: + if self._recent_tokens_savings > 0: + return f"{self._recent_tokens_savings} tokens saved with text compression.", True + else: + return "No tokens saved with text compression.", False + + def _compress(self, content: MessageContentType) -> tuple[MessageContentType, int]: + """Compresses the given text or multimodal content using the specified compression method.""" + if isinstance(content, str): + return self._compress_text(content) + elif isinstance(content, list): + return self._compress_multimodal(content) + else: + return content, 0 + + def _compress_multimodal(self, content: MessageContentType) -> tuple[MessageContentType, int]: + tokens_saved = 0 + for item in content: + if isinstance(item, dict) and "text" in item: + item["text"], savings = self._compress_text(item["text"]) + tokens_saved += savings + + elif isinstance(item, str): + item, savings = self._compress_text(item) + tokens_saved += savings + + return content, tokens_saved + + def _compress_text(self, text: str) -> tuple[str, int]: + """Compresses the given text using the specified compression method.""" + compressed_text = self._text_compressor.compress_text(text, **self._compression_args) + + savings = 0 + if "origin_tokens" in compressed_text and "compressed_tokens" in compressed_text: + savings = compressed_text["origin_tokens"] - compressed_text["compressed_tokens"] + + return compressed_text["compressed_prompt"], savings + + def _validate_min_tokens(self, min_tokens: Optional[int]): + if min_tokens is not None and min_tokens <= 0: + raise ValueError("min_tokens must be greater than 0 or None") + + +class TextMessageContentName: + """A transform for including the agent's name in the content of a message. + + How to create and apply the transform: + # Imports + from autogen.agentchat.contrib.capabilities import transform_messages, transforms + + # Create Transform + name_transform = transforms.TextMessageContentName(position="start", format_string="'{name}' said:\n") + + # Create the TransformMessages + context_handling = transform_messages.TransformMessages( + transforms=[ + name_transform + ] + ) + + # Add it to an agent so when they run inference it will apply to the messages + context_handling.add_to_agent(my_agent) + """ + + def __init__( + self, + position: str = "start", + format_string: str = "{name}:\n", + deduplicate: bool = True, + filter_dict: Optional[dict[str, Any]] = None, + exclude_filter: bool = True, + ): + """Args: + position (str): The position to add the name to the content. The possible options are 'start' or 'end'. Defaults to 'start'. 
+ format_string (str): The f-string to format the message name with. Use '{name}' as a placeholder for the agent's name. Defaults to '{name}:\n' and must contain '{name}'. + deduplicate (bool): Whether to deduplicate the formatted string so it doesn't appear twice (sometimes the LLM will add it to new messages itself). Defaults to True. + filter_dict (None or dict): A dictionary to filter out messages that you want/don't want to compress. + If None, no filters will be applied. + exclude_filter (bool): If exclude filter is True (the default value), messages that match the filter will be + excluded from compression. If False, messages that match the filter will be compressed. + """ + assert isinstance(position, str) and position in ["start", "end"] + assert isinstance(format_string, str) and "{name}" in format_string + assert isinstance(deduplicate, bool) and deduplicate is not None + + self._position = position + self._format_string = format_string + self._deduplicate = deduplicate + self._filter_dict = filter_dict + self._exclude_filter = exclude_filter + + # Track the number of messages changed for logging + self._messages_changed = 0 + + def apply_transform(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Applies the name change to the message based on the position and format string. + + Args: + messages (List[Dict]): A list of message dictionaries. + + Returns: + List[Dict]: A list of dictionaries with the message content updated with names. + """ + # Make sure there is at least one message + if not messages: + return messages + + messages_changed = 0 + processed_messages = copy.deepcopy(messages) + for message in processed_messages: + # Some messages may not have content. + if not transforms_util.is_content_right_type( + message.get("content") + ) or not transforms_util.is_content_right_type(message.get("name")): + continue + + if not transforms_util.should_transform_message(message, self._filter_dict, self._exclude_filter): + continue + + if transforms_util.is_content_text_empty(message["content"]) or transforms_util.is_content_text_empty( + message["name"] + ): + continue + + # Get and format the name in the content + content = message["content"] + formatted_name = self._format_string.format(name=message["name"]) + + if self._position == "start": + if not self._deduplicate or not content.startswith(formatted_name): + message["content"] = f"{formatted_name}{content}" + + messages_changed += 1 + else: + if not self._deduplicate or not content.endswith(formatted_name): + message["content"] = f"{content}{formatted_name}" + + messages_changed += 1 + + self._messages_changed = messages_changed + return processed_messages + + def get_logs( + self, pre_transform_messages: list[dict[str, Any]], post_transform_messages: list[dict[str, Any]] + ) -> tuple[str, bool]: + if self._messages_changed > 0: + return f"{self._messages_changed} message(s) changed to incorporate name.", True + else: + return "No messages changed to incorporate name.", False diff --git a/mm_agents/coact/autogen/agentchat/contrib/capabilities/transforms_util.py b/mm_agents/coact/autogen/agentchat/contrib/capabilities/transforms_util.py new file mode 100644 index 0000000..0aa73c1 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/contrib/capabilities/transforms_util.py @@ -0,0 +1,122 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 +# +# Portions derived from https://github.com/microsoft/autogen are under 
the MIT License. +# SPDX-License-Identifier: MIT +from collections.abc import Hashable +from typing import Any, Optional + +from .... import token_count_utils +from ....cache.abstract_cache_base import AbstractCache +from ....oai.openai_utils import filter_config +from ....types import MessageContentType + + +def cache_key(content: MessageContentType, *args: Hashable) -> str: + """Calculates the cache key for the given message content and any other hashable args. + + Args: + content (MessageContentType): The message content to calculate the cache key for. + *args: Any additional hashable args to include in the cache key. + """ + str_keys = [str(key) for key in (content, *args)] + return "".join(str_keys) + + +def cache_content_get(cache: Optional[AbstractCache], key: str) -> Optional[tuple[MessageContentType, ...]]: + """Retrieves cached content from the cache. + + Args: + cache (None or AbstractCache): The cache to retrieve the content from. If None, the cache is ignored. + key (str): The key to retrieve the content from. + """ + if cache: + cached_value = cache.get(key) + if cached_value: + return cached_value + + +def cache_content_set(cache: Optional[AbstractCache], key: str, content: MessageContentType, *extra_values): + """Sets content into the cache. + + Args: + cache (None or AbstractCache): The cache to set the content into. If None, the cache is ignored. + key (str): The key to set the content into. + content (MessageContentType): The message content to set into the cache. + *extra_values: Additional values to be passed to the cache. + """ + if cache: + cache_value = (content, *extra_values) + cache.set(key, cache_value) + + +def min_tokens_reached(messages: list[dict[str, Any]], min_tokens: Optional[int]) -> bool: + """Returns True if the total number of tokens in the messages is greater than or equal to the specified value. + + Args: + messages (List[Dict]): A list of messages to check. + min_tokens (None or int): The minimum number of tokens to check for. + """ + if not min_tokens: + return True + + messages_tokens = sum(count_text_tokens(msg["content"]) for msg in messages if "content" in msg) + return messages_tokens >= min_tokens + + +def count_text_tokens(content: MessageContentType) -> int: + """Calculates the number of text tokens in the given message content. + + Args: + content (MessageContentType): The message content to calculate the number of text tokens for. + """ + token_count = 0 + if isinstance(content, str): + token_count = token_count_utils.count_token(content) + elif isinstance(content, list): + for item in content: + if isinstance(item, str): + token_count += token_count_utils.count_token(item) + else: + token_count += count_text_tokens(item.get("text", "")) + return token_count + + +def is_content_right_type(content: Any) -> bool: + """A helper function to check if the passed in content is of the right type.""" + return isinstance(content, (str, list)) + + +def is_content_text_empty(content: MessageContentType) -> bool: + """Checks if the content of the message does not contain any text. + + Args: + content (MessageContentType): The message content to check. 
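+
+    For example, "" and a list containing only image items are treated as empty, while any
+    non-empty string or text item makes the content non-empty.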
+ """ + if isinstance(content, str): + return content == "" + elif isinstance(content, list): + texts = [] + for item in content: + if isinstance(item, str): + texts.append(item) + elif isinstance(item, dict): + texts.append(item.get("text", "")) + return not any(texts) + else: + return True + + +def should_transform_message(message: dict[str, Any], filter_dict: Optional[dict[str, Any]], exclude: bool) -> bool: + """Validates whether the transform should be applied according to the filter dictionary. + + Args: + message (Dict[str, Any]): The message to validate. + filter_dict (None or Dict[str, Any]): The filter dictionary to validate against. If None, the transform is always applied. + exclude (bool): Whether to exclude messages that match the filter dictionary. + """ + if not filter_dict: + return True + + return len(filter_config([message], filter_dict, exclude)) > 0 diff --git a/mm_agents/coact/autogen/agentchat/contrib/capabilities/vision_capability.py b/mm_agents/coact/autogen/agentchat/contrib/capabilities/vision_capability.py new file mode 100644 index 0000000..a0035c8 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/contrib/capabilities/vision_capability.py @@ -0,0 +1,212 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 +# +# Portions derived from https://github.com/microsoft/autogen are under the MIT License. +# SPDX-License-Identifier: MIT +import copy +from typing import Any, Callable, Optional, Union + +from ....code_utils import content_str +from ....oai.client import OpenAIWrapper +from ...assistant_agent import ConversableAgent +from ..img_utils import ( + convert_base64_to_data_uri, + get_image_data, + get_pil_image, + gpt4v_formatter, +) +from .agent_capability import AgentCapability + +DEFAULT_DESCRIPTION_PROMPT = ( + "Write a detailed caption for this image. " + "Pay special attention to any details that might be useful or relevant " + "to the ongoing conversation." +) + + +class VisionCapability(AgentCapability): + """We can add vision capability to regular ConversableAgent, even if the agent does not have the multimodal capability, + such as GPT-3.5-turbo agent, Llama, Orca, or Mistral agents. This vision capability will invoke a LMM client to describe + the image (captioning) before sending the information to the agent's actual client. + + The vision capability will hook to the ConversableAgent's `process_last_received_message`. + + Some technical details: + When the agent (who has the vision capability) received an message, it will: + 1. _process_received_message: + a. _append_oai_message + 2. generate_reply: if the agent is a MultimodalAgent, it will also use the image tag. + a. hook process_last_received_message (NOTE: this is where the vision capability will be hooked to.) + b. hook process_all_messages_before_reply + 3. send: + a. hook process_message_before_send + b. _append_oai_message + """ + + def __init__( + self, + lmm_config: dict[str, Any], + description_prompt: Optional[str] = DEFAULT_DESCRIPTION_PROMPT, + custom_caption_func: Callable = None, + ) -> None: + """Initializes a new instance, setting up the configuration for interacting with + a Language Multimodal (LMM) client and specifying optional parameters for image + description and captioning. + + Args: + lmm_config (Dict): Configuration for the LMM client, which is used to call + the LMM service for describing the image. 
This must be a dictionary containing + the necessary configuration parameters. If `lmm_config` is False or an empty dictionary, + it is considered invalid, and initialization will assert. + description_prompt (Optional[str], optional): The prompt to use for generating + descriptions of the image. This parameter allows customization of the + prompt passed to the LMM service. Defaults to `DEFAULT_DESCRIPTION_PROMPT` if not provided. + custom_caption_func (Callable, optional): A callable that, if provided, will be used + to generate captions for images. This allows for custom captioning logic outside + of the standard LMM service interaction. + The callable should take three parameters as input: + 1. an image URL (or local location) + 2. image_data (a PIL image) + 3. lmm_client (to call remote LMM) + and then return a description (as string). + If not provided, captioning will rely on the LMM client configured via `lmm_config`. + If provided, we will not run the default self._get_image_caption method. + + Raises: + AssertionError: If neither a valid `lmm_config` nor a `custom_caption_func` is provided, + an AssertionError is raised to indicate that the Vision Capability requires + one of these to be valid for operation. + """ + self._lmm_config = lmm_config + self._description_prompt = description_prompt + self._parent_agent = None + + if lmm_config: + self._lmm_client = OpenAIWrapper(**lmm_config) + else: + self._lmm_client = None + + self._custom_caption_func = custom_caption_func + assert self._lmm_config or custom_caption_func, ( + "Vision Capability requires a valid lmm_config or custom_caption_func." + ) + + def add_to_agent(self, agent: ConversableAgent) -> None: + self._parent_agent = agent + + # Append extra info to the system message. + agent.update_system_message(agent.system_message + "\nYou've been given the ability to interpret images.") + + # Register a hook for processing the last message. + agent.register_hook(hookable_method="process_last_received_message", hook=self.process_last_received_message) + + def process_last_received_message(self, content: Union[str, list[dict[str, Any]]]) -> str: + """Processes the last received message content by normalizing and augmenting it + with descriptions of any included images. The function supports input content + as either a string or a list of dictionaries, where each dictionary represents + a content item (e.g., text, image). If the content contains image URLs, it + fetches the image data, generates a caption for each image, and inserts the + caption into the augmented content. + + The function aims to transform the content into a format compatible with GPT-4V + multimodal inputs, specifically by formatting strings into PIL-compatible + images if needed and appending text descriptions for images. This allows for + a more accessible presentation of the content, especially in contexts where + images cannot be displayed directly. + + Args: + content (Union[str, List[dict[str, Any]]]): The last received message content, which + can be a plain text string or a list of dictionaries representing + different types of content items (e.g., text, image_url). + + Returns: + str: The augmented message content + + Raises: + AssertionError: If an item in the content list is not a dictionary. + + Examples: + Assuming `self._get_image_caption(img_data)` returns + "A beautiful sunset over the mountains" for the image. + + - Input as String: + content = "Check out this cool photo!" + Output: "Check out this cool photo!" 
+ (Content is a string without an image, remains unchanged.) + + - Input as String, with image location: + content = "What's weather in this cool photo: ``" + Output: "What's weather in this cool photo: `` in case you can not see, the caption of this image is: + A beautiful sunset over the mountains\n" + (Caption added after the image) + + - Input as List with Text Only: + content = `[{"type": "text", "text": "Here's an interesting fact."}]` + Output: "Here's an interesting fact." + (No images in the content, it remains unchanged.) + + - Input as List with Image URL: + ```python + content = [ + {"type": "text", "text": "What's weather in this cool photo:"}, + {"type": "image_url", "image_url": "http://example.com/photo.jpg"}, + ] + ``` + Output: "What's weather in this cool photo: `` in case you can not see, the caption of this image is: + A beautiful sunset over the mountains\n" + (Caption added after the image) + """ + copy.deepcopy(content) + # normalize the content into the gpt-4v format for multimodal + # we want to keep the URL format to keep it concise. + if isinstance(content, str): + content = gpt4v_formatter(content, img_format="url") + + aug_content: str = "" + for item in content: + assert isinstance(item, dict) + if item["type"] == "text": + aug_content += item["text"] + elif item["type"] == "image_url": + img_url = item["image_url"] + img_caption = "" + + if self._custom_caption_func: + img_caption = self._custom_caption_func(img_url, get_pil_image(img_url), self._lmm_client) + elif self._lmm_client: + img_data = get_image_data(img_url) + img_caption = self._get_image_caption(img_data) + else: + img_caption = "" + + aug_content += f" in case you can not see, the caption of this image is: {img_caption}\n" + else: + print(f"Warning: the input type should either be `test` or `image_url`. Skip {item['type']} here.") + + return aug_content + + def _get_image_caption(self, img_data: str) -> str: + """Args: + img_data (str): base64 encoded image data. + + Returns: + str: caption for the given image. + """ + response = self._lmm_client.create( + context=None, + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": self._description_prompt}, + { + "type": "image_url", + "image_url": convert_base64_to_data_uri(img_data), + }, + ], + } + ], + ) + description = response.choices[0].message.content + return content_str(description) diff --git a/mm_agents/coact/autogen/agentchat/contrib/img_utils.py b/mm_agents/coact/autogen/agentchat/contrib/img_utils.py new file mode 100644 index 0000000..be88a33 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/contrib/img_utils.py @@ -0,0 +1,411 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 +# +# Portions derived from https://github.com/microsoft/autogen are under the MIT License. +# SPDX-License-Identifier: MIT +import base64 +import copy +import os +import re +from io import BytesIO +from math import ceil +from typing import Any, Union + +import requests + +from ...import_utils import optional_import_block, require_optional_import +from .. 
import utils + +with optional_import_block(): + from PIL import Image + + +# Parameters for token counting for images for different models +MODEL_PARAMS = { + "gpt-4-vision": { + "max_edge": 2048, + "min_edge": 768, + "tile_size": 512, + "base_token_count": 85, + "token_multiplier": 170, + }, + "gpt-4o-mini": { + "max_edge": 2048, + "min_edge": 768, + "tile_size": 512, + "base_token_count": 2833, + "token_multiplier": 5667, + }, + "gpt-4o": {"max_edge": 2048, "min_edge": 768, "tile_size": 512, "base_token_count": 85, "token_multiplier": 170}, +} + + +@require_optional_import("PIL", "unknown") +def get_pil_image(image_file: Union[str, "Image.Image"]) -> "Image.Image": + """Loads an image from a file and returns a PIL Image object. + + Parameters: + image_file (str, or Image): The filename, URL, URI, or base64 string of the image file. + + Returns: + Image.Image: The PIL Image object. + """ + if isinstance(image_file, Image.Image): + # Already a PIL Image object + return image_file + + # Remove quotes if existed + if image_file.startswith('"') and image_file.endswith('"'): + image_file = image_file[1:-1] + if image_file.startswith("'") and image_file.endswith("'"): + image_file = image_file[1:-1] + + if image_file.startswith("http://") or image_file.startswith("https://"): + # A URL file + response = requests.get(image_file) + content = BytesIO(response.content) + image = Image.open(content) + # Match base64-encoded image URIs for supported formats: jpg, jpeg, png, gif, bmp, webp + elif re.match(r"data:image/(?:jpg|jpeg|png|gif|bmp|webp);base64,", image_file): + # A URI. Remove the prefix and decode the base64 string. + base64_data = re.sub(r"data:image/(?:jpg|jpeg|png|gif|bmp|webp);base64,", "", image_file) + image = _to_pil(base64_data) + elif os.path.exists(image_file): + # A local file + image = Image.open(image_file) + else: + # base64 encoded string + image = _to_pil(image_file) + + return image.convert("RGB") + + +@require_optional_import("PIL", "unknown") +def get_image_data(image_file: Union[str, "Image.Image"], use_b64=True) -> bytes: + """Loads an image and returns its data either as raw bytes or in base64-encoded format. + + This function first loads an image from the specified file, URL, or base64 string using + the `get_pil_image` function. It then saves this image in memory in PNG format and + retrieves its binary content. Depending on the `use_b64` flag, this binary content is + either returned directly or as a base64-encoded string. + + Parameters: + image_file (str, or Image): The path to the image file, a URL to an image, or a base64-encoded + string of the image. + use_b64 (bool): If True, the function returns a base64-encoded string of the image data. + If False, it returns the raw byte data of the image. Defaults to True. + + Returns: + bytes: The image data in raw bytes if `use_b64` is False, or a base64-encoded string + if `use_b64` is True. + """ + image = get_pil_image(image_file) + + buffered = BytesIO() + image.save(buffered, format="PNG") + content = buffered.getvalue() + + if use_b64: + return base64.b64encode(content).decode("utf-8") + else: + return content + + +@require_optional_import("PIL", "unknown") +def llava_formatter(prompt: str, order_image_tokens: bool = False) -> tuple[str, list[str]]: + """Formats the input prompt by replacing image tags and returns the new prompt along with image locations. + + Parameters: + - prompt (str): The input string that may contain image tags like ``. 
+    - order_image_tokens (bool, optional): Whether to order the image tokens with numbers.
+        It will be useful for GPT-4V. Defaults to False.
+
+    Returns:
+    - Tuple[str, List[str]]: A tuple containing the formatted string and a list of images (loaded in b64 format).
+    """
+    # Initialize variables
+    new_prompt = prompt
+    image_locations = []
+    images = []
+    image_count = 0
+
+    # Regular expression pattern for matching `<img ...>` tags
+    img_tag_pattern = re.compile(r"<img ([^>]+)>")
+
+    # Find all image tags
+    for match in img_tag_pattern.finditer(prompt):
+        image_location = match.group(1)
+
+        try:
+            img_data = get_image_data(image_location)
+        except Exception as e:
+            # Remove the token
+            print(f"Warning! Unable to load image from {image_location}, because of {e}")
+            new_prompt = new_prompt.replace(match.group(0), "", 1)
+            continue
+
+        image_locations.append(image_location)
+        images.append(img_data)
+
+        # Increment the image count and replace the `<img ...>` tag in the prompt
+        new_token = f"<image {image_count}>" if order_image_tokens else "<image>"
+        new_prompt = new_prompt.replace(match.group(0), new_token, 1)
+        image_count += 1
+
+    return new_prompt, images
+
+
+@require_optional_import("PIL", "unknown")
+def pil_to_data_uri(image: "Image.Image") -> str:
+    """Converts a PIL Image object to a data URI.
+
+    Parameters:
+        image (Image.Image): The PIL Image object.
+
+    Returns:
+        str: The data URI string.
+    """
+    buffered = BytesIO()
+    image.save(buffered, format="PNG")
+    content = buffered.getvalue()
+    return convert_base64_to_data_uri(base64.b64encode(content).decode("utf-8"))
+
+
+def convert_base64_to_data_uri(base64_image):
+    def _get_mime_type_from_data_uri(base64_image):
+        # Decode the base64 string
+        image_data = base64.b64decode(base64_image)
+        # Check the first few bytes for known signatures
+        if image_data.startswith(b"\xff\xd8\xff"):
+            return "image/jpeg"
+        elif image_data.startswith(b"\x89PNG\r\n\x1a\n"):
+            return "image/png"
+        elif image_data.startswith(b"GIF87a") or image_data.startswith(b"GIF89a"):
+            return "image/gif"
+        elif image_data.startswith(b"RIFF") and image_data[8:12] == b"WEBP":
+            return "image/webp"
+        return "image/jpeg"  # use jpeg for unknown formats, best guess.
+
+    mime_type = _get_mime_type_from_data_uri(base64_image)
+    data_uri = f"data:{mime_type};base64,{base64_image}"
+    return data_uri
+
+
+@require_optional_import("PIL", "unknown")
+def gpt4v_formatter(prompt: str, img_format: str = "uri") -> list[Union[str, dict[str, Any]]]:
+    """Formats the input prompt by replacing image tags and returns a list of text and images.
+
+    Args:
+        prompt (str): The input string that may contain image tags like `<img http://example.com/image.png>`.
+        img_format (str): what image format should be used. One of "uri", "url", "pil".
+
+    Returns:
+        List[Union[str, dict[str, Any]]]: A list of alternating text and image dictionary items.
+    """
+    assert img_format in ["uri", "url", "pil"]
+
+    output = []
+    last_index = 0
+    image_count = 0
+
+    # Find all image tags
+    for parsed_tag in utils.parse_tags_from_content("img", prompt):
+        image_location = parsed_tag["attr"]["src"]
+        try:
+            if img_format == "pil":
+                img_data = get_pil_image(image_location)
+            elif img_format == "uri":
+                img_data = get_image_data(image_location)
+                img_data = convert_base64_to_data_uri(img_data)
+            elif img_format == "url":
+                img_data = image_location
+            else:
+                raise ValueError(f"Unknown image format {img_format}")
+        except Exception as e:
+            # Warning and skip this token
+            print(f"Warning! 
Unable to load image from {image_location}, because {e}") + continue + + # Add text before this image tag to output list + output.append({"type": "text", "text": prompt[last_index : parsed_tag["match"].start()]}) + + # Add image data to output list + output.append({"type": "image_url", "image_url": {"url": img_data}}) + + last_index = parsed_tag["match"].end() + image_count += 1 + + # Add remaining text to output list + if last_index < len(prompt): + output.append({"type": "text", "text": prompt[last_index:]}) + return output + + +def extract_img_paths(paragraph: str) -> list: + """Extract image paths (URLs or local paths) from a text paragraph. + + Parameters: + paragraph (str): The input text paragraph. + + Returns: + list: A list of extracted image paths. + """ + # Regular expression to match image URLs and file paths. + # This regex detects URLs and file paths with common image extensions, including support for the webp format. + img_path_pattern = re.compile( + r"\b(?:http[s]?://\S+\.(?:jpg|jpeg|png|gif|bmp|webp)|\S+\.(?:jpg|jpeg|png|gif|bmp|webp))\b", re.IGNORECASE + ) + + # Find all matches in the paragraph + img_paths = re.findall(img_path_pattern, paragraph) + return img_paths + + +@require_optional_import("PIL", "unknown") +def _to_pil(data: str) -> "Image.Image": + """Converts a base64 encoded image data string to a PIL Image object. + + This function first decodes the base64 encoded string to bytes, then creates a BytesIO object from the bytes, + and finally creates and returns a PIL Image object from the BytesIO object. + + Parameters: + data (str): The encoded image data string. + + Returns: + Image.Image: The PIL Image object created from the input data. + """ + return Image.open(BytesIO(base64.b64decode(data))) + + +@require_optional_import("PIL", "unknown") +def message_formatter_pil_to_b64(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Converts the PIL image URLs in the messages to base64 encoded data URIs. + + This function iterates over a list of message dictionaries. For each message, + if it contains a 'content' key with a list of items, it looks for items + with an 'image_url' key. The function then converts the PIL image URL + (pointed to by 'image_url') to a base64 encoded data URI. + + Parameters: + messages (List[Dict]): A list of message dictionaries. Each dictionary + may contain a 'content' key with a list of items, + some of which might be image URLs. + + Returns: + List[Dict]: A new list of message dictionaries with PIL image URLs in the + 'image_url' key converted to base64 encoded data URIs. + + Example Input: + example 1: + ```python + [ + {'content': [{'type': 'text', 'text': 'You are a helpful AI assistant.'}], 'role': 'system'}, + {'content': [ + {'type': 'text', 'text': "What's the breed of this dog here?"}, + {'type': 'image_url', 'image_url': {'url': a PIL.Image.Image}}, + {'type': 'text', 'text': '.'}], + 'role': 'user'} + ] + ``` + + Example Output: + example 1: + ```python + [ + {'content': [{'type': 'text', 'text': 'You are a helpful AI assistant.'}], 'role': 'system'}, + {'content': [ + {'type': 'text', 'text': "What's the breed of this dog here?"}, + {'type': 'image_url', 'image_url': {'url': a B64 Image}}, + {'type': 'text', 'text': '.'}], + 'role': 'user'} + ] + ``` + """ + new_messages = [] + for message in messages: + # deepcopy to avoid modifying the original message. 
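+        # (PIL Image objects stored under 'image_url' are replaced with base64 data URIs below,
+        # so the copy keeps the caller's original message objects intact.)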
+ message = copy.deepcopy(message) + if isinstance(message, dict) and "content" in message: + # First, if the content is a string, parse it into a list of parts. + # This is for tool output that contains images. + if isinstance(message["content"], str): + message["content"] = gpt4v_formatter(message["content"], img_format="pil") + + # Second, if the content is a list, process any image parts. + if isinstance(message["content"], list): + for item in message["content"]: + if ( + isinstance(item, dict) + and "image_url" in item + and isinstance(item["image_url"]["url"], Image.Image) + ): + item["image_url"]["url"] = pil_to_data_uri(item["image_url"]["url"]) + + new_messages.append(message) + + return new_messages + + +@require_optional_import("PIL", "unknown") +def num_tokens_from_gpt_image( + image_data: Union[str, "Image.Image"], model: str = "gpt-4-vision", low_quality: bool = False +) -> int: + """Calculate the number of tokens required to process an image based on its dimensions + after scaling for different GPT models. Supports "gpt-4-vision", "gpt-4o", and "gpt-4o-mini". + This function scales the image so that its longest edge is at most 2048 pixels and its shortest + edge is at most 768 pixels (for "gpt-4-vision"). It then calculates the number of 512x512 tiles + needed to cover the scaled image and computes the total tokens based on the number of these tiles. + + Reference: https://openai.com/api/pricing/ + + Args: + image_data : Union[str, Image.Image]: The image data which can either be a base64 encoded string, a URL, a file path, or a PIL Image object. + model: str: The model being used for image processing. Can be "gpt-4-vision", "gpt-4o", or "gpt-4o-mini". + low_quality: bool: Whether to use low-quality processing. Defaults to False. + + Returns: + int: The total number of tokens required for processing the image. + + Examples: + -------- + >>> from PIL import Image + >>> img = Image.new("RGB", (2500, 2500), color="red") + >>> num_tokens_from_gpt_image(img, model="gpt-4-vision") + 765 + """ + image = get_pil_image(image_data) # PIL Image + width, height = image.size + + # Determine model parameters + if "gpt-4-vision" in model or "gpt-4-turbo" in model or "gpt-4v" in model or "gpt-4-v" in model: + params = MODEL_PARAMS["gpt-4-vision"] + elif "gpt-4o-mini" in model: + params = MODEL_PARAMS["gpt-4o-mini"] + elif "gpt-4o" in model: + params = MODEL_PARAMS["gpt-4o"] + else: + raise ValueError( + f"Model {model} is not supported. Choose 'gpt-4-vision', 'gpt-4-turbo', 'gpt-4v', 'gpt-4-v', 'gpt-4o', or 'gpt-4o-mini'." + ) + + if low_quality: + return params["base_token_count"] + + # 1. Constrain the longest edge + if max(width, height) > params["max_edge"]: + scale_factor = params["max_edge"] / max(width, height) + width, height = int(width * scale_factor), int(height * scale_factor) + + # 2. Further constrain the shortest edge + if min(width, height) > params["min_edge"]: + scale_factor = params["min_edge"] / min(width, height) + width, height = int(width * scale_factor), int(height * scale_factor) + + # 3. 
Count how many tiles are needed to cover the image + tiles_width = ceil(width / params["tile_size"]) + tiles_height = ceil(height / params["tile_size"]) + total_tokens = params["base_token_count"] + params["token_multiplier"] * (tiles_width * tiles_height) + + return total_tokens diff --git a/mm_agents/coact/autogen/agentchat/contrib/multimodal_conversable_agent.py b/mm_agents/coact/autogen/agentchat/contrib/multimodal_conversable_agent.py new file mode 100644 index 0000000..d37fff0 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/contrib/multimodal_conversable_agent.py @@ -0,0 +1,153 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 +# +# Portions derived from https://github.com/microsoft/autogen are under the MIT License. +# SPDX-License-Identifier: MIT +import copy +from typing import Any, Optional, Union + +from ... import OpenAIWrapper +from ...code_utils import content_str +from .. import Agent, ConversableAgent +from ..contrib.img_utils import ( + gpt4v_formatter, + message_formatter_pil_to_b64, +) + +DEFAULT_LMM_SYS_MSG = """You are a helpful AI assistant.""" +DEFAULT_MODEL = "gpt-4-vision-preview" + + +class MultimodalConversableAgent(ConversableAgent): + DEFAULT_CONFIG = { + "model": DEFAULT_MODEL, + } + + def __init__( + self, + name: str, + system_message: Optional[Union[str, list]] = DEFAULT_LMM_SYS_MSG, + is_termination_msg: str = None, + *args, + **kwargs: Any, + ): + """Args: + name (str): agent name. + system_message (str): system message for the OpenAIWrapper inference. + Please override this attribute if you want to reprogram the agent. + **kwargs (dict): Please refer to other kwargs in + [ConversableAgent](/docs/api-reference/autogen/ConversableAgent#conversableagent). + """ + super().__init__( + name, + system_message, + is_termination_msg=is_termination_msg, + *args, + **kwargs, + ) + # call the setter to handle special format. + self.update_system_message(system_message) + self._is_termination_msg = ( + is_termination_msg + if is_termination_msg is not None + else (lambda x: content_str(x.get("content")) == "TERMINATE") + ) + + # Override the `generate_oai_reply` + self.replace_reply_func(ConversableAgent.generate_oai_reply, MultimodalConversableAgent.generate_oai_reply) + self.replace_reply_func( + ConversableAgent.a_generate_oai_reply, + MultimodalConversableAgent.a_generate_oai_reply, + ) + + def update_system_message(self, system_message: Union[dict[str, Any], list[str], str]): + """Update the system message. + + Args: + system_message (str): system message for the OpenAIWrapper inference. + """ + self._oai_system_message[0]["content"] = self._message_to_dict(system_message)["content"] + self._oai_system_message[0]["role"] = "system" + + @staticmethod + def _message_to_dict(message: Union[dict[str, Any], list[str], str]) -> dict: + """Convert a message to a dictionary. This implementation + handles the GPT-4V formatting for easier prompts. + + The message can be a string, a dictionary, or a list of dictionaries: + - If it's a string, it will be cast into a list and placed in the 'content' field. + - If it's a list, it will be directly placed in the 'content' field. + - If it's a dictionary, it is already in message dict format. The 'content' field of this dictionary + will be processed using the gpt4v_formatter. 
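+
+        For example, a plain string such as "Describe the attached screenshot" becomes
+        {"content": [{"type": "text", "text": "Describe the attached screenshot"}]}, with any
+        inline image tags expanded into image entries by gpt4v_formatter.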
+ """ + if isinstance(message, str): + return {"content": gpt4v_formatter(message, img_format="pil")} + if isinstance(message, list): + return {"content": message} + if isinstance(message, dict): + assert "content" in message, "The message dict must have a `content` field" + if isinstance(message["content"], str): + message = copy.deepcopy(message) + message["content"] = gpt4v_formatter(message["content"], img_format="pil") + try: + content_str(message["content"]) + except (TypeError, ValueError) as e: + print("The `content` field should be compatible with the content_str function!") + raise e + return message + raise ValueError(f"Unsupported message type: {type(message)}") + + def generate_oai_reply( + self, + messages: Optional[list[dict[str, Any]]] = None, + sender: Optional[Agent] = None, + config: Optional[OpenAIWrapper] = None, + ) -> tuple[bool, Optional[Union[str, dict[str, Any]]]]: + """Generate a reply using autogen.oai.""" + client = self.client if config is None else config + if client is None: + return False, None + if messages is None: + messages = self._oai_messages[sender] + + messages_with_b64_img = message_formatter_pil_to_b64(self._oai_system_message + messages) + + new_messages = [] + for message in messages_with_b64_img: + if 'tool_responses' in message: + for tool_response in message['tool_responses']: + tmp_image = None + tmp_list = [] + for ctx in message['content']: + if ctx['type'] == 'image_url': + tmp_image = ctx + tmp_list.append({ + 'role': 'tool', + 'tool_call_id': tool_response['tool_call_id'], + 'content': [message['content'][0]] + }) + if tmp_image: + tmp_list.append({ + 'role': 'user', + 'content': [ + {'type': 'text', 'text': 'I take a screenshot for the current state for you.'}, + tmp_image + ] + }) + new_messages.extend(tmp_list) + else: + new_messages.append(message) + messages_with_b64_img = new_messages.copy() + + + # TODO: #1143 handle token limit exceeded error + response = client.create( + context=messages[-1].pop("context", None), messages=messages_with_b64_img, agent=self.name + ) + + # TODO: line 301, line 271 is converting messages to dict. Can be removed after ChatCompletionMessage_to_dict is merged. + extracted_response = client.extract_text_or_completion_object(response)[0] + if not isinstance(extracted_response, str): + extracted_response = extracted_response.model_dump() + return True, extracted_response diff --git a/mm_agents/coact/autogen/agentchat/conversable_agent.py b/mm_agents/coact/autogen/agentchat/conversable_agent.py new file mode 100644 index 0000000..e4ee85f --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/conversable_agent.py @@ -0,0 +1,4023 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 +# +# Portions derived from https://github.com/microsoft/autogen are under the MIT License. 
+# SPDX-License-Identifier: MIT +import asyncio +import copy +import functools +import inspect +import json +import logging +import re +import threading +import warnings +from collections import defaultdict +from contextlib import contextmanager +from dataclasses import dataclass +from inspect import signature +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generator, + Iterable, + Literal, + Optional, + TypeVar, + Union, +) + +from ..cache.cache import AbstractCache, Cache +from ..code_utils import ( + PYTHON_VARIANTS, + UNKNOWN, + check_can_use_docker_or_throw, + content_str, + decide_use_docker, + execute_code, + extract_code, + infer_lang, +) +from ..coding.base import CodeExecutor +from ..coding.factory import CodeExecutorFactory +from ..doc_utils import export_module +from ..events.agent_events import ( + ClearConversableAgentHistoryEvent, + ClearConversableAgentHistoryWarningEvent, + ConversableAgentUsageSummaryEvent, + ConversableAgentUsageSummaryNoCostIncurredEvent, + ErrorEvent, + ExecuteCodeBlockEvent, + ExecuteFunctionEvent, + ExecutedFunctionEvent, + GenerateCodeExecutionReplyEvent, + PostCarryoverProcessingEvent, + RunCompletionEvent, + TerminationAndHumanReplyNoInputEvent, + TerminationEvent, + UsingAutoReplyEvent, + create_received_event_model, +) +from ..exception_utils import InvalidCarryOverTypeError, SenderRequiredError +from ..io.base import IOStream +from ..io.run_response import AsyncRunResponse, AsyncRunResponseProtocol, RunResponse, RunResponseProtocol +from ..io.thread_io_stream import AsyncThreadIOStream, ThreadIOStream +from ..llm_config import LLMConfig +from ..oai.client import ModelClient, OpenAIWrapper +from ..runtime_logging import log_event, log_function_use, log_new_agent, logging_enabled +from ..tools import ChatContext, Tool, load_basemodels_if_needed, serialize_to_str +from .agent import Agent, LLMAgent +from .chat import ( + ChatResult, + _post_process_carryover_item, + _validate_recipients, + a_initiate_chats, + initiate_chats, +) +from .group.context_variables import ContextVariables +from .group.handoffs import Handoffs +from .utils import consolidate_chat_info, gather_usage_summary + +if TYPE_CHECKING: + from .group.on_condition import OnCondition + from .group.on_context_condition import OnContextCondition + +__all__ = ("ConversableAgent",) + +logger = logging.getLogger(__name__) + +F = TypeVar("F", bound=Callable[..., Any]) + + +@dataclass +@export_module("autogen") +class UpdateSystemMessage: + """Update the agent's system message before they reply + + Args: + content_updater: The format string or function to update the agent's system message. Can be a format string or a Callable. + If a string, it will be used as a template and substitute the context variables. + If a Callable, it should have the signature: + def my_content_updater(agent: ConversableAgent, messages: List[Dict[str, Any]]) -> str + """ + + content_updater: Union[Callable, str] + + def __post_init__(self): + if isinstance(self.content_updater, str): + # find all {var} in the string + vars = re.findall(r"\{(\w+)\}", self.content_updater) + if len(vars) == 0: + warnings.warn("Update function string contains no variables. 
This is probably unintended.") + + elif isinstance(self.content_updater, Callable): + sig = signature(self.content_updater) + if len(sig.parameters) != 2: + raise ValueError( + "The update function must accept two parameters of type ConversableAgent and List[Dict[str, Any]], respectively" + ) + if sig.return_annotation != str: + raise ValueError("The update function must return a string") + else: + raise ValueError("The update function must be either a string or a callable") + + +@export_module("autogen") +class ConversableAgent(LLMAgent): + """(In preview) A class for generic conversable agents which can be configured as assistant or user proxy. + + After receiving each message, the agent will send a reply to the sender unless the msg is a termination msg. + For example, AssistantAgent and UserProxyAgent are subclasses of this class, + configured with different default settings. + + To modify auto reply, override `generate_reply` method. + To disable/enable human response in every turn, set `human_input_mode` to "NEVER" or "ALWAYS". + To modify the way to get human input, override `get_human_input` method. + To modify the way to execute code blocks, single code block, or function call, override `execute_code_blocks`, + `run_code`, and `execute_function` methods respectively. + """ + + DEFAULT_CONFIG = False # False or dict, the default config for llm inference + MAX_CONSECUTIVE_AUTO_REPLY = 100 # maximum number of consecutive auto replies (subject to future change) + + DEFAULT_SUMMARY_PROMPT = "Summarize the takeaway from the conversation. Do not add any introductory phrases." + DEFAULT_SUMMARY_METHOD = "last_msg" + llm_config: Union[dict[str, Any], Literal[False]] + + def __init__( + self, + name: str, + system_message: Optional[Union[str, list]] = "You are a helpful AI Assistant.", + is_termination_msg: Optional[Callable[[dict[str, Any]], bool]] = None, + max_consecutive_auto_reply: Optional[int] = None, + human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "TERMINATE", + function_map: Optional[dict[str, Callable[..., Any]]] = None, + code_execution_config: Union[dict[str, Any], Literal[False]] = False, + llm_config: Optional[Union[LLMConfig, dict[str, Any], Literal[False]]] = None, + default_auto_reply: Union[str, dict[str, Any]] = "", + description: Optional[str] = None, + chat_messages: Optional[dict[Agent, list[dict[str, Any]]]] = None, + silent: Optional[bool] = None, + context_variables: Optional["ContextVariables"] = None, + functions: Union[list[Callable[..., Any]], Callable[..., Any]] = None, + update_agent_state_before_reply: Optional[ + Union[list[Union[Callable, UpdateSystemMessage]], Callable, UpdateSystemMessage] + ] = None, + handoffs: Optional[Handoffs] = None, + ): + """ + Args: + name (str): name of the agent. + system_message (str or list): system message for the ChatCompletion inference. + is_termination_msg (function): a function that takes a message in the form of a dictionary + and returns a boolean value indicating if this received message is a termination message. + The dict can contain the following keys: "content", "role", "name", "function_call". + max_consecutive_auto_reply (int): the maximum number of consecutive auto replies. + default to None (no limit provided, class attribute MAX_CONSECUTIVE_AUTO_REPLY will be used as the limit in this case). + When set to 0, no auto reply will be generated. + human_input_mode (str): whether to ask for human inputs every time a message is received. + Possible values are "ALWAYS", "TERMINATE", "NEVER". 
+                (1) When "ALWAYS", the agent prompts for human input every time a message is received.
+                    Under this mode, the conversation stops when the human input is "exit",
+                    or when is_termination_msg is True and there is no human input.
+                (2) When "TERMINATE", the agent prompts for human input only when a termination message is received or
+                    the number of auto replies reaches the max_consecutive_auto_reply.
+                (3) When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops
+                    when the number of auto replies reaches the max_consecutive_auto_reply or when is_termination_msg is True.
+            function_map (dict[str, callable]): Mapping function names (passed to openai) to callable functions, also used for tool calls.
+            code_execution_config (dict or False): config for the code execution.
+                To disable code execution, set to False. Otherwise, set to a dictionary with the following keys:
+                - work_dir (Optional, str): The working directory for the code execution.
+                    If None, a default working directory will be used.
+                    The default working directory is the "extensions" directory under
+                    "path_to_autogen".
+                - use_docker (Optional, list, str or bool): The docker image to use for code execution.
+                    Default is True, which means the code will be executed in a docker container. A default list of images will be used.
+                    If a list or a str of image name(s) is provided, the code will be executed in a docker container
+                    with the first image successfully pulled.
+                    If False, the code will be executed in the current environment.
+                    We strongly recommend using docker for code execution.
+                - timeout (Optional, int): The maximum execution time in seconds.
+                - last_n_messages (Experimental, int or str): The number of messages to look back for code execution.
+                    If set to 'auto', it will scan backwards through all messages arriving since the agent last spoke, which is typically the last time execution was attempted. (Default: auto)
+            llm_config (LLMConfig or dict or False or None): llm inference configuration.
+                Please refer to [OpenAIWrapper.create](https://docs.ag2.ai/latest/docs/api-reference/autogen/OpenAIWrapper/#autogen.OpenAIWrapper.create)
+                for available options.
+                When using OpenAI or Azure OpenAI endpoints, please specify a non-empty 'model' either in `llm_config` or in each config of 'config_list' in `llm_config`.
+                To disable llm-based auto reply, set to False.
+                When set to None, will use self.DEFAULT_CONFIG, which defaults to False.
+            default_auto_reply (str or dict): default auto reply when no code execution or llm-based reply is generated.
+            description (str): a short description of the agent. This description is used by other agents
+                (e.g. the GroupChatManager) to decide when to call upon this agent. (Default: system_message)
+            chat_messages (dict or None): the previous chat messages that this agent had in the past with other agents.
+                Can be used to give the agent a memory by providing the chat history. This will allow the agent to
+                resume previously held conversations. Defaults to an empty chat history.
+            silent (bool or None): (Experimental) whether to print the message sent. If None, will use the value of
+                silent in each function.
+            context_variables (ContextVariables or None): Context variables that provide a persistent context for the agent.
+                Note: This will be a reference to a shared context for multi-agent chats.
+                Behaves like a dictionary with keys and values (akin to dict[str, Any]).
+ functions (List[Callable[..., Any]]): A list of functions to register with the agent, these will be wrapped up as tools and registered for LLM (not execution). + update_agent_state_before_reply (List[Callable[..., Any]]): A list of functions, including UpdateSystemMessage's, called to update the agent before it replies. + handoffs (Handoffs): Handoffs object containing all handoff transition conditions. + """ + self.handoffs = handoffs if handoffs is not None else Handoffs() + + # we change code_execution_config below and we have to make sure we don't change the input + # in case of UserProxyAgent, without this we could even change the default value {} + code_execution_config = ( + code_execution_config.copy() if hasattr(code_execution_config, "copy") else code_execution_config + ) + + # a dictionary of conversations, default value is list + if chat_messages is None: + self._oai_messages = defaultdict(list) + else: + self._oai_messages = chat_messages + + self._oai_system_message = [{"content": system_message, "role": "system"}] + self._description = description if description is not None else system_message + self._is_termination_msg = ( + is_termination_msg + if is_termination_msg is not None + else (lambda x: content_str(x.get("content")) == "TERMINATE") + ) + self.silent = silent + self.run_executor: Optional[ConversableAgent] = None + + # Take a copy to avoid modifying the given dict + if isinstance(llm_config, dict): + try: + llm_config = copy.deepcopy(llm_config) + except TypeError as e: + raise TypeError( + "Please implement __deepcopy__ method for each value class in llm_config to support deepcopy." + " Refer to the docs for more details: https://docs.ag2.ai/docs/user-guide/advanced-concepts/llm-configuration-deep-dive/#adding-http-client-in-llm_config-for-proxy" + ) from e + + self.llm_config = self._validate_llm_config(llm_config) + self.client = self._create_client(self.llm_config) + self._validate_name(name) + self._name = name + + if logging_enabled(): + log_new_agent(self, locals()) + + # Initialize standalone client cache object. 
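+        # (Typically populated for the duration of a chat when a cache is passed to initiate_chat,
+        # so LLM responses can be reused across turns; None disables that reuse here.)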
+ self.client_cache = None + + # To track UI tools + self._ui_tools: list[Tool] = [] + + self.human_input_mode = human_input_mode + self._max_consecutive_auto_reply = ( + max_consecutive_auto_reply if max_consecutive_auto_reply is not None else self.MAX_CONSECUTIVE_AUTO_REPLY + ) + self._consecutive_auto_reply_counter = defaultdict(int) + self._max_consecutive_auto_reply_dict = defaultdict(self.max_consecutive_auto_reply) + self._function_map = ( + {} + if function_map is None + else {name: callable for name, callable in function_map.items() if self._assert_valid_name(name)} + ) + self._default_auto_reply = default_auto_reply + self._reply_func_list = [] + self._human_input = [] + self.reply_at_receive = defaultdict(bool) + self.register_reply([Agent, None], ConversableAgent.generate_oai_reply) + self.register_reply([Agent, None], ConversableAgent.a_generate_oai_reply, ignore_async_in_sync_chat=True) + + self.context_variables = context_variables if context_variables is not None else ContextVariables() + + self._tools: list[Tool] = [] + + # Register functions to the agent + if isinstance(functions, list): + if not all(isinstance(func, Callable) for func in functions): + raise TypeError("All elements in the functions list must be callable") + self._add_functions(functions) + elif isinstance(functions, Callable): + self._add_single_function(functions) + elif functions is not None: + raise TypeError("Functions must be a callable or a list of callables") + + # Setting up code execution. + # Do not register code execution reply if code execution is disabled. + if code_execution_config is not False: + # If code_execution_config is None, set it to an empty dict. + if code_execution_config is None: + warnings.warn( + "Using None to signal a default code_execution_config is deprecated. " + "Use {} to use default or False to disable code execution.", + stacklevel=2, + ) + code_execution_config = {} + if not isinstance(code_execution_config, dict): + raise ValueError("code_execution_config must be a dict or False.") + + # We have got a valid code_execution_config. + self._code_execution_config: Union[dict[str, Any], Literal[False]] = code_execution_config + + if self._code_execution_config.get("executor") is not None: + if "use_docker" in self._code_execution_config: + raise ValueError( + "'use_docker' in code_execution_config is not valid when 'executor' is set. Use the appropriate arg in the chosen executor instead." + ) + + if "work_dir" in self._code_execution_config: + raise ValueError( + "'work_dir' in code_execution_config is not valid when 'executor' is set. Use the appropriate arg in the chosen executor instead." + ) + + if "timeout" in self._code_execution_config: + raise ValueError( + "'timeout' in code_execution_config is not valid when 'executor' is set. Use the appropriate arg in the chosen executor instead." + ) + + # Use the new code executor. + self._code_executor = CodeExecutorFactory.create(self._code_execution_config) + self.register_reply([Agent, None], ConversableAgent._generate_code_execution_reply_using_executor) + else: + # Legacy code execution using code_utils. + use_docker = self._code_execution_config.get("use_docker", None) + use_docker = decide_use_docker(use_docker) + check_can_use_docker_or_throw(use_docker) + self._code_execution_config["use_docker"] = use_docker + self.register_reply([Agent, None], ConversableAgent.generate_code_execution_reply) + else: + # Code execution is disabled. 
+ self._code_execution_config = False + + self.register_reply([Agent, None], ConversableAgent.generate_tool_calls_reply) + self.register_reply([Agent, None], ConversableAgent.a_generate_tool_calls_reply, ignore_async_in_sync_chat=True) + self.register_reply([Agent, None], ConversableAgent.generate_function_call_reply) + self.register_reply( + [Agent, None], ConversableAgent.a_generate_function_call_reply, ignore_async_in_sync_chat=True + ) + self.register_reply([Agent, None], ConversableAgent.check_termination_and_human_reply) + self.register_reply( + [Agent, None], ConversableAgent.a_check_termination_and_human_reply, ignore_async_in_sync_chat=True + ) + + # Registered hooks are kept in lists, indexed by hookable method, to be called in their order of registration. + # New hookable methods should be added to this list as required to support new agent capabilities. + self.hook_lists: dict[str, list[Callable[..., Any]]] = { + "process_last_received_message": [], + "process_all_messages_before_reply": [], + "process_message_before_send": [], + "update_agent_state": [], + } + + # Associate agent update state hooks + self._register_update_agent_state_before_reply(update_agent_state_before_reply) + + def _validate_name(self, name: str) -> None: + if not self.llm_config: + return + + if any([ + entry for entry in self.llm_config.config_list if entry.api_type == "openai" and re.search(r"\s", name) + ]): + raise ValueError(f"The name of the agent cannot contain any whitespace. The name provided is: '{name}'") + + def _get_display_name(self): + """Get the string representation of the agent. + + If you would like to change the standard string representation for an + instance of ConversableAgent, you can point it to another function. + In this example a function called _group_agent_str that returns a string: + agent._get_display_name = MethodType(_group_agent_str, agent) + """ + return self.name + + def __str__(self): + return self._get_display_name() + + def _add_functions(self, func_list: list[Callable[..., Any]]): + """Add (Register) a list of functions to the agent + + Args: + func_list (list[Callable[..., Any]]): A list of functions to register with the agent.""" + for func in func_list: + self._add_single_function(func) + + def _add_single_function(self, func: Callable, name: Optional[str] = None, description: Optional[str] = ""): + """Add a single function to the agent + + Args: + func (Callable): The function to register. + name (str): The name of the function. If not provided, the function's name will be used. + description (str): The description of the function, used by the LLM. If not provided, the function's docstring will be used. + """ + if name: + func._name = name + elif not hasattr(func, "_name"): + func._name = func.__name__ + + if hasattr(func, "_description") and func._description and not description: + # If the function already has a description, use it + description = func._description + else: + if description: + func._description = description + else: + # Use function's docstring, strip whitespace, fall back to empty string + description = (func.__doc__ or "").strip() + func._description = description + + # Register the function + self.register_for_llm(name=name, description=description, silent_override=True)(func) + + def _register_update_agent_state_before_reply( + self, functions: Optional[Union[list[Callable[..., Any]], Callable[..., Any]]] + ): + """ + Register functions that will be called when the agent is selected and before it speaks. 
+ You can add your own validation or precondition functions here. + + Args: + functions (List[Callable[[], None]]): A list of functions to be registered. Each function + is called when the agent is selected and before it speaks. + """ + if functions is None: + return + if not isinstance(functions, list) and type(functions) not in [UpdateSystemMessage, Callable[..., Any]]: + raise ValueError("functions must be a list of callables") + + if not isinstance(functions, list): + functions = [functions] + + for func in functions: + if isinstance(func, UpdateSystemMessage): + # Wrapper function that allows this to be used in the update_agent_state hook + # Its primary purpose, however, is just to update the agent's system message + # Outer function to create a closure with the update function + def create_wrapper(update_func: UpdateSystemMessage): + def update_system_message_wrapper( + agent: ConversableAgent, messages: list[dict[str, Any]] + ) -> list[dict[str, Any]]: + if isinstance(update_func.content_updater, str): + # Templates like "My context variable passport is {passport}" will + # use the context_variables for substitution + sys_message = OpenAIWrapper.instantiate( + template=update_func.content_updater, + context=agent.context_variables.to_dict(), + allow_format_str_template=True, + ) + else: + sys_message = update_func.content_updater(agent, messages) + + agent.update_system_message(sys_message) + return messages + + return update_system_message_wrapper + + self.register_hook(hookable_method="update_agent_state", hook=create_wrapper(func)) + + else: + self.register_hook(hookable_method="update_agent_state", hook=func) + + @classmethod + def _validate_llm_config( + cls, llm_config: Optional[Union[LLMConfig, dict[str, Any], Literal[False]]] + ) -> Union[LLMConfig, Literal[False]]: + # if not(llm_config in (None, False) or isinstance(llm_config, [dict, LLMConfig])): + # raise ValueError( + # "llm_config must be a dict or False or None." + # ) + + if llm_config is None: + llm_config = LLMConfig.get_current_llm_config() + if llm_config is None: + llm_config = cls.DEFAULT_CONFIG + elif isinstance(llm_config, dict): + llm_config = LLMConfig(**llm_config) + elif isinstance(llm_config, LLMConfig): + llm_config = llm_config.copy() + elif llm_config is False: + pass + else: + raise ValueError("llm_config must be a LLMConfig, dict or False or None.") + + return llm_config + + @classmethod + def _create_client(cls, llm_config: Union[LLMConfig, Literal[False]]) -> Optional[OpenAIWrapper]: + return None if llm_config is False else OpenAIWrapper(**llm_config) + + @staticmethod + def _is_silent(agent: Agent, silent: Optional[bool] = False) -> bool: + return agent.silent if agent.silent is not None else silent + + @property + def name(self) -> str: + """Get the name of the agent.""" + return self._name + + @property + def description(self) -> str: + """Get the description of the agent.""" + return self._description + + @description.setter + def description(self, description: str): + """Set the description of the agent.""" + self._description = description + + @property + def code_executor(self) -> Optional[CodeExecutor]: + """The code executor used by this agent. 
Returns None if code execution is disabled.""" + if not hasattr(self, "_code_executor"): + return None + return self._code_executor + + def register_reply( + self, + trigger: Union[type[Agent], str, Agent, Callable[[Agent], bool], list], + reply_func: Callable, + position: int = 0, + config: Optional[Any] = None, + reset_config: Optional[Callable[..., Any]] = None, + *, + ignore_async_in_sync_chat: bool = False, + remove_other_reply_funcs: bool = False, + ): + """Register a reply function. + + The reply function will be called when the trigger matches the sender. + The function registered later will be checked earlier by default. + To change the order, set the position to a positive integer. + + Both sync and async reply functions can be registered. The sync reply function will be triggered + from both sync and async chats. However, an async reply function will only be triggered from async + chats (initiated with `ConversableAgent.a_initiate_chat`). If an `async` reply function is registered + and a chat is initialized with a sync function, `ignore_async_in_sync_chat` determines the behaviour as follows: + if `ignore_async_in_sync_chat` is set to `False` (default value), an exception will be raised, and + if `ignore_async_in_sync_chat` is set to `True`, the reply function will be ignored. + + Args: + trigger (Agent class, str, Agent instance, callable, or list): the trigger. + If a class is provided, the reply function will be called when the sender is an instance of the class. + If a string is provided, the reply function will be called when the sender's name matches the string. + If an agent instance is provided, the reply function will be called when the sender is the agent instance. + If a callable is provided, the reply function will be called when the callable returns True. + If a list is provided, the reply function will be called when any of the triggers in the list is activated. + If None is provided, the reply function will be called only when the sender is None. + Note: Be sure to register `None` as a trigger if you would like to trigger an auto-reply function with non-empty messages and `sender=None`. + reply_func (Callable): the reply function. + The function takes a recipient agent, a list of messages, a sender agent and a config as input and returns a reply message. + + ```python + def reply_func( + recipient: ConversableAgent, + messages: Optional[List[Dict]] = None, + sender: Optional[Agent] = None, + config: Optional[Any] = None, + ) -> Tuple[bool, Union[str, Dict, None]]: + ``` + position (int): the position of the reply function in the reply function list. + The function registered later will be checked earlier by default. + To change the order, set the position to a positive integer. + config (Any): the config to be passed to the reply function. + When an agent is reset, the config will be reset to the original value. + reset_config (Callable): the function to reset the config. + The function returns None. Signature: ```def reset_config(config: Any)``` + ignore_async_in_sync_chat (bool): whether to ignore the async reply function in sync chats. If `False`, an exception + will be raised if an async reply function is registered and a chat is initialized with a sync + function. + remove_other_reply_funcs (bool): whether to remove other reply functions when registering this reply function. 
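+
+        A minimal usage sketch (the `count_messages` reply function and the `agent` instance below are
+        illustrative and not part of this module; `Agent` is the autogen agent base class):
+
+        ```python
+        def count_messages(recipient, messages=None, sender=None, config=None):
+            # Returning (False, None) marks the reply as non-final, so later registered reply functions still run.
+            print(f"{recipient.name} now holds {len(messages or [])} messages in this conversation")
+            return False, None
+
+        # position=0 places the function at the front of the current reply-function list.
+        agent.register_reply([Agent, None], count_messages, position=0)
+        ```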
+ """ + if not isinstance(trigger, (type, str, Agent, Callable, list)): + raise ValueError("trigger must be a class, a string, an agent, a callable or a list.") + if remove_other_reply_funcs: + self._reply_func_list.clear() + self._reply_func_list.insert( + position, + { + "trigger": trigger, + "reply_func": reply_func, + "config": copy.copy(config), + "init_config": config, + "reset_config": reset_config, + "ignore_async_in_sync_chat": ignore_async_in_sync_chat and inspect.iscoroutinefunction(reply_func), + }, + ) + + def replace_reply_func(self, old_reply_func: Callable, new_reply_func: Callable): + """Replace a registered reply function with a new one. + + Args: + old_reply_func (Callable): the old reply function to be replaced. + new_reply_func (Callable): the new reply function to replace the old one. + """ + for f in self._reply_func_list: + if f["reply_func"] == old_reply_func: + f["reply_func"] = new_reply_func + + @staticmethod + def _get_chats_to_run( + chat_queue: list[dict[str, Any]], + recipient: Agent, + messages: Optional[list[dict[str, Any]]], + sender: Agent, + config: Any, + ) -> list[dict[str, Any]]: + """A simple chat reply function. + This function initiate one or a sequence of chats between the "recipient" and the agents in the + chat_queue. + + It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue. + + Returns: + Tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated. + """ + last_msg = messages[-1].get("content") + chat_to_run = [] + for i, c in enumerate(chat_queue): + current_c = c.copy() + if current_c.get("sender") is None: + current_c["sender"] = recipient + message = current_c.get("message") + # If message is not provided in chat_queue, we by default use the last message from the original chat history as the first message in this nested chat (for the first chat in the chat queue). + # NOTE: This setting is prone to change. + if message is None and i == 0: + message = last_msg + if callable(message): + message = message(recipient, messages, sender, config) + # We only run chat that has a valid message. NOTE: This is prone to change depending on applications. 
+ if message: + current_c["message"] = message + chat_to_run.append(current_c) + return chat_to_run + + @staticmethod + def _process_nested_chat_carryover( + chat: dict[str, Any], + recipient: Agent, + messages: list[dict[str, Any]], + sender: Agent, + config: Any, + trim_n_messages: int = 0, + ) -> None: + """Process carryover messages for a nested chat (typically for the first chat of a group chat) + + The carryover_config key is a dictionary containing: + "summary_method": The method to use to summarise the messages, can be "all", "last_msg", "reflection_with_llm" or a Callable + "summary_args": Optional arguments for the summary method + + Supported carryover 'summary_methods' are: + "all" - all messages will be incorporated + "last_msg" - the last message will be incorporated + "reflection_with_llm" - an llm will summarise all the messages and the summary will be incorporated as a single message + Callable - a callable with the signature: my_method(agent: ConversableAgent, messages: List[Dict[str, Any]]) -> str + + Args: + chat: The chat dictionary containing the carryover configuration + recipient: The recipient agent + messages: The messages from the parent chat + sender: The sender agent + config: The LLM configuration + trim_n_messages: The number of latest messages to trim from the messages list + """ + + def concat_carryover(chat_message: str, carryover_message: Union[str, list[dict[str, Any]]]) -> str: + """Concatenate the carryover message to the chat message.""" + prefix = f"{chat_message}\n" if chat_message else "" + + if isinstance(carryover_message, str): + content = carryover_message + elif isinstance(carryover_message, list): + content = "\n".join( + msg["content"] for msg in carryover_message if "content" in msg and msg["content"] is not None + ) + else: + raise ValueError("Carryover message must be a string or a list of dictionaries") + + return f"{prefix}Context:\n{content}" + + carryover_config = chat["carryover_config"] + + if "summary_method" not in carryover_config: + raise ValueError("Carryover configuration must contain a 'summary_method' key") + + carryover_summary_method = carryover_config["summary_method"] + carryover_summary_args = carryover_config.get("summary_args") or {} + + chat_message = "" + message = chat.get("message") + + # If the message is a callable, run it and get the result + if message: + chat_message = message(recipient, messages, sender, config) if callable(message) else message + + # deep copy and trim the latest messages + content_messages = copy.deepcopy(messages) + content_messages = content_messages[:-trim_n_messages] + + if carryover_summary_method == "all": + # Put a string concatenated value of all parent messages into the first message + # (e.g. message = \nContext: \n\n\n...) + carry_over_message = concat_carryover(chat_message, content_messages) + + elif carryover_summary_method == "last_msg": + # (e.g. message = \nContext: \n) + carry_over_message = concat_carryover(chat_message, content_messages[-1]["content"]) + + elif carryover_summary_method == "reflection_with_llm": + # (e.g. 
message = \nContext: \n) + + # Add the messages to the nested chat agent for reflection (we'll clear after reflection) + chat["recipient"]._oai_messages[sender] = content_messages + + carry_over_message_llm = ConversableAgent._reflection_with_llm_as_summary( + sender=sender, + recipient=chat["recipient"], # Chat recipient LLM config will be used for the reflection + summary_args=carryover_summary_args, + ) + + recipient._oai_messages[sender] = [] + + carry_over_message = concat_carryover(chat_message, carry_over_message_llm) + + elif isinstance(carryover_summary_method, Callable): + # (e.g. message = \nContext: \n) + carry_over_message_result = carryover_summary_method(recipient, content_messages, carryover_summary_args) + + carry_over_message = concat_carryover(chat_message, carry_over_message_result) + + chat["message"] = carry_over_message + + @staticmethod + def _process_chat_queue_carryover( + chat_queue: list[dict[str, Any]], + recipient: Agent, + messages: Union[str, Callable[..., Any]], + sender: Agent, + config: Any, + trim_messages: int = 2, + ) -> tuple[bool, Optional[str]]: + """Process carryover configuration for the first chat in the queue. + + Args: + chat_queue: List of chat configurations + recipient: Receiving agent + messages: Chat messages + sender: Sending agent + config: LLM configuration + trim_messages: Number of messages to trim for nested chat carryover (default 2 for nested chat in group chats) + + Returns: + Tuple containing: + - restore_flag: Whether the original message needs to be restored + - original_message: The original message to restore (if any) + """ + restore_chat_queue_message = False + original_chat_queue_message = None + + # Carryover configuration allowed on the first chat in the queue only, trim the last two messages specifically for group chat nested chat carryover as these are the messages for the transition to the nested chat agent + if len(chat_queue) > 0 and "carryover_config" in chat_queue[0]: + if "message" in chat_queue[0]: + # As we're updating the message in the nested chat queue, we need to restore it after finishing this nested chat. + restore_chat_queue_message = True + original_chat_queue_message = chat_queue[0]["message"] + + # TODO Check the trimming required if not a group chat, it may not be 2 because other chats don't have the group transition messages. We may need to add as a carryover_config parameter. + ConversableAgent._process_nested_chat_carryover( + chat=chat_queue[0], + recipient=recipient, + messages=messages, + sender=sender, + config=config, + trim_n_messages=trim_messages, + ) + + return restore_chat_queue_message, original_chat_queue_message + + @staticmethod + def _summary_from_nested_chats( + chat_queue: list[dict[str, Any]], + recipient: Agent, + messages: Optional[list[dict[str, Any]]], + sender: Agent, + config: Any, + ) -> tuple[bool, Union[str, None]]: + """A simple chat reply function. + This function initiate one or a sequence of chats between the "recipient" and the agents in the + chat_queue. + + It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue. + + The first chat in the queue can contain a 'carryover_config' which is a dictionary that denotes how to carryover messages from the parent chat into the first chat of the nested chats). Only applies to the first chat. 
+ e.g.: carryover_summarize_chat_config = {"summary_method": "reflection_with_llm", "summary_args": None} + summary_method can be "last_msg", "all", "reflection_with_llm", Callable + The Callable signature: my_method(agent: ConversableAgent, messages: List[Dict[str, Any]]) -> str + The summary will be concatenated to the message of the first chat in the queue. + + Returns: + Tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated. + """ + # Process carryover configuration + restore_chat_queue_message, original_chat_queue_message = ConversableAgent._process_chat_queue_carryover( + chat_queue, recipient, messages, sender, config + ) + + chat_to_run = ConversableAgent._get_chats_to_run(chat_queue, recipient, messages, sender, config) + if not chat_to_run: + return True, None + res = initiate_chats(chat_to_run) + + # We need to restore the chat queue message if it has been modified so that it will be the original message for subsequent uses + if restore_chat_queue_message: + chat_queue[0]["message"] = original_chat_queue_message + + return True, res[-1].summary + + @staticmethod + async def _a_summary_from_nested_chats( + chat_queue: list[dict[str, Any]], + recipient: Agent, + messages: Optional[list[dict[str, Any]]], + sender: Agent, + config: Any, + ) -> tuple[bool, Union[str, None]]: + """A simple chat reply function. + This function initiate one or a sequence of chats between the "recipient" and the agents in the + chat_queue. + + It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue. + + The first chat in the queue can contain a 'carryover_config' which is a dictionary that denotes how to carryover messages from the parent chat into the first chat of the nested chats). Only applies to the first chat. + e.g.: carryover_summarize_chat_config = {"summary_method": "reflection_with_llm", "summary_args": None} + summary_method can be "last_msg", "all", "reflection_with_llm", Callable + The Callable signature: my_method(agent: ConversableAgent, messages: List[Dict[str, Any]]) -> str + The summary will be concatenated to the message of the first chat in the queue. + + Returns: + Tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated. 
+ """ + # Process carryover configuration + restore_chat_queue_message, original_chat_queue_message = ConversableAgent._process_chat_queue_carryover( + chat_queue, recipient, messages, sender, config + ) + + chat_to_run = ConversableAgent._get_chats_to_run(chat_queue, recipient, messages, sender, config) + if not chat_to_run: + return True, None + res = await a_initiate_chats(chat_to_run) + index_of_last_chat = chat_to_run[-1]["chat_id"] + + # We need to restore the chat queue message if it has been modified so that it will be the original message for subsequent uses + if restore_chat_queue_message: + chat_queue[0]["message"] = original_chat_queue_message + + return True, res[index_of_last_chat].summary + + def register_nested_chats( + self, + chat_queue: list[dict[str, Any]], + trigger: Union[type[Agent], str, Agent, Callable[[Agent], bool], list], + reply_func_from_nested_chats: Union[str, Callable[..., Any]] = "summary_from_nested_chats", + position: int = 2, + use_async: Union[bool, None] = None, + **kwargs: Any, + ) -> None: + """Register a nested chat reply function. + + Args: + chat_queue (list): a list of chat objects to be initiated. If use_async is used, then all messages in chat_queue must have a chat-id associated with them. + trigger (Agent class, str, Agent instance, callable, or list): refer to `register_reply` for details. + reply_func_from_nested_chats (Callable, str): the reply function for the nested chat. + The function takes a chat_queue for nested chat, recipient agent, a list of messages, a sender agent and a config as input and returns a reply message. + Default to "summary_from_nested_chats", which corresponds to a built-in reply function that get summary from the nested chat_queue. + ```python + def reply_func_from_nested_chats( + chat_queue: List[Dict], + recipient: ConversableAgent, + messages: Optional[List[Dict]] = None, + sender: Optional[Agent] = None, + config: Optional[Any] = None, + ) -> Tuple[bool, Union[str, Dict, None]]: + ``` + position (int): Ref to `register_reply` for details. Default to 2. It means we first check the termination and human reply, then check the registered nested chat reply. + use_async: Uses a_initiate_chats internally to start nested chats. If the original chat is initiated with a_initiate_chats, you may set this to true so nested chats do not run in sync. + kwargs: Ref to `register_reply` for details. 
+ """ + if use_async: + for chat in chat_queue: + if chat.get("chat_id") is None: + raise ValueError("chat_id is required for async nested chats") + + if use_async: + if reply_func_from_nested_chats == "summary_from_nested_chats": + reply_func_from_nested_chats = self._a_summary_from_nested_chats + if not callable(reply_func_from_nested_chats) or not inspect.iscoroutinefunction( + reply_func_from_nested_chats + ): + raise ValueError("reply_func_from_nested_chats must be a callable and a coroutine") + + async def wrapped_reply_func(recipient, messages=None, sender=None, config=None): + return await reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config) + + else: + if reply_func_from_nested_chats == "summary_from_nested_chats": + reply_func_from_nested_chats = self._summary_from_nested_chats + if not callable(reply_func_from_nested_chats): + raise ValueError("reply_func_from_nested_chats must be a callable") + + def wrapped_reply_func(recipient, messages=None, sender=None, config=None): + return reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config) + + functools.update_wrapper(wrapped_reply_func, reply_func_from_nested_chats) + + self.register_reply( + trigger, + wrapped_reply_func, + position, + kwargs.get("config"), + kwargs.get("reset_config"), + ignore_async_in_sync_chat=( + not use_async if use_async is not None else kwargs.get("ignore_async_in_sync_chat") + ), + ) + + @property + def system_message(self) -> str: + """Return the system message.""" + return self._oai_system_message[0]["content"] + + def update_system_message(self, system_message: str) -> None: + """Update the system message. + + Args: + system_message (str): system message for the ChatCompletion inference. + """ + self._oai_system_message[0]["content"] = system_message + + def update_max_consecutive_auto_reply(self, value: int, sender: Optional[Agent] = None): + """Update the maximum number of consecutive auto replies. + + Args: + value (int): the maximum number of consecutive auto replies. + sender (Agent): when the sender is provided, only update the max_consecutive_auto_reply for that sender. + """ + if sender is None: + self._max_consecutive_auto_reply = value + for k in self._max_consecutive_auto_reply_dict: + self._max_consecutive_auto_reply_dict[k] = value + else: + self._max_consecutive_auto_reply_dict[sender] = value + + def max_consecutive_auto_reply(self, sender: Optional[Agent] = None) -> int: + """The maximum number of consecutive auto replies.""" + return self._max_consecutive_auto_reply if sender is None else self._max_consecutive_auto_reply_dict[sender] + + @property + def chat_messages(self) -> dict[Agent, list[dict[str, Any]]]: + """A dictionary of conversations from agent to list of messages.""" + return self._oai_messages + + def chat_messages_for_summary(self, agent: Agent) -> list[dict[str, Any]]: + """A list of messages as a conversation to summarize.""" + return self._oai_messages[agent] + + def last_message(self, agent: Optional[Agent] = None) -> Optional[dict[str, Any]]: + """The last message exchanged with the agent. + + Args: + agent (Agent): The agent in the conversation. + If None and more than one agent's conversations are found, an error will be raised. + If None and only one conversation is found, the last message of the only conversation will be returned. + + Returns: + The last message exchanged with the agent. 
+ """ + if agent is None: + n_conversations = len(self._oai_messages) + if n_conversations == 0: + return None + if n_conversations == 1: + for conversation in self._oai_messages.values(): + return conversation[-1] + raise ValueError("More than one conversation is found. Please specify the sender to get the last message.") + if agent not in self._oai_messages: + raise KeyError( + f"The agent '{agent.name}' is not present in any conversation. No history available for this agent." + ) + return self._oai_messages[agent][-1] + + @property + def use_docker(self) -> Union[bool, str, None]: + """Bool value of whether to use docker to execute the code, + or str value of the docker image name to use, or None when code execution is disabled. + """ + return None if self._code_execution_config is False else self._code_execution_config.get("use_docker") + + @staticmethod + def _message_to_dict(message: Union[dict[str, Any], str]) -> dict: + """Convert a message to a dictionary. + + The message can be a string or a dictionary. The string will be put in the "content" field of the new dictionary. + """ + if isinstance(message, str): + return {"content": message} + elif isinstance(message, dict): + return message + else: + return dict(message) + + @staticmethod + def _normalize_name(name): + """LLMs sometimes ask functions while ignoring their own format requirements, this function should be used to replace invalid characters with "_". + + Prefer _assert_valid_name for validating user configuration or input + """ + return re.sub(r"[^a-zA-Z0-9_-]", "_", name)[:64] + + @staticmethod + def _assert_valid_name(name): + """Ensure that configured names are valid, raises ValueError if not. + + For munging LLM responses use _normalize_name to ensure LLM specified names don't break the API. + """ + if not re.match(r"^[a-zA-Z0-9_-]+$", name): + raise ValueError(f"Invalid name: {name}. Only letters, numbers, '_' and '-' are allowed.") + if len(name) > 64: + raise ValueError(f"Invalid name: {name}. Name must be less than 64 characters.") + return name + + def _append_oai_message( + self, message: Union[dict[str, Any], str], role, conversation_id: Agent, is_sending: bool + ) -> bool: + """Append a message to the ChatCompletion conversation. + + If the message received is a string, it will be put in the "content" field of the new dictionary. + If the message received is a dictionary but does not have any of the three fields "content", "function_call", or "tool_calls", + this message is not a valid ChatCompletion message. + If only "function_call" or "tool_calls" is provided, "content" will be set to None if not provided, and the role of the message will be forced "assistant". + + Args: + message (dict or str): message to be appended to the ChatCompletion conversation. + role (str): role of the message, can be "assistant" or "function". + conversation_id (Agent): id of the conversation, should be the recipient or sender. + is_sending (bool): If the agent (aka self) is sending to the conversation_id agent, otherwise receiving. + + Returns: + bool: whether the message is appended to the ChatCompletion conversation. + """ + message = self._message_to_dict(message) + # create oai message to be appended to the oai conversation that can be passed to oai directly. 
+ oai_message = { + k: message[k] + for k in ("content", "function_call", "tool_calls", "tool_responses", "tool_call_id", "name", "context") + if k in message and message[k] is not None + } + if "content" not in oai_message: + if "function_call" in oai_message or "tool_calls" in oai_message: + oai_message["content"] = None # if only function_call is provided, content will be set to None. + else: + return False + + if message.get("role") in ["function", "tool"]: + oai_message["role"] = message.get("role") + if "tool_responses" in oai_message: + for tool_response in oai_message["tool_responses"]: + tool_response["content"] = str(tool_response["content"]) + elif "override_role" in message: + # If we have a direction to override the role then set the + # role accordingly. Used to customise the role for the + # select speaker prompt. + oai_message["role"] = message.get("override_role") + else: + oai_message["role"] = role + + if oai_message.get("function_call", False) or oai_message.get("tool_calls", False): + oai_message["role"] = "assistant" # only messages with role 'assistant' can have a function call. + elif "name" not in oai_message: + # If we don't have a name field, append it + if is_sending: + oai_message["name"] = self.name + else: + oai_message["name"] = conversation_id.name + + self._oai_messages[conversation_id].append(oai_message) + + return True + + def _process_message_before_send( + self, message: Union[dict[str, Any], str], recipient: Agent, silent: bool + ) -> Union[dict[str, Any], str]: + """Process the message before sending it to the recipient.""" + hook_list = self.hook_lists["process_message_before_send"] + for hook in hook_list: + message = hook( + sender=self, message=message, recipient=recipient, silent=ConversableAgent._is_silent(self, silent) + ) + return message + + def send( + self, + message: Union[dict[str, Any], str], + recipient: Agent, + request_reply: Optional[bool] = None, + silent: Optional[bool] = False, + ): + """Send a message to another agent. + + Args: + message (dict or str): message to be sent. + The message could contain the following fields: + - content (str or List): Required, the content of the message. (Can be None) + - function_call (str): the name of the function to be called. + - name (str): the name of the function to be called. + - role (str): the role of the message, any role that is not "function" + will be modified to "assistant". + - context (dict): the context of the message, which will be passed to + [OpenAIWrapper.create](https://docs.ag2.ai/latest/docs/api-reference/autogen/OpenAIWrapper/#autogen.OpenAIWrapper.create). + For example, one agent can send a message A as: + ```python + { + "content": lambda context: context["use_tool_msg"], + "context": {"use_tool_msg": "Use tool X if they are relevant."}, + } + ``` + Next time, one agent can send a message B with a different "use_tool_msg". + Then the content of message A will be refreshed to the new "use_tool_msg". + So effectively, this provides a way for an agent to send a "link" and modify + the content of the "link" later. + recipient (Agent): the recipient of the message. + request_reply (bool or None): whether to request a reply from the recipient. + silent (bool or None): (Experimental) whether to print the message sent. + + Raises: + ValueError: if the message can't be converted into a valid ChatCompletion message. 
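+
+        A minimal sketch of sending a plain-text message (assuming `planner` and `coder` are
+        ConversableAgent instances created elsewhere):
+
+        ```python
+        # Appends the message to the shared history and asks `coder` for a reply.
+        planner.send("Please draft a unit test for the new parser.", coder, request_reply=True)
+        ```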
+ """ + message = self._process_message_before_send(message, recipient, ConversableAgent._is_silent(self, silent)) + # When the agent composes and sends the message, the role of the message is "assistant" + # unless it's "function". + valid = self._append_oai_message(message, "assistant", recipient, is_sending=True) + if valid: + recipient.receive(message, self, request_reply, silent) + else: + raise ValueError( + "Message can't be converted into a valid ChatCompletion message. Either content or function_call must be provided." + ) + + async def a_send( + self, + message: Union[dict[str, Any], str], + recipient: Agent, + request_reply: Optional[bool] = None, + silent: Optional[bool] = False, + ): + """(async) Send a message to another agent. + + Args: + message (dict or str): message to be sent. + The message could contain the following fields: + - content (str or List): Required, the content of the message. (Can be None) + - function_call (str): the name of the function to be called. + - name (str): the name of the function to be called. + - role (str): the role of the message, any role that is not "function" + will be modified to "assistant". + - context (dict): the context of the message, which will be passed to + [OpenAIWrapper.create](https://docs.ag2.ai/latest/docs/api-reference/autogen/OpenAIWrapper/#autogen.OpenAIWrapper.create). + For example, one agent can send a message A as: + ```python + { + "content": lambda context: context["use_tool_msg"], + "context": {"use_tool_msg": "Use tool X if they are relevant."}, + } + ``` + Next time, one agent can send a message B with a different "use_tool_msg". + Then the content of message A will be refreshed to the new "use_tool_msg". + So effectively, this provides a way for an agent to send a "link" and modify + the content of the "link" later. + recipient (Agent): the recipient of the message. + request_reply (bool or None): whether to request a reply from the recipient. + silent (bool or None): (Experimental) whether to print the message sent. + + Raises: + ValueError: if the message can't be converted into a valid ChatCompletion message. + """ + message = self._process_message_before_send(message, recipient, ConversableAgent._is_silent(self, silent)) + # When the agent composes and sends the message, the role of the message is "assistant" + # unless it's "function". + valid = self._append_oai_message(message, "assistant", recipient, is_sending=True) + if valid: + await recipient.a_receive(message, self, request_reply, silent) + else: + raise ValueError( + "Message can't be converted into a valid ChatCompletion message. Either content or function_call must be provided." + ) + + def _print_received_message(self, message: Union[dict[str, Any], str], sender: Agent, skip_head: bool = False): + message = self._message_to_dict(message) + message_model = create_received_event_model(event=message, sender=sender, recipient=self) + iostream = IOStream.get_default() + # message_model.print(iostream.print) + iostream.send(message_model) + + def _process_received_message(self, message: Union[dict[str, Any], str], sender: Agent, silent: bool): + # When the agent receives a message, the role of the message is "user". (If 'role' exists and is 'function', it will remain unchanged.) 
+ valid = self._append_oai_message(message, "user", sender, is_sending=False) + if logging_enabled(): + log_event(self, "received_message", message=message, sender=sender.name, valid=valid) + + if not valid: + raise ValueError( + "Received message can't be converted into a valid ChatCompletion message. Either content or function_call must be provided." + ) + + if not ConversableAgent._is_silent(sender, silent): + self._print_received_message(message, sender) + + def receive( + self, + message: Union[dict[str, Any], str], + sender: Agent, + request_reply: Optional[bool] = None, + silent: Optional[bool] = False, + ): + """Receive a message from another agent. + + Once a message is received, this function sends a reply to the sender or stop. + The reply can be generated automatically or entered manually by a human. + + Args: + message (dict or str): message from the sender. If the type is dict, it may contain the following reserved fields (either content or function_call need to be provided). + 1. "content": content of the message, can be None. + 2. "function_call": a dictionary containing the function name and arguments. (deprecated in favor of "tool_calls") + 3. "tool_calls": a list of dictionaries containing the function name and arguments. + 4. "role": role of the message, can be "assistant", "user", "function", "tool". + This field is only needed to distinguish between "function" or "assistant"/"user". + 5. "name": In most cases, this field is not needed. When the role is "function", this field is needed to indicate the function name. + 6. "context" (dict): the context of the message, which will be passed to + [OpenAIWrapper.create](https://docs.ag2.ai/latest/docs/api-reference/autogen/OpenAIWrapper/#autogen.OpenAIWrapper.create). + sender: sender of an Agent instance. + request_reply (bool or None): whether a reply is requested from the sender. + If None, the value is determined by `self.reply_at_receive[sender]`. + silent (bool or None): (Experimental) whether to print the message received. + + Raises: + ValueError: if the message can't be converted into a valid ChatCompletion message. + """ + self._process_received_message(message, sender, silent) + if request_reply is False or (request_reply is None and self.reply_at_receive[sender] is False): + return + reply = self.generate_reply(messages=self.chat_messages[sender], sender=sender) + if reply is not None: + self.send(reply, sender, silent=silent) + + async def a_receive( + self, + message: Union[dict[str, Any], str], + sender: Agent, + request_reply: Optional[bool] = None, + silent: Optional[bool] = False, + ): + """(async) Receive a message from another agent. + + Once a message is received, this function sends a reply to the sender or stop. + The reply can be generated automatically or entered manually by a human. + + Args: + message (dict or str): message from the sender. If the type is dict, it may contain the following reserved fields (either content or function_call need to be provided). + 1. "content": content of the message, can be None. + 2. "function_call": a dictionary containing the function name and arguments. (deprecated in favor of "tool_calls") + 3. "tool_calls": a list of dictionaries containing the function name and arguments. + 4. "role": role of the message, can be "assistant", "user", "function". + This field is only needed to distinguish between "function" or "assistant"/"user". + 5. "name": In most cases, this field is not needed. When the role is "function", this field is needed to indicate the function name. 
+ 6. "context" (dict): the context of the message, which will be passed to + [OpenAIWrapper.create](https://docs.ag2.ai/latest/docs/api-reference/autogen/OpenAIWrapper/#autogen.OpenAIWrapper.create). + sender: sender of an Agent instance. + request_reply (bool or None): whether a reply is requested from the sender. + If None, the value is determined by `self.reply_at_receive[sender]`. + silent (bool or None): (Experimental) whether to print the message received. + + Raises: + ValueError: if the message can't be converted into a valid ChatCompletion message. + """ + self._process_received_message(message, sender, silent) + if request_reply is False or (request_reply is None and self.reply_at_receive[sender] is False): + return + reply = await self.a_generate_reply(messages=self.chat_messages[sender], sender=sender) + if reply is not None: + await self.a_send(reply, sender, silent=silent) + + def _prepare_chat( + self, + recipient: "ConversableAgent", + clear_history: bool, + prepare_recipient: bool = True, + reply_at_receive: bool = True, + ) -> None: + self.reset_consecutive_auto_reply_counter(recipient) + self.reply_at_receive[recipient] = reply_at_receive + if clear_history: + self.clear_history(recipient) + self._human_input = [] + if prepare_recipient: + recipient._prepare_chat(self, clear_history, False, reply_at_receive) + + def _raise_exception_on_async_reply_functions(self) -> None: + """Raise an exception if any async reply functions are registered. + + Raises: + RuntimeError: if any async reply functions are registered. + """ + reply_functions = { + f["reply_func"] for f in self._reply_func_list if not f.get("ignore_async_in_sync_chat", False) + } + + async_reply_functions = [f for f in reply_functions if inspect.iscoroutinefunction(f)] + if async_reply_functions: + msg = ( + "Async reply functions can only be used with ConversableAgent.a_initiate_chat(). The following async reply functions are found: " + + ", ".join([f.__name__ for f in async_reply_functions]) + ) + + raise RuntimeError(msg) + + def initiate_chat( + self, + recipient: "ConversableAgent", + clear_history: bool = True, + silent: Optional[bool] = False, + cache: Optional[AbstractCache] = None, + max_turns: Optional[int] = None, + summary_method: Optional[Union[str, Callable[..., Any]]] = DEFAULT_SUMMARY_METHOD, + summary_args: Optional[dict[str, Any]] = {}, + message: Optional[Union[dict[str, Any], str, Callable[..., Any]]] = None, + **kwargs: Any, + ) -> ChatResult: + """Initiate a chat with the recipient agent. + + Reset the consecutive auto reply counter. + If `clear_history` is True, the chat history with the recipient agent will be cleared. + + + Args: + recipient: the recipient agent. + clear_history (bool): whether to clear the chat history with the agent. Default is True. + silent (bool or None): (Experimental) whether to print the messages for this conversation. Default is False. + cache (AbstractCache or None): the cache client to be used for this conversation. Default is None. + max_turns (int or None): the maximum number of turns for the chat between the two agents. One turn means one conversation round trip. Note that this is different from + `max_consecutive_auto_reply` which is the maximum number of consecutive auto replies; and it is also different from `max_rounds` in GroupChat which is the maximum number of rounds in a group chat session. + If max_turns is set to None, the chat will continue until a termination condition is met. Default is None. 
+ summary_method (str or callable): a method to get a summary from the chat. Default is DEFAULT_SUMMARY_METHOD, i.e., "last_msg". + Supported strings are "last_msg" and "reflection_with_llm": + - when set to "last_msg", it returns the last message of the dialog as the summary. + - when set to "reflection_with_llm", it returns a summary extracted using an llm client. + `llm_config` must be set in either the recipient or sender. + + A callable summary_method should take the recipient and sender agent in a chat as input and return a string of summary. E.g., + + ```python + def my_summary_method( + sender: ConversableAgent, + recipient: ConversableAgent, + summary_args: dict, + ): + return recipient.last_message(sender)["content"] + ``` + summary_args (dict): a dictionary of arguments to be passed to the summary_method. + One example key is "summary_prompt", and value is a string of text used to prompt a LLM-based agent (the sender or recipient agent) to reflect + on the conversation and extract a summary when summary_method is "reflection_with_llm". + The default summary_prompt is DEFAULT_SUMMARY_PROMPT, i.e., "Summarize takeaway from the conversation. Do not add any introductory phrases. If the intended request is NOT properly addressed, please point it out." + Another available key is "summary_role", which is the role of the message sent to the agent in charge of summarizing. Default is "system". + message (str, dict or Callable): the initial message to be sent to the recipient. Needs to be provided. Otherwise, input() will be called to get the initial message. + - If a string or a dict is provided, it will be used as the initial message. `generate_init_message` is called to generate the initial message for the agent based on this string and the context. + If dict, it may contain the following reserved fields (either content or tool_calls need to be provided). + + 1. "content": content of the message, can be None. + 2. "function_call": a dictionary containing the function name and arguments. (deprecated in favor of "tool_calls") + 3. "tool_calls": a list of dictionaries containing the function name and arguments. + 4. "role": role of the message, can be "assistant", "user", "function". + This field is only needed to distinguish between "function" or "assistant"/"user". + 5. "name": In most cases, this field is not needed. When the role is "function", this field is needed to indicate the function name. + 6. "context" (dict): the context of the message, which will be passed to + `OpenAIWrapper.create`. + + - If a callable is provided, it will be called to get the initial message in the form of a string or a dict. + If the returned type is dict, it may contain the reserved fields mentioned above. + + Example of a callable message (returning a string): + + ```python + def my_message( + sender: ConversableAgent, recipient: ConversableAgent, context: dict + ) -> Union[str, Dict]: + carryover = context.get("carryover", "") + if isinstance(message, list): + carryover = carryover[-1] + final_msg = "Write a blogpost." + "\\nContext: \\n" + carryover + return final_msg + ``` + + Example of a callable message (returning a dict): + + ```python + def my_message( + sender: ConversableAgent, recipient: ConversableAgent, context: dict + ) -> Union[str, Dict]: + final_msg = {} + carryover = context.get("carryover", "") + if isinstance(message, list): + carryover = carryover[-1] + final_msg["content"] = "Write a blogpost." 
+ "\\nContext: \\n" + carryover + final_msg["context"] = {"prefix": "Today I feel"} + return final_msg + ``` + **kwargs: any additional information. It has the following reserved fields: + - "carryover": a string or a list of string to specify the carryover information to be passed to this chat. + If provided, we will combine this carryover (by attaching a "context: " string and the carryover content after the message content) with the "message" content when generating the initial chat + message in `generate_init_message`. + - "verbose": a boolean to specify whether to print the message and carryover in a chat. Default is False. + + Raises: + RuntimeError: if any async reply functions are registered and not ignored in sync chat. + + Returns: + ChatResult: an ChatResult object. + """ + iostream = IOStream.get_default() + + cache = Cache.get_current_cache(cache) + _chat_info = locals().copy() + _chat_info["sender"] = self + consolidate_chat_info(_chat_info, uniform_sender=self) + for agent in [self, recipient]: + agent._raise_exception_on_async_reply_functions() + agent.previous_cache = agent.client_cache + agent.client_cache = cache + if isinstance(max_turns, int): + self._prepare_chat(recipient, clear_history, reply_at_receive=False) + for i in range(max_turns): + # check recipient max consecutive auto reply limit + if self._consecutive_auto_reply_counter[recipient] >= recipient._max_consecutive_auto_reply: + break + if i == 0: + if isinstance(message, Callable): + msg2send = message(_chat_info["sender"], _chat_info["recipient"], kwargs) + else: + msg2send = self.generate_init_message(message, **kwargs) + else: + msg2send = self.generate_reply(messages=self.chat_messages[recipient], sender=recipient) + if msg2send is None: + break + self.send(msg2send, recipient, request_reply=True, silent=silent) + + else: # No breaks in the for loop, so we have reached max turns + iostream.send(TerminationEvent(termination_reason=f"Maximum turns ({max_turns}) reached")) + else: + self._prepare_chat(recipient, clear_history) + if isinstance(message, Callable): + msg2send = message(_chat_info["sender"], _chat_info["recipient"], kwargs) + else: + msg2send = self.generate_init_message(message, **kwargs) + self.send(msg2send, recipient, silent=silent) + summary = self._summarize_chat( + summary_method, + summary_args, + recipient, + cache=cache, + ) + for agent in [self, recipient]: + agent.client_cache = agent.previous_cache + agent.previous_cache = None + chat_result = ChatResult( + chat_history=self.chat_messages[recipient], + summary=summary, + cost=gather_usage_summary([self, recipient]), + human_input=self._human_input, + ) + return chat_result + + def run( + self, + recipient: Optional["ConversableAgent"] = None, + clear_history: bool = True, + silent: Optional[bool] = False, + cache: Optional[AbstractCache] = None, + max_turns: Optional[int] = None, + summary_method: Optional[Union[str, Callable[..., Any]]] = DEFAULT_SUMMARY_METHOD, + summary_args: Optional[dict[str, Any]] = {}, + message: Optional[Union[dict[str, Any], str, Callable[..., Any]]] = None, + executor_kwargs: Optional[dict[str, Any]] = None, + tools: Optional[Union[Tool, Iterable[Tool]]] = None, + user_input: Optional[bool] = False, + msg_to: Optional[str] = "agent", + **kwargs: Any, + ) -> RunResponseProtocol: + iostream = ThreadIOStream() + agents = [self, recipient] if recipient else [self] + response = RunResponse(iostream, agents=agents) + + if recipient is None: + + def initiate_chat( + self=self, + iostream: ThreadIOStream = 
iostream, + response: RunResponse = response, + ) -> None: + with ( + IOStream.set_default(iostream), + self._create_or_get_executor( + executor_kwargs=executor_kwargs, + tools=tools, + agent_name="user", + agent_human_input_mode="ALWAYS" if user_input else "NEVER", + ) as executor, + ): + try: + if msg_to == "agent": + chat_result = executor.initiate_chat( + self, + message=message, + clear_history=clear_history, + max_turns=max_turns, + summary_method=summary_method, + ) + else: + chat_result = self.initiate_chat( + executor, + message=message, + clear_history=clear_history, + max_turns=max_turns, + summary_method=summary_method, + ) + + IOStream.get_default().send( + RunCompletionEvent( + history=chat_result.chat_history, + summary=chat_result.summary, + cost=chat_result.cost, + last_speaker=self.name, + ) + ) + except Exception as e: + response.iostream.send(ErrorEvent(error=e)) + + else: + + def initiate_chat( + self=self, + iostream: ThreadIOStream = iostream, + response: RunResponse = response, + ) -> None: + with IOStream.set_default(iostream): # type: ignore[arg-type] + try: + chat_result = self.initiate_chat( + recipient, + clear_history=clear_history, + silent=silent, + cache=cache, + max_turns=max_turns, + summary_method=summary_method, + summary_args=summary_args, + message=message, + **kwargs, + ) + + response._summary = chat_result.summary + response._messages = chat_result.chat_history + + _last_speaker = recipient if chat_result.chat_history[-1]["name"] == recipient.name else self + if hasattr(recipient, "last_speaker"): + _last_speaker = recipient.last_speaker + + IOStream.get_default().send( + RunCompletionEvent( + history=chat_result.chat_history, + summary=chat_result.summary, + cost=chat_result.cost, + last_speaker=_last_speaker.name, + ) + ) + except Exception as e: + response.iostream.send(ErrorEvent(error=e)) + + threading.Thread( + target=initiate_chat, + ).start() + + return response + + async def a_initiate_chat( + self, + recipient: "ConversableAgent", + clear_history: bool = True, + silent: Optional[bool] = False, + cache: Optional[AbstractCache] = None, + max_turns: Optional[int] = None, + summary_method: Optional[Union[str, Callable[..., Any]]] = DEFAULT_SUMMARY_METHOD, + summary_args: Optional[dict[str, Any]] = {}, + message: Optional[Union[str, Callable[..., Any]]] = None, + **kwargs: Any, + ) -> ChatResult: + """(async) Initiate a chat with the recipient agent. + + Reset the consecutive auto reply counter. + If `clear_history` is True, the chat history with the recipient agent will be cleared. + `a_generate_init_message` is called to generate the initial message for the agent. + + Args: Please refer to `initiate_chat`. + + Returns: + ChatResult: an ChatResult object. 
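+
+        A minimal async usage sketch (assuming `user_proxy` and `assistant` are configured
+        ConversableAgent instances):
+
+        ```python
+        import asyncio
+
+        async def main():
+            result = await user_proxy.a_initiate_chat(
+                assistant,
+                message="Outline the steps to reproduce the bug.",
+                max_turns=2,
+            )
+            print(result.summary)
+
+        asyncio.run(main())
+        ```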
+ """ + iostream = IOStream.get_default() + + _chat_info = locals().copy() + _chat_info["sender"] = self + consolidate_chat_info(_chat_info, uniform_sender=self) + for agent in [self, recipient]: + agent.previous_cache = agent.client_cache + agent.client_cache = cache + if isinstance(max_turns, int): + self._prepare_chat(recipient, clear_history, reply_at_receive=False) + for _ in range(max_turns): + if _ == 0: + if isinstance(message, Callable): + msg2send = message(_chat_info["sender"], _chat_info["recipient"], kwargs) + else: + msg2send = await self.a_generate_init_message(message, **kwargs) + else: + msg2send = await self.a_generate_reply(messages=self.chat_messages[recipient], sender=recipient) + if msg2send is None: + break + await self.a_send(msg2send, recipient, request_reply=True, silent=silent) + else: # No breaks in the for loop, so we have reached max turns + iostream.send(TerminationEvent(termination_reason=f"Maximum turns ({max_turns}) reached")) + else: + self._prepare_chat(recipient, clear_history) + if isinstance(message, Callable): + msg2send = message(_chat_info["sender"], _chat_info["recipient"], kwargs) + else: + msg2send = await self.a_generate_init_message(message, **kwargs) + await self.a_send(msg2send, recipient, silent=silent) + summary = self._summarize_chat( + summary_method, + summary_args, + recipient, + cache=cache, + ) + for agent in [self, recipient]: + agent.client_cache = agent.previous_cache + agent.previous_cache = None + chat_result = ChatResult( + chat_history=self.chat_messages[recipient], + summary=summary, + cost=gather_usage_summary([self, recipient]), + human_input=self._human_input, + ) + return chat_result + + async def a_run( + self, + recipient: Optional["ConversableAgent"] = None, + clear_history: bool = True, + silent: Optional[bool] = False, + cache: Optional[AbstractCache] = None, + max_turns: Optional[int] = None, + summary_method: Optional[Union[str, Callable[..., Any]]] = DEFAULT_SUMMARY_METHOD, + summary_args: Optional[dict[str, Any]] = {}, + message: Optional[Union[dict[str, Any], str, Callable[..., Any]]] = None, + executor_kwargs: Optional[dict[str, Any]] = None, + tools: Optional[Union[Tool, Iterable[Tool]]] = None, + user_input: Optional[bool] = False, + msg_to: Optional[str] = "agent", + **kwargs: Any, + ) -> AsyncRunResponseProtocol: + iostream = AsyncThreadIOStream() + agents = [self, recipient] if recipient else [self] + response = AsyncRunResponse(iostream, agents=agents) + + if recipient is None: + + async def initiate_chat( + self=self, + iostream: AsyncThreadIOStream = iostream, + response: AsyncRunResponse = response, + ) -> None: + with ( + IOStream.set_default(iostream), + self._create_or_get_executor( + executor_kwargs=executor_kwargs, + tools=tools, + agent_name="user", + agent_human_input_mode="ALWAYS" if user_input else "NEVER", + ) as executor, + ): + try: + if msg_to == "agent": + chat_result = await executor.a_initiate_chat( + self, + message=message, + clear_history=clear_history, + max_turns=max_turns, + summary_method=summary_method, + ) + else: + chat_result = await self.a_initiate_chat( + executor, + message=message, + clear_history=clear_history, + max_turns=max_turns, + summary_method=summary_method, + ) + + IOStream.get_default().send( + RunCompletionEvent( + history=chat_result.chat_history, + summary=chat_result.summary, + cost=chat_result.cost, + last_speaker=self.name, + ) + ) + except Exception as e: + response.iostream.send(ErrorEvent(error=e)) + + else: + + async def initiate_chat( + self=self, + 
iostream: AsyncThreadIOStream = iostream, + response: AsyncRunResponse = response, + ) -> None: + with IOStream.set_default(iostream): # type: ignore[arg-type] + try: + chat_result = await self.a_initiate_chat( + recipient, + clear_history=clear_history, + silent=silent, + cache=cache, + max_turns=max_turns, + summary_method=summary_method, + summary_args=summary_args, + message=message, + **kwargs, + ) + + last_speaker = recipient if chat_result.chat_history[-1]["name"] == recipient.name else self + if hasattr(recipient, "last_speaker"): + last_speaker = recipient.last_speaker + + IOStream.get_default().send( + RunCompletionEvent( + history=chat_result.chat_history, + summary=chat_result.summary, + cost=chat_result.cost, + last_speaker=last_speaker.name, + ) + ) + + except Exception as e: + response.iostream.send(ErrorEvent(error=e)) + + asyncio.create_task(initiate_chat()) + + return response + + def _summarize_chat( + self, + summary_method, + summary_args, + recipient: Optional[Agent] = None, + cache: Optional[AbstractCache] = None, + ) -> str: + """Get a chat summary from an agent participating in a chat. + + Args: + summary_method (str or callable): the summary_method to get the summary. + The callable summary_method should take the recipient and sender agent in a chat as input and return a string of summary. E.g, + ```python + def my_summary_method( + sender: ConversableAgent, + recipient: ConversableAgent, + summary_args: dict, + ): + return recipient.last_message(sender)["content"] + ``` + summary_args (dict): a dictionary of arguments to be passed to the summary_method. + recipient: the recipient agent in a chat. + cache: the cache client to be used for this conversation. When provided, + the cache will be used to store and retrieve LLM responses when generating + summaries, which can improve performance and reduce API costs for + repetitive summary requests. The cache is passed to the summary_method + via summary_args['cache']. + + Returns: + str: a chat summary from the agent. + """ + summary = "" + if summary_method is None: + return summary + if "cache" not in summary_args: + summary_args["cache"] = cache + if summary_method == "reflection_with_llm": + summary_method = self._reflection_with_llm_as_summary + elif summary_method == "last_msg": + summary_method = self._last_msg_as_summary + + if isinstance(summary_method, Callable): + summary = summary_method(self, recipient, summary_args) + else: + raise ValueError( + "If not None, the summary_method must be a string from [`reflection_with_llm`, `last_msg`] or a callable." + ) + return summary + + @staticmethod + def _last_msg_as_summary(sender, recipient, summary_args) -> str: + """Get a chat summary from the last message of the recipient.""" + summary = "" + try: + content = recipient.last_message(sender)["content"] + if isinstance(content, str): + summary = content.replace("TERMINATE", "") + elif isinstance(content, list): + # Remove the `TERMINATE` word in the content list. + summary = "\n".join( + x["text"].replace("TERMINATE", "") for x in content if isinstance(x, dict) and "text" in x + ) + except (IndexError, AttributeError) as e: + warnings.warn(f"Cannot extract summary using last_msg: {e}. 
Using an empty str as summary.", UserWarning) + return summary + + @staticmethod + def _reflection_with_llm_as_summary(sender, recipient, summary_args): + prompt = summary_args.get("summary_prompt") + prompt = ConversableAgent.DEFAULT_SUMMARY_PROMPT if prompt is None else prompt + if not isinstance(prompt, str): + raise ValueError("The summary_prompt must be a string.") + msg_list = recipient.chat_messages_for_summary(sender) + agent = sender if recipient is None else recipient + role = summary_args.get("summary_role", None) + if role and not isinstance(role, str): + raise ValueError("The summary_role in summary_arg must be a string.") + try: + summary = sender._reflection_with_llm( + prompt, msg_list, llm_agent=agent, cache=summary_args.get("cache"), role=role + ) + except Exception as e: + warnings.warn( + f"Cannot extract summary using reflection_with_llm: {e}. Using an empty str as summary.", UserWarning + ) + summary = "" + return summary + + def _reflection_with_llm( + self, + prompt, + messages, + llm_agent: Optional[Agent] = None, + cache: Optional[AbstractCache] = None, + role: Union[str, None] = None, + ) -> str: + """Get a chat summary using reflection with an llm client based on the conversation history. + + Args: + prompt (str): The prompt (in this method it is used as system prompt) used to get the summary. + messages (list): The messages generated as part of a chat conversation. + llm_agent: the agent with an llm client. + cache (AbstractCache or None): the cache client to be used for this conversation. + role (str): the role of the message, usually "system" or "user". Default is "system". + """ + if not role: + role = "system" + + system_msg = [ + { + "role": role, + "content": prompt, + } + ] + + messages = messages + system_msg + if llm_agent and llm_agent.client is not None: + llm_client = llm_agent.client + elif self.client is not None: + llm_client = self.client + else: + raise ValueError("No OpenAIWrapper client is found.") + response = self._generate_oai_reply_from_client(llm_client=llm_client, messages=messages, cache=cache) + return response + + def _check_chat_queue_for_sender(self, chat_queue: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Check the chat queue and add the "sender" key if it's missing. + + Args: + chat_queue (List[Dict[str, Any]]): A list of dictionaries containing chat information. + + Returns: + List[Dict[str, Any]]: A new list of dictionaries with the "sender" key added if it was missing. + """ + chat_queue_with_sender = [] + for chat_info in chat_queue: + if chat_info.get("sender") is None: + chat_info["sender"] = self + chat_queue_with_sender.append(chat_info) + return chat_queue_with_sender + + def initiate_chats(self, chat_queue: list[dict[str, Any]]) -> list[ChatResult]: + """(Experimental) Initiate chats with multiple agents. + + Args: + chat_queue (List[Dict]): a list of dictionaries containing the information of the chats. + Each dictionary should contain the input arguments for [`initiate_chat`](#initiate-chat) + + Returns: a list of ChatResult objects corresponding to the finished chats in the chat_queue. + """ + _chat_queue = self._check_chat_queue_for_sender(chat_queue) + self._finished_chats = initiate_chats(_chat_queue) + + return self._finished_chats + + def sequential_run( + self, + chat_queue: list[dict[str, Any]], + ) -> list[RunResponseProtocol]: + """(Experimental) Initiate chats with multiple agents sequentially. + + Args: + chat_queue (List[Dict]): a list of dictionaries containing the information of the chats. 
+ Each dictionary should contain the input arguments for [`initiate_chat`](#initiate-chat) + + Returns: a list of ChatResult objects corresponding to the finished chats in the chat_queue. + """ + iostreams = [ThreadIOStream() for _ in range(len(chat_queue))] + # todo: add agents + responses = [RunResponse(iostream, agents=[]) for iostream in iostreams] + + def _initiate_chats( + iostreams: list[ThreadIOStream] = iostreams, + responses: list[RunResponseProtocol] = responses, + ) -> None: + response = responses[0] + try: + _chat_queue = self._check_chat_queue_for_sender(chat_queue) + + consolidate_chat_info(_chat_queue) + _validate_recipients(_chat_queue) + finished_chats = [] + for chat_info, response, iostream in zip(_chat_queue, responses, iostreams): + with IOStream.set_default(iostream): + _chat_carryover = chat_info.get("carryover", []) + finished_chat_indexes_to_exclude_from_carryover = chat_info.get( + "finished_chat_indexes_to_exclude_from_carryover", [] + ) + + if isinstance(_chat_carryover, str): + _chat_carryover = [_chat_carryover] + chat_info["carryover"] = _chat_carryover + [ + r.summary + for i, r in enumerate(finished_chats) + if i not in finished_chat_indexes_to_exclude_from_carryover + ] + + if not chat_info.get("silent", False): + IOStream.get_default().send(PostCarryoverProcessingEvent(chat_info=chat_info)) + + sender = chat_info["sender"] + chat_res = sender.initiate_chat(**chat_info) + + IOStream.get_default().send( + RunCompletionEvent( + history=chat_res.chat_history, + summary=chat_res.summary, + cost=chat_res.cost, + last_speaker=(self if chat_res.chat_history[-1]["name"] == self.name else sender).name, + ) + ) + + finished_chats.append(chat_res) + except Exception as e: + response.iostream.send(ErrorEvent(error=e)) + + threading.Thread(target=_initiate_chats).start() + + return responses + + async def a_initiate_chats(self, chat_queue: list[dict[str, Any]]) -> dict[int, ChatResult]: + _chat_queue = self._check_chat_queue_for_sender(chat_queue) + self._finished_chats = await a_initiate_chats(_chat_queue) + return self._finished_chats + + async def a_sequential_run( + self, + chat_queue: list[dict[str, Any]], + ) -> list[AsyncRunResponseProtocol]: + """(Experimental) Initiate chats with multiple agents sequentially. + + Args: + chat_queue (List[Dict]): a list of dictionaries containing the information of the chats. + Each dictionary should contain the input arguments for [`initiate_chat`](#initiate-chat) + + Returns: a list of ChatResult objects corresponding to the finished chats in the chat_queue. 
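+
+        Example (illustrative only; `user`, `writer`, and `critic` are placeholder agents
+        created elsewhere, and the chat_queue entries simply mirror `initiate_chat` arguments):
+            ```python
+            responses = await user.a_sequential_run([
+                {"recipient": writer, "message": "Draft a haiku about autumn.", "max_turns": 1},
+                {"recipient": critic, "message": "Critique the haiku.", "max_turns": 1},
+            ])
+            ```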
+ """ + iostreams = [AsyncThreadIOStream() for _ in range(len(chat_queue))] + # todo: add agents + responses = [AsyncRunResponse(iostream, agents=[]) for iostream in iostreams] + + async def _a_initiate_chats( + iostreams: list[AsyncThreadIOStream] = iostreams, + responses: list[AsyncRunResponseProtocol] = responses, + ) -> None: + response = responses[0] + try: + _chat_queue = self._check_chat_queue_for_sender(chat_queue) + + consolidate_chat_info(_chat_queue) + _validate_recipients(_chat_queue) + finished_chats = [] + for chat_info, response, iostream in zip(_chat_queue, responses, iostreams): + with IOStream.set_default(iostream): + _chat_carryover = chat_info.get("carryover", []) + finished_chat_indexes_to_exclude_from_carryover = chat_info.get( + "finished_chat_indexes_to_exclude_from_carryover", [] + ) + + if isinstance(_chat_carryover, str): + _chat_carryover = [_chat_carryover] + chat_info["carryover"] = _chat_carryover + [ + r.summary + for i, r in enumerate(finished_chats) + if i not in finished_chat_indexes_to_exclude_from_carryover + ] + + if not chat_info.get("silent", False): + IOStream.get_default().send(PostCarryoverProcessingEvent(chat_info=chat_info)) + + sender = chat_info["sender"] + chat_res = await sender.a_initiate_chat(**chat_info) + + IOStream.get_default().send( + RunCompletionEvent( + history=chat_res.chat_history, + summary=chat_res.summary, + cost=chat_res.cost, + last_speaker=(self if chat_res.chat_history[-1]["name"] == self.name else sender).name, + ) + ) + + finished_chats.append(chat_res) + + except Exception as e: + response.iostream.send(ErrorEvent(error=e)) + + asyncio.create_task(_a_initiate_chats()) + + return responses + + def get_chat_results(self, chat_index: Optional[int] = None) -> Union[list[ChatResult], ChatResult]: + """A summary from the finished chats of particular agents.""" + if chat_index is not None: + return self._finished_chats[chat_index] + else: + return self._finished_chats + + def reset(self) -> None: + """Reset the agent.""" + self.clear_history() + self.reset_consecutive_auto_reply_counter() + self.stop_reply_at_receive() + if self.client is not None: + self.client.clear_usage_summary() + for reply_func_tuple in self._reply_func_list: + if reply_func_tuple["reset_config"] is not None: + reply_func_tuple["reset_config"](reply_func_tuple["config"]) + else: + reply_func_tuple["config"] = copy.copy(reply_func_tuple["init_config"]) + + def stop_reply_at_receive(self, sender: Optional[Agent] = None): + """Reset the reply_at_receive of the sender.""" + if sender is None: + self.reply_at_receive.clear() + else: + self.reply_at_receive[sender] = False + + def reset_consecutive_auto_reply_counter(self, sender: Optional[Agent] = None): + """Reset the consecutive_auto_reply_counter of the sender.""" + if sender is None: + self._consecutive_auto_reply_counter.clear() + else: + self._consecutive_auto_reply_counter[sender] = 0 + + def clear_history(self, recipient: Optional[Agent] = None, nr_messages_to_preserve: Optional[int] = None): + """Clear the chat history of the agent. + + Args: + recipient: the agent with whom the chat history to clear. If None, clear the chat history with all agents. + nr_messages_to_preserve: the number of newest messages to preserve in the chat history. 
+ """ + iostream = IOStream.get_default() + if recipient is None: + no_messages_preserved = 0 + if nr_messages_to_preserve: + for key in self._oai_messages: + nr_messages_to_preserve_internal = nr_messages_to_preserve + # if breaking history between function call and function response, save function call message + # additionally, otherwise openai will return error + first_msg_to_save = self._oai_messages[key][-nr_messages_to_preserve_internal] + if "tool_responses" in first_msg_to_save: + nr_messages_to_preserve_internal += 1 + # clear_conversable_agent_history.print_preserving_message(iostream.print) + no_messages_preserved += 1 + # Remove messages from history except last `nr_messages_to_preserve` messages. + self._oai_messages[key] = self._oai_messages[key][-nr_messages_to_preserve_internal:] + iostream.send(ClearConversableAgentHistoryEvent(agent=self, no_events_preserved=no_messages_preserved)) + else: + self._oai_messages.clear() + else: + self._oai_messages[recipient].clear() + # clear_conversable_agent_history.print_warning(iostream.print) + if nr_messages_to_preserve: + iostream.send(ClearConversableAgentHistoryWarningEvent(recipient=self)) + + def generate_oai_reply( + self, + messages: Optional[list[dict[str, Any]]] = None, + sender: Optional[Agent] = None, + config: Optional[OpenAIWrapper] = None, + ) -> tuple[bool, Optional[Union[str, dict[str, Any]]]]: + """Generate a reply using autogen.oai.""" + client = self.client if config is None else config + if client is None: + return False, None + if messages is None: + messages = self._oai_messages[sender] + extracted_response = self._generate_oai_reply_from_client( + client, self._oai_system_message + messages, self.client_cache + ) + return (False, None) if extracted_response is None else (True, extracted_response) + + def _generate_oai_reply_from_client(self, llm_client, messages, cache) -> Optional[Union[str, dict[str, Any]]]: + # unroll tool_responses + all_messages = [] + for message in messages: + tool_responses = message.get("tool_responses", []) + if tool_responses: + all_messages += tool_responses + # tool role on the parent message means the content is just concatenation of all of the tool_responses + if message.get("role") != "tool": + all_messages.append({key: message[key] for key in message if key != "tool_responses"}) + else: + all_messages.append(message) + + # TODO: #1143 handle token limit exceeded error + response = llm_client.create( + context=messages[-1].pop("context", None), + messages=all_messages, + cache=cache, + agent=self, + ) + extracted_response = llm_client.extract_text_or_completion_object(response)[0] + + if extracted_response is None: + warnings.warn(f"Extracted_response from {response} is None.", UserWarning) + return None + # ensure function and tool calls will be accepted when sent back to the LLM + if not isinstance(extracted_response, str) and hasattr(extracted_response, "model_dump"): + extracted_response = extracted_response.model_dump() + if isinstance(extracted_response, dict): + if extracted_response.get("function_call"): + extracted_response["function_call"]["name"] = self._normalize_name( + extracted_response["function_call"]["name"] + ) + for tool_call in extracted_response.get("tool_calls") or []: + tool_call["function"]["name"] = self._normalize_name(tool_call["function"]["name"]) + # Remove id and type if they are not present. + # This is to make the tool call object compatible with Mistral API. 
+ if tool_call.get("id") is None: + tool_call.pop("id") + if tool_call.get("type") is None: + tool_call.pop("type") + return extracted_response + + async def a_generate_oai_reply( + self, + messages: Optional[list[dict[str, Any]]] = None, + sender: Optional[Agent] = None, + config: Optional[Any] = None, + ) -> tuple[bool, Optional[Union[str, dict[str, Any]]]]: + """Generate a reply using autogen.oai asynchronously.""" + iostream = IOStream.get_default() + + def _generate_oai_reply( + self, iostream: IOStream, *args: Any, **kwargs: Any + ) -> tuple[bool, Optional[Union[str, dict[str, Any]]]]: + with IOStream.set_default(iostream): + return self.generate_oai_reply(*args, **kwargs) + + return await asyncio.get_event_loop().run_in_executor( + None, + functools.partial( + _generate_oai_reply, self=self, iostream=iostream, messages=messages, sender=sender, config=config + ), + ) + + def _generate_code_execution_reply_using_executor( + self, + messages: Optional[list[dict[str, Any]]] = None, + sender: Optional[Agent] = None, + config: Optional[Union[dict[str, Any], Literal[False]]] = None, + ): + """Generate a reply using code executor.""" + iostream = IOStream.get_default() + + if config is not None: + raise ValueError("config is not supported for _generate_code_execution_reply_using_executor.") + if self._code_execution_config is False: + return False, None + if messages is None: + messages = self._oai_messages[sender] + last_n_messages = self._code_execution_config.get("last_n_messages", "auto") + + if not (isinstance(last_n_messages, (int, float)) and last_n_messages >= 0) and last_n_messages != "auto": + raise ValueError("last_n_messages must be either a non-negative integer, or the string 'auto'.") + + num_messages_to_scan = last_n_messages + if last_n_messages == "auto": + # Find when the agent last spoke + num_messages_to_scan = 0 + for message in reversed(messages): + if "role" not in message or message["role"] != "user": + break + else: + num_messages_to_scan += 1 + num_messages_to_scan = min(len(messages), num_messages_to_scan) + messages_to_scan = messages[-num_messages_to_scan:] + + # iterate through the last n messages in reverse + # if code blocks are found, execute the code blocks and return the output + # if no code blocks are found, continue + for message in reversed(messages_to_scan): + if not message["content"]: + continue + code_blocks = self._code_executor.code_extractor.extract_code_blocks(message["content"]) + if len(code_blocks) == 0: + continue + + iostream.send(GenerateCodeExecutionReplyEvent(code_blocks=code_blocks, sender=sender, recipient=self)) + + # found code blocks, execute code. 
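+            # The executor result is folded into a single reply string of the form
+            # "exitcode: <code> (<status>)\nCode output: <output>".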
+ code_result = self._code_executor.execute_code_blocks(code_blocks) + exitcode2str = "execution succeeded" if code_result.exit_code == 0 else "execution failed" + return True, f"exitcode: {code_result.exit_code} ({exitcode2str})\nCode output: {code_result.output}" + + return False, None + + def generate_code_execution_reply( + self, + messages: Optional[list[dict[str, Any]]] = None, + sender: Optional[Agent] = None, + config: Optional[Union[dict[str, Any], Literal[False]]] = None, + ): + """Generate a reply using code execution.""" + code_execution_config = config if config is not None else self._code_execution_config + if code_execution_config is False: + return False, None + if messages is None: + messages = self._oai_messages[sender] + last_n_messages = code_execution_config.pop("last_n_messages", "auto") + + if not (isinstance(last_n_messages, (int, float)) and last_n_messages >= 0) and last_n_messages != "auto": + raise ValueError("last_n_messages must be either a non-negative integer, or the string 'auto'.") + + messages_to_scan = last_n_messages + if last_n_messages == "auto": + # Find when the agent last spoke + messages_to_scan = 0 + for i in range(len(messages)): + message = messages[-(i + 1)] + if "role" not in message or message["role"] != "user": + break + else: + messages_to_scan += 1 + + # iterate through the last n messages in reverse + # if code blocks are found, execute the code blocks and return the output + # if no code blocks are found, continue + for i in range(min(len(messages), messages_to_scan)): + message = messages[-(i + 1)] + if not message["content"]: + continue + code_blocks = extract_code(message["content"]) + if len(code_blocks) == 1 and code_blocks[0][0] == UNKNOWN: + continue + + # found code blocks, execute code and push "last_n_messages" back + exitcode, logs = self.execute_code_blocks(code_blocks) + code_execution_config["last_n_messages"] = last_n_messages + exitcode2str = "execution succeeded" if exitcode == 0 else "execution failed" + return True, f"exitcode: {exitcode} ({exitcode2str})\nCode output: {logs}" + + # no code blocks are found, push last_n_messages back and return. + code_execution_config["last_n_messages"] = last_n_messages + + return False, None + + def _run_async_in_thread(self, coro): + """Run an async coroutine in a separate thread with its own event loop.""" + result = {} + + def runner(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + result["value"] = loop.run_until_complete(coro) + loop.close() + + t = threading.Thread(target=runner) + t.start() + t.join() + return result["value"] + + def generate_function_call_reply( + self, + messages: Optional[list[dict[str, Any]]] = None, + sender: Optional[Agent] = None, + config: Optional[Any] = None, + ) -> tuple[bool, Optional[dict[str, Any]]]: + """Generate a reply using function call. 
+ + "function_call" replaced by "tool_calls" as of [OpenAI API v1.1.0](https://github.com/openai/openai-python/releases/tag/v1.1.0) + See https://platform.openai.com/docs/api-reference/chat/create#chat-create-functions + """ + if config is None: + config = self + if messages is None: + messages = self._oai_messages[sender] + message = messages[-1] + if message.get("function_call"): + call_id = message.get("id", None) + func_call = message["function_call"] + func = self._function_map.get(func_call.get("name", None), None) + if inspect.iscoroutinefunction(func): + coro = self.a_execute_function(func_call, call_id=call_id) + _, func_return = self._run_async_in_thread(coro) + else: + _, func_return = self.execute_function(message["function_call"], call_id=call_id) + return True, func_return + return False, None + + async def a_generate_function_call_reply( + self, + messages: Optional[list[dict[str, Any]]] = None, + sender: Optional[Agent] = None, + config: Optional[Any] = None, + ) -> tuple[bool, Optional[dict[str, Any]]]: + """Generate a reply using async function call. + + "function_call" replaced by "tool_calls" as of [OpenAI API v1.1.0](https://github.com/openai/openai-python/releases/tag/v1.1.0) + See https://platform.openai.com/docs/api-reference/chat/create#chat-create-functions + """ + if config is None: + config = self + if messages is None: + messages = self._oai_messages[sender] + message = messages[-1] + if "function_call" in message: + call_id = message.get("id", None) + func_call = message["function_call"] + func_name = func_call.get("name", "") + func = self._function_map.get(func_name, None) + if func and inspect.iscoroutinefunction(func): + _, func_return = await self.a_execute_function(func_call, call_id=call_id) + else: + _, func_return = self.execute_function(func_call, call_id=call_id) + return True, func_return + + return False, None + + def _str_for_tool_response(self, tool_response): + return str(tool_response.get("content", "")) + + def generate_tool_calls_reply( + self, + messages: Optional[list[dict[str, Any]]] = None, + sender: Optional[Agent] = None, + config: Optional[Any] = None, + ) -> tuple[bool, Optional[dict[str, Any]]]: + """Generate a reply using tool call.""" + if config is None: + config = self + if messages is None: + messages = self._oai_messages[sender] + message = messages[-1] + tool_returns = [] + for tool_call in message.get("tool_calls", []): + function_call = tool_call.get("function", {}) + tool_call_id = tool_call.get("id", None) + func = self._function_map.get(function_call.get("name", None), None) + if inspect.iscoroutinefunction(func): + coro = self.a_execute_function(function_call, call_id=tool_call_id) + _, func_return = self._run_async_in_thread(coro) + else: + _, func_return = self.execute_function(function_call, call_id=tool_call_id) + content = func_return.get("content", "") + if content is None: + content = "" + + if tool_call_id is not None: + tool_call_response = { + "tool_call_id": tool_call_id, + "role": "tool", + "content": content, + } + else: + # Do not include tool_call_id if it is not present. + # This is to make the tool call object compatible with Mistral API. 
+ tool_call_response = { + "role": "tool", + "content": content, + } + tool_returns.append(tool_call_response) + if tool_returns: + return True, { + "role": "tool", + "tool_responses": tool_returns, + "content": "\n\n".join([self._str_for_tool_response(tool_return) for tool_return in tool_returns]), + } + return False, None + + async def _a_execute_tool_call(self, tool_call): + tool_call_id = tool_call["id"] + function_call = tool_call.get("function", {}) + _, func_return = await self.a_execute_function(function_call, call_id=tool_call_id) + return { + "tool_call_id": tool_call_id, + "role": "tool", + "content": func_return.get("content", ""), + } + + async def a_generate_tool_calls_reply( + self, + messages: Optional[list[dict[str, Any]]] = None, + sender: Optional[Agent] = None, + config: Optional[Any] = None, + ) -> tuple[bool, Optional[dict[str, Any]]]: + """Generate a reply using async function call.""" + if config is None: + config = self + if messages is None: + messages = self._oai_messages[sender] + message = messages[-1] + async_tool_calls = [] + for tool_call in message.get("tool_calls", []): + async_tool_calls.append(self._a_execute_tool_call(tool_call)) + if async_tool_calls: + tool_returns = await asyncio.gather(*async_tool_calls) + return True, { + "role": "tool", + "tool_responses": tool_returns, + "content": "\n\n".join([self._str_for_tool_response(tool_return) for tool_return in tool_returns]), + } + + return False, None + + def check_termination_and_human_reply( + self, + messages: Optional[list[dict[str, Any]]] = None, + sender: Optional[Agent] = None, + config: Optional[Any] = None, + ) -> tuple[bool, Union[str, None]]: + """Check if the conversation should be terminated, and if human reply is provided. + + This method checks for conditions that require the conversation to be terminated, such as reaching + a maximum number of consecutive auto-replies or encountering a termination message. Additionally, + it prompts for and processes human input based on the configured human input mode, which can be + 'ALWAYS', 'NEVER', or 'TERMINATE'. The method also manages the consecutive auto-reply counter + for the conversation and prints relevant messages based on the human input received. + + Args: + messages: A list of message dictionaries, representing the conversation history. + sender: The agent object representing the sender of the message. + config: Configuration object, defaults to the current instance if not provided. + + Returns: + A tuple containing a boolean indicating if the conversation + should be terminated, and a human reply which can be a string, a dictionary, or None. + """ + iostream = IOStream.get_default() + + if config is None: + config = self + if messages is None: + messages = self._oai_messages[sender] if sender else [] + + termination_reason = None + + # if there are no messages, continue the conversation + if not messages: + return False, None + message = messages[-1] + + reply = "" + no_human_input_msg = "" + sender_name = "the sender" if sender is None else sender.name + if self.human_input_mode == "ALWAYS": + reply = self.get_human_input( + f"Replying as {self.name}. Provide feedback to {sender_name}. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: " + ) + no_human_input_msg = "NO HUMAN INPUT RECEIVED." 
if not reply else "" + # if the human input is empty, and the message is a termination message, then we will terminate the conversation + if not reply and self._is_termination_msg(message): + termination_reason = f"Termination message condition on agent '{self.name}' met" + elif reply == "exit": + termination_reason = "User requested to end the conversation" + + reply = reply if reply or not self._is_termination_msg(message) else "exit" + else: + if self._consecutive_auto_reply_counter[sender] >= self._max_consecutive_auto_reply_dict[sender]: + if self.human_input_mode == "NEVER": + termination_reason = "Maximum number of consecutive auto-replies reached" + reply = "exit" + else: + # self.human_input_mode == "TERMINATE": + terminate = self._is_termination_msg(message) + reply = self.get_human_input( + f"Please give feedback to {sender_name}. Press enter or type 'exit' to stop the conversation: " + if terminate + else f"Please give feedback to {sender_name}. Press enter to skip and use auto-reply, or type 'exit' to stop the conversation: " + ) + no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else "" + # if the human input is empty, and the message is a termination message, then we will terminate the conversation + if reply != "exit" and terminate: + termination_reason = ( + f"Termination message condition on agent '{self.name}' met and no human input provided" + ) + elif reply == "exit": + termination_reason = "User requested to end the conversation" + + reply = reply if reply or not terminate else "exit" + elif self._is_termination_msg(message): + if self.human_input_mode == "NEVER": + termination_reason = f"Termination message condition on agent '{self.name}' met" + reply = "exit" + else: + # self.human_input_mode == "TERMINATE": + reply = self.get_human_input( + f"Please give feedback to {sender_name}. Press enter or type 'exit' to stop the conversation: " + ) + no_human_input_msg = "NO HUMAN INPUT RECEIVED." 
if not reply else "" + + # if the human input is empty, and the message is a termination message, then we will terminate the conversation + if not reply or reply == "exit": + termination_reason = ( + f"Termination message condition on agent '{self.name}' met and no human input provided" + ) + + reply = reply or "exit" + + # print the no_human_input_msg + if no_human_input_msg: + iostream.send( + TerminationAndHumanReplyNoInputEvent( + no_human_input_msg=no_human_input_msg, sender=sender, recipient=self + ) + ) + + # stop the conversation + if reply == "exit": + # reset the consecutive_auto_reply_counter + self._consecutive_auto_reply_counter[sender] = 0 + + if termination_reason: + iostream.send(TerminationEvent(termination_reason=termination_reason)) + + return True, None + + # send the human reply + if reply or self._max_consecutive_auto_reply_dict[sender] == 0: + # reset the consecutive_auto_reply_counter + self._consecutive_auto_reply_counter[sender] = 0 + # User provided a custom response, return function and tool failures indicating user interruption + tool_returns = [] + if message.get("function_call", False): + tool_returns.append({ + "role": "function", + "name": message["function_call"].get("name", ""), + "content": "USER INTERRUPTED", + }) + + if message.get("tool_calls", False): + tool_returns.extend([ + {"role": "tool", "tool_call_id": tool_call.get("id", ""), "content": "USER INTERRUPTED"} + for tool_call in message["tool_calls"] + ]) + + response = {"role": "user", "content": reply} + if tool_returns: + response["tool_responses"] = tool_returns + + return True, response + + # increment the consecutive_auto_reply_counter + self._consecutive_auto_reply_counter[sender] += 1 + if self.human_input_mode != "NEVER": + iostream.send(UsingAutoReplyEvent(human_input_mode=self.human_input_mode, sender=sender, recipient=self)) + + return False, None + + async def a_check_termination_and_human_reply( + self, + messages: Optional[list[dict[str, Any]]] = None, + sender: Optional[Agent] = None, + config: Optional[Any] = None, + ) -> tuple[bool, Union[str, None]]: + """(async) Check if the conversation should be terminated, and if human reply is provided. + + This method checks for conditions that require the conversation to be terminated, such as reaching + a maximum number of consecutive auto-replies or encountering a termination message. Additionally, + it prompts for and processes human input based on the configured human input mode, which can be + 'ALWAYS', 'NEVER', or 'TERMINATE'. The method also manages the consecutive auto-reply counter + for the conversation and prints relevant messages based on the human input received. + + Args: + messages (Optional[List[Dict]]): A list of message dictionaries, representing the conversation history. + sender (Optional[Agent]): The agent object representing the sender of the message. + config (Optional[Any]): Configuration object, defaults to the current instance if not provided. + + Returns: + Tuple[bool, Union[str, Dict, None]]: A tuple containing a boolean indicating if the conversation + should be terminated, and a human reply which can be a string, a dictionary, or None. 
+ """ + iostream = IOStream.get_default() + + if config is None: + config = self + if messages is None: + messages = self._oai_messages[sender] if sender else [] + + termination_reason = None + + message = messages[-1] if messages else {} + reply = "" + no_human_input_msg = "" + sender_name = "the sender" if sender is None else sender.name + if self.human_input_mode == "ALWAYS": + reply = await self.a_get_human_input( + f"Replying as {self.name}. Provide feedback to {sender_name}. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: " + ) + no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else "" + # if the human input is empty, and the message is a termination message, then we will terminate the conversation + if not reply and self._is_termination_msg(message): + termination_reason = f"Termination message condition on agent '{self.name}' met" + elif reply == "exit": + termination_reason = "User requested to end the conversation" + + reply = reply if reply or not self._is_termination_msg(message) else "exit" + else: + if self._consecutive_auto_reply_counter[sender] >= self._max_consecutive_auto_reply_dict[sender]: + if self.human_input_mode == "NEVER": + termination_reason = "Maximum number of consecutive auto-replies reached" + reply = "exit" + else: + # self.human_input_mode == "TERMINATE": + terminate = self._is_termination_msg(message) + reply = await self.a_get_human_input( + f"Please give feedback to {sender_name}. Press enter or type 'exit' to stop the conversation: " + if terminate + else f"Please give feedback to {sender_name}. Press enter to skip and use auto-reply, or type 'exit' to stop the conversation: " + ) + no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else "" + # if the human input is empty, and the message is a termination message, then we will terminate the conversation + if reply != "exit" and terminate: + termination_reason = ( + f"Termination message condition on agent '{self.name}' met and no human input provided" + ) + elif reply == "exit": + termination_reason = "User requested to end the conversation" + + reply = reply if reply or not terminate else "exit" + elif self._is_termination_msg(message): + if self.human_input_mode == "NEVER": + termination_reason = f"Termination message condition on agent '{self.name}' met" + reply = "exit" + else: + # self.human_input_mode == "TERMINATE": + reply = await self.a_get_human_input( + f"Please give feedback to {sender_name}. Press enter or type 'exit' to stop the conversation: " + ) + no_human_input_msg = "NO HUMAN INPUT RECEIVED." 
if not reply else "" + + # if the human input is empty, and the message is a termination message, then we will terminate the conversation + if not reply or reply == "exit": + termination_reason = ( + f"Termination message condition on agent '{self.name}' met and no human input provided" + ) + + reply = reply or "exit" + + # print the no_human_input_msg + if no_human_input_msg: + iostream.send( + TerminationAndHumanReplyNoInputEvent( + no_human_input_msg=no_human_input_msg, sender=sender, recipient=self + ) + ) + + # stop the conversation + if reply == "exit": + # reset the consecutive_auto_reply_counter + self._consecutive_auto_reply_counter[sender] = 0 + + if termination_reason: + iostream.send(TerminationEvent(termination_reason=termination_reason)) + + return True, None + + # send the human reply + if reply or self._max_consecutive_auto_reply_dict[sender] == 0: + # User provided a custom response, return function and tool results indicating user interruption + # reset the consecutive_auto_reply_counter + self._consecutive_auto_reply_counter[sender] = 0 + tool_returns = [] + if message.get("function_call", False): + tool_returns.append({ + "role": "function", + "name": message["function_call"].get("name", ""), + "content": "USER INTERRUPTED", + }) + + if message.get("tool_calls", False): + tool_returns.extend([ + {"role": "tool", "tool_call_id": tool_call.get("id", ""), "content": "USER INTERRUPTED"} + for tool_call in message["tool_calls"] + ]) + + response = {"role": "user", "content": reply} + if tool_returns: + response["tool_responses"] = tool_returns + + return True, response + + # increment the consecutive_auto_reply_counter + self._consecutive_auto_reply_counter[sender] += 1 + if self.human_input_mode != "NEVER": + iostream.send(UsingAutoReplyEvent(human_input_mode=self.human_input_mode, sender=sender, recipient=self)) + + return False, None + + def generate_reply( + self, + messages: Optional[list[dict[str, Any]]] = None, + sender: Optional["Agent"] = None, + **kwargs: Any, + ) -> Optional[Union[str, dict[str, Any]]]: + """Reply based on the conversation history and the sender. + + Either messages or sender must be provided. + Register a reply_func with `None` as one trigger for it to be activated when `messages` is non-empty and `sender` is `None`. + Use registered auto reply functions to generate replies. + By default, the following functions are checked in order: + 1. check_termination_and_human_reply + 2. generate_function_call_reply (deprecated in favor of tool_calls) + 3. generate_tool_calls_reply + 4. generate_code_execution_reply + 5. generate_oai_reply + Every function returns a tuple (final, reply). + When a function returns final=False, the next function will be checked. + So by default, termination and human reply will be checked first. + If not terminating and human reply is skipped, execute function or code and return the result. + AI replies are generated only when no code execution is performed. + + Args: + messages: a list of messages in the conversation history. + sender: sender of an Agent instance. + **kwargs (Any): Additional arguments to customize reply generation. Supported kwargs: + - exclude (List[Callable[..., Any]]): A list of reply functions to exclude from + the reply generation process. Functions in this list will be skipped even if + they would normally be triggered. + + Returns: + str or dict or None: reply. None if no reply is generated. 
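+
+        Example (a minimal sketch; `assistant` is a placeholder agent configured elsewhere):
+            ```python
+            reply = assistant.generate_reply(
+                messages=[{"role": "user", "content": "Summarize the plan in one sentence."}]
+            )
+            print(reply)
+            ```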
+ """ + if all((messages is None, sender is None)): + error_msg = f"Either {messages=} or {sender=} must be provided." + logger.error(error_msg) + raise AssertionError(error_msg) + + if messages is None: + messages = self._oai_messages[sender] + + # Call the hookable method that gives registered hooks a chance to update agent state, used for their context variables. + self.update_agent_state_before_reply(messages) + + # Call the hookable method that gives registered hooks a chance to process the last message. + # Message modifications do not affect the incoming messages or self._oai_messages. + messages = self.process_last_received_message(messages) + + # Call the hookable method that gives registered hooks a chance to process all messages. + # Message modifications do not affect the incoming messages or self._oai_messages. + messages = self.process_all_messages_before_reply(messages) + + for reply_func_tuple in self._reply_func_list: + reply_func = reply_func_tuple["reply_func"] + if "exclude" in kwargs and reply_func in kwargs["exclude"]: + continue + if inspect.iscoroutinefunction(reply_func): + continue + if self._match_trigger(reply_func_tuple["trigger"], sender): + final, reply = reply_func(self, messages=messages, sender=sender, config=reply_func_tuple["config"]) + if logging_enabled(): + log_event( + self, + "reply_func_executed", + reply_func_module=reply_func.__module__, + reply_func_name=reply_func.__name__, + final=final, + reply=reply, + ) + if final: + return reply + return self._default_auto_reply + + async def a_generate_reply( + self, + messages: Optional[list[dict[str, Any]]] = None, + sender: Optional["Agent"] = None, + **kwargs: Any, + ) -> Union[str, dict[str, Any], None]: + """(async) Reply based on the conversation history and the sender. + + Either messages or sender must be provided. + Register a reply_func with `None` as one trigger for it to be activated when `messages` is non-empty and `sender` is `None`. + Use registered auto reply functions to generate replies. + By default, the following functions are checked in order: + 1. check_termination_and_human_reply + 2. generate_function_call_reply + 3. generate_tool_calls_reply + 4. generate_code_execution_reply + 5. generate_oai_reply + Every function returns a tuple (final, reply). + When a function returns final=False, the next function will be checked. + So by default, termination and human reply will be checked first. + If not terminating and human reply is skipped, execute function or code and return the result. + AI replies are generated only when no code execution is performed. + + Args: + messages: a list of messages in the conversation history. + sender: sender of an Agent instance. + **kwargs (Any): Additional arguments to customize reply generation. Supported kwargs: + - exclude (List[Callable[..., Any]]): A list of reply functions to exclude from + the reply generation process. Functions in this list will be skipped even if + they would normally be triggered. + + Returns: + str or dict or None: reply. None if no reply is generated. + """ + if all((messages is None, sender is None)): + error_msg = f"Either {messages=} or {sender=} must be provided." + logger.error(error_msg) + raise AssertionError(error_msg) + + if messages is None: + messages = self._oai_messages[sender] + + # Call the hookable method that gives registered hooks a chance to update agent state, used for their context variables. 
+ self.update_agent_state_before_reply(messages) + + # Call the hookable method that gives registered hooks a chance to process the last message. + # Message modifications do not affect the incoming messages or self._oai_messages. + messages = self.process_last_received_message(messages) + + # Call the hookable method that gives registered hooks a chance to process all messages. + # Message modifications do not affect the incoming messages or self._oai_messages. + messages = self.process_all_messages_before_reply(messages) + + for reply_func_tuple in self._reply_func_list: + reply_func = reply_func_tuple["reply_func"] + if "exclude" in kwargs and reply_func in kwargs["exclude"]: + continue + + if self._match_trigger(reply_func_tuple["trigger"], sender): + if inspect.iscoroutinefunction(reply_func): + final, reply = await reply_func( + self, messages=messages, sender=sender, config=reply_func_tuple["config"] + ) + else: + final, reply = reply_func(self, messages=messages, sender=sender, config=reply_func_tuple["config"]) + if final: + return reply + return self._default_auto_reply + + def _match_trigger(self, trigger: Union[None, str, type, Agent, Callable, list], sender: Optional[Agent]) -> bool: + """Check if the sender matches the trigger. + + Args: + trigger (Union[None, str, type, Agent, Callable, List]): The condition to match against the sender. + Can be `None`, string, type, `Agent` instance, callable, or a list of these. + sender (Agent): The sender object or type to be matched against the trigger. + + Returns: + `True` if the sender matches the trigger, otherwise `False`. + + Raises: + ValueError: If the trigger type is unsupported. + """ + if trigger is None: + return sender is None + elif isinstance(trigger, str): + if sender is None: + raise SenderRequiredError() + return trigger == sender.name + elif isinstance(trigger, type): + return isinstance(sender, trigger) + elif isinstance(trigger, Agent): + # return True if the sender is the same type (class) as the trigger + return trigger == sender + elif isinstance(trigger, Callable): + rst = trigger(sender) + assert isinstance(rst, bool), f"trigger {trigger} must return a boolean value." + return rst + elif isinstance(trigger, list): + return any(self._match_trigger(t, sender) for t in trigger) + else: + raise ValueError(f"Unsupported trigger type: {type(trigger)}") + + def get_human_input(self, prompt: str) -> str: + """Get human input. + + Override this method to customize the way to get human input. + + Args: + prompt (str): prompt for the human input. + + Returns: + str: human input. + """ + iostream = IOStream.get_default() + + reply = iostream.input(prompt) + self._human_input.append(reply) + return reply + + async def a_get_human_input(self, prompt: str) -> str: + """(Async) Get human input. + + Override this method to customize the way to get human input. + + Args: + prompt (str): prompt for the human input. + + Returns: + str: human input. 
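+
+        Example (an illustrative override in a subclass; `AutoApproveAgent` is a hypothetical name):
+            ```python
+            class AutoApproveAgent(ConversableAgent):
+                async def a_get_human_input(self, prompt: str) -> str:
+                    # Never block on console input; always fall back to auto-reply.
+                    return ""
+            ```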
+ """ + iostream = IOStream.get_default() + + reply = await iostream.input(prompt) + self._human_input.append(reply) + return reply + + # def _get_human_input( + # self, iostream: IOStream, prompt: str, + # ) -> tuple[bool, Optional[Union[str, dict[str, Any]]]]: + # with IOStream.set_default(iostream): + # print("!"*100) + # print("Getting human input...") + # return self.get_human_input(prompt) + + # return await asyncio.get_event_loop().run_in_executor( + # None, + # functools.partial( + # _get_human_input, self=self, iostream=iostream, prompt=prompt, + # ), + # ) + + def run_code(self, code: str, **kwargs: Any) -> tuple[int, str, Optional[str]]: + """Run the code and return the result. + + Override this function to modify the way to run the code. + + Args: + code (str): the code to be executed. + **kwargs: other keyword arguments. + + Returns: + A tuple of (exitcode, logs, image). + exitcode (int): the exit code of the code execution. + logs (str): the logs of the code execution. + image (str or None): the docker image used for the code execution. + """ + return execute_code(code, **kwargs) + + def execute_code_blocks(self, code_blocks): + """Execute the code blocks and return the result.""" + iostream = IOStream.get_default() + + logs_all = "" + for i, code_block in enumerate(code_blocks): + lang, code = code_block + if not lang: + lang = infer_lang(code) + + iostream.send(ExecuteCodeBlockEvent(code=code, language=lang, code_block_count=i, recipient=self)) + + if lang in ["bash", "shell", "sh"]: + exitcode, logs, image = self.run_code(code, lang=lang, **self._code_execution_config) + elif lang in PYTHON_VARIANTS: + filename = code[11 : code.find("\n")].strip() if code.startswith("# filename: ") else None + exitcode, logs, image = self.run_code( + code, + lang="python", + filename=filename, + **self._code_execution_config, + ) + else: + # In case the language is not supported, we return an error message. + exitcode, logs, image = ( + 1, + f"unknown language {lang}", + None, + ) + # raise NotImplementedError + if image is not None: + self._code_execution_config["use_docker"] = image + logs_all += "\n" + logs + if exitcode != 0: + return exitcode, logs_all + return exitcode, logs_all + + @staticmethod + def _format_json_str(jstr): + """Remove newlines outside of quotes, and handle JSON escape sequences. + + 1. this function removes the newline in the query outside of quotes otherwise json.loads(s) will fail. + Ex 1: + "{\n"tool": "python",\n"query": "print('hello')\nprint('world')"\n}" -> "{"tool": "python","query": "print('hello')\nprint('world')"}" + Ex 2: + "{\n \"location\": \"Boston, MA\"\n}" -> "{"location": "Boston, MA"}" + + 2. this function also handles JSON escape sequences inside quotes. + Ex 1: + '{"args": "a\na\na\ta"}' -> '{"args": "a\\na\\na\\ta"}' + """ + result = [] + inside_quotes = False + last_char = " " + for char in jstr: + if last_char != "\\" and char == '"': + inside_quotes = not inside_quotes + last_char = char + if not inside_quotes and char == "\n": + continue + if inside_quotes and char == "\n": + char = "\\n" + if inside_quotes and char == "\t": + char = "\\t" + result.append(char) + return "".join(result) + + def execute_function( + self, func_call: dict[str, Any], call_id: Optional[str] = None, verbose: bool = False + ) -> tuple[bool, dict[str, Any]]: + """Execute a function call and return the result. + + Override this function to modify the way to execute function and tool calls. 
+ + Args: + func_call: a dictionary extracted from openai message at "function_call" or "tool_calls" with keys "name" and "arguments". + call_id: a string to identify the tool call. + verbose (bool): Whether to send messages about the execution details to the + output stream. When True, both the function call arguments and the execution + result will be displayed. Defaults to False. + + + Returns: + A tuple of (is_exec_success, result_dict). + is_exec_success (boolean): whether the execution is successful. + result_dict: a dictionary with keys "name", "role", and "content". Value of "role" is "function". + + "function_call" deprecated as of [OpenAI API v1.1.0](https://github.com/openai/openai-python/releases/tag/v1.1.0) + See https://platform.openai.com/docs/api-reference/chat/create#chat-create-function_call + """ + iostream = IOStream.get_default() + + func_name = func_call.get("name", "") + func = self._function_map.get(func_name, None) + + is_exec_success = False + if func is not None: + # Extract arguments from a json-like string and put it into a dict. + input_string = self._format_json_str(func_call.get("arguments", "{}")) + try: + arguments = json.loads(input_string) + except json.JSONDecodeError as e: + arguments = None + content = f"Error: {e}\n The argument must be in JSON format." + + # Try to execute the function + if arguments is not None: + iostream.send( + ExecuteFunctionEvent(func_name=func_name, call_id=call_id, arguments=arguments, recipient=self) + ) + try: + content = func(**arguments) + is_exec_success = True + except Exception as e: + content = f"Error: {e}" + else: + arguments = {} + content = f"Error: Function {func_name} not found." + + iostream.send( + ExecutedFunctionEvent( + func_name=func_name, + call_id=call_id, + arguments=arguments, + content=content, + recipient=self, + is_exec_success=is_exec_success, + ) + ) + + return is_exec_success, { + "name": func_name, + "role": "function", + "content": content, + } + + async def a_execute_function( + self, func_call: dict[str, Any], call_id: Optional[str] = None, verbose: bool = False + ) -> tuple[bool, dict[str, Any]]: + """Execute an async function call and return the result. + + Override this function to modify the way async functions and tools are executed. + + Args: + func_call: a dictionary extracted from openai message at key "function_call" or "tool_calls" with keys "name" and "arguments". + call_id: a string to identify the tool call. + verbose (bool): Whether to send messages about the execution details to the + output stream. When True, both the function call arguments and the execution + result will be displayed. Defaults to False. + + Returns: + A tuple of (is_exec_success, result_dict). + is_exec_success (boolean): whether the execution is successful. + result_dict: a dictionary with keys "name", "role", and "content". Value of "role" is "function". + + "function_call" deprecated as of [OpenAI API v1.1.0](https://github.com/openai/openai-python/releases/tag/v1.1.0) + See https://platform.openai.com/docs/api-reference/chat/create#chat-create-function_call + """ + iostream = IOStream.get_default() + + func_name = func_call.get("name", "") + func = self._function_map.get(func_name, None) + + is_exec_success = False + if func is not None: + # Extract arguments from a json-like string and put it into a dict. 
+ input_string = self._format_json_str(func_call.get("arguments", "{}")) + try: + arguments = json.loads(input_string) + except json.JSONDecodeError as e: + arguments = None + content = f"Error: {e}\n The argument must be in JSON format." + + # Try to execute the function + if arguments is not None: + iostream.send( + ExecuteFunctionEvent(func_name=func_name, call_id=call_id, arguments=arguments, recipient=self) + ) + try: + if inspect.iscoroutinefunction(func): + content = await func(**arguments) + else: + # Fallback to sync function if the function is not async + content = func(**arguments) + is_exec_success = True + except Exception as e: + content = f"Error: {e}" + else: + arguments = {} + content = f"Error: Function {func_name} not found." + + iostream.send( + ExecutedFunctionEvent( + func_name=func_name, + call_id=call_id, + arguments=arguments, + content=content, + recipient=self, + is_exec_success=is_exec_success, + ) + ) + + return is_exec_success, { + "name": func_name, + "role": "function", + "content": content, + } + + def generate_init_message( + self, message: Optional[Union[dict[str, Any], str]], **kwargs: Any + ) -> Union[str, dict[str, Any]]: + """Generate the initial message for the agent. + If message is None, input() will be called to get the initial message. + + Args: + message (str or None): the message to be processed. + **kwargs: any additional information. It has the following reserved fields: + "carryover": a string or a list of string to specify the carryover information to be passed to this chat. It can be a string or a list of string. + If provided, we will combine this carryover with the "message" content when generating the initial chat + message. + + Returns: + str or dict: the processed message. + """ + if message is None: + message = self.get_human_input(">") + + return self._handle_carryover(message, kwargs) + + def _handle_carryover(self, message: Union[str, dict[str, Any]], kwargs: dict) -> Union[str, dict[str, Any]]: + if not kwargs.get("carryover"): + return message + + if isinstance(message, str): + return self._process_carryover(message, kwargs) + + elif isinstance(message, dict): + if isinstance(message.get("content"), str): + # Makes sure the original message is not mutated + message = message.copy() + message["content"] = self._process_carryover(message["content"], kwargs) + elif isinstance(message.get("content"), list): + # Makes sure the original message is not mutated + message = message.copy() + message["content"] = self._process_multimodal_carryover(message["content"], kwargs) + else: + raise InvalidCarryOverTypeError("Carryover should be a string or a list of strings.") + + return message + + def _process_carryover(self, content: str, kwargs: dict) -> str: + # Makes sure there's a carryover + if not kwargs.get("carryover"): + return content + + # if carryover is string + if isinstance(kwargs["carryover"], str): + content += "\nContext: \n" + kwargs["carryover"] + elif isinstance(kwargs["carryover"], list): + content += "\nContext: \n" + ("\n").join([_post_process_carryover_item(t) for t in kwargs["carryover"]]) + else: + raise InvalidCarryOverTypeError( + "Carryover should be a string or a list of strings. Not adding carryover to the message." 
+            )
+        return content
+
+    def _process_multimodal_carryover(self, content: list[dict[str, Any]], kwargs: dict) -> list[dict[str, Any]]:
+        """Prepends the context to a multimodal message."""
+        # Makes sure there's a carryover
+        if not kwargs.get("carryover"):
+            return content
+
+        return [{"type": "text", "text": self._process_carryover("", kwargs)}] + content
+
+    async def a_generate_init_message(
+        self, message: Optional[Union[dict[str, Any], str]], **kwargs: Any
+    ) -> Union[str, dict[str, Any]]:
+        """Generate the initial message for the agent.
+        If message is None, input() will be called to get the initial message.
+
+        Args:
+            message (str or None): the message to be processed.
+            **kwargs: any additional information. It has the following reserved fields:
+                "carryover": a string or a list of strings specifying the carryover information to be passed to this chat.
+                    If provided, we will combine this carryover with the "message" content when generating the initial chat
+                    message.
+
+        Returns:
+            str or dict: the processed message.
+        """
+        if message is None:
+            message = await self.a_get_human_input(">")
+
+        return self._handle_carryover(message, kwargs)
+
+    @property
+    def tools(self) -> list[Tool]:
+        """Get the agent's tools (registered for LLM).
+
+        Note this is a copy of the tools list; use add_tool and remove_tool to modify the tools list.
+        """
+        return self._tools.copy()
+
+    def remove_tool_for_llm(self, tool: Tool) -> None:
+        """Remove a tool (registered for LLM)."""
+        try:
+            self._register_for_llm(tool=tool, api_style="tool", is_remove=True)
+            self._tools.remove(tool)
+        except ValueError:
+            raise ValueError(f"Tool {tool} not found in collection")
+
+    def register_function(self, function_map: dict[str, Optional[Callable[..., Any]]], silent_override: bool = False):
+        """Register functions to the agent.
+
+        Args:
+            function_map: a dictionary mapping function names to functions. If function_map[name] is None, the function will be removed from the function_map.
+            silent_override: whether to suppress warnings when overriding existing functions.
+        """
+        for name, func in function_map.items():
+            self._assert_valid_name(name)
+            if func is None and name not in self._function_map:
+                warnings.warn(f"The function {name} to remove doesn't exist", UserWarning)
+            if not silent_override and name in self._function_map:
+                warnings.warn(f"Function '{name}' is being overridden.", UserWarning)
+        self._function_map.update(function_map)
+        self._function_map = {k: v for k, v in self._function_map.items() if v is not None}
+
+    def update_function_signature(
+        self, func_sig: Union[str, dict[str, Any]], is_remove: bool, silent_override: bool = False
+    ):
+        """Update a function_signature in the LLM configuration for function_call.
+
+        Args:
+            func_sig (str or dict): description/name of the function to update/remove to the model. See: https://platform.openai.com/docs/api-reference/chat/create#chat/create-functions
+            is_remove: whether to remove the function named 'func_sig' from llm_config
+            silent_override: whether to suppress warnings when overriding existing functions.
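+
+        Example (illustrative sketch; `agent` and the `get_weather` schema are placeholders, and the agent is
+        assumed to have been constructed with an `llm_config`):
+            ```
+            weather_schema = {
+                "name": "get_weather",
+                "description": "Look up the current weather for a city.",
+                "parameters": {
+                    "type": "object",
+                    "properties": {"city": {"type": "string"}},
+                    "required": ["city"],
+                },
+            }
+            # Add (or override) the signature in llm_config["functions"]
+            agent.update_function_signature(weather_schema, is_remove=False)
+
+            # Remove it again by name
+            agent.update_function_signature("get_weather", is_remove=True)
+            ```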
+ + Deprecated as of [OpenAI API v1.1.0](https://github.com/openai/openai-python/releases/tag/v1.1.0) + See https://platform.openai.com/docs/api-reference/chat/create#chat-create-function_call + """ + if not isinstance(self.llm_config, (dict, LLMConfig)): + error_msg = "To update a function signature, agent must have an llm_config" + logger.error(error_msg) + raise AssertionError(error_msg) + + if is_remove: + if "functions" not in self.llm_config or len(self.llm_config["functions"]) == 0: + error_msg = f"The agent config doesn't have function {func_sig}." + logger.error(error_msg) + raise AssertionError(error_msg) + else: + self.llm_config["functions"] = [ + func for func in self.llm_config["functions"] if func["name"] != func_sig + ] + else: + if not isinstance(func_sig, dict): + raise ValueError( + f"The function signature must be of the type dict. Received function signature type {type(func_sig)}" + ) + if "name" not in func_sig: + raise ValueError(f"The function signature must have a 'name' key. Received: {func_sig}") + self._assert_valid_name(func_sig["name"]), func_sig + if "functions" in self.llm_config: + if not silent_override and any( + func["name"] == func_sig["name"] for func in self.llm_config["functions"] + ): + warnings.warn(f"Function '{func_sig['name']}' is being overridden.", UserWarning) + + self.llm_config["functions"] = [ + func for func in self.llm_config["functions"] if func.get("name") != func_sig["name"] + ] + [func_sig] + else: + self.llm_config["functions"] = [func_sig] + + # Do this only if llm_config is a dict. If llm_config is LLMConfig, LLMConfig will handle this. + if len(self.llm_config["functions"]) == 0 and isinstance(self.llm_config, dict): + del self.llm_config["functions"] + + self.client = OpenAIWrapper(**self.llm_config) + + def update_tool_signature( + self, tool_sig: Union[str, dict[str, Any]], is_remove: bool, silent_override: bool = False + ): + """Update a tool_signature in the LLM configuration for tool_call. + + Args: + tool_sig (str or dict): description/name of the tool to update/remove to the model. See: https://platform.openai.com/docs/api-reference/chat/create#chat-create-tools + is_remove: whether removing the tool from llm_config with name 'tool_sig' + silent_override: whether to print warnings when overriding functions. + """ + if not self.llm_config: + error_msg = "To update a tool signature, agent must have an llm_config" + logger.error(error_msg) + raise AssertionError(error_msg) + + if is_remove: + if "tools" not in self.llm_config or len(self.llm_config["tools"]) == 0: + error_msg = f"The agent config doesn't have tool {tool_sig}." + logger.error(error_msg) + raise AssertionError(error_msg) + else: + current_tools = self.llm_config["tools"] + filtered_tools = [] + + # Loop through and rebuild tools list without the tool to remove + for tool in current_tools: + tool_name = tool["function"]["name"] + + # Match by tool name, or by tool signature + is_different = tool_name != tool_sig if isinstance(tool_sig, str) else tool != tool_sig + + if is_different: + filtered_tools.append(tool) + + self.llm_config["tools"] = filtered_tools + else: + if not isinstance(tool_sig, dict): + raise ValueError( + f"The tool signature must be of the type dict. 
Received tool signature type {type(tool_sig)}" + ) + self._assert_valid_name(tool_sig["function"]["name"]) + if "tools" in self.llm_config and len(self.llm_config["tools"]) > 0: + if not silent_override and any( + tool["function"]["name"] == tool_sig["function"]["name"] for tool in self.llm_config["tools"] + ): + warnings.warn(f"Function '{tool_sig['function']['name']}' is being overridden.", UserWarning) + self.llm_config["tools"] = [ + tool + for tool in self.llm_config["tools"] + if tool.get("function", {}).get("name") != tool_sig["function"]["name"] + ] + [tool_sig] + else: + self.llm_config["tools"] = [tool_sig] + + # Do this only if llm_config is a dict. If llm_config is LLMConfig, LLMConfig will handle this. + if len(self.llm_config["tools"]) == 0 and isinstance(self.llm_config, dict): + del self.llm_config["tools"] + + self.client = OpenAIWrapper(**self.llm_config) + + def can_execute_function(self, name: Union[list[str], str]) -> bool: + """Whether the agent can execute the function.""" + names = name if isinstance(name, list) else [name] + return all([n in self._function_map for n in names]) + + @property + def function_map(self) -> dict[str, Callable[..., Any]]: + """Return the function map.""" + return self._function_map + + def _wrap_function(self, func: F, inject_params: dict[str, Any] = {}, *, serialize: bool = True) -> F: + """Wrap the function inject chat context parameters and to dump the return value to json. + + Handles both sync and async functions. + + Args: + func: the function to be wrapped. + inject_params: the chat context parameters which will be passed to the function. + serialize: whether to serialize the return value + + Returns: + The wrapped function. + """ + + @load_basemodels_if_needed + @functools.wraps(func) + def _wrapped_func(*args, **kwargs): + retval = func(*args, **kwargs, **inject_params) + if logging_enabled(): + log_function_use(self, func, kwargs, retval) + return serialize_to_str(retval) if serialize else retval + + @load_basemodels_if_needed + @functools.wraps(func) + async def _a_wrapped_func(*args, **kwargs): + retval = await func(*args, **kwargs, **inject_params) + if logging_enabled(): + log_function_use(self, func, kwargs, retval) + return serialize_to_str(retval) if serialize else retval + + wrapped_func = _a_wrapped_func if inspect.iscoroutinefunction(func) else _wrapped_func + + # needed for testing + wrapped_func._origin = func + + return wrapped_func + + @staticmethod + def _create_tool_if_needed( + func_or_tool: Union[F, Tool], + name: Optional[str], + description: Optional[str], + ) -> Tool: + if isinstance(func_or_tool, Tool): + tool: Tool = func_or_tool + # create new tool object if name or description is not None + if name or description: + tool = Tool(func_or_tool=tool, name=name, description=description) + elif inspect.isfunction(func_or_tool): + function: Callable[..., Any] = func_or_tool + tool = Tool(func_or_tool=function, name=name, description=description) + else: + raise TypeError(f"'func_or_tool' must be a function or a Tool object, got '{type(func_or_tool)}' instead.") + return tool + + def register_for_llm( + self, + *, + name: Optional[str] = None, + description: Optional[str] = None, + api_style: Literal["function", "tool"] = "tool", + silent_override: bool = False, + ) -> Callable[[Union[F, Tool]], Tool]: + """Decorator factory for registering a function to be used by an agent. + + It's return value is used to decorate a function to be registered to the agent. 
The function uses type hints to + specify the arguments and return type. The function name is used as the default name for the function, + but a custom name can be provided. The function description is used to describe the function in the + agent's configuration. + + Args: + name (optional(str)): name of the function. If None, the function name will be used (default: None). + description (optional(str)): description of the function (default: None). It is mandatory + for the initial decorator, but the following ones can omit it. + api_style: (literal): the API style for function call. + For Azure OpenAI API, use version 2023-12-01-preview or later. + `"function"` style will be deprecated. For earlier version use + `"function"` if `"tool"` doesn't work. + See [Azure OpenAI documentation](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/function-calling?tabs=python) for details. + silent_override (bool): whether to suppress any override warning messages. + + Returns: + The decorator for registering a function to be used by an agent. + + Examples: + ``` + @user_proxy.register_for_execution() + @agent2.register_for_llm() + @agent1.register_for_llm(description="This is a very useful function") + def my_function(a: Annotated[str, "description of a parameter"] = "a", b: int, c=3.14) -> str: + return a + str(b * c) + ``` + + For Azure OpenAI versions prior to 2023-12-01-preview, set `api_style` + to `"function"` if `"tool"` doesn't work: + ``` + @agent2.register_for_llm(api_style="function") + def my_function(a: Annotated[str, "description of a parameter"] = "a", b: int, c=3.14) -> str: + return a + str(b * c) + ``` + + """ + + def _decorator( + func_or_tool: Union[F, Tool], name: Optional[str] = name, description: Optional[str] = description + ) -> Tool: + """Decorator for registering a function to be used by an agent. + + Args: + func_or_tool: The function or the tool to be registered. + name: The name of the function or the tool. + description: The description of the function or the tool. + + Returns: + The function to be registered, with the _description attribute set to the function description. + + Raises: + ValueError: if the function description is not provided and not propagated by a previous decorator. + RuntimeError: if the LLM config is not set up before registering a function. + + """ + tool = self._create_tool_if_needed(func_or_tool, name, description) + + self._register_for_llm(tool, api_style, silent_override=silent_override) + if tool not in self._tools: + self._tools.append(tool) + + return tool + + return _decorator + + def _register_for_llm( + self, tool: Tool, api_style: Literal["tool", "function"], is_remove: bool = False, silent_override: bool = False + ) -> None: + """ + Register a tool for LLM. + + Args: + tool: the tool to be registered. + api_style: the API style for function call ("tool" or "function"). + is_remove: whether to remove the function or tool. + silent_override: whether to suppress any override warning messages. 
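+
+        Example (illustrative; `my_tool` stands for any `Tool` instance and the agent is assumed to have an
+        `llm_config`):
+            ```
+            # Advertise the tool's schema to the LLM ...
+            agent._register_for_llm(my_tool, api_style="tool")
+            # ... and later withdraw it again
+            agent._register_for_llm(my_tool, api_style="tool", is_remove=True)
+            ```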
+ + Returns: + None + """ + # register the function to the agent if there is LLM config, raise an exception otherwise + if self.llm_config is None: + raise RuntimeError("LLM config must be setup before registering a function for LLM.") + + if api_style == "function": + self.update_function_signature(tool.function_schema, is_remove=is_remove, silent_override=silent_override) + elif api_style == "tool": + self.update_tool_signature(tool.tool_schema, is_remove=is_remove, silent_override=silent_override) + else: + raise ValueError(f"Unsupported API style: {api_style}") + + def set_ui_tools(self, tools: list[Tool]) -> None: + """Set the UI tools for the agent. + + Args: + tools: a list of tools to be set. + """ + # Unset the previous UI tools + self._unset_previous_ui_tools() + + # Set the new UI tools + for tool in tools: + # Register the tool for LLM + self._register_for_llm(tool, api_style="tool", silent_override=True) + if tool not in self._tools: + self._tools.append(tool) + + # Register for execution + self.register_for_execution(serialize=False, silent_override=True)(tool) + + # Set the current UI tools + self._ui_tools = tools + + def unset_ui_tools(self, tools: list[Tool]) -> None: + """Unset the UI tools for the agent. + + Args: + tools: a list of tools to be unset. + """ + for tool in tools: + self.remove_tool_for_llm(tool) + + def _unset_previous_ui_tools(self) -> None: + """Unset the previous UI tools for the agent. + + This is used to remove UI tools that were previously registered for LLM. + """ + self.unset_ui_tools(self._ui_tools) + for tool in self._ui_tools: + if tool in self._tools: + self._tools.remove(tool) + + # Unregister the function from the function map + if tool.name in self._function_map: + del self._function_map[tool.name] + + self._ui_tools = [] + + def register_for_execution( + self, + name: Optional[str] = None, + description: Optional[str] = None, + *, + serialize: bool = True, + silent_override: bool = False, + ) -> Callable[[Union[Tool, F]], Tool]: + """Decorator factory for registering a function to be executed by an agent. + + It's return value is used to decorate a function to be registered to the agent. + + Args: + name: name of the function. If None, the function name will be used (default: None). + description: description of the function (default: None). + serialize: whether to serialize the return value + silent_override: whether to suppress any override warning messages + + Returns: + The decorator for registering a function to be used by an agent. + + Examples: + ``` + @user_proxy.register_for_execution() + @agent2.register_for_llm() + @agent1.register_for_llm(description="This is a very useful function") + def my_function(a: Annotated[str, "description of a parameter"] = "a", b: int, c=3.14): + return a + str(b * c) + ``` + + """ + + def _decorator( + func_or_tool: Union[Tool, F], name: Optional[str] = name, description: Optional[str] = description + ) -> Tool: + """Decorator for registering a function to be used by an agent. + + Args: + func_or_tool: the function or the tool to be registered. + name: the name of the function. + description: the description of the function. + + Returns: + The tool to be registered. 
+ + """ + + tool = self._create_tool_if_needed(func_or_tool, name, description) + chat_context = ChatContext(self) + chat_context_params = {param: chat_context for param in tool._chat_context_param_names} + + self.register_function( + {tool.name: self._wrap_function(tool.func, chat_context_params, serialize=serialize)}, + silent_override=silent_override, + ) + + return tool + + return _decorator + + def register_model_client(self, model_client_cls: ModelClient, **kwargs: Any): + """Register a model client. + + Args: + model_client_cls: A custom client class that follows the Client interface + **kwargs: The kwargs for the custom client class to be initialized with + """ + self.client.register_model_client(model_client_cls, **kwargs) + + def register_hook(self, hookable_method: str, hook: Callable): + """Registers a hook to be called by a hookable method, in order to add a capability to the agent. + Registered hooks are kept in lists (one per hookable method), and are called in their order of registration. + + Args: + hookable_method: A hookable method name implemented by ConversableAgent. + hook: A method implemented by a subclass of AgentCapability. + """ + assert hookable_method in self.hook_lists, f"{hookable_method} is not a hookable method." + hook_list = self.hook_lists[hookable_method] + assert hook not in hook_list, f"{hook} is already registered as a hook." + hook_list.append(hook) + + def update_agent_state_before_reply(self, messages: list[dict[str, Any]]) -> None: + """Calls any registered capability hooks to update the agent's state. + Primarily used to update context variables. + Will, potentially, modify the messages. + """ + hook_list = self.hook_lists["update_agent_state"] + + # Call each hook (in order of registration) to process the messages. + for hook in hook_list: + hook(self, messages) + + def process_all_messages_before_reply(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Calls any registered capability hooks to process all messages, potentially modifying the messages.""" + hook_list = self.hook_lists["process_all_messages_before_reply"] + # If no hooks are registered, or if there are no messages to process, return the original message list. + if len(hook_list) == 0 or messages is None: + return messages + + # Call each hook (in order of registration) to process the messages. + processed_messages = messages + for hook in hook_list: + processed_messages = hook(processed_messages) + return processed_messages + + def process_last_received_message(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Calls any registered capability hooks to use and potentially modify the text of the last message, + as long as the last message is not a function call or exit command. + """ + # If any required condition is not met, return the original message list. + hook_list = self.hook_lists["process_last_received_message"] + if len(hook_list) == 0: + return messages # No hooks registered. + if messages is None: + return None # No message to process. + if len(messages) == 0: + return messages # No message to process. + last_message = messages[-1] + if "function_call" in last_message: + return messages # Last message is a function call. + if "context" in last_message: + return messages # Last message contains a context key. + if "content" not in last_message: + return messages # Last message has no content. 
+ + user_content = last_message["content"] + if not isinstance(user_content, str) and not isinstance(user_content, list): + # if the user_content is a string, it is for regular LLM + # if the user_content is a list, it should follow the multimodal LMM format. + return messages + if user_content == "exit": + return messages # Last message is an exit command. + + # Call each hook (in order of registration) to process the user's message. + processed_user_content = user_content + for hook in hook_list: + processed_user_content = hook(processed_user_content) + + if processed_user_content == user_content: + return messages # No hooks actually modified the user's message. + + # Replace the last user message with the expanded one. + messages = messages.copy() + messages[-1]["content"] = processed_user_content + return messages + + def print_usage_summary(self, mode: Union[str, list[str]] = ["actual", "total"]) -> None: + """Print the usage summary.""" + iostream = IOStream.get_default() + if self.client is None: + iostream.send(ConversableAgentUsageSummaryNoCostIncurredEvent(recipient=self)) + else: + iostream.send(ConversableAgentUsageSummaryEvent(recipient=self)) + + if self.client is not None: + self.client.print_usage_summary(mode) + + def get_actual_usage(self) -> Union[None, dict[str, int]]: + """Get the actual usage summary.""" + if self.client is None: + return None + else: + return self.client.actual_usage_summary + + def get_total_usage(self) -> Union[None, dict[str, int]]: + """Get the total usage summary.""" + if self.client is None: + return None + else: + return self.client.total_usage_summary + + @contextmanager + def _create_or_get_executor( + self, + executor_kwargs: Optional[dict[str, Any]] = None, + tools: Optional[Union[Tool, Iterable[Tool]]] = None, + agent_name: str = "executor", + agent_human_input_mode: str = "NEVER", + ) -> Generator["ConversableAgent", None, None]: + """Creates a user proxy / tool executor agent. + + Note: Code execution is not enabled by default. Pass the code execution config into executor_kwargs, if needed. + + Args: + executor_kwargs: agent's arguments. + tools: tools to register for execution with the agent. + agent_name: agent's name, defaults to 'executor'. + agent_human_input_mode: agent's human input mode, defaults to 'NEVER'. + """ + if executor_kwargs is None: + executor_kwargs = {} + if "is_termination_msg" not in executor_kwargs: + executor_kwargs["is_termination_msg"] = lambda x: (x["content"] is not None) and "TERMINATE" in x["content"] + + try: + if not self.run_executor: + self.run_executor = ConversableAgent( + name=agent_name, + human_input_mode=agent_human_input_mode, + **executor_kwargs, + ) + + tools = [] if tools is None else tools + tools = [tools] if isinstance(tools, Tool) else tools + for tool in tools: + tool.register_for_execution(self.run_executor) + tool.register_for_llm(self) + yield self.run_executor + finally: + if tools is not None: + for tool in tools: + self.update_tool_signature(tool_sig=tool.tool_schema["function"]["name"], is_remove=True) + + def _deprecated_run( + self, + message: str, + *, + tools: Optional[Union[Tool, Iterable[Tool]]] = None, + executor_kwargs: Optional[dict[str, Any]] = None, + max_turns: Optional[int] = None, + msg_to: Literal["agent", "user"] = "agent", + clear_history: bool = False, + user_input: bool = True, + summary_method: Optional[Union[str, Callable[..., Any]]] = DEFAULT_SUMMARY_METHOD, + ) -> ChatResult: + """Run a chat with the agent using the given message. 
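+
+        Example (illustrative sketch; the message text and keyword values are arbitrary):
+            ```
+            result = agent._deprecated_run(
+                "Plot the first ten Fibonacci numbers",
+                max_turns=3,
+                user_input=False,
+            )
+            ```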
+
+        A second agent will be created to represent the user; it will be known by the name 'user'. This agent does not have code execution enabled by default; if needed, pass the code execution config in with the executor_kwargs parameter.
+
+        The user can terminate the conversation when prompted or, if the agent's reply contains 'TERMINATE', it will terminate.
+
+        Args:
+            message: the message to be processed.
+            tools: the tools to be used by the agent.
+            executor_kwargs: the keyword arguments for the executor.
+            max_turns: maximum number of turns (a turn is equivalent to both agents having replied), defaults to None, which means unlimited. The original message is included.
+            msg_to: which agent is receiving the message and will be the first to reply, defaults to the agent.
+            clear_history: whether to clear the chat history.
+            user_input: the user will be asked for input at their turn.
+            summary_method: the method to summarize the chat.
+        """
+        with self._create_or_get_executor(
+            executor_kwargs=executor_kwargs,
+            tools=tools,
+            agent_name="user",
+            agent_human_input_mode="ALWAYS" if user_input else "NEVER",
+        ) as executor:
+            if msg_to == "agent":
+                return executor.initiate_chat(
+                    self,
+                    message=message,
+                    clear_history=clear_history,
+                    max_turns=max_turns,
+                    summary_method=summary_method,
+                )
+            else:
+                return self.initiate_chat(
+                    executor,
+                    message=message,
+                    clear_history=clear_history,
+                    max_turns=max_turns,
+                    summary_method=summary_method,
+                )
+
+    async def _deprecated_a_run(
+        self,
+        message: str,
+        *,
+        tools: Optional[Union[Tool, Iterable[Tool]]] = None,
+        executor_kwargs: Optional[dict[str, Any]] = None,
+        max_turns: Optional[int] = None,
+        msg_to: Literal["agent", "user"] = "agent",
+        clear_history: bool = False,
+        user_input: bool = True,
+        summary_method: Optional[Union[str, Callable[..., Any]]] = DEFAULT_SUMMARY_METHOD,
+    ) -> ChatResult:
+        """Run a chat asynchronously with the agent using the given message.
+
+        A second agent will be created to represent the user; it will be known by the name 'user'.
+
+        The user can terminate the conversation when prompted or, if the agent's reply contains 'TERMINATE', it will terminate.
+
+        Args:
+            message: the message to be processed.
+            tools: the tools to be used by the agent.
+            executor_kwargs: the keyword arguments for the executor.
+            max_turns: maximum number of turns (a turn is equivalent to both agents having replied), defaults to None, which means unlimited. The original message is included.
+            msg_to: which agent is receiving the message and will be the first to reply, defaults to the agent.
+            clear_history: whether to clear the chat history.
+            user_input: the user will be asked for input at their turn.
+            summary_method: the method to summarize the chat.
+        """
+        with self._create_or_get_executor(
+            executor_kwargs=executor_kwargs,
+            tools=tools,
+            agent_name="user",
+            agent_human_input_mode="ALWAYS" if user_input else "NEVER",
+        ) as executor:
+            if msg_to == "agent":
+                return await executor.a_initiate_chat(
+                    self,
+                    message=message,
+                    clear_history=clear_history,
+                    max_turns=max_turns,
+                    summary_method=summary_method,
+                )
+            else:
+                return await self.a_initiate_chat(
+                    executor,
+                    message=message,
+                    clear_history=clear_history,
+                    max_turns=max_turns,
+                    summary_method=summary_method,
+                )
+
+    def register_handoff(self, condition: Union["OnContextCondition", "OnCondition"]) -> None:
+        """
+        Register a single handoff condition (OnContextCondition or OnCondition).
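+
+        Example (illustrative sketch; `triage_agent` and the `needs_review` context variable are assumptions,
+        and the target/condition classes are used with the keyword fields they expose in this package):
+            ```
+            agent.register_handoff(
+                OnContextCondition(
+                    target=AgentTarget(triage_agent),
+                    condition=StringContextCondition(variable_name="needs_review"),
+                )
+            )
+            ```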
+ + Args: + condition: The condition to add (OnContextCondition, OnCondition) + """ + self.handoffs.add(condition) + + def register_handoffs(self, conditions: list[Union["OnContextCondition", "OnCondition"]]) -> None: + """ + Register multiple handoff conditions (OnContextCondition or OnCondition). + + Args: + conditions: List of conditions to add + """ + self.handoffs.add_many(conditions) + + +@export_module("autogen") +def register_function( + f: Callable[..., Any], + *, + caller: ConversableAgent, + executor: ConversableAgent, + name: Optional[str] = None, + description: str, +) -> None: + """Register a function to be proposed by an agent and executed for an executor. + + This function can be used instead of function decorators `@ConversationAgent.register_for_llm` and + `@ConversationAgent.register_for_execution`. + + Args: + f: the function to be registered. + caller: the agent calling the function, typically an instance of ConversableAgent. + executor: the agent executing the function, typically an instance of UserProxy. + name: name of the function. If None, the function name will be used (default: None). + description: description of the function. The description is used by LLM to decode whether the function + is called. Make sure the description is properly describing what the function does or it might not be + called by LLM when needed. + + """ + f = caller.register_for_llm(name=name, description=description)(f) + executor.register_for_execution(name=name)(f) diff --git a/mm_agents/coact/autogen/agentchat/group/__init__.py b/mm_agents/coact/autogen/agentchat/group/__init__.py new file mode 100644 index 0000000..4b95a59 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/group/__init__.py @@ -0,0 +1,64 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 +# +__all__: list[str] = [] + +from .available_condition import ExpressionAvailableCondition, StringAvailableCondition +from .context_condition import ExpressionContextCondition, StringContextCondition +from .context_expression import ContextExpression +from .context_str import ContextStr +from .context_variables import ContextVariables +from .handoffs import Handoffs +from .llm_condition import ContextStrLLMCondition, StringLLMCondition +from .on_condition import OnCondition +from .on_context_condition import OnContextCondition +from .reply_result import ReplyResult +from .speaker_selection_result import SpeakerSelectionResult +from .targets.group_chat_target import GroupChatConfig, GroupChatTarget + +""" +from .targets.group_manager_target import ( + GroupManagerSelectionMessageContextStr, + GroupManagerSelectionMessageString, + GroupManagerTarget, +) +""" +from .targets.transition_target import ( + AgentNameTarget, + AgentTarget, + AskUserTarget, + NestedChatTarget, + RevertToUserTarget, + StayTarget, + TerminateTarget, +) + +__all__ = [ + "AgentNameTarget", + "AgentTarget", + "AskUserTarget", + "ContextExpression", + "ContextStr", + "ContextStrLLMCondition", + "ContextVariables", + "ExpressionAvailableCondition", + "ExpressionContextCondition", + "GroupChatConfig", + "GroupChatTarget", + # "GroupManagerSelectionMessageContextStr", + # "GroupManagerSelectionMessageString", + # "GroupManagerTarget", + "Handoffs", + "NestedChatTarget", + "OnCondition", + "OnContextCondition", + "ReplyResult", + "RevertToUserTarget", + "SpeakerSelectionResult", + "StayTarget", + "StringAvailableCondition", + "StringContextCondition", + 
"StringLLMCondition", + "TerminateTarget", +] diff --git a/mm_agents/coact/autogen/agentchat/group/available_condition.py b/mm_agents/coact/autogen/agentchat/group/available_condition.py new file mode 100644 index 0000000..3fbc871 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/group/available_condition.py @@ -0,0 +1,91 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import TYPE_CHECKING, Any + +from pydantic import BaseModel + +from .context_expression import ContextExpression + +if TYPE_CHECKING: + # Avoid circular import + from ..conversable_agent import ConversableAgent + +__all__ = ["AvailableCondition", "ExpressionAvailableCondition", "StringAvailableCondition"] + + +class AvailableCondition(BaseModel): + """Protocol for determining if a condition is available to be evaluated.""" + + def is_available(self, agent: "ConversableAgent", messages: list[dict[str, Any]]) -> bool: + """Determine if the condition should be considered for evaluation. + + Args: + agent: The agent evaluating the condition + messages: The conversation history + + Returns: + True if the condition should be evaluated, False otherwise + """ + raise NotImplementedError("Requires subclasses to implement.") + + +class StringAvailableCondition(AvailableCondition): + """String-based available condition. + + This condition checks if a named context variable exists and is truthy. + """ + + context_variable: str + + def __init__(self, context_variable: str, **data: Any) -> None: + """Initialize with a context variable name as a positional parameter. + + Args: + context_variable: The name of the context variable to check + data: Additional data for the parent class + """ + super().__init__(context_variable=context_variable, **data) + + def is_available(self, agent: "ConversableAgent", messages: list[dict[str, Any]]) -> bool: + """Check if the named context variable is truthy. + + Args: + agent: The agent with context variables + messages: The conversation history (not used) + + Returns: + True if the variable exists and is truthy, False otherwise + """ + return bool(agent.context_variables.get(self.context_variable, False)) + + +class ExpressionAvailableCondition(AvailableCondition): + """Expression-based available condition. + + This condition evaluates a ContextExpression against the context variables. + """ + + expression: ContextExpression + + def __init__(self, expression: ContextExpression, **data: Any) -> None: + """Initialize with an expression as a positional parameter. + + Args: + expression: The context expression to evaluate + data: Additional data for the parent class + """ + super().__init__(expression=expression, **data) + + def is_available(self, agent: "ConversableAgent", messages: list[dict[str, Any]]) -> bool: + """Evaluate the expression against the context variables. 
+ + Args: + agent: The agent with context variables + messages: The conversation history (not used) + + Returns: + Boolean result of the expression evaluation + """ + return self.expression.evaluate(agent.context_variables) diff --git a/mm_agents/coact/autogen/agentchat/group/context_condition.py b/mm_agents/coact/autogen/agentchat/group/context_condition.py new file mode 100644 index 0000000..9a4b495 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/group/context_condition.py @@ -0,0 +1,77 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + + +from typing import Any + +from pydantic import BaseModel + +from .context_expression import ContextExpression +from .context_variables import ContextVariables + +__all__ = ["ContextCondition", "ExpressionContextCondition", "StringContextCondition"] + + +class ContextCondition(BaseModel): + """Protocol for conditions evaluated directly using context variables.""" + + def evaluate(self, context_variables: ContextVariables) -> bool: + """Evaluate the condition to a boolean result. + + Args: + context_variables: The context variables to evaluate against + + Returns: + Boolean result of the condition evaluation + """ + raise NotImplementedError("Requires subclasses to implement.") + + +class StringContextCondition(ContextCondition): + """Simple string-based context condition. + + This condition checks if a named context variable exists and is truthy. + """ + + variable_name: str + + def evaluate(self, context_variables: ContextVariables) -> bool: + """Check if the named context variable is truthy. + + Args: + context_variables: The context variables to check against + + Returns: + True if the variable exists and is truthy, False otherwise + """ + return bool(context_variables.get(self.variable_name, False)) + + +class ExpressionContextCondition(ContextCondition): + """Complex expression-based context condition. + + This condition evaluates a ContextExpression against the context variables. + """ + + expression: ContextExpression + + def __init__(self, expression: ContextExpression, **data: Any) -> None: + """Initialize with an expression as a positional parameter. + + Args: + expression: The context expression to evaluate + data: Additional data for the parent class + """ + super().__init__(expression=expression, **data) + + def evaluate(self, context_variables: ContextVariables) -> bool: + """Evaluate the expression against the context variables. + + Args: + context_variables: The context variables to evaluate against + + Returns: + Boolean result of the expression evaluation + """ + return self.expression.evaluate(context_variables) diff --git a/mm_agents/coact/autogen/agentchat/group/context_expression.py b/mm_agents/coact/autogen/agentchat/group/context_expression.py new file mode 100644 index 0000000..64f2052 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/group/context_expression.py @@ -0,0 +1,238 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +import ast +import re +from dataclasses import dataclass + +from ...doc_utils import export_module +from .context_variables import ContextVariables + + +@dataclass +@export_module("autogen") +class ContextExpression: + """A class to evaluate logical expressions using context variables. + + Args: + expression (str): A string containing a logical expression with context variable references. 
+ - Variable references use ${var_name} syntax: ${logged_in}, ${attempts} + - String literals can use normal quotes: 'hello', "world" + - Supported operators: + - Logical: not/!, and/&, or/| + - Comparison: >, <, >=, <=, ==, != + - Supported functions: + - len(${var_name}): Gets the length of a list, string, or other collection + - Parentheses can be used for grouping + - Examples: + - "not ${logged_in} and ${is_admin} or ${guest_checkout}" + - "!${logged_in} & ${is_admin} | ${guest_checkout}" + - "len(${orders}) > 0 & ${user_active}" + - "len(${cart_items}) == 0 | ${checkout_started}" + + Raises: + SyntaxError: If the expression cannot be parsed + ValueError: If the expression contains disallowed operations + """ + + expression: str + + def __post_init__(self) -> None: + # Validate the expression immediately upon creation + try: + # Extract variable references and replace with placeholders + self._variable_names = self._extract_variable_names(self.expression) + + # Convert symbolic operators to Python keywords + python_expr = self._convert_to_python_syntax(self.expression) + + # Sanitize for AST parsing + sanitized_expr = self._prepare_for_ast(python_expr) + + # Use ast to parse and validate the expression + self._ast = ast.parse(sanitized_expr, mode="eval") + + # Verify it only contains allowed operations + self._validate_operations(self._ast.body) + + # Store the Python-syntax version for evaluation + self._python_expr = python_expr + + except SyntaxError as e: + raise SyntaxError(f"Invalid expression syntax in '{self.expression}': {str(e)}") + except Exception as e: + raise ValueError(f"Error validating expression '{self.expression}': {str(e)}") + + def _extract_variable_names(self, expr: str) -> list[str]: + """Extract all variable references ${var_name} from the expression.""" + # Find all patterns like ${var_name} + matches = re.findall(r"\${([^}]*)}", expr) + return matches + + def _convert_to_python_syntax(self, expr: str) -> str: + """Convert symbolic operators to Python keywords.""" + # We need to be careful about operators inside string literals + # First, temporarily replace string literals with placeholders + string_literals = [] + + def replace_string_literal(match: re.Match[str]) -> str: + string_literals.append(match.group(0)) + return f"__STRING_LITERAL_{len(string_literals) - 1}__" + + # Replace both single and double quoted strings + expr_without_strings = re.sub(r"'[^']*'|\"[^\"]*\"", replace_string_literal, expr) + + # Handle the NOT operator (!) - no parentheses handling needed + # Replace standalone ! 
before variables or expressions + expr_without_strings = re.sub(r"!\s*(\${|\()", "not \\1", expr_without_strings) + + # Handle AND and OR operators - simpler approach without parentheses handling + expr_without_strings = re.sub(r"\s+&\s+", " and ", expr_without_strings) + expr_without_strings = re.sub(r"\s+\|\s+", " or ", expr_without_strings) + + # Now put string literals back + for i, literal in enumerate(string_literals): + expr_without_strings = expr_without_strings.replace(f"__STRING_LITERAL_{i}__", literal) + + return expr_without_strings + + def _prepare_for_ast(self, expr: str) -> str: + """Convert the expression to valid Python for AST parsing by replacing variables with placeholders.""" + # Replace ${var_name} with var_name for AST parsing + processed_expr = expr + for var_name in self._variable_names: + processed_expr = processed_expr.replace(f"${{{var_name}}}", var_name) + + return processed_expr + + def _validate_operations(self, node: ast.AST) -> None: + """Recursively validate that only allowed operations exist in the AST.""" + allowed_node_types = ( + # Boolean operations + ast.BoolOp, + ast.UnaryOp, + ast.And, + ast.Or, + ast.Not, + # Comparison operations + ast.Compare, + ast.Eq, + ast.NotEq, + ast.Lt, + ast.LtE, + ast.Gt, + ast.GtE, + # Basic nodes + ast.Name, + ast.Load, + ast.Constant, + ast.Expression, + # Support for basic numeric operations in comparisons + ast.Num, + ast.NameConstant, + # Support for negative numbers + ast.USub, + ast.UnaryOp, + # Support for string literals + ast.Str, + ast.Constant, + # Support for function calls (specifically len()) + ast.Call, + ) + + if not isinstance(node, allowed_node_types): + raise ValueError(f"Operation type {type(node).__name__} is not allowed in logical expressions") + + # Special validation for function calls - only allow len() + if isinstance(node, ast.Call): + if not (isinstance(node.func, ast.Name) and node.func.id == "len"): + raise ValueError(f"Only the len() function is allowed, got: {getattr(node.func, 'id', 'unknown')}") + if len(node.args) != 1: + raise ValueError(f"len() function must have exactly one argument, got {len(node.args)}") + + # Special validation for Compare nodes + if isinstance(node, ast.Compare): + for op in node.ops: + if not isinstance(op, (ast.Eq, ast.NotEq, ast.Lt, ast.LtE, ast.Gt, ast.GtE)): + raise ValueError(f"Comparison operator {type(op).__name__} is not allowed") + + # Recursively check child nodes + for child in ast.iter_child_nodes(node): + self._validate_operations(child) + + def evaluate(self, context_variables: ContextVariables) -> bool: + """Evaluate the expression using the provided context variables. 
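+
+        Example (a minimal sketch using the operator and len() syntax documented on this class):
+            ```
+            expr = ContextExpression("not ${logged_in} and len(${cart_items}) > 0")
+            ctx = ContextVariables(data={"logged_in": False, "cart_items": ["apple", "pear"]})
+            expr.evaluate(ctx)  # True: the user is not logged in and the cart is non-empty
+            ```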
+ + Args: + context_variables: Dictionary of context variables to use for evaluation + + Returns: + bool: The result of evaluating the expression + + Raises: + KeyError: If a variable referenced in the expression is not found in the context + """ + # Create a modified expression that we can safely evaluate + eval_expr = self._python_expr # Use the Python-syntax version + + # First, handle len() functions with variable references inside + len_pattern = r"len\(\${([^}]*)}\)" + len_matches = list(re.finditer(len_pattern, eval_expr)) + + # Process all len() operations first + for match in len_matches: + var_name = match.group(1) + # Check if variable exists in context, raise KeyError if not + if not context_variables.contains(var_name): + raise KeyError(f"Missing context variable: '{var_name}'") + + var_value = context_variables.get(var_name) + + # Calculate the length - works for lists, strings, dictionaries, etc. + try: + length_value = len(var_value) # type: ignore[arg-type] + except TypeError: + # If the value doesn't support len(), treat as 0 + length_value = 0 + + # Replace the len() expression with the actual length + full_match = match.group(0) + eval_expr = eval_expr.replace(full_match, str(length_value)) + + # Then replace remaining variable references with their values + for var_name in self._variable_names: + # Skip variables that were already processed in len() expressions + if any(m.group(1) == var_name for m in len_matches): + continue + + # Check if variable exists in context, raise KeyError if not + if not context_variables.contains(var_name): + raise KeyError(f"Missing context variable: '{var_name}'") + + # Get the value from context + var_value = context_variables.get(var_name) + + # Format the value appropriately based on its type + if isinstance(var_value, (bool, int, float)): + formatted_value = str(var_value) + elif isinstance(var_value, str): + formatted_value = f"'{var_value}'" # Quote strings + elif isinstance(var_value, (list, dict, tuple)): + # For collections, convert to their boolean evaluation + formatted_value = str(bool(var_value)) + else: + formatted_value = str(var_value) + + # Replace the variable reference with the formatted value + eval_expr = eval_expr.replace(f"${{{var_name}}}", formatted_value) + + try: + return eval(eval_expr) # type: ignore[no-any-return] + except Exception as e: + raise ValueError( + f"Error evaluating expression '{self.expression}' (are you sure you're using ${{my_context_variable_key}}): {str(e)}" + ) + + def __str__(self) -> str: + return f"ContextExpression('{self.expression}')" diff --git a/mm_agents/coact/autogen/agentchat/group/context_str.py b/mm_agents/coact/autogen/agentchat/group/context_str.py new file mode 100644 index 0000000..2c42d83 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/group/context_str.py @@ -0,0 +1,41 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +from pydantic import BaseModel + +from .context_variables import ContextVariables + +__all__ = ["ContextStr"] + + +class ContextStr(BaseModel): + """A string that requires context variable substitution. + + Use the format method to substitute context variables into the string. + """ + + """The string to be substituted with context variables. 
It is expected that the string will contain `{var}` placeholders and that string format will be able to replace all values.""" + template: str + + def format(self, context_variables: ContextVariables) -> Optional[str]: + """Substitute context variables into the string. + + Args: + context_variables (ContextVariables): The context variables to substitute into the string. + + Returns: + Optional[str]: The formatted string with context variables substituted. + """ + + context = context_variables.to_dict() + + if not context: + return self.template + + return self.template.format(**context) + + def __str__(self) -> str: + return f"ContextStr, unformatted: {self.template}" diff --git a/mm_agents/coact/autogen/agentchat/group/context_variables.py b/mm_agents/coact/autogen/agentchat/group/context_variables.py new file mode 100644 index 0000000..4ef8fef --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/group/context_variables.py @@ -0,0 +1,192 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Generator, Iterable, Optional + +from pydantic import BaseModel, Field + +__all__ = ["ContextVariables"] + +# Parameter name for context variables +# Use the value in functions and they will be substituted with the context variables: +# e.g. def my_function(context_variables: ContextVariables, my_other_parameters: Any) -> Any: +__CONTEXT_VARIABLES_PARAM_NAME__ = "context_variables" + + +class ContextVariables(BaseModel): + """ + Stores and manages context variables for agentic workflows. + + Utilises a dictionary-like interface for setting, getting, and removing variables. + """ + + # Internal storage for context variables + data: dict[str, Any] = Field(default_factory=dict) + + def __init__(self, data: Optional[dict[str, Any]] = None, **kwargs: Any) -> None: + """Initialize with data dictionary as an optional positional parameter. + + Args: + data: Initial dictionary of context variables (optional) + kwargs: Additional keyword arguments for the parent class + """ + init_data = data or {} + super().__init__(data=init_data, **kwargs) + + def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]: + """ + Get a value from the context by key. + + Args: + key: The key to retrieve + default: The default value to return if key is not found + + Returns: + The value associated with the key or default if not found + """ + return self.data.get(key, default) + + def set(self, key: str, value: Any) -> None: + """ + Set a value in the context by key. + + Args: + key: The key to set + value: The value to store + """ + self.data[key] = value + + def remove(self, key: str) -> bool: + """ + Remove a key from the context. + + Args: + key: The key to remove + + Returns: + True if the key was removed, False if it didn't exist + """ + if key in self.data: + del self.data[key] + return True + return False + + def keys(self) -> Iterable[str]: + """ + Get all keys in the context. + + Returns: + An iterable of all keys + """ + return self.data.keys() + + def values(self) -> Iterable[Any]: + """ + Get all values in the context. + + Returns: + An iterable of all values + """ + return self.data.values() + + def items(self) -> Iterable[tuple[str, Any]]: + """ + Get all key-value pairs in the context. 
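+
+        Example (illustrative):
+            ```
+            ctx = ContextVariables(data={"user": "ada", "attempts": 3})
+            for key, value in ctx.items():
+                print(key, value)  # -> "user ada", then "attempts 3"
+            ```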
+ + Returns: + An iterable of all key-value pairs + """ + return self.data.items() + + def clear(self) -> None: + """Clear all keys and values from the context.""" + self.data.clear() + + def contains(self, key: str) -> bool: + """ + Check if a key exists in the context. + + Args: + key: The key to check + + Returns: + True if the key exists, False otherwise + """ + return key in self.data + + def update(self, other: dict[str, Any]) -> None: + """ + Update context with key-value pairs from another dictionary. + + Args: + other: Dictionary containing key-value pairs to add + """ + self.data.update(other) + + def to_dict(self) -> dict[str, Any]: + """ + Convert context variables to a dictionary. + + Returns: + Dictionary representation of all context variables + """ + return self.data.copy() + + # Dictionary-compatible interface + def __getitem__(self, key: str) -> Any: + """Get a value using dictionary syntax: context[key]""" + try: + return self.data[key] + except KeyError: + raise KeyError(f"Context variable '{key}' not found") + + def __setitem__(self, key: str, value: Any) -> None: + """Set a value using dictionary syntax: context[key] = value""" + self.data[key] = value + + def __delitem__(self, key: str) -> None: + """Delete a key using dictionary syntax: del context[key]""" + try: + del self.data[key] + except KeyError: + raise KeyError(f"Cannot delete non-existent context variable '{key}'") + + def __contains__(self, key: str) -> bool: + """Check if key exists using 'in' operator: key in context""" + return key in self.data + + def __len__(self) -> int: + """Get the number of items: len(context)""" + return len(self.data) + + def __iter__(self) -> Generator[tuple[str, Any], None, None]: + """Iterate over keys: for key in context""" + for key in self.data: + yield (key, self.data[key]) + + def __str__(self) -> str: + """String representation of context variables.""" + return f"ContextVariables({self.data})" + + def __repr__(self) -> str: + """Detailed representation of context variables.""" + return f"ContextVariables(data={self.data!r})" + + # Utility methods + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ContextVariables": + """ + Create a new ContextVariables instance from a dictionary. 
+ + E.g.: + my_context = {"user_id": "12345", "settings": {"theme": "dark"}} + context = ContextVariables.from_dict(my_context) + + Args: + data: Dictionary of key-value pairs + + Returns: + New ContextVariables instance + """ + return cls(data=data) diff --git a/mm_agents/coact/autogen/agentchat/group/group_tool_executor.py b/mm_agents/coact/autogen/agentchat/group/group_tool_executor.py new file mode 100644 index 0000000..1f50e09 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/group/group_tool_executor.py @@ -0,0 +1,202 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +import inspect +from copy import deepcopy +from typing import Annotated, Any, Callable, Optional + +from ...oai import OpenAIWrapper +from ...tools import Depends, Tool +from ...tools.dependency_injection import inject_params, on +from ..agent import Agent +from ..conversable_agent import ConversableAgent +from .context_variables import __CONTEXT_VARIABLES_PARAM_NAME__, ContextVariables +from .reply_result import ReplyResult +from .targets.transition_target import TransitionTarget + +__TOOL_EXECUTOR_NAME__ = "_Group_Tool_Executor" + + +class GroupToolExecutor(ConversableAgent): + """Tool executor for the group chat initiated with initiate_group_chat""" + + def __init__(self) -> None: + super().__init__( + name=__TOOL_EXECUTOR_NAME__, + system_message="Tool Execution, do not use this agent directly.", + human_input_mode="NEVER", + code_execution_config=False, + ) + + # Store the next target from a tool call + self._group_next_target: Optional[TransitionTarget] = None + + # Primary tool reply function for handling the tool reply and the ReplyResult and TransitionTarget returns + self.register_reply([Agent, None], self._generate_group_tool_reply, remove_other_reply_funcs=True) + + def set_next_target(self, next_target: TransitionTarget) -> None: + """Sets the next target to transition to, used in the determine_next_agent function.""" + self._group_next_target = next_target + + def get_next_target(self) -> TransitionTarget: + """Gets the next target to transition to.""" + """Returns the next target to transition to, if it exists.""" + if self._group_next_target is None: + raise ValueError( + "No next target set. Please set a next target before calling this method. Use has_next_target() to check if a next target exists." + ) + return self._group_next_target + + def has_next_target(self) -> bool: + """Checks if there is a next target to transition to.""" + return self._group_next_target is not None + + def clear_next_target(self) -> None: + """Clears the next target to transition to.""" + self._group_next_target = None + + def _modify_context_variables_param( + self, f: Callable[..., Any], context_variables: ContextVariables + ) -> Callable[..., Any]: + """Modifies the context_variables parameter to use dependency injection and link it to the group context variables. 
+ + This essentially changes: + def some_function(some_variable: int, context_variables: ContextVariables) -> str: + + to: + + def some_function(some_variable: int, context_variables: Annotated[ContextVariables, Depends(on(self.context_variables))]) -> str: + """ + sig = inspect.signature(f) + + # Check if context_variables parameter exists and update it if so + if __CONTEXT_VARIABLES_PARAM_NAME__ in sig.parameters: + new_params = [] + for name, param in sig.parameters.items(): + if name == __CONTEXT_VARIABLES_PARAM_NAME__: + # Replace with new annotation using Depends + new_param = param.replace(annotation=Annotated[ContextVariables, Depends(on(context_variables))]) + new_params.append(new_param) + else: + new_params.append(param) + + # Update signature + new_sig = sig.replace(parameters=new_params) + f.__signature__ = new_sig # type: ignore[attr-defined] + + return f + + def _change_tool_context_variables_to_depends( + self, agent: ConversableAgent, current_tool: Tool, context_variables: ContextVariables + ) -> None: + """Checks for the context_variables parameter in the tool and updates it to use dependency injection.""" + + # If the tool has a context_variables parameter, remove the tool and reregister it without the parameter + if __CONTEXT_VARIABLES_PARAM_NAME__ in current_tool.tool_schema["function"]["parameters"]["properties"]: + # We'll replace the tool, so start with getting the underlying function + tool_func = current_tool._func + + # Remove the Tool from the agent + name = current_tool._name + description = current_tool._description + agent.remove_tool_for_llm(current_tool) + + # Recreate the tool without the context_variables parameter + tool_func = self._modify_context_variables_param(current_tool._func, context_variables) + tool_func = inject_params(tool_func) + new_tool = ConversableAgent._create_tool_if_needed( + func_or_tool=tool_func, name=name, description=description + ) + + # Re-register with the agent + agent.register_for_llm()(new_tool) + + def register_agents_functions(self, agents: list[ConversableAgent], context_variables: ContextVariables) -> None: + """Adds the functions of the agents to the group tool executor.""" + for agent in agents: + # As we're moving towards tools and away from function maps, this may not be used + self._function_map.update(agent._function_map) + + # Update any agent tools that have context_variables parameters to use Dependency Injection + for tool in agent.tools: + self._change_tool_context_variables_to_depends(agent, tool, context_variables) + + # Add all tools to the Tool Executor agent + for tool in agent.tools: + self.register_for_execution(serialize=False, silent_override=True)(tool) + + def _generate_group_tool_reply( + self, + agent: ConversableAgent, + messages: Optional[list[dict[str, Any]]] = None, + sender: Optional[Agent] = None, + config: Optional[OpenAIWrapper] = None, + ) -> tuple[bool, Optional[dict[str, Any]]]: + """Pre-processes and generates tool call replies. + + This function: + 1. Adds context_variables back to the tool call for the function, if necessary. + 2. Generates the tool calls reply. + 3. 
Updates context_variables and next_agent based on the tool call response.""" + + if config is None: + config = agent # type: ignore[assignment] + if messages is None: + messages = agent._oai_messages[sender] + + message = messages[-1] + if "tool_calls" in message: + tool_call_count = len(message["tool_calls"]) + + # Loop through tool calls individually (so context can be updated after each function call) + next_target: Optional[TransitionTarget] = None + tool_responses_inner = [] + contents = [] + for index in range(tool_call_count): + message_copy = deepcopy(message) + + # 1. add context_variables to the tool call arguments + tool_call = message_copy["tool_calls"][index] + + # Ensure we are only executing the one tool at a time + message_copy["tool_calls"] = [tool_call] + + # 2. generate tool calls reply + _, tool_message = agent.generate_tool_calls_reply([message_copy]) + + if tool_message is None: + raise ValueError("Tool call did not return a message") + + # 3. update context_variables and next_agent, convert content to string + for tool_response in tool_message["tool_responses"]: + content = tool_response.get("content") + + # Tool Call returns that are a target are either a ReplyResult or a TransitionTarget are the next agent + if isinstance(content, ReplyResult): + if content.context_variables and content.context_variables.to_dict() != {}: + agent.context_variables.update(content.context_variables.to_dict()) + if content.target is not None: + next_target = content.target + elif isinstance(content, TransitionTarget): + next_target = content + + # Serialize the content to a string + if content is not None: + tool_response["content"] = str(content) + + tool_responses_inner.append(tool_response) + contents.append(str(tool_response["content"])) + + self._group_next_target = next_target # type: ignore[attr-defined] + + # Put the tool responses and content strings back into the response message + # Caters for multiple tool calls + if tool_message is None: + raise ValueError("Tool call did not return a message") + + tool_message["tool_responses"] = tool_responses_inner + tool_message["content"] = "\n".join(contents) + + return True, tool_message + return False, None diff --git a/mm_agents/coact/autogen/agentchat/group/group_utils.py b/mm_agents/coact/autogen/agentchat/group/group_utils.py new file mode 100644 index 0000000..beef771 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/group/group_utils.py @@ -0,0 +1,636 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +import copy +from functools import partial +from types import MethodType +from typing import TYPE_CHECKING, Any, Callable, Optional, Union + +from ..agent import Agent +from ..groupchat import GroupChat, GroupChatManager +from .context_variables import ContextVariables +from .group_tool_executor import GroupToolExecutor +from .targets.group_manager_target import GroupManagerTarget +from .targets.transition_target import ( + AgentNameTarget, + AgentTarget, + TransitionTarget, +) + +if TYPE_CHECKING: + from ..conversable_agent import ConversableAgent + +# Utility functions for group chat preparation and management +# These are extracted from multi_agent_chat.py to avoid circular imports + + +def update_conditional_functions(agent: "ConversableAgent", messages: list[dict[str, Any]]) -> None: + """Updates the agent's functions based on the OnCondition's available condition. 
+ + All functions are removed and then added back if they are available + """ + for on_condition in agent.handoffs.llm_conditions: + is_available = on_condition.available.is_available(agent, messages) if on_condition.available else True + + # Remove it from their tools + for tool in agent.tools: + if tool.name == on_condition.llm_function_name: + agent.remove_tool_for_llm(tool) + break + + # then add the function if it is available, so that the function signature is updated + if is_available: + agent._add_single_function( + _create_on_condition_handoff_function(on_condition.target), + on_condition.llm_function_name, + on_condition.condition.get_prompt(agent, messages), + ) + + +def establish_group_agent(agent: "ConversableAgent") -> None: + """Establish the group agent with the group-related attributes and hooks. Not for the tool executor. + + Args: + agent ("ConversableAgent"): The agent to establish as a group agent. + """ + + def _group_agent_str(self: "ConversableAgent") -> str: + """Customise the __str__ method to show the agent name for transition messages.""" + return f"Group agent --> {self.name}" + + # Register the hook to update agent state (except tool executor) + agent.register_hook("update_agent_state", update_conditional_functions) + + # Register a reply function to run Python function-based OnContextConditions before any other reply function + agent.register_reply(trigger=([Agent, None]), reply_func=_run_oncontextconditions, position=0) + + agent._get_display_name = MethodType(_group_agent_str, agent) # type: ignore[method-assign] + + # Mark this agent as established as a group agent + agent._group_is_established = True # type: ignore[attr-defined] + + +def link_agents_to_group_manager(agents: list[Agent], group_chat_manager: Agent) -> None: + """Link all agents to the GroupChatManager so they can access the underlying GroupChat and other agents. + + This is primarily used so that agents can get to the tool executor to help set the next agent. + + Does not link the Tool Executor agent. + """ + for agent in agents: + agent._group_manager = group_chat_manager # type: ignore[attr-defined] + + +def _evaluate_after_works_conditions( + agent: "ConversableAgent", + groupchat: GroupChat, + user_agent: Optional["ConversableAgent"], +) -> Optional[Union[Agent, str]]: + """Evaluate after_works context conditions for an agent. 
+ + Args: + agent: The agent to evaluate after_works conditions for + groupchat: The current group chat + user_agent: Optional user proxy agent + + Returns: + The resolved speaker selection result if a condition matches, None otherwise + """ + if not hasattr(agent, "handoffs") or not agent.handoffs.after_works: # type: ignore[attr-defined] + return None + + for after_work_condition in agent.handoffs.after_works: # type: ignore[attr-defined] + # Check if condition is available + is_available = ( + after_work_condition.available.is_available(agent, groupchat.messages) + if after_work_condition.available + else True + ) + + # Evaluate the condition (None condition means always true) + if is_available and ( + after_work_condition.condition is None or after_work_condition.condition.evaluate(agent.context_variables) + ): + # Condition matched, resolve and return + return after_work_condition.target.resolve( + groupchat, + agent, + user_agent, + ).get_speaker_selection_result(groupchat) + + return None + + +def _run_oncontextconditions( + agent: "ConversableAgent", + messages: Optional[list[dict[str, Any]]] = None, + sender: Optional[Agent] = None, + config: Optional[Any] = None, +) -> tuple[bool, Optional[Union[str, dict[str, Any]]]]: + """Run OnContextConditions for an agent before any other reply function.""" + for on_condition in agent.handoffs.context_conditions: # type: ignore[attr-defined] + is_available = ( + on_condition.available.is_available(agent, messages if messages else []) if on_condition.available else True + ) + + if is_available and ( + on_condition.condition is None or on_condition.condition.evaluate(agent.context_variables) + ): + # Condition has been met, we'll set the Tool Executor's next target + # attribute and that will be picked up on the next iteration when + # _determine_next_agent is called + for agent in agent._group_manager.groupchat.agents: # type: ignore[attr-defined] + if isinstance(agent, GroupToolExecutor): + agent.set_next_target(on_condition.target) + break + + transfer_name = on_condition.target.display_name() + + return True, "[Handing off to " + transfer_name + "]" + + return False, None + + +def _create_on_condition_handoff_function(target: TransitionTarget) -> Callable[[], TransitionTarget]: + """Creates a function that will be used by the tool call reply function when the condition is met. + + Args: + target (TransitionTarget): The target to transfer to. + + Returns: + Callable: The transfer function. + """ + + def transfer_to_target() -> TransitionTarget: + return target + + return transfer_to_target + + +def create_on_condition_handoff_functions(agent: "ConversableAgent") -> None: + """Creates the functions for the OnConditions so that the current tool handling works. + + Args: + agent ("ConversableAgent"): The agent to create the functions for. 
+ """ + # Populate the function names for the handoffs + agent.handoffs.set_llm_function_names() + + # Create a function for each OnCondition + for on_condition in agent.handoffs.llm_conditions: + # Create a function that will be called when the condition is met + agent._add_single_function( + _create_on_condition_handoff_function(on_condition.target), + on_condition.llm_function_name, + on_condition.condition.get_prompt(agent, []), + ) + + +def ensure_handoff_agents_in_group(agents: list["ConversableAgent"]) -> None: + """Ensure the agents in handoffs are in the group chat.""" + agent_names = [agent.name for agent in agents] + for agent in agents: + for llm_conditions in agent.handoffs.llm_conditions: + if ( + isinstance(llm_conditions.target, (AgentTarget, AgentNameTarget)) + and llm_conditions.target.agent_name not in agent_names + ): + raise ValueError("Agent in OnCondition Hand-offs must be in the agents list") + for context_conditions in agent.handoffs.context_conditions: + if ( + isinstance(context_conditions.target, (AgentTarget, AgentNameTarget)) + and context_conditions.target.agent_name not in agent_names + ): + raise ValueError("Agent in OnContextCondition Hand-offs must be in the agents list") + # Check after_works targets + for after_work_condition in agent.handoffs.after_works: + if ( + isinstance(after_work_condition.target, (AgentTarget, AgentNameTarget)) + and after_work_condition.target.agent_name not in agent_names + ): + raise ValueError("Agent in after work target Hand-offs must be in the agents list") + + +def prepare_exclude_transit_messages(agents: list["ConversableAgent"]) -> None: + """Preparation for excluding transit messages by getting all tool names and registering a hook on agents to remove those messages.""" + # get all transit functions names + to_be_removed: list[str] = [] + for agent in agents: + for on_condition in agent.handoffs.llm_conditions: + if on_condition.llm_function_name: + to_be_removed.append(on_condition.llm_function_name) + else: + raise ValueError("OnCondition must have a function name") + + remove_function = make_remove_function(to_be_removed) + + # register hook to remove transit messages for group agents + for agent in agents: + agent.register_hook("process_all_messages_before_reply", remove_function) + + +def prepare_group_agents( + agents: list["ConversableAgent"], + context_variables: ContextVariables, + exclude_transit_message: bool = True, +) -> tuple[GroupToolExecutor, list["ConversableAgent"]]: + """Validates agents, create the tool executor, wrap necessary targets in agents. + + Args: + agents (list["ConversableAgent"]): List of all agents in the conversation. + context_variables (ContextVariables): Context variables to assign to all agents. + exclude_transit_message (bool): Whether to exclude transit messages from the agents. + + Returns: + "ConversableAgent": The tool executor agent. + list["ConversableAgent"]: List of wrapped agents. 
+ """ + # Initialise all agents as group agents + for agent in agents: + if not hasattr(agent, "_group_is_established"): + establish_group_agent(agent) + + # Ensure all agents in hand-off after-works are in the passed in agents list + ensure_handoff_agents_in_group(agents) + + # Create Tool Executor for the group + tool_execution = GroupToolExecutor() + + # Wrap handoff targets in agents that need to be wrapped + wrapped_chat_agents: list["ConversableAgent"] = [] + for agent in agents: + wrap_agent_handoff_targets(agent, wrapped_chat_agents) + + # Create the functions for the OnConditions so that the current tool handling works + for agent in agents: + create_on_condition_handoff_functions(agent) + + # Register all the agents' functions with the tool executor and + # use dependency injection for the context variables parameter + # Update tool execution agent with all the functions from all the agents + tool_execution.register_agents_functions(agents + wrapped_chat_agents, context_variables) + + if exclude_transit_message: + prepare_exclude_transit_messages(agents) + + return tool_execution, wrapped_chat_agents + + +def wrap_agent_handoff_targets(agent: "ConversableAgent", wrapped_agent_list: list["ConversableAgent"]) -> None: + """Wrap handoff targets in agents that need to be wrapped to be part of the group chat. + + Example is NestedChatTarget. + + Args: + agent ("ConversableAgent"): The agent to wrap the handoff targets for. + wrapped_agent_list (list["ConversableAgent"]): List of wrapped chat agents that will be appended to. + """ + # Wrap OnCondition targets + for i, handoff_oncondition_requiring_wrapping in enumerate(agent.handoffs.get_llm_conditions_requiring_wrapping()): + # Create wrapper agent + wrapper_agent = handoff_oncondition_requiring_wrapping.target.create_wrapper_agent(parent_agent=agent, index=i) + wrapped_agent_list.append(wrapper_agent) + + # Change this handoff target to point to the newly created agent + handoff_oncondition_requiring_wrapping.target = AgentTarget(wrapper_agent) + + for i, handoff_oncontextcondition_requiring_wrapping in enumerate( + agent.handoffs.get_context_conditions_requiring_wrapping() + ): + # Create wrapper agent + wrapper_agent = handoff_oncontextcondition_requiring_wrapping.target.create_wrapper_agent( + parent_agent=agent, index=i + ) + wrapped_agent_list.append(wrapper_agent) + + # Change this handoff target to point to the newly created agent + handoff_oncontextcondition_requiring_wrapping.target = AgentTarget(wrapper_agent) + + +def process_initial_messages( + messages: Union[list[dict[str, Any]], str], + user_agent: Optional["ConversableAgent"], + agents: list["ConversableAgent"], + wrapped_agents: list["ConversableAgent"], +) -> tuple[list[dict[str, Any]], Optional["ConversableAgent"], list[str], list[Agent]]: + """Process initial messages, validating agent names against messages, and determining the last agent to speak. + + Args: + messages: Initial messages to process. + user_agent: Optional user proxy agent passed in to a_/initiate_group_chat. + agents: Agents in the group. + wrapped_agents: List of wrapped agents. + + Returns: + list[dict[str, Any]]: Processed message(s). + Agent: Last agent to speak. + list[str]: List of agent names. + list[Agent]: List of temporary user proxy agents to add to GroupChat. 
+ """ + from ..conversable_agent import ConversableAgent # NEED SOLUTION + + if isinstance(messages, str): + messages = [{"role": "user", "content": messages}] + + group_agent_names = [agent.name for agent in agents + wrapped_agents] + + # If there's only one message and there's no identified group agent + # Start with a user proxy agent, creating one if they haven't passed one in + last_agent: Optional[ConversableAgent] + temp_user_proxy: Optional[ConversableAgent] = None + temp_user_list: list[Agent] = [] + if len(messages) == 1 and "name" not in messages[0] and not user_agent: + temp_user_proxy = ConversableAgent(name="_User", code_execution_config=False, human_input_mode="ALWAYS") + last_agent = temp_user_proxy + temp_user_list.append(temp_user_proxy) + else: + last_message = messages[0] + if "name" in last_message: + if last_message["name"] in group_agent_names: + last_agent = next(agent for agent in agents + wrapped_agents if agent.name == last_message["name"]) # type: ignore[assignment] + elif user_agent and last_message["name"] == user_agent.name: + last_agent = user_agent + else: + raise ValueError(f"Invalid group agent name in last message: {last_message['name']}") + else: + last_agent = user_agent if user_agent else temp_user_proxy + + return messages, last_agent, group_agent_names, temp_user_list + + +def setup_context_variables( + tool_execution: "ConversableAgent", + agents: list["ConversableAgent"], + manager: GroupChatManager, + user_agent: Optional["ConversableAgent"], + context_variables: ContextVariables, +) -> None: + """Assign a common context_variables reference to all agents in the group, including the tool executor, group chat manager, and user proxy agent. + + Args: + tool_execution: The tool execution agent. + agents: List of all agents in the conversation. + manager: GroupChatManager instance. + user_agent: Optional user proxy agent. + context_variables: Context variables to assign to all agents. + """ + for agent in agents + [tool_execution] + [manager] + ([user_agent] if user_agent else []): + agent.context_variables = context_variables + + +def cleanup_temp_user_messages(chat_result: Any) -> None: + """Remove temporary user proxy agent name from messages before returning. + + Args: + chat_result: ChatResult instance. + """ + for message in chat_result.chat_history: + if "name" in message and message["name"] == "_User": + del message["name"] + + +def get_last_agent_speaker( + groupchat: GroupChat, group_agent_names: list[str], tool_executor: GroupToolExecutor +) -> Agent: + """Get the last group agent from the group chat messages. Not including the tool executor.""" + last_group_speaker = None + for message in reversed(groupchat.messages): + if "name" in message and message["name"] in group_agent_names and message["name"] != tool_executor.name: + agent = groupchat.agent_by_name(name=message["name"]) + if agent: + last_group_speaker = agent + break + if last_group_speaker is None: + raise ValueError("No group agent found in the message history") + + return last_group_speaker + + +def determine_next_agent( + last_speaker: "ConversableAgent", + groupchat: GroupChat, + initial_agent: "ConversableAgent", + use_initial_agent: bool, + tool_executor: GroupToolExecutor, + group_agent_names: list[str], + user_agent: Optional["ConversableAgent"], + group_after_work: TransitionTarget, +) -> Optional[Union[Agent, str]]: + """Determine the next agent in the conversation. + + Args: + last_speaker ("ConversableAgent"): The last agent to speak. 
+ groupchat (GroupChat): GroupChat instance. + initial_agent ("ConversableAgent"): The initial agent in the conversation. + use_initial_agent (bool): Whether to use the initial agent straight away. + tool_executor ("ConversableAgent"): The tool execution agent. + group_agent_names (list[str]): List of agent names. + user_agent (UserProxyAgent): Optional user proxy agent. + group_after_work (TransitionTarget): Group-level Transition option when an agent doesn't select the next agent. + + Returns: + Optional[Union[Agent, str]]: The next agent or speaker selection method. + """ + + # Logic for determining the next target (anything based on Transition Target: an agent, wrapped agent, TerminateTarget, StayTarget, RevertToUserTarget, GroupManagerTarget, etc. + # 1. If it's the first response -> initial agent + # 2. If the last message is a tool call -> tool execution agent + # 3. If the Tool Executor has determined a next target (e.g. ReplyResult specified target) -> transition to tool reply target + # 4. If the user last spoke -> return to the previous agent + # NOW "AFTER WORK": + # 5. Get the After Work condition (if the agent doesn't have one, get the group-level one) + # 6. Resolve and return the After Work condition -> agent / wrapped agent / TerminateTarget / StayTarget / RevertToUserTarget / GroupManagerTarget / etc. + + # 1. If it's the first response, return the initial agent + if use_initial_agent: + return initial_agent + + # 2. If the last message is a tool call, return the tool execution agent + if "tool_calls" in groupchat.messages[-1]: + return tool_executor + + # 3. If the Tool Executor has determined a next target, return that + if tool_executor.has_next_target(): + next_agent = tool_executor.get_next_target() + tool_executor.clear_next_target() + + if next_agent.can_resolve_for_speaker_selection(): + return next_agent.resolve(groupchat, last_speaker, user_agent).get_speaker_selection_result(groupchat) + else: + raise ValueError( + "Tool Executor next target must be a valid TransitionTarget that can resolve for speaker selection." 
+ ) + + # get the last group agent + last_agent_speaker = get_last_agent_speaker(groupchat, group_agent_names, tool_executor) + + # If we are returning from a tool execution, return to the last agent that spoke + if groupchat.messages[-1]["role"] == "tool": + return last_agent_speaker + + # If the user last spoke, return to the agent prior to them (if they don't have an after work, otherwise it's treated like any other agent) + if user_agent and last_speaker == user_agent: + if not user_agent.handoffs.after_works: + return last_agent_speaker + else: + last_agent_speaker = user_agent + + # AFTER WORK: + + # First, try to evaluate after_works context conditions + after_works_result = _evaluate_after_works_conditions( + last_agent_speaker, # type: ignore[arg-type] + groupchat, + user_agent, + ) + if after_works_result is not None: + return after_works_result + + # If no after_works conditions matched, use the group-level after_work + # Resolve the next agent, termination, or speaker selection method + resolved_speaker_selection_result = group_after_work.resolve( + groupchat, + last_agent_speaker, # type: ignore[arg-type] + user_agent, + ).get_speaker_selection_result(groupchat) + + return resolved_speaker_selection_result + + +def create_group_transition( + initial_agent: "ConversableAgent", + tool_execution: GroupToolExecutor, + group_agent_names: list[str], + user_agent: Optional["ConversableAgent"], + group_after_work: TransitionTarget, +) -> Callable[["ConversableAgent", GroupChat], Optional[Union[Agent, str]]]: + """Creates a transition function for group chat with enclosed state for the use_initial_agent. + + Args: + initial_agent ("ConversableAgent"): The first agent to speak + tool_execution (GroupToolExecutor): The tool execution agent + group_agent_names (list[str]): List of all agent names + user_agent (UserProxyAgent): Optional user proxy agent + group_after_work (TransitionTarget): Group-level after work + + Returns: + Callable[["ConversableAgent", GroupChat], Optional[Union[Agent, str]]]: The transition function + """ + # Create enclosed state, this will be set once per creation so will only be True on the first execution + # of group_transition + state = {"use_initial_agent": True} + + def group_transition(last_speaker: "ConversableAgent", groupchat: GroupChat) -> Optional[Union[Agent, str]]: + result = determine_next_agent( + last_speaker=last_speaker, + groupchat=groupchat, + initial_agent=initial_agent, + use_initial_agent=state["use_initial_agent"], + tool_executor=tool_execution, + group_agent_names=group_agent_names, + user_agent=user_agent, + group_after_work=group_after_work, + ) + state["use_initial_agent"] = False + return result + + return group_transition + + +def create_group_manager( + groupchat: GroupChat, + group_manager_args: Optional[dict[str, Any]], + agents: list["ConversableAgent"], + group_after_work: TransitionTarget, +) -> GroupChatManager: + """Create a GroupChatManager for the group chat utilising any arguments passed in and ensure an LLM Config exists if needed + + Args: + groupchat (GroupChat): The groupchat. + group_manager_args (dict[str, Any]): Group manager arguments to create the GroupChatManager. + agents (list["ConversableAgent"]): List of agents in the group to check handoffs and after work. + group_after_work (TransitionTarget): Group-level after work to check. + + Returns: + GroupChatManager: GroupChatManager instance. 
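+
+ Example (illustrative sketch; llm_config is a placeholder and group_after_work
+ may be any TransitionTarget, e.g. TerminateTarget()):
+
+ manager = create_group_manager(
+ groupchat=groupchat,
+ group_manager_args={"llm_config": llm_config},
+ agents=agents,
+ group_after_work=TerminateTarget(),
+ )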
+ """ + manager_args = (group_manager_args or {}).copy() + if "groupchat" in manager_args: + raise ValueError("'groupchat' cannot be specified in group_manager_args as it is set by initiate_group_chat") + manager = GroupChatManager(groupchat, **manager_args) + + # Ensure that our manager has an LLM Config if we have any GroupManagerTarget targets used + if manager.llm_config is False: + has_group_manager_target = False + + if isinstance(group_after_work, GroupManagerTarget): + # Check group after work + has_group_manager_target = True + else: + # Check agent hand-offs and after work + for agent in agents: + if ( + len(agent.handoffs.get_context_conditions_by_target_type(GroupManagerTarget)) > 0 + or len(agent.handoffs.get_llm_conditions_by_target_type(GroupManagerTarget)) > 0 + or any(isinstance(aw.target, GroupManagerTarget) for aw in agent.handoffs.after_works) + ): + has_group_manager_target = True + break + + if has_group_manager_target: + raise ValueError( + "The group manager doesn't have an LLM Config and it is required for any targets or after works using a GroupManagerTarget. Use the 'llm_config' in the group_manager_args parameter to specify the LLM Config for the group manager." + ) + + return manager + + +def make_remove_function(tool_msgs_to_remove: list[str]) -> Callable[[list[dict[str, Any]]], list[dict[str, Any]]]: + """Create a function to remove messages with tool calls from the messages list. + + The returned function can be registered as a hook to "process_all_messages_before_reply"" to remove messages with tool calls. + """ + + def remove_messages(messages: list[dict[str, Any]], tool_msgs_to_remove: list[str]) -> list[dict[str, Any]]: + copied = copy.deepcopy(messages) + new_messages = [] + removed_tool_ids = [] + for message in copied: + # remove tool calls + if message.get("tool_calls") is not None: + filtered_tool_calls = [] + for tool_call in message["tool_calls"]: + if tool_call.get("function") is not None and tool_call["function"]["name"] in tool_msgs_to_remove: + # remove + removed_tool_ids.append(tool_call["id"]) + else: + filtered_tool_calls.append(tool_call) + if len(filtered_tool_calls) > 0: + message["tool_calls"] = filtered_tool_calls + else: + del message["tool_calls"] + if ( + message.get("content") is None + or message.get("content") == "" + or message.get("content") == "None" + ): + continue # if no tool call and no content, skip this message + # else: keep the message with tool_calls removed + # remove corresponding tool responses + elif message.get("tool_responses") is not None: + filtered_tool_responses = [] + for tool_response in message["tool_responses"]: + if tool_response["tool_call_id"] not in removed_tool_ids: + filtered_tool_responses.append(tool_response) + + if len(filtered_tool_responses) > 0: + message["tool_responses"] = filtered_tool_responses + else: + continue + + new_messages.append(message) + + return new_messages + + return partial(remove_messages, tool_msgs_to_remove=tool_msgs_to_remove) diff --git a/mm_agents/coact/autogen/agentchat/group/handoffs.py b/mm_agents/coact/autogen/agentchat/group/handoffs.py new file mode 100644 index 0000000..4d62f2e --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/group/handoffs.py @@ -0,0 +1,320 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union, overload + +from pydantic import BaseModel, Field + +from .on_condition import OnCondition +from .on_context_condition import 
OnContextCondition +from .targets.transition_target import TransitionTarget + +__all__ = ["Handoffs"] + + +class Handoffs(BaseModel): + """ + Container for all handoff transition conditions of a ConversableAgent. + + Three types of conditions can be added, each with a different order and time of use: + 1. OnContextConditions (evaluated without an LLM) + 2. OnConditions (evaluated with an LLM) + 3. After work TransitionTarget (if no other transition is triggered) + + Supports method chaining: + agent.handoffs.add_context_conditions([condition1]) \ + .add_llm_condition(condition2) \ + .set_after_work(after_work) + """ + + context_conditions: list[OnContextCondition] = Field(default_factory=list) + llm_conditions: list[OnCondition] = Field(default_factory=list) + after_works: list[OnContextCondition] = Field(default_factory=list) + + def add_context_condition(self, condition: OnContextCondition) -> "Handoffs": + """ + Add a single context condition. + + Args: + condition: The OnContextCondition to add + + Returns: + Self for method chaining + """ + # Validate that it is an OnContextCondition + if not isinstance(condition, OnContextCondition): + raise TypeError(f"Expected an OnContextCondition instance, got {type(condition).__name__}") + + self.context_conditions.append(condition) + return self + + def add_context_conditions(self, conditions: list[OnContextCondition]) -> "Handoffs": + """ + Add multiple context conditions. + + Args: + conditions: List of OnContextConditions to add + + Returns: + Self for method chaining + """ + # Validate that it is a list of OnContextConditions + if not all(isinstance(condition, OnContextCondition) for condition in conditions): + raise TypeError("All conditions must be of type OnContextCondition") + + self.context_conditions.extend(conditions) + return self + + def add_llm_condition(self, condition: OnCondition) -> "Handoffs": + """ + Add a single LLM condition. + + Args: + condition: The OnCondition to add + + Returns: + Self for method chaining + """ + # Validate that it is an OnCondition + if not isinstance(condition, OnCondition): + raise TypeError(f"Expected an OnCondition instance, got {type(condition).__name__}") + + self.llm_conditions.append(condition) + return self + + def add_llm_conditions(self, conditions: list[OnCondition]) -> "Handoffs": + """ + Add multiple LLM conditions. + + Args: + conditions: List of OnConditions to add + + Returns: + Self for method chaining + """ + # Validate that it is a list of OnConditions + if not all(isinstance(condition, OnCondition) for condition in conditions): + raise TypeError("All conditions must be of type OnCondition") + + self.llm_conditions.extend(conditions) + return self + + def set_after_work(self, target: TransitionTarget) -> "Handoffs": + """ + Set the after work target (replaces all after_works with single entry). + + For backward compatibility, this creates an OnContextCondition with no condition (always true). + + Args: + target: The after work TransitionTarget to set + + Returns: + Self for method chaining + """ + if not isinstance(target, TransitionTarget): + raise TypeError(f"Expected a TransitionTarget instance, got {type(target).__name__}") + + # Create OnContextCondition with no condition (always true) + after_work_condition = OnContextCondition(target=target, condition=None) + self.after_works = [after_work_condition] + return self + + def add_after_work(self, condition: OnContextCondition) -> "Handoffs": + """ + Add a single after-work condition. 
+ + If the condition has condition=None, it will replace any existing + condition=None entry and be placed at the end. + + Args: + condition: The OnContextCondition to add + + Returns: + Self for method chaining + """ + if not isinstance(condition, OnContextCondition): + raise TypeError(f"Expected an OnContextCondition instance, got {type(condition).__name__}") + + if condition.condition is None: + # Remove any existing condition=None entries + self.after_works = [c for c in self.after_works if c.condition is not None] + # Add the new one at the end + self.after_works.append(condition) + else: + # For regular conditions, check if we need to move condition=None to the end + none_conditions = [c for c in self.after_works if c.condition is None] + if none_conditions: + # Remove the None condition temporarily + self.after_works = [c for c in self.after_works if c.condition is not None] + # Add the new regular condition + self.after_works.append(condition) + # Re-add the None condition at the end + self.after_works.append(none_conditions[0]) + else: + # No None condition exists, just append + self.after_works.append(condition) + + return self + + def add_after_works(self, conditions: list[OnContextCondition]) -> "Handoffs": + """ + Add multiple after-work conditions. + + Special handling for condition=None entries: + - Only one condition=None entry is allowed (the fallback) + - It will always be placed at the end of the list + - If multiple condition=None entries are provided, only the last one is kept + + Args: + conditions: List of OnContextConditions to add + + Returns: + Self for method chaining + """ + # Validate that it is a list of OnContextConditions + if not all(isinstance(condition, OnContextCondition) for condition in conditions): + raise TypeError("All conditions must be of type OnContextCondition") + + # Separate conditions with None and without None + none_conditions = [c for c in conditions if c.condition is None] + regular_conditions = [c for c in conditions if c.condition is not None] + + # Remove any existing condition=None entries + self.after_works = [c for c in self.after_works if c.condition is not None] + + # Add regular conditions + self.after_works.extend(regular_conditions) + + # Add at most one None condition at the end + if none_conditions: + self.after_works.append(none_conditions[-1]) # Use the last one if multiple provided + + return self + + @overload + def add(self, condition: OnContextCondition) -> "Handoffs": ... + + @overload + def add(self, condition: OnCondition) -> "Handoffs": ... + + def add(self, condition: Union[OnContextCondition, OnCondition]) -> "Handoffs": + """ + Add a single condition (OnContextCondition or OnCondition). + + Args: + condition: The condition to add (OnContextCondition or OnCondition) + + Raises: + TypeError: If the condition type is not supported + + Returns: + Self for method chaining + """ + # This add method is a helper method designed to make it easier for + # adding handoffs without worrying about the specific type. + if isinstance(condition, OnContextCondition): + return self.add_context_condition(condition) + elif isinstance(condition, OnCondition): + return self.add_llm_condition(condition) + else: + raise TypeError(f"Unsupported condition type: {type(condition).__name__}") + + def add_many(self, conditions: list[Union[OnContextCondition, OnCondition]]) -> "Handoffs": + """ + Add multiple conditions of any supported types (OnContextCondition and OnCondition). 
+ + Args: + conditions: List of conditions to add + + Raises: + TypeError: If an unsupported condition type is provided + + Returns: + Self for method chaining + """ + # This add_many method is a helper method designed to make it easier for + # adding handoffs without worrying about the specific type. + context_conditions = [] + llm_conditions = [] + + for condition in conditions: + if isinstance(condition, OnContextCondition): + context_conditions.append(condition) + elif isinstance(condition, OnCondition): + llm_conditions.append(condition) + else: + raise TypeError(f"Unsupported condition type: {type(condition).__name__}") + + if context_conditions: + self.add_context_conditions(context_conditions) + if llm_conditions: + self.add_llm_conditions(llm_conditions) + + return self + + def clear(self) -> "Handoffs": + """ + Clear all handoff conditions. + + Returns: + Self for method chaining + """ + self.context_conditions.clear() + self.llm_conditions.clear() + self.after_works.clear() + return self + + def get_llm_conditions_by_target_type(self, target_type: type) -> list[OnCondition]: + """ + Get OnConditions for a specific target type. + + Args: + target_type: The type of condition to retrieve + + Returns: + List of conditions of the specified type, or None if none exist + """ + return [on_condition for on_condition in self.llm_conditions if on_condition.has_target_type(target_type)] + + def get_context_conditions_by_target_type(self, target_type: type) -> list[OnContextCondition]: + """ + Get OnContextConditions for a specific target type. + + Args: + target_type: The type of condition to retrieve + + Returns: + List of conditions of the specified type, or None if none exist + """ + return [ + on_context_condition + for on_context_condition in self.context_conditions + if on_context_condition.has_target_type(target_type) + ] + + def get_llm_conditions_requiring_wrapping(self) -> list[OnCondition]: + """ + Get LLM conditions that have targets that require wrapping. + + Returns: + List of LLM conditions that require wrapping + """ + return [condition for condition in self.llm_conditions if condition.target_requires_wrapping()] + + def get_context_conditions_requiring_wrapping(self) -> list[OnContextCondition]: + """ + Get context conditions that have targets that require wrapping. + + Returns: + List of context conditions that require wrapping + """ + return [condition for condition in self.context_conditions if condition.target_requires_wrapping()] + + def set_llm_function_names(self) -> None: + """ + Set the LLM function names for all LLM conditions, creating unique names for each function. 
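+
+ For example (illustrative), the first OnCondition targeting an agent named
+ "billing" would receive a function name such as "transfer_to_billing_1".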
+ """ + for i, condition in enumerate(self.llm_conditions): + # Function names are made unique and allow multiple OnCondition's to the same agent + condition.llm_function_name = f"transfer_to_{condition.target.normalized_name()}_{i + 1}" diff --git a/mm_agents/coact/autogen/agentchat/group/llm_condition.py b/mm_agents/coact/autogen/agentchat/group/llm_condition.py new file mode 100644 index 0000000..8bc13fe --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/group/llm_condition.py @@ -0,0 +1,93 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import TYPE_CHECKING, Any + +from pydantic import BaseModel + +from .context_str import ContextStr + +if TYPE_CHECKING: + # Avoid circular import + from ..conversable_agent import ConversableAgent + +__all__ = ["ContextStrLLMCondition", "LLMCondition", "StringLLMCondition"] + + +class LLMCondition(BaseModel): + """Protocol for conditions evaluated by an LLM.""" + + def get_prompt(self, agent: "ConversableAgent", messages: list[dict[str, Any]]) -> str: + """Get the prompt text for LLM evaluation. + + Args: + agent: The agent evaluating the condition + messages: The conversation history + + Returns: + The prompt text to be evaluated by the LLM + """ + raise NotImplementedError("Requires subclasses to implement.") + + +class StringLLMCondition(LLMCondition): + """Simple string-based LLM condition. + + This condition provides a static string prompt to be evaluated by an LLM. + """ + + prompt: str + + def __init__(self, prompt: str, **data: Any) -> None: + """Initialize with a prompt string as a positional parameter. + + Args: + prompt: The static prompt string to evaluate + data: Additional data for the parent class + """ + super().__init__(prompt=prompt, **data) + + def get_prompt(self, agent: "ConversableAgent", messages: list[dict[str, Any]]) -> str: + """Return the static prompt string. + + Args: + agent: The agent evaluating the condition (not used) + messages: The conversation history (not used) + + Returns: + The static prompt string + """ + return self.prompt + + +class ContextStrLLMCondition(LLMCondition): + """Context variable-based LLM condition. + + This condition uses a ContextStr object with context variable placeholders that + will be substituted before being evaluated by an LLM. + """ + + context_str: ContextStr + + def __init__(self, context_str: ContextStr, **data: Any) -> None: + """Initialize with a context string as a positional parameter. + + Args: + context_str: The ContextStr object with variable placeholders + data: Additional data for the parent class + """ + super().__init__(context_str=context_str, **data) + + def get_prompt(self, agent: "ConversableAgent", messages: list[dict[str, Any]]) -> str: + """Return the prompt with context variables substituted. 
+ + Args: + agent: The agent evaluating the condition (provides context variables) + messages: The conversation history (not used) + + Returns: + The prompt with context variables substituted + """ + result = self.context_str.format(agent.context_variables) + return result if result is not None else "" diff --git a/mm_agents/coact/autogen/agentchat/group/multi_agent_chat.py b/mm_agents/coact/autogen/agentchat/group/multi_agent_chat.py new file mode 100644 index 0000000..7e0a91c --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/group/multi_agent_chat.py @@ -0,0 +1,237 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import threading +from typing import TYPE_CHECKING, Any, Union + +from ...doc_utils import export_module +from ...events.agent_events import ErrorEvent, RunCompletionEvent +from ...io.base import IOStream +from ...io.run_response import AsyncRunResponse, AsyncRunResponseProtocol, RunResponse, RunResponseProtocol +from ...io.thread_io_stream import AsyncThreadIOStream, ThreadIOStream +from ..chat import ChatResult +from .context_variables import ContextVariables +from .group_utils import cleanup_temp_user_messages + +if TYPE_CHECKING: + from ..agent import Agent + from .patterns.pattern import Pattern + +__all__ = [ + "a_initiate_group_chat", + "a_run_group_chat", + "initiate_group_chat", + "run_group_chat", +] + + +@export_module("autogen") +def initiate_group_chat( + pattern: "Pattern", + messages: Union[list[dict[str, Any]], str], + max_rounds: int = 20, +) -> tuple[ChatResult, ContextVariables, "Agent"]: + """Initialize and run a group chat using a pattern for configuration. + + Args: + pattern: Pattern object that encapsulates the chat configuration. + messages: Initial message(s). + max_rounds: Maximum number of conversation rounds. + + Returns: + ChatResult: Conversations chat history. + ContextVariables: Updated Context variables. + "ConversableAgent": Last speaker. + """ + # Let the pattern prepare the group chat and all its components + # Only passing the necessary parameters that aren't already in the pattern + ( + _, # agents, + _, # wrapped_agents, + _, # user_agent, + context_variables, + _, # initial_agent, + _, # group_after_work, + _, # tool_execution, + _, # groupchat, + manager, + processed_messages, + last_agent, + _, # group_agent_names, + _, # temp_user_list, + ) = pattern.prepare_group_chat( + max_rounds=max_rounds, + messages=messages, + ) + + # Start or resume the conversation + if len(processed_messages) > 1: + last_agent, last_message = manager.resume(messages=processed_messages) + clear_history = False + else: + last_message = processed_messages[0] + clear_history = True + + if last_agent is None: + raise ValueError("No agent selected to start the conversation") + + chat_result = last_agent.initiate_chat( + manager, + message=last_message, + clear_history=clear_history, + summary_method=pattern.summary_method, + ) + + cleanup_temp_user_messages(chat_result) + + return chat_result, context_variables, manager.last_speaker + + +@export_module("autogen.agentchat") +async def a_initiate_group_chat( + pattern: "Pattern", + messages: Union[list[dict[str, Any]], str], + max_rounds: int = 20, +) -> tuple[ChatResult, ContextVariables, "Agent"]: + """Initialize and run a group chat using a pattern for configuration, asynchronously. + + Args: + pattern: Pattern object that encapsulates the chat configuration. + messages: Initial message(s). 
+ max_rounds: Maximum number of conversation rounds. + + Returns: + ChatResult: Conversations chat history. + ContextVariables: Updated Context variables. + "ConversableAgent": Last speaker. + """ + # Let the pattern prepare the group chat and all its components + # Only passing the necessary parameters that aren't already in the pattern + ( + _, # agents, + _, # wrapped_agents, + _, # user_agent, + context_variables, + _, # initial_agent, + _, # group_after_work, + _, # tool_execution, + _, # groupchat, + manager, + processed_messages, + last_agent, + _, # group_agent_names, + _, # temp_user_list, + ) = pattern.prepare_group_chat( + max_rounds=max_rounds, + messages=messages, + ) + + # Start or resume the conversation + if len(processed_messages) > 1: + last_agent, last_message = await manager.a_resume(messages=processed_messages) + clear_history = False + else: + last_message = processed_messages[0] + clear_history = True + + if last_agent is None: + raise ValueError("No agent selected to start the conversation") + + chat_result = await last_agent.a_initiate_chat( + manager, + message=last_message, # type: ignore[arg-type] + clear_history=clear_history, + summary_method=pattern.summary_method, + ) + + cleanup_temp_user_messages(chat_result) + + return chat_result, context_variables, manager.last_speaker + + +@export_module("autogen.agentchat") +def run_group_chat( + pattern: "Pattern", + messages: Union[list[dict[str, Any]], str], + max_rounds: int = 20, +) -> RunResponseProtocol: + iostream = ThreadIOStream() + # todo: add agents + response = RunResponse(iostream, agents=[]) + + def _initiate_group_chat( + pattern: "Pattern" = pattern, + messages: Union[list[dict[str, Any]], str] = messages, + max_rounds: int = max_rounds, + iostream: ThreadIOStream = iostream, + response: RunResponse = response, + ) -> None: + with IOStream.set_default(iostream): + try: + chat_result, context_vars, agent = initiate_group_chat( + pattern=pattern, + messages=messages, + max_rounds=max_rounds, + ) + + IOStream.get_default().send( + RunCompletionEvent( # type: ignore[call-arg] + history=chat_result.chat_history, + summary=chat_result.summary, + cost=chat_result.cost, + last_speaker=agent.name, + context_variables=context_vars, + ) + ) + except Exception as e: + response.iostream.send(ErrorEvent(error=e)) # type: ignore[call-arg] + + threading.Thread( + target=_initiate_group_chat, + ).start() + + return response + + +@export_module("autogen.agentchat") +async def a_run_group_chat( + pattern: "Pattern", + messages: Union[list[dict[str, Any]], str], + max_rounds: int = 20, +) -> AsyncRunResponseProtocol: + iostream = AsyncThreadIOStream() + # todo: add agents + response = AsyncRunResponse(iostream, agents=[]) + + async def _initiate_group_chat( + pattern: "Pattern" = pattern, + messages: Union[list[dict[str, Any]], str] = messages, + max_rounds: int = max_rounds, + iostream: AsyncThreadIOStream = iostream, + response: AsyncRunResponse = response, + ) -> None: + with IOStream.set_default(iostream): + try: + chat_result, context_vars, agent = await a_initiate_group_chat( + pattern=pattern, + messages=messages, + max_rounds=max_rounds, + ) + + IOStream.get_default().send( + RunCompletionEvent( # type: ignore[call-arg] + history=chat_result.chat_history, + summary=chat_result.summary, + cost=chat_result.cost, + last_speaker=agent.name, + context_variables=context_vars, + ) + ) + except Exception as e: + response.iostream.send(ErrorEvent(error=e)) # type: ignore[call-arg] + + 
asyncio.create_task(_initiate_group_chat()) + + return response diff --git a/mm_agents/coact/autogen/agentchat/group/on_condition.py b/mm_agents/coact/autogen/agentchat/group/on_condition.py new file mode 100644 index 0000000..8360dc0 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/group/on_condition.py @@ -0,0 +1,58 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +from pydantic import BaseModel + +from ...doc_utils import export_module +from .available_condition import AvailableCondition +from .llm_condition import LLMCondition +from .targets.transition_target import TransitionTarget + +__all__ = [ + "OnCondition", +] + + +@export_module("autogen") +class OnCondition(BaseModel): # noqa: N801 + """Defines a condition for transitioning to another agent or nested chats. + + This is for LLM-based condition evaluation where these conditions are translated into tools and attached to the agent. + + These are evaluated after the OnCondition conditions but before the after work condition. + + Args: + target (TransitionTarget): The transition (essentially an agent) to hand off to. + condition (LLMCondition): The condition for transitioning to the target agent, evaluated by the LLM. + available (AvailableCondition): Optional condition to determine if this OnCondition is included for the LLM to evaluate based on context variables using classes like StringAvailableCondition and ContextExpressionAvailableCondition. + llm_function_name (Optional[str]): The name of the LLM function to use for this condition. + """ + + target: TransitionTarget + condition: LLMCondition + available: Optional[AvailableCondition] = None + llm_function_name: Optional[str] = None + + def has_target_type(self, target_type: type) -> bool: + """ + Check if the target type matches the specified type. + + Args: + target_type (type): The target type to check against, which should be a subclass of TransitionTarget + + Returns: + bool: True if the target type matches, False otherwise + """ + return isinstance(self.target, target_type) + + def target_requires_wrapping(self) -> bool: + """ + Check if the target requires wrapping in an agent. + + Returns: + bool: True if the target requires wrapping, False otherwise + """ + return self.target.needs_agent_wrapper() diff --git a/mm_agents/coact/autogen/agentchat/group/on_context_condition.py b/mm_agents/coact/autogen/agentchat/group/on_context_condition.py new file mode 100644 index 0000000..013f1c1 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/group/on_context_condition.py @@ -0,0 +1,54 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +from pydantic import BaseModel + +from .available_condition import AvailableCondition +from .context_condition import ContextCondition +from .targets.transition_target import TransitionTarget + +__all__ = [ + "OnContextCondition", +] + + +class OnContextCondition(BaseModel): # noqa: N801 + """Defines a condition for transitioning to another agent or nested chats using context variables and the ContextExpression class. + + This is for context variable-based condition evaluation (does not use the agent's LLM). + + These are evaluated before the OnCondition and after work conditions. + + Args: + target (TransitionTarget): The transition (essentially an agent) to hand off to. 
+ condition (Optional[ContextCondition]): The context variable based condition for transitioning to the target agent. If None, the condition always evaluates to True. + available (AvailableCondition): Optional condition to determine if this OnCondition is included for the LLM to evaluate based on context variables using classes like StringAvailableCondition and ContextExpressionAvailableCondition. + """ + + target: TransitionTarget + condition: Optional[ContextCondition] = None + available: Optional[AvailableCondition] = None + + def has_target_type(self, target_type: type) -> bool: + """ + Check if the target type matches the specified type. + + Args: + target_type (type): The target type to check against. Should be a subclass of TransitionTarget. + + Returns: + bool: True if the target type matches, False otherwise + """ + return isinstance(self.target, target_type) + + def target_requires_wrapping(self) -> bool: + """ + Check if the target requires wrapping in an agent. + + Returns: + bool: True if the target requires wrapping, False otherwise + """ + return self.target.needs_agent_wrapper() diff --git a/mm_agents/coact/autogen/agentchat/group/patterns/__init__.py b/mm_agents/coact/autogen/agentchat/group/patterns/__init__.py new file mode 100644 index 0000000..a26451c --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/group/patterns/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 +# + +from .auto import AutoPattern +from .manual import ManualPattern +from .pattern import DefaultPattern +from .random import RandomPattern +from .round_robin import RoundRobinPattern + +__all__ = [ + "AutoPattern", + "DefaultPattern", + "ManualPattern", + "RandomPattern", + "RoundRobinPattern", +] diff --git a/mm_agents/coact/autogen/agentchat/group/patterns/auto.py b/mm_agents/coact/autogen/agentchat/group/patterns/auto.py new file mode 100644 index 0000000..f4a1d71 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/group/patterns/auto.py @@ -0,0 +1,159 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union + +from ..context_variables import ContextVariables +from ..targets.group_manager_target import GroupManagerSelectionMessage, GroupManagerTarget +from ..targets.transition_target import TransitionTarget +from .pattern import Pattern + +if TYPE_CHECKING: + from ...conversable_agent import ConversableAgent + from ...groupchat import GroupChat, GroupChatManager + from ..group_tool_executor import GroupToolExecutor + + +class AutoPattern(Pattern): + """AutoPattern implements a flexible pattern where agents are selected based on their expertise. + + In this pattern, a group manager automatically selects the next agent to speak based on the context + of the conversation and agent descriptions. The after_work is always set to "group_manager" as + this is the defining characteristic of this pattern. 
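+
+ Example (illustrative sketch, assuming AutoPattern, initiate_group_chat and
+ ConversableAgent are imported from this package; the agent names and
+ llm_config below are placeholders, not part of this patch):
+
+ triage = ConversableAgent(name="triage", llm_config=llm_config)
+ tech = ConversableAgent(name="tech_support", llm_config=llm_config)
+
+ pattern = AutoPattern(
+ initial_agent=triage,
+ agents=[triage, tech],
+ group_manager_args={"llm_config": llm_config},
+ )
+ chat_result, context_variables, last_speaker = initiate_group_chat(
+ pattern=pattern,
+ messages="My laptop will not boot",
+ max_rounds=20,
+ )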
+ """ + + def __init__( + self, + initial_agent: "ConversableAgent", + agents: list["ConversableAgent"], + user_agent: Optional["ConversableAgent"] = None, + group_manager_args: Optional[dict[str, Any]] = None, + context_variables: Optional[ContextVariables] = None, + selection_message: Optional[GroupManagerSelectionMessage] = None, + exclude_transit_message: bool = True, + summary_method: Optional[Union[str, Callable[..., Any]]] = "last_msg", + ): + """Initialize the AutoPattern. + + The after_work is always set to group_manager selection, which is the defining + characteristic of this pattern. You can customize the selection message used + by the group manager when selecting the next agent. + + Args: + initial_agent: The first agent to speak in the group chat. + agents: List of all agents participating in the chat. + user_agent: Optional user proxy agent. + group_manager_args: Optional arguments for the GroupChatManager. + context_variables: Initial context variables for the chat. + selection_message: Custom message to use when the group manager is selecting agents. + exclude_transit_message: Whether to exclude transit messages from the conversation. + summary_method: Method for summarizing the conversation. + """ + # Create the group_manager after_work with the provided selection message + group_manager_after_work = GroupManagerTarget(selection_message=selection_message) + + super().__init__( + initial_agent=initial_agent, + agents=agents, + user_agent=user_agent, + group_manager_args=group_manager_args, + context_variables=context_variables, + group_after_work=group_manager_after_work, + exclude_transit_message=exclude_transit_message, + summary_method=summary_method, + ) + + # Store the selection message for potential use + self.selection_message = selection_message + + def prepare_group_chat( + self, + max_rounds: int, + messages: Union[list[dict[str, Any]], str], + ) -> Tuple[ + list["ConversableAgent"], + list["ConversableAgent"], + Optional["ConversableAgent"], + ContextVariables, + "ConversableAgent", + TransitionTarget, + "GroupToolExecutor", + "GroupChat", + "GroupChatManager", + list[dict[str, Any]], + Any, + list[str], + list[Any], + ]: + """Prepare the group chat for organic agent selection. + + Ensures that: + 1. The group manager has a valid LLM config + 2. All agents have appropriate descriptions for the group manager to use + + Args: + max_rounds: Maximum number of conversation rounds. + messages: Initial message(s) to start the conversation. + + Returns: + Tuple containing all necessary components for the group chat. 
+ """ + # Validate that group_manager_args has an LLM config which is required for this pattern + if not self.group_manager_args.get("llm_config", False): + # Check if any agent has an LLM config we can use + has_llm_config = any(getattr(agent, "llm_config", False) for agent in self.agents) + + if not has_llm_config: + raise ValueError( + "AutoPattern requires the group_manager_args to include an llm_config, " + "or at least one agent to have an llm_config" + ) + + # Check that all agents have descriptions for effective group manager selection + for agent in self.agents: + if not hasattr(agent, "description") or not agent.description: + agent.description = f"Agent {agent.name}" + + # Use the parent class's implementation to prepare the agents and group chat + components = super().prepare_group_chat( + max_rounds=max_rounds, + messages=messages, + ) + + # Extract the group_after_work and the rest of the components + ( + agents, + wrapped_agents, + user_agent, + context_variables, + initial_agent, + _, + tool_executor, + groupchat, + manager, + processed_messages, + last_agent, + group_agent_names, + temp_user_list, + ) = components + + # Ensure we're using the group_manager after_work + group_after_work = self.group_after_work + + # Return all components with our group_after_work + return ( + agents, + wrapped_agents, + user_agent, + context_variables, + initial_agent, + group_after_work, + tool_executor, + groupchat, + manager, + processed_messages, + last_agent, + group_agent_names, + temp_user_list, + ) diff --git a/mm_agents/coact/autogen/agentchat/group/patterns/manual.py b/mm_agents/coact/autogen/agentchat/group/patterns/manual.py new file mode 100644 index 0000000..3d9c90c --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/group/patterns/manual.py @@ -0,0 +1,176 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union + +from ..context_variables import ContextVariables +from ..group_tool_executor import GroupToolExecutor +from ..targets.transition_target import AskUserTarget, TransitionTarget +from .pattern import Pattern + +if TYPE_CHECKING: + from ...conversable_agent import ConversableAgent + from ...groupchat import GroupChat, GroupChatManager + + +class ManualPattern(Pattern): + """ManualPattern will ask the user to nominate the next agent to speak at each turn.""" + + def __init__( + self, + initial_agent: "ConversableAgent", + agents: list["ConversableAgent"], + user_agent: Optional["ConversableAgent"] = None, + group_manager_args: Optional[dict[str, Any]] = None, + context_variables: Optional[ContextVariables] = None, + exclude_transit_message: bool = True, + summary_method: Optional[Union[str, Callable[..., Any]]] = "last_msg", + ): + """Initialize the ManualPattern. + + The after_work is always set to ask_user, which will prompt the user for the next agent + + Args: + initial_agent: The first agent to speak in the group chat. + agents: List of all agents participating in the chat. + user_agent: Optional user proxy agent. + group_manager_args: Optional arguments for the GroupChatManager. + context_variables: Initial context variables for the chat. + exclude_transit_message: Whether to exclude transit messages from the conversation. + summary_method: Method for summarizing the conversation. 
+ """ + # The group after work will be to ask the user + group_after_work = AskUserTarget() + + super().__init__( + initial_agent=initial_agent, + agents=agents, + user_agent=user_agent, + group_manager_args=group_manager_args, + context_variables=context_variables, + group_after_work=group_after_work, + exclude_transit_message=exclude_transit_message, + summary_method=summary_method, + ) + + def prepare_group_chat( + self, + max_rounds: int, + messages: Union[list[dict[str, Any]], str], + ) -> Tuple[ + list["ConversableAgent"], + list["ConversableAgent"], + Optional["ConversableAgent"], + ContextVariables, + "ConversableAgent", + TransitionTarget, + "GroupToolExecutor", + "GroupChat", + "GroupChatManager", + list[dict[str, Any]], + Any, + list[str], + list[Any], + ]: + """Prepare the group chat for organic agent selection. + + Ensures that: + 1. The group manager has a valid LLM config + 2. All agents have appropriate descriptions for the group manager to use + + Args: + max_rounds: Maximum number of conversation rounds. + messages: Initial message(s) to start the conversation. + + Returns: + Tuple containing all necessary components for the group chat. + """ + # Use the parent class's implementation to prepare the agents and group chat + components = super().prepare_group_chat( + max_rounds=max_rounds, + messages=messages, + ) + + # Extract the group_after_work and the rest of the components + ( + agents, + wrapped_agents, + user_agent, + context_variables, + initial_agent, + _, + tool_executor, + groupchat, + manager, + processed_messages, + last_agent, + group_agent_names, + temp_user_list, + ) = components + + # Ensure we're using the group_manager after_work + group_after_work = self.group_after_work + + # Set up the allowed speaker transitions to exclude user_agent and GroupToolExecutor + self._setup_allowed_transitions(groupchat, user_agent, tool_executor) + + # Return all components with our group_after_work + return ( + agents, + wrapped_agents, + user_agent, + context_variables, + initial_agent, + group_after_work, + tool_executor, + groupchat, + manager, + processed_messages, + last_agent, + group_agent_names, + temp_user_list, + ) + + def _setup_allowed_transitions( + self, groupchat: "GroupChat", user_agent: Optional["ConversableAgent"], tool_executor: "GroupToolExecutor" + ) -> None: + """Set up the allowed speaker transitions for the group chat so that when a user selects the next agent the tool executor and user agent don't appear as options. + + Creates transitions where: + 1. Any agent can speak after any other agent, including themselves + 2. The user_agent and GroupToolExecutor are excluded from transitions + + Args: + groupchat: The GroupChat instance to configure + user_agent: The user agent to exclude from transitions + tool_executor: The GroupToolExecutor to exclude from transitions + """ + # NOTE: THIS IS NOT WORKING - THE TRANSITIONS ARE NOT BEING KEPT?! 
+ """ + # Get all agents in the group chat + all_agents = groupchat.agents + + # Filter out user_agent and group tool executor + eligible_agents = [] + for agent in all_agents: + # Skip user_agent + if agent == user_agent: + continue + + # Skip GroupToolExecutor + if isinstance(agent, GroupToolExecutor): + continue + + eligible_agents.append(agent) + + # Create a fully connected graph among eligible agents + # Each agent can be followed by any other eligible agent + allowed_transitions = {} + for agent in eligible_agents: + # For each agent, every other eligible agent can follow + allowed_transitions[agent] = eligible_agents + + # Set the transitions in the group chat + groupchat.allowed_speaker_transitions_dict = allowed_transitions + """ diff --git a/mm_agents/coact/autogen/agentchat/group/patterns/pattern.py b/mm_agents/coact/autogen/agentchat/group/patterns/pattern.py new file mode 100644 index 0000000..6c0d748 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/group/patterns/pattern.py @@ -0,0 +1,294 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +# Patterns of agent orchestrations +# Uses the group chat or the agents' handoffs to create a pattern + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union + +from ..context_variables import ContextVariables +from ..group_utils import ( + create_group_manager, + create_group_transition, + link_agents_to_group_manager, + prepare_group_agents, + process_initial_messages, + setup_context_variables, +) +from ..targets.transition_target import TerminateTarget, TransitionTarget + +if TYPE_CHECKING: + from ...agent import Agent + from ...conversable_agent import ConversableAgent + from ...groupchat import GroupChat, GroupChatManager + from ..group_tool_executor import GroupToolExecutor + + +class Pattern(ABC): + """Base abstract class for all orchestration patterns. + + Patterns provide a reusable way to define how agents interact within a group chat. + Each pattern encapsulates the logic for setting up agents, configuring handoffs, + and determining the flow of conversation. + + This is an abstract base class and should not be instantiated directly. + Use one of the concrete pattern implementations like AutoPattern, + RoundRobinPattern, RandomPattern, or ManualPattern. + """ + + def __init__( + self, + initial_agent: "ConversableAgent", + agents: list["ConversableAgent"], + user_agent: Optional["ConversableAgent"] = None, + group_manager_args: Optional[dict[str, Any]] = None, + context_variables: Optional[ContextVariables] = None, + group_after_work: Optional[TransitionTarget] = None, + exclude_transit_message: bool = True, + summary_method: Optional[Union[str, Callable[..., Any]]] = "last_msg", + ): + """Initialize the pattern with the required components. + + Args: + initial_agent: The first agent to speak in the group chat. + agents: List of all agents participating in the chat. + user_agent: Optional user proxy agent. + group_manager_args: Optional arguments for the GroupChatManager. + context_variables: Initial context variables for the chat. + group_after_work: Default after work transition behavior when no specific next agent is determined. + exclude_transit_message: Whether to exclude transit messages from the conversation. + summary_method: Method for summarizing the conversation. 
+ """ + self.initial_agent = initial_agent + self.agents = agents + self.user_agent = user_agent + self.group_manager_args = group_manager_args or {} + self.context_variables = context_variables or ContextVariables() + self.group_after_work = group_after_work if group_after_work is not None else TerminateTarget() + self.exclude_transit_message = exclude_transit_message + self.summary_method = summary_method + + @abstractmethod + def prepare_group_chat( + self, + max_rounds: int, + messages: Union[list[dict[str, Any]], str], + ) -> Tuple[ + list["ConversableAgent"], + list["ConversableAgent"], + Optional["ConversableAgent"], + ContextVariables, + "ConversableAgent", + TransitionTarget, + "GroupToolExecutor", + "GroupChat", + "GroupChatManager", + list[dict[str, Any]], + "ConversableAgent", + list[str], + list["Agent"], + ]: + """Prepare the group chat for orchestration. + + This is the main method called by initiate_group_chat to set up the pattern. + Subclasses must implement or extend this method to define pattern-specific behavior. + + Args: + max_rounds: Maximum number of conversation rounds. + messages: Initial message(s) to start the conversation. + + Returns: + Tuple containing: + - List of agents involved in the group chat + - List of wrapped agents + - User agent, if applicable + - Context variables for the group chat + - Initial agent for the group chat + - Group-level after work transition for the group chat + - Tool executor for the group chat + - GroupChat instance + - GroupChatManager instance + - Processed messages + - Last agent to speak + - List of group agent names + - List of temporary user agents + """ + from ...groupchat import GroupChat + + # Prepare the agents using the existing helper function + tool_executor, wrapped_agents = prepare_group_agents( + self.agents, self.context_variables, self.exclude_transit_message + ) + + # Process the initial messages BEFORE creating the GroupChat + # This will create a temporary user agent if needed + processed_messages, last_agent, group_agent_names, temp_user_list = process_initial_messages( + messages, self.user_agent, self.agents, wrapped_agents + ) + + # Create transition function (has enclosed state for initial agent) + group_transition = create_group_transition( + initial_agent=self.initial_agent, + tool_execution=tool_executor, + group_agent_names=group_agent_names, + user_agent=self.user_agent, + group_after_work=self.group_after_work, + ) + + # Create the group chat - now we use temp_user_list if no user_agent + groupchat = GroupChat( + agents=[tool_executor] + + self.agents + + wrapped_agents + + ([self.user_agent] if self.user_agent else temp_user_list), + messages=[], + max_round=max_rounds, + speaker_selection_method=group_transition, + ) + + # Create the group manager + manager = create_group_manager(groupchat, self.group_manager_args, self.agents, self.group_after_work) + + # Point all agent's context variables to this function's context_variables + setup_context_variables( + tool_execution=tool_executor, + agents=self.agents, + manager=manager, + user_agent=self.user_agent, + context_variables=self.context_variables, + ) + + # Link all agents with the GroupChatManager to allow access to the group chat + link_agents_to_group_manager(groupchat.agents, manager) + + return ( + self.agents, + wrapped_agents, + self.user_agent, + self.context_variables, + self.initial_agent, + self.group_after_work, + tool_executor, + groupchat, + manager, + processed_messages, + last_agent, + group_agent_names, + temp_user_list, + ) # 
type: ignore[return-value] + + @classmethod + def create_default( + cls, + initial_agent: "ConversableAgent", + agents: list["ConversableAgent"], + user_agent: Optional["ConversableAgent"] = None, + group_manager_args: Optional[dict[str, Any]] = None, + context_variables: Optional[ContextVariables] = None, + exclude_transit_message: bool = True, + summary_method: Optional[Union[str, Callable[..., Any]]] = "last_msg", + ) -> "DefaultPattern": + """Create a default pattern with minimal configuration. + + This replaces the need for a separate BasePattern class by providing + a factory method that creates a simple DefaultPattern instance. + + Args: + initial_agent: The first agent to speak in the group chat. + agents: List of all agents participating in the chat. + user_agent: Optional user proxy agent. + group_manager_args: Optional arguments for the GroupChatManager. + context_variables: Initial context variables for the chat. + exclude_transit_message: Whether to exclude transit messages from the conversation. + summary_method: Method for summarizing the conversation. + + Returns: + A DefaultPattern instance with basic configuration. + """ + return DefaultPattern( + initial_agent=initial_agent, + agents=agents, + user_agent=user_agent, + group_manager_args=group_manager_args, + context_variables=context_variables, + exclude_transit_message=exclude_transit_message, + summary_method=summary_method, + ) + + +class DefaultPattern(Pattern): + """DefaultPattern implements a minimal pattern for simple agent interactions. + + This replaces the previous BasePattern and provides a concrete implementation + of the Pattern abstract base class. + """ + + def prepare_group_chat( + self, + max_rounds: int, + messages: Union[list[dict[str, Any]], str], + ) -> Tuple[ + list["ConversableAgent"], + list["ConversableAgent"], + Optional["ConversableAgent"], + ContextVariables, + "ConversableAgent", + TransitionTarget, + "GroupToolExecutor", + "GroupChat", + "GroupChatManager", + list[dict[str, Any]], + Any, + list[str], + list[Any], + ]: + """Prepare the group chat with default configuration. + + This implementation calls the parent class method but ensures that + the group_after_work in the returned tuple is the pattern's own. + + Args: + max_rounds: Maximum number of conversation rounds. + messages: Initial message(s) to start the conversation. + + Returns: + Tuple containing all necessary components for the group chat. 
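The `create_default` factory and `DefaultPattern` illustrate the intended extension point: delegate to `Pattern.prepare_group_chat` and substitute your own after-work target in the returned 13-tuple. A minimal, hypothetical subclass along the same lines (import paths assume this patch's vendored layout; it only makes sense when a user agent is supplied to the pattern):

```python
# Toy pattern: after every turn, control reverts to the user agent instead of
# terminating. It swaps element 5 (the group-level after-work target) of the
# tuple returned by the base class, as documented in the Returns section above.
from typing import Any, Union

from mm_agents.coact.autogen.agentchat.group.patterns.pattern import Pattern
from mm_agents.coact.autogen.agentchat.group.targets.transition_target import RevertToUserTarget


class RevertToUserPattern(Pattern):
    def prepare_group_chat(self, max_rounds: int, messages: Union[list[dict[str, Any]], str]):
        components = super().prepare_group_chat(max_rounds=max_rounds, messages=messages)
        # Index 5 is group_after_work in the 13-element tuple.
        return components[:5] + (RevertToUserTarget(),) + components[6:]
```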
+ """ + # Use the parent class's implementation to prepare the agents and group chat + ( + agents, + wrapped_agents, + user_agent, + context_variables, + initial_agent, + _, # Ignore the group_after_work from parent + tool_executor, + groupchat, + manager, + processed_messages, + last_agent, + group_agent_names, + temp_user_list, + ) = super().prepare_group_chat( + max_rounds=max_rounds, + messages=messages, + ) + + # Return all components with our group_after_work + return ( + agents, + wrapped_agents, + user_agent, + context_variables, + initial_agent, + self.group_after_work, # Use our own group_after_work + tool_executor, + groupchat, + manager, + processed_messages, + last_agent, + group_agent_names, + temp_user_list, + ) diff --git a/mm_agents/coact/autogen/agentchat/group/patterns/random.py b/mm_agents/coact/autogen/agentchat/group/patterns/random.py new file mode 100644 index 0000000..24138ed --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/group/patterns/random.py @@ -0,0 +1,106 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import TYPE_CHECKING, Any, Optional, Tuple, Union + +from ..context_variables import ContextVariables +from ..targets.transition_target import RandomAgentTarget, TransitionTarget +from .pattern import Pattern + +if TYPE_CHECKING: + from ...conversable_agent import ConversableAgent + from ...groupchat import GroupChat, GroupChatManager + from ..group_tool_executor import GroupToolExecutor + + +class RandomPattern(Pattern): + """RandomPattern implements a random agent selection process.""" + + def _generate_handoffs( + self, + initial_agent: "ConversableAgent", + agents: list["ConversableAgent"], + user_agent: Optional["ConversableAgent"], + ) -> None: + """Generate handoffs between agents in a random fashion.""" + agent_list = agents + ([user_agent] if user_agent is not None else []) + + for agent in agent_list: + # Get the list of agents except itself + other_agents = [a for a in agent_list if a != agent] + + # Create a random after work + agent.handoffs.set_after_work(target=RandomAgentTarget(agents=other_agents)) + + def prepare_group_chat( + self, + max_rounds: int, + messages: Union[list[dict[str, Any]], str], + ) -> Tuple[ + list["ConversableAgent"], + list["ConversableAgent"], + Optional["ConversableAgent"], + ContextVariables, + "ConversableAgent", + TransitionTarget, + "GroupToolExecutor", + "GroupChat", + "GroupChatManager", + list[dict[str, Any]], + Any, + list[str], + list[Any], + ]: + """Prepare the group chat for organic agent selection. + + Ensures that: + 1. The group manager has a valid LLM config + 2. All agents have appropriate descriptions for the group manager to use + + Args: + max_rounds: Maximum number of conversation rounds. + messages: Initial message(s) to start the conversation. + + Returns: + Tuple containing all necessary components for the group chat. 
+ """ + # Use the parent class's implementation to prepare the agents and group chat + ( + agents, + wrapped_agents, + user_agent, + context_variables, + initial_agent, + group_after_work, + tool_executor, + groupchat, + manager, + processed_messages, + last_agent, + group_agent_names, + temp_user_list, + ) = super().prepare_group_chat( + max_rounds=max_rounds, + messages=messages, + ) + + # Create the random handoffs between agents + self._generate_handoffs(initial_agent=initial_agent, agents=agents, user_agent=user_agent) + + # Return all components with our group_after_work + return ( + agents, + wrapped_agents, + user_agent, + context_variables, + initial_agent, + group_after_work, + tool_executor, + groupchat, + manager, + processed_messages, + last_agent, + group_agent_names, + temp_user_list, + ) diff --git a/mm_agents/coact/autogen/agentchat/group/patterns/round_robin.py b/mm_agents/coact/autogen/agentchat/group/patterns/round_robin.py new file mode 100644 index 0000000..9ecbaae --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/group/patterns/round_robin.py @@ -0,0 +1,117 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import TYPE_CHECKING, Any, Optional, Tuple, Union + +from ..context_variables import ContextVariables +from ..targets.transition_target import AgentTarget, TransitionTarget +from .pattern import Pattern + +if TYPE_CHECKING: + from ...conversable_agent import ConversableAgent + from ...groupchat import GroupChat, GroupChatManager + from ..group_tool_executor import GroupToolExecutor + + +class RoundRobinPattern(Pattern): + """RoundRobinPattern implements a round robin with handoffs between agents.""" + + def _generate_handoffs( + self, + initial_agent: "ConversableAgent", + agents: list["ConversableAgent"], + user_agent: Optional["ConversableAgent"], + ) -> None: + """Generate handoffs between agents in a round-robin fashion.""" + # Create a list of the agents and the user_agent but put the initial_agent first + agent_list = [initial_agent] + + # Add the rest of the agents, excluding the initial_agent and user_agent + for agent in agents: + if agent != initial_agent and (user_agent is None or agent != user_agent): + agent_list.append(agent) + + # Add the user_agent last if it exists + if user_agent is not None: + agent_list.append(user_agent) + + # Create handoffs in a round-robin fashion + for i, agent in enumerate(agent_list): + # Last agent hands off to the first agent + # Otherwise agent hands off to the next one + handoff_target = agent_list[0] if i == len(agent_list) - 1 else agent_list[i + 1] + + agent.handoffs.set_after_work(target=AgentTarget(agent=handoff_target)) + + def prepare_group_chat( + self, + max_rounds: int, + messages: Union[list[dict[str, Any]], str], + ) -> Tuple[ + list["ConversableAgent"], + list["ConversableAgent"], + Optional["ConversableAgent"], + ContextVariables, + "ConversableAgent", + TransitionTarget, + "GroupToolExecutor", + "GroupChat", + "GroupChatManager", + list[dict[str, Any]], + Any, + list[str], + list[Any], + ]: + """Prepare the group chat for organic agent selection. + + Ensures that: + 1. The group manager has a valid LLM config + 2. All agents have appropriate descriptions for the group manager to use + + Args: + max_rounds: Maximum number of conversation rounds. + messages: Initial message(s) to start the conversation. + + Returns: + Tuple containing all necessary components for the group chat. 
+ """ + # Use the parent class's implementation to prepare the agents and group chat + ( + agents, + wrapped_agents, + user_agent, + context_variables, + initial_agent, + group_after_work, + tool_executor, + groupchat, + manager, + processed_messages, + last_agent, + group_agent_names, + temp_user_list, + ) = super().prepare_group_chat( + max_rounds=max_rounds, + messages=messages, + ) + + # Create the handoffs between agents + self._generate_handoffs(initial_agent=initial_agent, agents=agents, user_agent=user_agent) + + # Return all components with our group_after_work + return ( + agents, + wrapped_agents, + user_agent, + context_variables, + initial_agent, + group_after_work, + tool_executor, + groupchat, + manager, + processed_messages, + last_agent, + group_agent_names, + temp_user_list, + ) diff --git a/mm_agents/coact/autogen/agentchat/group/reply_result.py b/mm_agents/coact/autogen/agentchat/group/reply_result.py new file mode 100644 index 0000000..7ee8e02 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/group/reply_result.py @@ -0,0 +1,26 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + + +__all__ = ["ReplyResult"] + + +from typing import Optional + +from pydantic import BaseModel + +from .context_variables import ContextVariables +from .targets.transition_target import TransitionTarget + + +class ReplyResult(BaseModel): + """Result of a tool call that is used to provide the return message and the target to transition to.""" + + message: str + target: Optional[TransitionTarget] = None + context_variables: Optional[ContextVariables] = None + + def __str__(self) -> str: + """The string representation for ReplyResult will be just the message.""" + return self.message diff --git a/mm_agents/coact/autogen/agentchat/group/speaker_selection_result.py b/mm_agents/coact/autogen/agentchat/group/speaker_selection_result.py new file mode 100644 index 0000000..9c8ba11 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/group/speaker_selection_result.py @@ -0,0 +1,41 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import TYPE_CHECKING, Optional, Union + +from pydantic import BaseModel + +from ..agent import Agent + +if TYPE_CHECKING: + # Avoid circular import + from ..groupchat import GroupChat + + +class SpeakerSelectionResult(BaseModel): + """Represents a speaker selection result that will be returned to GroupChat._prepare_and_select_agents to determine the next speaker. + + This class can return an Agent, a None to end the conversation, or a string for a speaker selection method. + """ + + terminate: Optional[bool] = None + agent_name: Optional[str] = None + speaker_selection_method: Optional[str] = None + + def get_speaker_selection_result(self, groupchat: "GroupChat") -> Optional[Union[Agent, str]]: + """Get the speaker selection result. If None, the conversation will end.""" + if self.agent_name is not None: + # Find the agent by name in the groupchat + for agent in groupchat.agents: + if agent.name == self.agent_name: + return agent + raise ValueError(f"Agent '{self.agent_name}' not found in groupchat.") + elif self.speaker_selection_method is not None: + return self.speaker_selection_method + elif self.terminate is not None and self.terminate: + return None + else: + raise ValueError( + "Unable to establish speaker selection result. 
No terminate, agent, or speaker selection method provided." + ) diff --git a/mm_agents/coact/autogen/agentchat/group/targets/__init__.py b/mm_agents/coact/autogen/agentchat/group/targets/__init__.py new file mode 100644 index 0000000..78d5382 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/group/targets/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 +# diff --git a/mm_agents/coact/autogen/agentchat/group/targets/group_chat_target.py b/mm_agents/coact/autogen/agentchat/group/targets/group_chat_target.py new file mode 100644 index 0000000..62ed046 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/group/targets/group_chat_target.py @@ -0,0 +1,132 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import TYPE_CHECKING, Any, Optional, Union + +from pydantic import BaseModel + +from ....doc_utils import export_module +from ...agent import Agent +from ..speaker_selection_result import SpeakerSelectionResult +from .transition_target import AgentTarget, TransitionTarget +from .transition_utils import __AGENT_WRAPPER_PREFIX__ + +if TYPE_CHECKING: + from ...conversable_agent import ConversableAgent + from ...groupchat import GroupChat + from ..patterns.pattern import Pattern + + +__all__ = ["GroupChatConfig", "GroupChatTarget"] + + +@export_module("autogen.agentchat.group") +class GroupChatConfig(BaseModel): + """Configuration for a group chat transition target. + + Note: If context_variables are not passed in, the outer context variables will be passed in""" + + pattern: "Pattern" + messages: Union[list[dict[str, Any]], str] + max_rounds: int = 20 + + +@export_module("autogen.agentchat.group") +class GroupChatTarget(TransitionTarget): + """Target that represents a group chat.""" + + group_chat_config: GroupChatConfig + + def can_resolve_for_speaker_selection(self) -> bool: + """Check if the target can resolve for speaker selection. For GroupChatTarget the chat must be encapsulated into an agent.""" + return False + + def resolve( + self, + groupchat: "GroupChat", + current_agent: "ConversableAgent", + user_agent: Optional["ConversableAgent"], + ) -> SpeakerSelectionResult: + """Resolve to the nested chat configuration.""" + raise NotImplementedError( + "GroupChatTarget does not support the resolve method. An agent should be used to encapsulate this nested chat and then the target changed to an AgentTarget." + ) + + def display_name(self) -> str: + """Get the display name for the target.""" + return "a group chat" + + def normalized_name(self) -> str: + """Get a normalized name for the target that has no spaces, used for function calling.""" + return "group_chat" + + def __str__(self) -> str: + """String representation for AgentTarget, can be shown as a function call message.""" + return "Transfer to group chat" + + def needs_agent_wrapper(self) -> bool: + """Check if the target needs to be wrapped in an agent. 
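A `SpeakerSelectionResult` carries exactly one of three outcomes. The sketch below uses a duck-typed stand-in for `GroupChat` (an assumption for illustration only) to show how each outcome resolves:

```python
# Hedged sketch of the three resolutions of SpeakerSelectionResult defined above:
# an agent looked up by name, a built-in selection-method string, or termination.
from types import SimpleNamespace

from mm_agents.coact.autogen.agentchat.group.speaker_selection_result import SpeakerSelectionResult

alice = SimpleNamespace(name="alice")   # stands in for a ConversableAgent
chat = SimpleNamespace(agents=[alice])  # stands in for a GroupChat

assert SpeakerSelectionResult(agent_name="alice").get_speaker_selection_result(chat) is alice
assert SpeakerSelectionResult(speaker_selection_method="auto").get_speaker_selection_result(chat) == "auto"
assert SpeakerSelectionResult(terminate=True).get_speaker_selection_result(chat) is None
```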
GroupChatTarget must be wrapped in an agent.""" + return True + + def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent": + """Create a wrapper agent for the group chat.""" + from autogen.agentchat import initiate_group_chat + + from ...conversable_agent import ConversableAgent # to avoid circular import + + # Create the wrapper agent with a name that identifies it as a wrapped group chat + group_chat_agent = ConversableAgent( + name=f"{__AGENT_WRAPPER_PREFIX__}group_{parent_agent.name}_{index + 1}", + # Copy LLM config from parent agent to ensure it can generate replies if needed + llm_config=parent_agent.llm_config, + ) + + # Store the config directly on the agent + group_chat_agent._group_chat_config = self.group_chat_config # type: ignore[attr-defined] + + # Define the reply function that will run the group chat + def group_chat_reply( + agent: "ConversableAgent", + messages: Optional[list[dict[str, Any]]] = None, + sender: Optional["Agent"] = None, + config: Optional[Any] = None, + ) -> tuple[bool, Optional[dict[str, Any]]]: + """Run the inner group chat and return its results as a reply.""" + # Get the configuration stored directly on the agent + group_config = agent._group_chat_config # type: ignore[attr-defined] + + # Pull through the second last message from the outer chat (the last message will be the handoff message) + # This may need work to make sure we get the right message(s) from the outer chat + message = ( + messages[-2]["content"] + if messages and len(messages) >= 2 and "content" in messages[-2] + else "No message to pass through." + ) + + try: + # Run the group chat with direct agent references from the config + result, _, _ = initiate_group_chat( + pattern=group_config.pattern, + messages=message, + max_rounds=group_config.max_rounds, + ) + + # Return the summary from the chat result summary + return True, {"content": result.summary} + + except Exception as e: + # Handle any errors during execution + return True, {"content": f"Error running group chat: {str(e)}"} + + # Register the reply function with the wrapper agent + group_chat_agent.register_reply( + trigger=[ConversableAgent, None], + reply_func=group_chat_reply, + remove_other_reply_funcs=True, # Use only this reply function + ) + + # After the group chat completes, transition back to the parent agent + group_chat_agent.handoffs.set_after_work(AgentTarget(parent_agent)) + + return group_chat_agent diff --git a/mm_agents/coact/autogen/agentchat/group/targets/group_manager_target.py b/mm_agents/coact/autogen/agentchat/group/targets/group_manager_target.py new file mode 100644 index 0000000..3aab182 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/group/targets/group_manager_target.py @@ -0,0 +1,151 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import TYPE_CHECKING, Any, Optional, Type, Union + +from pydantic import BaseModel, field_validator + +from ....doc_utils import export_module +from ..context_str import ContextStr +from ..group_tool_executor import GroupToolExecutor +from ..speaker_selection_result import SpeakerSelectionResult +from .transition_target import TransitionTarget +from .transition_utils import __AGENT_WRAPPER_PREFIX__ + +if TYPE_CHECKING: + # Avoid circular import + from ...conversable_agent import ConversableAgent + from ...groupchat import GroupChat + +__all__ = ["GroupManagerTarget"] + + +def prepare_groupchat_auto_speaker( + 
groupchat: "GroupChat", + last_group_agent: "ConversableAgent", + group_chat_manager_selection_msg: Optional[Any], +) -> None: + """Prepare the group chat for auto speaker selection, includes updating or restore the groupchat speaker selection message. + + Tool Executor and wrapped agents will be removed from the available agents list. + + Args: + groupchat (GroupChat): GroupChat instance. + last_group_agent ("ConversableAgent"): The last group agent for which the LLM config is used + group_chat_manager_selection_msg (GroupManagerSelectionMessage): Optional message to use for the agent selection (in internal group chat). + """ + from ...groupchat import SELECT_SPEAKER_PROMPT_TEMPLATE + + def substitute_agentlist(template: str) -> str: + # Run through group chat's string substitution first for {agentlist} + # We need to do this so that the next substitution doesn't fail with agentlist + # and we can remove the tool executor and wrapped chats from the available agents list + agent_list = [ + agent + for agent in groupchat.agents + if not isinstance(agent, GroupToolExecutor) and not agent.name.startswith(__AGENT_WRAPPER_PREFIX__) + ] + + groupchat.select_speaker_prompt_template = template + return groupchat.select_speaker_prompt(agent_list) + + # Use the default speaker selection prompt if one is not specified, otherwise use the specified one + groupchat.select_speaker_prompt_template = substitute_agentlist( + SELECT_SPEAKER_PROMPT_TEMPLATE + if group_chat_manager_selection_msg is None + else group_chat_manager_selection_msg.get_message(last_group_agent) + ) + + +# GroupManagerSelectionMessage protocol and implementations +@export_module("autogen.agentchat.group") +class GroupManagerSelectionMessage(BaseModel): + """Base class for all GroupManager selection message types.""" + + def get_message(self, agent: "ConversableAgent") -> str: + """Get the formatted message.""" + raise NotImplementedError("Requires subclasses to implement.") + + +@export_module("autogen.agentchat.group") +class GroupManagerSelectionMessageString(GroupManagerSelectionMessage): + """Selection message that uses a plain string template.""" + + message: str + + def get_message(self, agent: "ConversableAgent") -> str: + """Get the message string.""" + return self.message + + +@export_module("autogen.agentchat.group") +class GroupManagerSelectionMessageContextStr(GroupManagerSelectionMessage): + """Selection message that uses a ContextStr template.""" + + context_str_template: str + + # We will replace {agentlist} with another term and return it later for use with the internal group chat auto speaker selection + # Otherwise our format will fail + @field_validator("context_str_template", mode="before") + def _replace_agentlist_placeholder(cls: Type["GroupManagerSelectionMessageContextStr"], v: Any) -> Union[str, Any]: # noqa: N805 + """Replace {agentlist} placeholder before validation/assignment.""" + if isinstance(v, str): + if "{agentlist}" in v: + return v.replace("{agentlist}", "<>") # Perform the replacement + else: + return v # If no replacement is needed, return the original value + return "" + + def get_message(self, agent: "ConversableAgent") -> str: + """Get the formatted message with context variables substituted.""" + context_str = ContextStr(template=self.context_str_template) + format_result = context_str.format(agent.context_variables) + if format_result is None: + return "" + + return format_result.replace( + "<>", "{agentlist}" + ) # Restore agentlist so it can be substituted by the internal group chat auto 
speaker selection + + +class GroupManagerTarget(TransitionTarget): + """Target that represents an agent by name.""" + + selection_message: Optional[GroupManagerSelectionMessage] = None + + def can_resolve_for_speaker_selection(self) -> bool: + """Check if the target can resolve for speaker selection.""" + return True + + def resolve( + self, + groupchat: "GroupChat", + current_agent: "ConversableAgent", + user_agent: Optional["ConversableAgent"], + ) -> SpeakerSelectionResult: + """Resolve to the speaker selection for the group.""" + if self.selection_message is not None: + prepare_groupchat_auto_speaker(groupchat, current_agent, self.selection_message) + + return SpeakerSelectionResult(speaker_selection_method="auto") + + def display_name(self) -> str: + """Get the display name for the target.""" + return "the group manager" + + def normalized_name(self) -> str: + """Get a normalized name for the target that has no spaces, used for function calling""" + return self.display_name() + + def __str__(self) -> str: + """String representation for AgentTarget, can be shown as a function call message.""" + return "Transfer to the group manager" + + def needs_agent_wrapper(self) -> bool: + """Check if the target needs to be wrapped in an agent.""" + return False + + def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent": + """Create a wrapper agent for the target if needed.""" + raise NotImplementedError("GroupManagerTarget does not require wrapping in an agent.") diff --git a/mm_agents/coact/autogen/agentchat/group/targets/transition_target.py b/mm_agents/coact/autogen/agentchat/group/targets/transition_target.py new file mode 100644 index 0000000..eb7ac0f --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/group/targets/transition_target.py @@ -0,0 +1,413 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +import random +from typing import TYPE_CHECKING, Any, Optional + +from pydantic import BaseModel + +from ..speaker_selection_result import SpeakerSelectionResult +from .transition_utils import __AGENT_WRAPPER_PREFIX__ + +if TYPE_CHECKING: + # Avoid circular import + from ...conversable_agent import ConversableAgent + from ...groupchat import GroupChat + +__all__ = [ + "AgentNameTarget", + "AgentTarget", + "AskUserTarget", + "NestedChatTarget", + "RandomAgentTarget", + "RevertToUserTarget", + "StayTarget", + "TerminateTarget", + "TransitionTarget", +] + +# Common options for transitions +# terminate: Terminate the conversation +# revert_to_user: Revert to the user agent +# stay: Stay with the current agent +# group_manager: Use the group manager (auto speaker selection) +# ask_user: Use the user manager (ask the user, aka manual) +# TransitionOption = Literal["terminate", "revert_to_user", "stay", "group_manager", "ask_user"] + + +class TransitionTarget(BaseModel): + """Base class for all transition targets across OnCondition, OnContextCondition, and after work.""" + + def can_resolve_for_speaker_selection(self) -> bool: + """Check if the target can resolve to an option for speaker selection (Agent, 'None' to end, Str for speaker selection method). 
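`GroupManagerTarget` defers to the group manager's "auto" speaker selection, optionally with a custom selection message. A hedged configuration sketch follows; the import paths and the `set_after_work` call mirror usage elsewhere in this patch, and the prompt text is illustrative:

```python
# Hypothetical after-work configuration: when this agent finishes its turn,
# the group manager picks the next speaker using the custom prompt below.
from mm_agents.coact.autogen.agentchat.group.targets.group_manager_target import (
    GroupManagerSelectionMessageString,
    GroupManagerTarget,
)

after_work = GroupManagerTarget(
    selection_message=GroupManagerSelectionMessageString(
        message="Read the conversation so far, then pick the single role best placed to act next. Only return the role."
    )
)
# e.g. some_agent.handoffs.set_after_work(after_work)
```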
In the case of a nested chat, this will return False as it should be encapsulated in an agent.""" + return False + + def resolve( + self, + groupchat: "GroupChat", + current_agent: "ConversableAgent", + user_agent: Optional["ConversableAgent"], + ) -> SpeakerSelectionResult: + """Resolve to a speaker selection result (Agent, None for termination, or str for speaker selection method).""" + raise NotImplementedError("Requires subclasses to implement.") + + def display_name(self) -> str: + """Get the display name for the target.""" + raise NotImplementedError("Requires subclasses to implement.") + + def normalized_name(self) -> str: + """Get a normalized name for the target that has no spaces, used for function calling""" + raise NotImplementedError("Requires subclasses to implement.") + + def needs_agent_wrapper(self) -> bool: + """Check if the target needs to be wrapped in an agent.""" + raise NotImplementedError("Requires subclasses to implement.") + + def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent": + """Create a wrapper agent for the target if needed.""" + raise NotImplementedError("Requires subclasses to implement.") + + +class AgentTarget(TransitionTarget): + """Target that represents a direct agent reference.""" + + agent_name: str + + def __init__(self, agent: "ConversableAgent", **data: Any) -> None: # type: ignore[no-untyped-def] + # Store the name from the agent for serialization + super().__init__(agent_name=agent.name, **data) + + def can_resolve_for_speaker_selection(self) -> bool: + """Check if the target can resolve for speaker selection.""" + return True + + def resolve( + self, + groupchat: "GroupChat", + current_agent: "ConversableAgent", + user_agent: Optional["ConversableAgent"], + ) -> SpeakerSelectionResult: + """Resolve to the actual agent object from the groupchat.""" + return SpeakerSelectionResult(agent_name=self.agent_name) + + def display_name(self) -> str: + """Get the display name for the target.""" + return f"{self.agent_name}" + + def normalized_name(self) -> str: + """Get a normalized name for the target that has no spaces, used for function calling""" + return self.display_name() + + def __str__(self) -> str: + """String representation for AgentTarget, can be shown as a function call message.""" + return f"Transfer to {self.agent_name}" + + def needs_agent_wrapper(self) -> bool: + """Check if the target needs to be wrapped in an agent.""" + return False + + def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent": + """Create a wrapper agent for the target if needed.""" + raise NotImplementedError("AgentTarget does not require wrapping in an agent.") + + +class AgentNameTarget(TransitionTarget): + """Target that represents an agent by name.""" + + agent_name: str + + def __init__(self, agent_name: str, **data: Any) -> None: + """Initialize with agent name as a positional parameter.""" + super().__init__(agent_name=agent_name, **data) + + def can_resolve_for_speaker_selection(self) -> bool: + """Check if the target can resolve for speaker selection.""" + return True + + def resolve( + self, + groupchat: "GroupChat", + current_agent: "ConversableAgent", + user_agent: Optional["ConversableAgent"], + ) -> SpeakerSelectionResult: + """Resolve to the agent name string.""" + return SpeakerSelectionResult(agent_name=self.agent_name) + + def display_name(self) -> str: + """Get the display name for the target.""" + return f"{self.agent_name}" + + def normalized_name(self) -> 
str: + """Get a normalized name for the target that has no spaces, used for function calling""" + return self.display_name() + + def __str__(self) -> str: + """String representation for AgentTarget, can be shown as a function call message.""" + return f"Transfer to {self.agent_name}" + + def needs_agent_wrapper(self) -> bool: + """Check if the target needs to be wrapped in an agent.""" + return False + + def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent": + """Create a wrapper agent for the target if needed.""" + raise NotImplementedError("AgentNameTarget does not require wrapping in an agent.") + + +class NestedChatTarget(TransitionTarget): + """Target that represents a nested chat configuration.""" + + nested_chat_config: dict[str, Any] + + def can_resolve_for_speaker_selection(self) -> bool: + """Check if the target can resolve for speaker selection. For NestedChatTarget the nested chat must be encapsulated into an agent.""" + return False + + def resolve( + self, + groupchat: "GroupChat", + current_agent: "ConversableAgent", + user_agent: Optional["ConversableAgent"], + ) -> SpeakerSelectionResult: + """Resolve to the nested chat configuration.""" + raise NotImplementedError( + "NestedChatTarget does not support the resolve method. An agent should be used to encapsulate this nested chat and then the target changed to an AgentTarget." + ) + + def display_name(self) -> str: + """Get the display name for the target.""" + return "a nested chat" + + def normalized_name(self) -> str: + """Get a normalized name for the target that has no spaces, used for function calling""" + return "nested_chat" + + def __str__(self) -> str: + """String representation for AgentTarget, can be shown as a function call message.""" + return "Transfer to nested chat" + + def needs_agent_wrapper(self) -> bool: + """Check if the target needs to be wrapped in an agent. 
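The concrete targets in this file all follow the same small contract. A hypothetical extra target, sketched under the assumption that the vendored import paths are available, shows the minimum a directly-resolvable target must implement:

```python
# Toy target: always hand the floor to the first agent registered in the group
# chat. Because it resolves directly, it never needs an agent wrapper.
from mm_agents.coact.autogen.agentchat.group.speaker_selection_result import SpeakerSelectionResult
from mm_agents.coact.autogen.agentchat.group.targets.transition_target import TransitionTarget


class FirstAgentTarget(TransitionTarget):
    def can_resolve_for_speaker_selection(self) -> bool:
        return True

    def resolve(self, groupchat, current_agent, user_agent) -> SpeakerSelectionResult:
        return SpeakerSelectionResult(agent_name=groupchat.agents[0].name)

    def display_name(self) -> str:
        return "First agent"

    def normalized_name(self) -> str:
        return "first_agent"

    def __str__(self) -> str:
        return "Transfer to the first agent"

    def needs_agent_wrapper(self) -> bool:
        return False
```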
NestedChatTarget must be wrapped in an agent.""" + return True + + def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent": + """Create a wrapper agent for the nested chat.""" + from ...conversable_agent import ConversableAgent # to avoid circular import - NEED SOLUTION + + nested_chat_agent = ConversableAgent(name=f"{__AGENT_WRAPPER_PREFIX__}nested_{parent_agent.name}_{index + 1}") + + nested_chat_agent.register_nested_chats( + self.nested_chat_config["chat_queue"], + reply_func_from_nested_chats=self.nested_chat_config.get("reply_func_from_nested_chats") + or "summary_from_nested_chats", + config=self.nested_chat_config.get("config"), + trigger=lambda sender: True, + position=0, + use_async=self.nested_chat_config.get("use_async", False), + ) + + # After the nested chat is complete, transfer back to the parent agent + nested_chat_agent.handoffs.set_after_work(AgentTarget(parent_agent)) + + return nested_chat_agent + + +class TerminateTarget(TransitionTarget): + """Target that represents a termination of the conversation.""" + + def can_resolve_for_speaker_selection(self) -> bool: + """Check if the target can resolve for speaker selection.""" + return True + + def resolve( + self, + groupchat: "GroupChat", + current_agent: "ConversableAgent", + user_agent: Optional["ConversableAgent"], + ) -> SpeakerSelectionResult: + """Resolve to termination.""" + return SpeakerSelectionResult(terminate=True) + + def display_name(self) -> str: + """Get the display name for the target.""" + return "Terminate" + + def normalized_name(self) -> str: + """Get a normalized name for the target that has no spaces, used for function calling""" + return "terminate" + + def __str__(self) -> str: + """String representation for AgentTarget, can be shown as a function call message.""" + return "Terminate" + + def needs_agent_wrapper(self) -> bool: + """Check if the target needs to be wrapped in an agent.""" + return False + + def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent": + """Create a wrapper agent for the target if needed.""" + raise NotImplementedError("TerminateTarget does not require wrapping in an agent.") + + +class StayTarget(TransitionTarget): + """Target that represents staying with the current agent.""" + + def can_resolve_for_speaker_selection(self) -> bool: + """Check if the target can resolve for speaker selection.""" + return True + + def resolve( + self, + groupchat: "GroupChat", + current_agent: "ConversableAgent", + user_agent: Optional["ConversableAgent"], + ) -> SpeakerSelectionResult: + """Resolve to staying with the current agent.""" + return SpeakerSelectionResult(agent_name=current_agent.name) + + def display_name(self) -> str: + """Get the display name for the target.""" + return "Stay" + + def normalized_name(self) -> str: + """Get a normalized name for the target that has no spaces, used for function calling""" + return "stay" + + def __str__(self) -> str: + """String representation for AgentTarget, can be shown as a function call message.""" + return "Stay with agent" + + def needs_agent_wrapper(self) -> bool: + """Check if the target needs to be wrapped in an agent.""" + return False + + def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent": + """Create a wrapper agent for the target if needed.""" + raise NotImplementedError("StayTarget does not require wrapping in an agent.") + + +class RevertToUserTarget(TransitionTarget): + """Target that 
represents reverting to the user agent.""" + + def can_resolve_for_speaker_selection(self) -> bool: + """Check if the target can resolve for speaker selection.""" + return True + + def resolve( + self, + groupchat: "GroupChat", + current_agent: "ConversableAgent", + user_agent: Optional["ConversableAgent"], + ) -> SpeakerSelectionResult: + """Resolve to reverting to the user agent.""" + if user_agent is None: + raise ValueError("User agent must be provided to the chat for the revert_to_user option.") + return SpeakerSelectionResult(agent_name=user_agent.name) + + def display_name(self) -> str: + """Get the display name for the target.""" + return "Revert to User" + + def normalized_name(self) -> str: + """Get a normalized name for the target that has no spaces, used for function calling""" + return "revert_to_user" + + def __str__(self) -> str: + """String representation for AgentTarget, can be shown as a function call message.""" + return "Revert to User" + + def needs_agent_wrapper(self) -> bool: + """Check if the target needs to be wrapped in an agent.""" + return False + + def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent": + """Create a wrapper agent for the target if needed.""" + raise NotImplementedError("RevertToUserTarget does not require wrapping in an agent.") + + +class AskUserTarget(TransitionTarget): + """Target that represents asking the user for input.""" + + def can_resolve_for_speaker_selection(self) -> bool: + """Check if the target can resolve for speaker selection.""" + return True + + def resolve( + self, + groupchat: "GroupChat", + current_agent: "ConversableAgent", + user_agent: Optional["ConversableAgent"], + ) -> SpeakerSelectionResult: + """Resolve to asking the user for input.""" + return SpeakerSelectionResult(speaker_selection_method="manual") + + def display_name(self) -> str: + """Get the display name for the target.""" + return "Ask User" + + def normalized_name(self) -> str: + """Get a normalized name for the target that has no spaces, used for function calling""" + return "ask_user" + + def __str__(self) -> str: + """String representation for AgentTarget, can be shown as a function call message.""" + return "Ask User" + + def needs_agent_wrapper(self) -> bool: + """Check if the target needs to be wrapped in an agent.""" + return False + + def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent": + """Create a wrapper agent for the target if needed.""" + raise NotImplementedError("AskUserTarget does not require wrapping in an agent.") + + +class RandomAgentTarget(TransitionTarget): + """Target that represents a random selection from a list of agents.""" + + agent_names: list[str] + nominated_name: str = "" + + def __init__(self, agents: list["ConversableAgent"], **data: Any) -> None: # type: ignore[no-untyped-def] + # Store the name from the agent for serialization + super().__init__(agent_names=[agent.name for agent in agents], **data) + + def can_resolve_for_speaker_selection(self) -> bool: + """Check if the target can resolve for speaker selection.""" + return True + + def resolve( + self, + groupchat: "GroupChat", + current_agent: "ConversableAgent", + user_agent: Optional["ConversableAgent"], + ) -> SpeakerSelectionResult: + """Resolve to the actual agent object from the groupchat, choosing a random agent (except the current one)""" + # Randomly select the next agent + self.nominated_name = random.choice([name for name in self.agent_names if name != 
current_agent.name]) + + return SpeakerSelectionResult(agent_name=self.nominated_name) + + def display_name(self) -> str: + """Get the display name for the target.""" + return self.nominated_name + + def normalized_name(self) -> str: + """Get a normalized name for the target that has no spaces, used for function calling""" + return self.display_name() + + def __str__(self) -> str: + """String representation for RandomAgentTarget, can be shown as a function call message.""" + return f"Transfer to {self.nominated_name}" + + def needs_agent_wrapper(self) -> bool: + """Check if the target needs to be wrapped in an agent.""" + return False + + def create_wrapper_agent(self, parent_agent: "ConversableAgent", index: int) -> "ConversableAgent": + """Create a wrapper agent for the target if needed.""" + raise NotImplementedError("RandomAgentTarget does not require wrapping in an agent.") + + +# TODO: Consider adding a SequentialChatTarget class diff --git a/mm_agents/coact/autogen/agentchat/group/targets/transition_utils.py b/mm_agents/coact/autogen/agentchat/group/targets/transition_utils.py new file mode 100644 index 0000000..fd904d4 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/group/targets/transition_utils.py @@ -0,0 +1,6 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +# Prefix for all wrapped agent names +__AGENT_WRAPPER_PREFIX__ = "wrapped_" diff --git a/mm_agents/coact/autogen/agentchat/groupchat.py b/mm_agents/coact/autogen/agentchat/groupchat.py new file mode 100644 index 0000000..960bbc2 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/groupchat.py @@ -0,0 +1,1694 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 +# +# Portions derived from https://github.com/microsoft/autogen are under the MIT License. +# SPDX-License-Identifier: MIT +import copy +import json +import logging +import random +import re +import sys +from dataclasses import dataclass, field +from typing import Any, Callable, Literal, Optional, Union + +from ..code_utils import content_str +from ..doc_utils import export_module +from ..events.agent_events import ( + ClearAgentsHistoryEvent, + GroupChatResumeEvent, + GroupChatRunChatEvent, + SelectSpeakerEvent, + SelectSpeakerInvalidInputEvent, + SelectSpeakerTryCountExceededEvent, + SpeakerAttemptFailedMultipleAgentsEvent, + SpeakerAttemptFailedNoAgentsEvent, + SpeakerAttemptSuccessfulEvent, + TerminationEvent, +) +from ..exception_utils import AgentNameConflictError, NoEligibleSpeakerError, UndefinedNextAgentError +from ..graph_utils import check_graph_validity, invert_disallowed_to_allowed +from ..io.base import IOStream +from ..llm_config import LLMConfig +from ..oai.client import ModelClient +from ..runtime_logging import log_new_agent, logging_enabled +from .agent import Agent +from .contrib.capabilities import transform_messages +from .conversable_agent import ConversableAgent + +logger = logging.getLogger(__name__) + +SELECT_SPEAKER_PROMPT_TEMPLATE = ( + "Read the above conversation. Then select the next role from {agentlist} to play. Only return the role." +) + + +@dataclass +@export_module("autogen") +class GroupChat: + """(In preview) A group chat class that contains the following data fields: + - agents: a list of participating agents. + - messages: a list of messages in the group chat. + - max_round: the maximum number of rounds. 
+ - admin_name: the name of the admin agent if there is one. Default is "Admin". + KeyBoardInterrupt will make the admin agent take over. + - func_call_filter: whether to enforce function call filter. Default is True. + When set to True and when a message is a function call suggestion, + the next speaker will be chosen from an agent which contains the corresponding function name + in its `function_map`. + - select_speaker_message_template: customize the select speaker message (used in "auto" speaker selection), which appears first in the message context and generally includes the agent descriptions and list of agents. If the string contains "`{roles}`" it will replaced with the agent's and their role descriptions. If the string contains "`{agentlist}`" it will be replaced with a comma-separated list of agent names in square brackets. The default value is: + "You are in a role play game. The following roles are available: + `{roles}`. + Read the following conversation. + Then select the next role from `{agentlist}` to play. Only return the role." + - select_speaker_prompt_template: customize the select speaker prompt (used in "auto" speaker selection), which appears last in the message context and generally includes the list of agents and guidance for the LLM to select the next agent. If the string contains "`{agentlist}`" it will be replaced with a comma-separated list of agent names in square brackets. The default value is: + "Read the above conversation. Then select the next role from `{agentlist}` to play. Only return the role." + To ignore this prompt being used, set this to None. If set to None, ensure your instructions for selecting a speaker are in the select_speaker_message_template string. + - select_speaker_auto_multiple_template: customize the follow-up prompt used when selecting a speaker fails with a response that contains multiple agent names. This prompt guides the LLM to return just one agent name. Applies only to "auto" speaker selection method. If the string contains "`{agentlist}`" it will be replaced with a comma-separated list of agent names in square brackets. The default value is: + "You provided more than one name in your text, please return just the name of the next speaker. To determine the speaker use these prioritised rules: + 1. If the context refers to themselves as a speaker e.g. "As the..." , choose that speaker's name + 2. If it refers to the "next" speaker name, choose that name + 3. Otherwise, choose the first provided speaker's name in the context + The names are case-sensitive and should not be abbreviated or changed. + Respond with ONLY the name of the speaker and DO NOT provide a reason." + - select_speaker_auto_none_template: customize the follow-up prompt used when selecting a speaker fails with a response that contains no agent names. This prompt guides the LLM to return an agent name and provides a list of agent names. Applies only to "auto" speaker selection method. If the string contains "`{agentlist}`" it will be replaced with a comma-separated list of agent names in square brackets. The default value is: + "You didn't choose a speaker. As a reminder, to determine the speaker use these prioritised rules: + 1. If the context refers to themselves as a speaker e.g. "As the..." , choose that speaker's name + 2. If it refers to the "next" speaker name, choose that name + 3. Otherwise, choose the first provided speaker's name in the context + The names are case-sensitive and should not be abbreviated or changed. 
+ The only names that are accepted are `{agentlist}`. + Respond with ONLY the name of the speaker and DO NOT provide a reason." + - speaker_selection_method: the method for selecting the next speaker. Default is "auto". + Could be any of the following (case insensitive), will raise ValueError if not recognized: + - "auto": the next speaker is selected automatically by LLM. + - "manual": the next speaker is selected manually by user input. + - "random": the next speaker is selected randomly. + - "round_robin": the next speaker is selected in a round robin fashion, i.e., iterating in the same order as provided in `agents`. + - a customized speaker selection function (Callable): the function will be called to select the next speaker. + The function should take the last speaker and the group chat as input and return one of the following: + 1. an `Agent` class, it must be one of the agents in the group chat. + 2. a string from ['auto', 'manual', 'random', 'round_robin'] to select a default method to use. + 3. None, which would terminate the conversation gracefully. + ```python + def custom_speaker_selection_func( + last_speaker: Agent, groupchat: GroupChat + ) -> Union[Agent, str, None]: + ``` + - max_retries_for_selecting_speaker: the maximum number of times the speaker selection requery process will run. + If, during speaker selection, multiple agent names or no agent names are returned by the LLM as the next agent, it will be queried again up to the maximum number + of times until a single agent is returned or it exhausts the maximum attempts. + Applies only to "auto" speaker selection method. + Default is 2. + - select_speaker_transform_messages: (optional) the message transformations to apply to the nested select speaker agent-to-agent chat messages. + Takes a TransformMessages object, defaults to None and is only utilised when the speaker selection method is "auto". + - select_speaker_auto_verbose: whether to output the select speaker responses and selections + If set to True, the outputs from the two agents in the nested select speaker chat will be output, along with + whether the responses were successful, or not, in selecting an agent + Applies only to "auto" speaker selection method. + - allow_repeat_speaker: whether to allow the same speaker to speak consecutively. + Default is True, in which case all speakers are allowed to speak consecutively. + If `allow_repeat_speaker` is a list of Agents, then only those listed agents are allowed to repeat. + If set to False, then no speakers are allowed to repeat. + `allow_repeat_speaker` and `allowed_or_disallowed_speaker_transitions` are mutually exclusive. + - allowed_or_disallowed_speaker_transitions: dict. + The keys are source agents, and the values are agents that the key agent can/can't transit to, + depending on speaker_transitions_type. Default is None, which means all agents can transit to all other agents. + `allow_repeat_speaker` and `allowed_or_disallowed_speaker_transitions` are mutually exclusive. + - speaker_transitions_type: whether the speaker_transitions_type is a dictionary containing lists of allowed agents or disallowed agents. + "allowed" means the `allowed_or_disallowed_speaker_transitions` is a dictionary containing lists of allowed agents. + If set to "disallowed", then the `allowed_or_disallowed_speaker_transitions` is a dictionary containing lists of disallowed agents. + Must be supplied if `allowed_or_disallowed_speaker_transitions` is not None. 
+ - enable_clear_history: enable possibility to clear history of messages for agents manually by providing + "clear history" phrase in user prompt. This is experimental feature. + See description of GroupChatManager.clear_agents_history function for more info. + - send_introductions: send a round of introductions at the start of the group chat, so agents know who they can speak to (default: False) + - select_speaker_auto_model_client_cls: Custom model client class for the internal speaker select agent used during 'auto' speaker selection (optional) + - select_speaker_auto_llm_config: LLM config for the internal speaker select agent used during 'auto' speaker selection (optional) + - role_for_select_speaker_messages: sets the role name for speaker selection when in 'auto' mode, typically 'user' or 'system'. (default: 'system') + """ + + agents: list[Agent] + messages: list[dict[str, Any]] = field(default_factory=list) + max_round: int = 10 + admin_name: str = "Admin" + func_call_filter: bool = True + speaker_selection_method: Union[Literal["auto", "manual", "random", "round_robin"], Callable[..., Any]] = "auto" + max_retries_for_selecting_speaker: int = 2 + allow_repeat_speaker: Optional[Union[bool, list[Agent]]] = None + allowed_or_disallowed_speaker_transitions: Optional[dict[str, Any]] = None + speaker_transitions_type: Literal["allowed", "disallowed", None] = None + enable_clear_history: bool = False + send_introductions: bool = False + select_speaker_message_template: str = """You are in a role play game. The following roles are available: + {roles}. + Read the following conversation. + Then select the next role from {agentlist} to play. Only return the role.""" + select_speaker_prompt_template: str = SELECT_SPEAKER_PROMPT_TEMPLATE + select_speaker_auto_multiple_template: str = """You provided more than one name in your text, please return just the name of the next speaker. To determine the speaker use these prioritised rules: + 1. If the context refers to themselves as a speaker e.g. "As the..." , choose that speaker's name + 2. If it refers to the "next" speaker name, choose that name + 3. Otherwise, choose the first provided speaker's name in the context + The names are case-sensitive and should not be abbreviated or changed. + Respond with ONLY the name of the speaker and DO NOT provide a reason.""" + select_speaker_auto_none_template: str = """You didn't choose a speaker. As a reminder, to determine the speaker use these prioritised rules: + 1. If the context refers to themselves as a speaker e.g. "As the..." , choose that speaker's name + 2. If it refers to the "next" speaker name, choose that name + 3. Otherwise, choose the first provided speaker's name in the context + The names are case-sensitive and should not be abbreviated or changed. + The only names that are accepted are {agentlist}. 
+ Respond with ONLY the name of the speaker and DO NOT provide a reason.""" + select_speaker_transform_messages: Optional[transform_messages.TransformMessages] = None + select_speaker_auto_verbose: Optional[bool] = False + select_speaker_auto_model_client_cls: Optional[Union[ModelClient, list[ModelClient]]] = None + select_speaker_auto_llm_config: Optional[Union[LLMConfig, dict[str, Any], Literal[False]]] = None + role_for_select_speaker_messages: Optional[str] = "system" + + _VALID_SPEAKER_SELECTION_METHODS = ["auto", "manual", "random", "round_robin"] + _VALID_SPEAKER_TRANSITIONS_TYPE = ["allowed", "disallowed", None] + + # Define a class attribute for the default introduction message + DEFAULT_INTRO_MSG = ( + "Hello everyone. We have assembled a great team today to answer questions and solve tasks. In attendance are:" + ) + + allowed_speaker_transitions_dict: dict[str, list[Agent]] = field(init=False) + + def __post_init__(self): + # Post init steers clears of the automatically generated __init__ method from dataclass + + if self.allow_repeat_speaker is not None and not isinstance(self.allow_repeat_speaker, (bool, list)): + raise ValueError("GroupChat allow_repeat_speaker should be a bool or a list of Agents.") + + # Here, we create allowed_speaker_transitions_dict from the supplied allowed_or_disallowed_speaker_transitions and speaker_transitions_type, and lastly checks for validity. + + # Check input + if self.speaker_transitions_type is not None: + self.speaker_transitions_type = self.speaker_transitions_type.lower() + + if self.speaker_transitions_type not in self._VALID_SPEAKER_TRANSITIONS_TYPE: + raise ValueError( + f"GroupChat speaker_transitions_type is set to '{self.speaker_transitions_type}'. " + f"It should be one of {self._VALID_SPEAKER_TRANSITIONS_TYPE} (case insensitive). " + ) + + # If both self.allowed_or_disallowed_speaker_transitions is None and self.allow_repeat_speaker is None, set allow_repeat_speaker to True to ensure backward compatibility + # Discussed in https://github.com/microsoft/autogen/pull/857#discussion_r1451541204 + if self.allowed_or_disallowed_speaker_transitions is None and self.allow_repeat_speaker is None: + self.allow_repeat_speaker = True + + # self.allowed_or_disallowed_speaker_transitions and self.allow_repeat_speaker are mutually exclusive parameters. + # Discussed in https://github.com/microsoft/autogen/pull/857#discussion_r1451266661 + if self.allowed_or_disallowed_speaker_transitions is not None and self.allow_repeat_speaker is not None: + raise ValueError( + "Don't provide both allowed_or_disallowed_speaker_transitions and allow_repeat_speaker in group chat. " + "Please set one of them to None." + ) + + # Asks the user to specify whether the speaker_transitions_type is allowed or disallowed if speaker_transitions_type is supplied + # Discussed in https://github.com/microsoft/autogen/pull/857#discussion_r1451259524 + if self.allowed_or_disallowed_speaker_transitions is not None and self.speaker_transitions_type is None: + raise ValueError( + "GroupChat allowed_or_disallowed_speaker_transitions is not None, but speaker_transitions_type is None. " + "Please set speaker_transitions_type to either 'allowed' or 'disallowed'." 
+                )
+
+        # Inferring self.allowed_speaker_transitions_dict
+        # Create self.allowed_speaker_transitions_dict if allowed_or_disallowed_speaker_transitions is None, using allow_repeat_speaker
+        if self.allowed_or_disallowed_speaker_transitions is None:
+            self.allowed_speaker_transitions_dict = {}
+
+            # Create a fully connected allowed_speaker_transitions_dict not including self loops
+            for agent in self.agents:
+                self.allowed_speaker_transitions_dict[agent] = [
+                    other_agent for other_agent in self.agents if other_agent != agent
+                ]
+
+            # If self.allow_repeat_speaker is True, add self loops to all agents
+            if self.allow_repeat_speaker is True:
+                for agent in self.agents:
+                    self.allowed_speaker_transitions_dict[agent].append(agent)
+
+            # Else if self.allow_repeat_speaker is a list of Agents, add self loops to the agents in the list
+            elif isinstance(self.allow_repeat_speaker, list):
+                for agent in self.allow_repeat_speaker:
+                    self.allowed_speaker_transitions_dict[agent].append(agent)
+
+        # Create self.allowed_speaker_transitions_dict if allowed_or_disallowed_speaker_transitions is not None, using allowed_or_disallowed_speaker_transitions
+        else:
+            # Process based on speaker_transitions_type
+            if self.speaker_transitions_type == "allowed":
+                self.allowed_speaker_transitions_dict = self.allowed_or_disallowed_speaker_transitions
+            else:
+                # Logic for processing disallowed allowed_or_disallowed_speaker_transitions to allowed_speaker_transitions_dict
+                self.allowed_speaker_transitions_dict = invert_disallowed_to_allowed(
+                    self.allowed_or_disallowed_speaker_transitions, self.agents
+                )
+
+        # Check for validity
+        check_graph_validity(
+            allowed_speaker_transitions_dict=self.allowed_speaker_transitions_dict,
+            agents=self.agents,
+        )
+
+        # Check select speaker messages, prompts, roles, and retries have values
+        if self.select_speaker_message_template is None or len(self.select_speaker_message_template) == 0:
+            raise ValueError("select_speaker_message_template cannot be empty or None.")
+
+        if self.select_speaker_prompt_template is not None and len(self.select_speaker_prompt_template) == 0:
+            self.select_speaker_prompt_template = None
+
+        if self.role_for_select_speaker_messages is None or len(self.role_for_select_speaker_messages) == 0:
+            raise ValueError("role_for_select_speaker_messages cannot be empty or None.")
+
+        if self.select_speaker_auto_multiple_template is None or len(self.select_speaker_auto_multiple_template) == 0:
+            raise ValueError("select_speaker_auto_multiple_template cannot be empty or None.")
+
+        if self.select_speaker_auto_none_template is None or len(self.select_speaker_auto_none_template) == 0:
+            raise ValueError("select_speaker_auto_none_template cannot be empty or None.")
+
+        # Validate max select speakers retries
+        if self.max_retries_for_selecting_speaker is None or not isinstance(
+            self.max_retries_for_selecting_speaker, int
+        ):
+            raise ValueError("max_retries_for_selecting_speaker cannot be None or non-int")
+        elif self.max_retries_for_selecting_speaker < 0:
+            raise ValueError("max_retries_for_selecting_speaker must be greater than or equal to zero")
+
+        # Load message transforms here (load once for the Group Chat so we don't have to re-initiate it and it maintains the cache across subsequent select speaker calls)
+        if self.select_speaker_transform_messages is not None:
+            if 
isinstance(self.select_speaker_transform_messages, transform_messages.TransformMessages): + self._speaker_selection_transforms = self.select_speaker_transform_messages + else: + raise ValueError("select_speaker_transform_messages must be None or MessageTransforms.") + else: + self._speaker_selection_transforms = None + + # Validate select_speaker_auto_verbose + if self.select_speaker_auto_verbose is None or not isinstance(self.select_speaker_auto_verbose, bool): + raise ValueError("select_speaker_auto_verbose cannot be None or non-bool") + + @property + def agent_names(self) -> list[str]: + """Return the names of the agents in the group chat.""" + return [agent.name for agent in self.agents] + + def reset(self): + """Reset the group chat.""" + self.messages.clear() + + def append(self, message: dict[str, Any], speaker: Agent): + """Append a message to the group chat. + We cast the content to str here so that it can be managed by text-based + model. + """ + # set the name to speaker's name if the role is not function + # if the role is tool, it is OK to modify the name + if message["role"] != "function": + message["name"] = speaker.name + if not isinstance(message["content"], str) and not isinstance(message["content"], list): + message["content"] = str(message["content"]) + message["content"] = content_str(message["content"]) + self.messages.append(message) + + def agent_by_name( + self, name: str, recursive: bool = False, raise_on_name_conflict: bool = False + ) -> Optional[Agent]: + """Returns the agent with a given name. If recursive is True, it will search in nested teams.""" + agents = self.nested_agents() if recursive else self.agents + filtered_agents = [agent for agent in agents if agent.name == name] + + if raise_on_name_conflict and len(filtered_agents) > 1: + raise AgentNameConflictError() + + return filtered_agents[0] if filtered_agents else None + + def nested_agents(self) -> list[Agent]: + """Returns all agents in the group chat manager.""" + agents = self.agents.copy() + for agent in agents: + if isinstance(agent, GroupChatManager): + # Recursive call for nested teams + agents.extend(agent.groupchat.nested_agents()) + return agents + + def next_agent(self, agent: Agent, agents: Optional[list[Agent]] = None) -> Agent: + """Return the next agent in the list.""" + if agents is None: + agents = self.agents + + # Ensure the provided list of agents is a subset of self.agents + if not set(agents).issubset(set(self.agents)): + raise UndefinedNextAgentError() + + # What index is the agent? (-1 if not present) + idx = self.agent_names.index(agent.name) if agent.name in self.agent_names else -1 + + # Return the next agent + if agents == self.agents: + return agents[(idx + 1) % len(agents)] + else: + offset = idx + 1 + for i in range(len(self.agents)): + if self.agents[(offset + i) % len(self.agents)] in agents: + return self.agents[(offset + i) % len(self.agents)] + + # Explicitly handle cases where no valid next agent exists in the provided subset. + raise UndefinedNextAgentError() + + def select_speaker_msg(self, agents: Optional[list[Agent]] = None) -> str: + """Return the system message for selecting the next speaker. 
This is always the *first* message in the context.""" + if agents is None: + agents = self.agents + + roles = self._participant_roles(agents) + agentlist = f"{[agent.name for agent in agents]}" + + return_msg = self.select_speaker_message_template.format(roles=roles, agentlist=agentlist) + return return_msg + + def select_speaker_prompt(self, agents: Optional[list[Agent]] = None) -> str: + """Return the floating system prompt selecting the next speaker. + This is always the *last* message in the context. + Will return None if the select_speaker_prompt_template is None. + """ + if self.select_speaker_prompt_template is None: + return None + + if agents is None: + agents = self.agents + + agentlist = f"{[agent.name for agent in agents]}" + + return_prompt = f"{self.select_speaker_prompt_template}".replace("{agentlist}", agentlist) + return return_prompt + + def introductions_msg(self, agents: Optional[list[Agent]] = None) -> str: + """Return the system message for selecting the next speaker. This is always the *first* message in the context.""" + if agents is None: + agents = self.agents + + # Use the class attribute instead of a hardcoded string + intro_msg = self.DEFAULT_INTRO_MSG + participant_roles = self._participant_roles(agents) + + return f"{intro_msg}\n\n{participant_roles}" + + def manual_select_speaker(self, agents: Optional[list[Agent]] = None) -> Union[Agent, None]: + """Manually select the next speaker.""" + iostream = IOStream.get_default() + + if agents is None: + agents = self.agents + + iostream.send(SelectSpeakerEvent(agents=agents)) + + try_count = 0 + # Assume the user will enter a valid number within 3 tries, otherwise use auto selection to avoid blocking. + while try_count <= 3: + try_count += 1 + if try_count >= 3: + iostream.send(SelectSpeakerTryCountExceededEvent(try_count=try_count, agents=agents)) + break + try: + i = iostream.input( + "Enter the number of the next speaker (enter nothing or `q` to use auto selection): " + ) + if i == "" or i == "q": + break + i = int(i) + if i > 0 and i <= len(agents): + return agents[i - 1] + else: + raise ValueError + except ValueError: + iostream.send(SelectSpeakerInvalidInputEvent(agents=agents)) + return None + + def random_select_speaker(self, agents: Optional[list[Agent]] = None) -> Union[Agent, None]: + """Randomly select the next speaker.""" + if agents is None: + agents = self.agents + return random.choice(agents) + + def _prepare_and_select_agents( + self, + last_speaker: Agent, + ) -> tuple[Optional[Agent], list[Agent], Optional[list[dict[str, Any]]]]: + # If self.speaker_selection_method is a callable, call it to get the next speaker. + # If self.speaker_selection_method is a string, return it. + speaker_selection_method = self.speaker_selection_method + if isinstance(self.speaker_selection_method, Callable): + selected_agent = self.speaker_selection_method(last_speaker, self) + if selected_agent is None: + raise NoEligibleSpeakerError( + "Custom speaker selection function returned None. Terminating conversation." + ) + elif isinstance(selected_agent, Agent): + if selected_agent in self.agents: + return selected_agent, self.agents, None + else: + raise ValueError( + f"Custom speaker selection function returned an agent {selected_agent.name} not in the group chat." 
+ ) + elif isinstance(selected_agent, str): + # If returned a string, assume it is a speaker selection method + speaker_selection_method = selected_agent + else: + raise ValueError( + f"Custom speaker selection function returned an object of type {type(selected_agent)} instead of Agent or str." + ) + + if speaker_selection_method.lower() not in self._VALID_SPEAKER_SELECTION_METHODS: + raise ValueError( + f"GroupChat speaker_selection_method is set to '{speaker_selection_method}'. " + f"It should be one of {self._VALID_SPEAKER_SELECTION_METHODS} (case insensitive). " + ) + + # If provided a list, make sure the agent is in the list + allow_repeat_speaker = ( + self.allow_repeat_speaker + if isinstance(self.allow_repeat_speaker, bool) or self.allow_repeat_speaker is None + else last_speaker in self.allow_repeat_speaker + ) + + agents = self.agents + n_agents = len(agents) + # Warn if GroupChat is underpopulated + if n_agents < 2: + raise ValueError( + f"GroupChat is underpopulated with {n_agents} agents. " + "Please add more agents to the GroupChat or use direct communication instead." + ) + elif n_agents == 2 and speaker_selection_method.lower() != "round_robin" and allow_repeat_speaker: + logger.warning( + f"GroupChat is underpopulated with {n_agents} agents. " + "Consider setting speaker_selection_method to 'round_robin' or allow_repeat_speaker to False, " + "or use direct communication, unless repeated speaker is desired." + ) + + if ( + self.func_call_filter + and self.messages + and ("function_call" in self.messages[-1] or "tool_calls" in self.messages[-1]) + ): + funcs = [] + if "function_call" in self.messages[-1]: + funcs += [self.messages[-1]["function_call"]["name"]] + if "tool_calls" in self.messages[-1]: + funcs += [ + tool["function"]["name"] for tool in self.messages[-1]["tool_calls"] if tool["type"] == "function" + ] + + # find agents with the right function_map which contains the function name + agents = [agent for agent in self.agents if agent.can_execute_function(funcs)] + if len(agents) == 1: + # only one agent can execute the function + return agents[0], agents, None + elif not agents: + # find all the agents with function_map + agents = [agent for agent in self.agents if agent.function_map] + if len(agents) == 1: + return agents[0], agents, None + elif not agents: + raise ValueError( + f"No agent can execute the function {', '.join(funcs)}. " + "Please check the function_map of the agents." + ) + # remove the last speaker from the list to avoid selecting the same speaker if allow_repeat_speaker is False + agents = [agent for agent in agents if agent != last_speaker] if allow_repeat_speaker is False else agents + + # Filter agents with allowed_speaker_transitions_dict + + is_last_speaker_in_group = last_speaker in self.agents + + # this condition means last_speaker is a sink in the graph, then no agents are eligible + if last_speaker not in self.allowed_speaker_transitions_dict and is_last_speaker_in_group: + raise NoEligibleSpeakerError( + f"Last speaker {last_speaker.name} is not in the allowed_speaker_transitions_dict." 
+ ) + # last_speaker is not in the group, so all agents are eligible + elif last_speaker not in self.allowed_speaker_transitions_dict and not is_last_speaker_in_group: + graph_eligible_agents = [] + else: + # Extract agent names from the list of agents + graph_eligible_agents = [ + agent for agent in agents if agent in self.allowed_speaker_transitions_dict[last_speaker] + ] + + # If there is only one eligible agent, just return it to avoid the speaker selection prompt + if len(graph_eligible_agents) == 1: + return graph_eligible_agents[0], graph_eligible_agents, None + + # If there are no eligible agents, return None, which means all agents will be taken into consideration in the next step + if len(graph_eligible_agents) == 0: + graph_eligible_agents = None + + # Use the selected speaker selection method + select_speaker_messages = None + if speaker_selection_method.lower() == "manual": + selected_agent = self.manual_select_speaker(graph_eligible_agents) + elif speaker_selection_method.lower() == "round_robin": + selected_agent = self.next_agent(last_speaker, graph_eligible_agents) + elif speaker_selection_method.lower() == "random": + selected_agent = self.random_select_speaker(graph_eligible_agents) + else: # auto + selected_agent = None + select_speaker_messages = self.messages.copy() + # If last message is a tool call or function call, blank the call so the api doesn't throw + if select_speaker_messages[-1].get("function_call", False): + select_speaker_messages[-1] = dict(select_speaker_messages[-1], function_call=None) + if select_speaker_messages[-1].get("tool_calls", False): + select_speaker_messages[-1] = dict(select_speaker_messages[-1], tool_calls=None) + return selected_agent, graph_eligible_agents, select_speaker_messages + + def select_speaker(self, last_speaker: Agent, selector: ConversableAgent) -> Agent: + """Select the next speaker (with requery).""" + # Prepare the list of available agents and select an agent if selection method allows (non-auto) + selected_agent, agents, messages = self._prepare_and_select_agents(last_speaker) + if selected_agent: + return selected_agent + elif self.speaker_selection_method == "manual": + # An agent has not been selected while in manual mode, so move to the next agent + return self.next_agent(last_speaker) + + # auto speaker selection with 2-agent chat + return self._auto_select_speaker(last_speaker, selector, messages, agents) + + async def a_select_speaker(self, last_speaker: Agent, selector: ConversableAgent) -> Agent: + """Select the next speaker (with requery), asynchronously.""" + selected_agent, agents, messages = self._prepare_and_select_agents(last_speaker) + if selected_agent: + return selected_agent + elif self.speaker_selection_method == "manual": + # An agent has not been selected while in manual mode, so move to the next agent + return self.next_agent(last_speaker) + + # auto speaker selection with 2-agent chat + return await self.a_auto_select_speaker(last_speaker, selector, messages, agents) + + def _finalize_speaker(self, last_speaker: Agent, final: bool, name: str, agents: Optional[list[Agent]]) -> Agent: + if not final: + # the LLM client is None, thus no reply is generated. Use round robin instead. + return self.next_agent(last_speaker, agents) + + # If exactly one agent is mentioned, use it. 
Otherwise, leave the OAI response unmodified + mentions = self._mentioned_agents(name, agents) + if len(mentions) == 1: + name = next(iter(mentions)) + else: + logger.warning( + f"GroupChat select_speaker failed to resolve the next speaker's name. This is because the speaker selection OAI call returned:\n{name}" + ) + + # Return the result + agent = self.agent_by_name(name) + return agent if agent else self.next_agent(last_speaker, agents) + + def _register_client_from_config(self, agent: Agent, config: dict): + model_client_cls_to_match = config.get("model_client_cls") + if model_client_cls_to_match: + if not self.select_speaker_auto_model_client_cls: + raise ValueError( + "A custom model was detected in the config but no 'model_client_cls' " + "was supplied for registration in GroupChat." + ) + + if isinstance(self.select_speaker_auto_model_client_cls, list): + # Register the first custom model client class matching the name specified in the config + matching_model_cls = [ + client_cls + for client_cls in self.select_speaker_auto_model_client_cls + if client_cls.__name__ == model_client_cls_to_match + ] + if len(set(matching_model_cls)) > 1: + raise RuntimeError( + f"More than one unique 'model_client_cls' with __name__ '{model_client_cls_to_match}'." + ) + if not matching_model_cls: + raise ValueError( + "No model's __name__ matches the model client class " + f"'{model_client_cls_to_match}' specified in select_speaker_auto_llm_config." + ) + select_speaker_auto_model_client_cls = matching_model_cls[0] + else: + # Register the only custom model client + select_speaker_auto_model_client_cls = self.select_speaker_auto_model_client_cls + + agent.register_model_client(select_speaker_auto_model_client_cls) + + def _register_custom_model_clients(self, agent: ConversableAgent): + if not self.select_speaker_auto_llm_config: + return + + config_format_is_list = "config_list" in self.select_speaker_auto_llm_config + if config_format_is_list: + for config in self.select_speaker_auto_llm_config["config_list"]: + self._register_client_from_config(agent, config) + elif not config_format_is_list: + self._register_client_from_config(agent, self.select_speaker_auto_llm_config) + + def _create_internal_agents( + self, agents, max_attempts, messages, validate_speaker_name, selector: Optional[ConversableAgent] = None + ): + checking_agent = ConversableAgent("checking_agent", default_auto_reply=max_attempts) + + # Register the speaker validation function with the checking agent + checking_agent.register_reply( + [ConversableAgent, None], + reply_func=validate_speaker_name, # Validate each response + remove_other_reply_funcs=True, + ) + + # Override the selector's config if one was passed as a parameter to this class + speaker_selection_llm_config = self.select_speaker_auto_llm_config or selector.llm_config + + if speaker_selection_llm_config is False: + raise ValueError( + "The group chat's internal speaker selection agent does not have an LLM configuration. Please provide a valid LLM config to the group chat's GroupChatManager or set it with the select_speaker_auto_llm_config parameter." 
+ ) + + # Agent for selecting a single agent name from the response + speaker_selection_agent = ConversableAgent( + "speaker_selection_agent", + system_message=self.select_speaker_msg(agents), + chat_messages={checking_agent: messages}, + llm_config=speaker_selection_llm_config, + human_input_mode="NEVER", + # Suppresses some extra terminal outputs, outputs will be handled by select_speaker_auto_verbose + ) + + # Register any custom model passed in select_speaker_auto_llm_config with the speaker_selection_agent + self._register_custom_model_clients(speaker_selection_agent) + + return checking_agent, speaker_selection_agent + + def _auto_select_speaker( + self, + last_speaker: Agent, + selector: ConversableAgent, + messages: Optional[list[dict[str, Any]]], + agents: Optional[list[Agent]], + ) -> Agent: + """Selects next speaker for the "auto" speaker selection method. Utilises its own two-agent chat to determine the next speaker and supports requerying. + + Speaker selection for "auto" speaker selection method: + 1. Create a two-agent chat with a speaker selector agent and a speaker validator agent, like a nested chat + 2. Inject the group messages into the new chat + 3. Run the two-agent chat, evaluating the result of response from the speaker selector agent: + - If a single agent is provided then we return it and finish. If not, we add an additional message to this nested chat in an attempt to guide the LLM to a single agent response + 4. Chat continues until a single agent is nominated or there are no more attempts left + 5. If we run out of turns and no single agent can be determined, the next speaker in the list of agents is returned + + Args: + last_speaker: The previous speaker in the group chat + selector: The ConversableAgent that initiated the speaker selection + messages: Current chat messages + agents: Valid list of agents for speaker selection + + Returns: + A counter for mentioned agents. + """ + # If no agents are passed in, assign all the group chat's agents + if agents is None: + agents = self.agents + + # The maximum number of speaker selection attempts (including requeries) + # is the initial speaker selection attempt plus the maximum number of retries. + # We track these and use them in the validation function as we can't + # access the max_turns from within validate_speaker_name. 
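+        # Illustrative arithmetic (not part of the original logic): with the default
+        # max_retries_for_selecting_speaker=2, max_attempts below evaluates to 3, and the nested
+        # selection chat further down is capped at max_turns=2 * max(1, 3) = 6, i.e. one
+        # exchange per attempt between the checking agent and the speaker selection agent.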
+ max_attempts = 1 + self.max_retries_for_selecting_speaker + attempts_left = max_attempts + attempt = 0 + + # Registered reply function for checking_agent, checks the result of the response for agent names + def validate_speaker_name( + recipient, messages, sender, config + ) -> tuple[bool, Optional[Union[str, dict[str, Any]]]]: + # The number of retries left, starting at max_retries_for_selecting_speaker + nonlocal attempts_left + nonlocal attempt + + attempt = attempt + 1 + attempts_left = attempts_left - 1 + + return self._validate_speaker_name(recipient, messages, sender, config, attempts_left, attempt, agents) + + # Two-agent chat for speaker selection + + # Agent for checking the response from the speaker_select_agent + checking_agent, speaker_selection_agent = self._create_internal_agents( + agents, max_attempts, messages, validate_speaker_name, selector + ) + + # Create the starting message + if self.select_speaker_prompt_template is not None: + start_message = { + "content": self.select_speaker_prompt(agents), + "name": "checking_agent", + "override_role": self.role_for_select_speaker_messages, + } + else: + start_message = messages[-1] + + # Add the message transforms, if any, to the speaker selection agent + if self._speaker_selection_transforms is not None: + self._speaker_selection_transforms.add_to_agent(speaker_selection_agent) + + # Run the speaker selection chat + result = checking_agent.initiate_chat( + speaker_selection_agent, + cache=None, # don't use caching for the speaker selection chat + message=start_message, + max_turns=2 + * max(1, max_attempts), # Limiting the chat to the number of attempts, including the initial one + clear_history=False, + silent=not self.select_speaker_auto_verbose, # Base silence on the verbose attribute + ) + + return self._process_speaker_selection_result(result, last_speaker, agents) + + async def a_auto_select_speaker( + self, + last_speaker: Agent, + selector: ConversableAgent, + messages: Optional[list[dict[str, Any]]], + agents: Optional[list[Agent]], + ) -> Agent: + """(Asynchronous) Selects next speaker for the "auto" speaker selection method. Utilises its own two-agent chat to determine the next speaker and supports requerying. + + Speaker selection for "auto" speaker selection method: + 1. Create a two-agent chat with a speaker selector agent and a speaker validator agent, like a nested chat + 2. Inject the group messages into the new chat + 3. Run the two-agent chat, evaluating the result of response from the speaker selector agent: + - If a single agent is provided then we return it and finish. If not, we add an additional message to this nested chat in an attempt to guide the LLM to a single agent response + 4. Chat continues until a single agent is nominated or there are no more attempts left + 5. If we run out of turns and no single agent can be determined, the next speaker in the list of agents is returned + + Args: + last_speaker: The previous speaker in the group chat + selector: The ConversableAgent that initiated the speaker selection + messages: Current chat messages + agents: Valid list of agents for speaker selection + + Returns: + A counter for mentioned agents. 
+ """ + # If no agents are passed in, assign all the group chat's agents + if agents is None: + agents = self.agents + + # The maximum number of speaker selection attempts (including requeries) + # We track these and use them in the validation function as we can't + # access the max_turns from within validate_speaker_name + max_attempts = 1 + self.max_retries_for_selecting_speaker + attempts_left = max_attempts + attempt = 0 + + # Registered reply function for checking_agent, checks the result of the response for agent names + def validate_speaker_name( + recipient, messages, sender, config + ) -> tuple[bool, Optional[Union[str, dict[str, Any]]]]: + # The number of retries left, starting at max_retries_for_selecting_speaker + nonlocal attempts_left + nonlocal attempt + + attempt = attempt + 1 + attempts_left = attempts_left - 1 + + return self._validate_speaker_name(recipient, messages, sender, config, attempts_left, attempt, agents) + + # Two-agent chat for speaker selection + + # Agent for checking the response from the speaker_select_agent + checking_agent, speaker_selection_agent = self._create_internal_agents( + agents, max_attempts, messages, validate_speaker_name, selector + ) + + # Create the starting message + if self.select_speaker_prompt_template is not None: + start_message = { + "content": self.select_speaker_prompt(agents), + "override_role": self.role_for_select_speaker_messages, + } + else: + start_message = messages[-1] + + # Add the message transforms, if any, to the speaker selection agent + if self._speaker_selection_transforms is not None: + self._speaker_selection_transforms.add_to_agent(speaker_selection_agent) + + # Run the speaker selection chat + result = await checking_agent.a_initiate_chat( + speaker_selection_agent, + cache=None, # don't use caching for the speaker selection chat + message=start_message, + max_turns=2 + * max(1, max_attempts), # Limiting the chat to the number of attempts, including the initial one + clear_history=False, + silent=not self.select_speaker_auto_verbose, # Base silence on the verbose attribute + ) + + return self._process_speaker_selection_result(result, last_speaker, agents) + + def _validate_speaker_name( + self, recipient, messages, sender, config, attempts_left, attempt, agents + ) -> tuple[bool, Optional[Union[str, dict[str, Any]]]]: + """Validates the speaker response for each round in the internal 2-agent + chat within the auto select speaker method. + + Used by auto_select_speaker and a_auto_select_speaker. 
+ """ + # Validate the speaker name selected + select_name = messages[-1]["content"].strip() + + mentions = self._mentioned_agents(select_name, agents) + + # Output the query and requery results + if self.select_speaker_auto_verbose: + iostream = IOStream.get_default() + no_of_mentions = len(mentions) + if no_of_mentions == 1: + # Success on retry, we have just one name mentioned + iostream.send( + SpeakerAttemptSuccessfulEvent( + mentions=mentions, + attempt=attempt, + attempts_left=attempts_left, + select_speaker_auto_verbose=self.select_speaker_auto_verbose, + ) + ) + elif no_of_mentions == 1: + iostream.send( + SpeakerAttemptFailedMultipleAgentsEvent( + mentions=mentions, + attempt=attempt, + attempts_left=attempts_left, + select_speaker_auto_verbose=self.select_speaker_auto_verbose, + ) + ) + else: + iostream.send( + SpeakerAttemptFailedNoAgentsEvent( + mentions=mentions, + attempt=attempt, + attempts_left=attempts_left, + select_speaker_auto_verbose=self.select_speaker_auto_verbose, + ) + ) + + if len(mentions) == 1: + # Success on retry, we have just one name mentioned + selected_agent_name = next(iter(mentions)) + + # Add the selected agent to the response so we can return it + messages.append({"role": "user", "content": f"[AGENT SELECTED]{selected_agent_name}"}) + + elif len(mentions) > 1: + # More than one name on requery so add additional reminder prompt for next retry + + if attempts_left: + # Message to return to the chat for the next attempt + agentlist = f"{[agent.name for agent in agents]}" + + return True, { + "content": self.select_speaker_auto_multiple_template.format(agentlist=agentlist), + "name": "checking_agent", + "override_role": self.role_for_select_speaker_messages, + } + else: + # Final failure, no attempts left + messages.append({ + "role": "user", + "content": f"[AGENT SELECTION FAILED]Select speaker attempt #{attempt} of {attempt + attempts_left} failed as it returned multiple names.", + }) + + else: + # No names at all on requery so add additional reminder prompt for next retry + + if attempts_left: + # Message to return to the chat for the next attempt + agentlist = f"{[agent.name for agent in agents]}" + + return True, { + "content": self.select_speaker_auto_none_template.format(agentlist=agentlist), + "name": "checking_agent", + "override_role": self.role_for_select_speaker_messages, + } + else: + # Final failure, no attempts left + messages.append({ + "role": "user", + "content": f"[AGENT SELECTION FAILED]Select speaker attempt #{attempt} of {attempt + attempts_left} failed as it did not include any agent names.", + }) + + return True, None + + def _process_speaker_selection_result(self, result, last_speaker: ConversableAgent, agents: Optional[list[Agent]]): + """Checks the result of the auto_select_speaker function, returning the + agent to speak. + + Used by auto_select_speaker and a_auto_select_speaker. 
+ """ + if len(result.chat_history) > 0: + # Use the final message, which will have the selected agent or reason for failure + final_message = result.chat_history[-1]["content"] + + if "[AGENT SELECTED]" in final_message: + # Have successfully selected an agent, return it + return self.agent_by_name(final_message.replace("[AGENT SELECTED]", "")) + + else: # "[AGENT SELECTION FAILED]" + # Failed to select an agent, so we'll select the next agent in the list + next_agent = self.next_agent(last_speaker, agents) + + # No agent, return the failed reason + return next_agent + + def _participant_roles(self, agents: list[Agent] = None) -> str: + # Default to all agents registered + if agents is None: + agents = self.agents + + roles = [] + for agent in agents: + if agent.description.strip() == "": + logger.warning( + f"The agent '{agent.name}' has an empty description, and may not work well with GroupChat." + ) + roles.append(f"{agent.name}: {agent.description}".strip()) + return "\n".join(roles) + + def _mentioned_agents(self, message_content: Union[str, list], agents: Optional[list[Agent]]) -> dict: + """Counts the number of times each agent is mentioned in the provided message content. + Agent names will match under any of the following conditions (all case-sensitive): + - Exact name match + - If the agent name has underscores it will match with spaces instead (e.g. 'Story_writer' == 'Story writer') + - If the agent name has underscores it will match with '\\_' instead of '_' (e.g. 'Story_writer' == 'Story\\_writer') + + Args: + message_content (Union[str, List]): The content of the message, either as a single string or a list of strings. + agents (List[Agent]): A list of Agent objects, each having a 'name' attribute to be searched in the message content. + + Returns: + Dict: a counter for mentioned agents. + """ + if agents is None: + agents = self.agents + + # Cast message content to str + if isinstance(message_content, dict): + message_content = message_content["content"] + message_content = content_str(message_content) + + mentions = dict() + for agent in agents: + # Finds agent mentions, taking word boundaries into account, + # accommodates escaping underscores and underscores as spaces + regex = ( + r"(?<=\W)(" + + re.escape(agent.name) + + r"|" + + re.escape(agent.name.replace("_", " ")) + + r"|" + + re.escape(agent.name.replace("_", r"\_")) + + r")(?=\W)" + ) + count = len(re.findall(regex, f" {message_content} ")) # Pad the message to help with matching + if count > 0: + mentions[agent.name] = count + return mentions + + +@export_module("autogen") +class GroupChatManager(ConversableAgent): + """(In preview) A chat manager agent that can manage a group chat of multiple agents.""" + + def __init__( + self, + groupchat: GroupChat, + name: Optional[str] = "chat_manager", + # unlimited consecutive auto reply by default + max_consecutive_auto_reply: Optional[int] = sys.maxsize, + human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER", + system_message: Optional[Union[str, list]] = "Group chat manager.", + silent: bool = False, + **kwargs: Any, + ): + if ( + kwargs.get("llm_config") + and isinstance(kwargs["llm_config"], dict) + and (kwargs["llm_config"].get("functions") or kwargs["llm_config"].get("tools")) + ): + raise ValueError( + "GroupChatManager is not allowed to make function/tool calls. Please remove the 'functions' or 'tools' config in 'llm_config' you passed in." 
+ ) + + super().__init__( + name=name, + max_consecutive_auto_reply=max_consecutive_auto_reply, + human_input_mode=human_input_mode, + system_message=system_message, + **kwargs, + ) + if logging_enabled(): + log_new_agent(self, locals()) + # Store groupchat + self._groupchat = groupchat + + self._last_speaker = None + self._silent = silent + + # Order of register_reply is important. + # Allow sync chat if initiated using initiate_chat + self.register_reply(Agent, GroupChatManager.run_chat, config=groupchat, reset_config=GroupChat.reset) + # Allow async chat if initiated using a_initiate_chat + self.register_reply( + Agent, + GroupChatManager.a_run_chat, + config=groupchat, + reset_config=GroupChat.reset, + ignore_async_in_sync_chat=True, + ) + + @property + def groupchat(self) -> GroupChat: + """Returns the group chat managed by the group chat manager.""" + return self._groupchat + + def chat_messages_for_summary(self, agent: Agent) -> list[dict[str, Any]]: + """The list of messages in the group chat as a conversation to summarize. + The agent is ignored. + """ + return self._groupchat.messages + + def _prepare_chat( + self, + recipient: ConversableAgent, + clear_history: bool, + prepare_recipient: bool = True, + reply_at_receive: bool = True, + ) -> None: + super()._prepare_chat(recipient, clear_history, prepare_recipient, reply_at_receive) + + if clear_history: + self._groupchat.reset() + + for agent in self._groupchat.agents: + if (recipient != agent or prepare_recipient) and isinstance(agent, ConversableAgent): + agent._prepare_chat(self, clear_history, False, reply_at_receive) + + @property + def last_speaker(self) -> Agent: + """Return the agent who sent the last message to group chat manager. + + In a group chat, an agent will always send a message to the group chat manager, and the group chat manager will + send the message to all other agents in the group chat. So, when an agent receives a message, it will always be + from the group chat manager. With this property, the agent receiving the message can know who actually sent the + message. 
+ + Example: + ```python + from autogen import ConversableAgent + from autogen import GroupChat, GroupChatManager + + + def print_messages(recipient, messages, sender, config): + # Print the message immediately + print(f"Sender: {sender.name} | Recipient: {recipient.name} | Message: {messages[-1].get('content')}") + print(f"Real Sender: {sender.last_speaker.name}") + assert sender.last_speaker.name in messages[-1].get("content") + return False, None # Required to ensure the agent communication flow continues + + + agent_a = ConversableAgent("agent A", default_auto_reply="I'm agent A.") + agent_b = ConversableAgent("agent B", default_auto_reply="I'm agent B.") + agent_c = ConversableAgent("agent C", default_auto_reply="I'm agent C.") + for agent in [agent_a, agent_b, agent_c]: + agent.register_reply([ConversableAgent, None], reply_func=print_messages, config=None) + group_chat = GroupChat( + [agent_a, agent_b, agent_c], + messages=[], + max_round=6, + speaker_selection_method="random", + allow_repeat_speaker=True, + ) + chat_manager = GroupChatManager(group_chat) + groupchat_result = agent_a.initiate_chat(chat_manager, message="Hi, there, I'm agent A.") + ``` + """ + return self._last_speaker + + def run_chat( + self, + messages: Optional[list[dict[str, Any]]] = None, + sender: Optional[Agent] = None, + config: Optional[GroupChat] = None, + ) -> tuple[bool, Optional[str]]: + """Run a group chat.""" + iostream = IOStream.get_default() + + if messages is None: + messages = self._oai_messages[sender] + message = messages[-1] + speaker = sender + groupchat = config + send_introductions = getattr(groupchat, "send_introductions", False) + silent = getattr(self, "_silent", False) + termination_reason = None + + if send_introductions: + # Broadcast the intro + intro = groupchat.introductions_msg() + for agent in groupchat.agents: + self.send(intro, agent, request_reply=False, silent=True) + # NOTE: We do not also append to groupchat.messages, + # since groupchat handles its own introductions + + if self.client_cache is not None: + for a in groupchat.agents: + a.previous_cache = a.client_cache + a.client_cache = self.client_cache + for i in range(groupchat.max_round): + self._last_speaker = speaker + groupchat.append(message, speaker) + # broadcast the message to all agents except the speaker + for agent in groupchat.agents: + if agent != speaker: + self.send(message, agent, request_reply=False, silent=True) + if self._is_termination_msg(message): + # The conversation is over + termination_reason = f"Termination message condition on the GroupChatManager '{self.name}' met" + break + elif i == groupchat.max_round - 1: + # It's the last round + termination_reason = f"Maximum rounds ({groupchat.max_round}) reached" + break + try: + # select the next speaker + speaker = groupchat.select_speaker(speaker, self) + if not silent: + iostream = IOStream.get_default() + iostream.send(GroupChatRunChatEvent(speaker=speaker, silent=silent)) + # let the speaker speak + reply = speaker.generate_reply(sender=self) + except KeyboardInterrupt: + # let the admin agent speak if interrupted + if groupchat.admin_name in groupchat.agent_names: + # admin agent is one of the participants + speaker = groupchat.agent_by_name(groupchat.admin_name) + reply = speaker.generate_reply(sender=self) + else: + # admin agent is not found in the participants + raise + except NoEligibleSpeakerError: + # No eligible speaker, terminate the conversation + termination_reason = "No next speaker selected" + break + + if reply is None: + # no 
reply is generated, exit the chat + termination_reason = "No reply generated" + break + + # check for "clear history" phrase in reply and activate clear history function if found + if ( + groupchat.enable_clear_history + and isinstance(reply, dict) + and reply["content"] + and "CLEAR HISTORY" in reply["content"].upper() + ): + reply["content"] = self.clear_agents_history(reply, groupchat) + + # The speaker sends the message without requesting a reply + speaker.send(reply, self, request_reply=False, silent=silent) + message = self.last_message(speaker) + if self.client_cache is not None: + for a in groupchat.agents: + a.client_cache = a.previous_cache + a.previous_cache = None + + if termination_reason: + iostream.send(TerminationEvent(termination_reason=termination_reason)) + + return True, None + + async def a_run_chat( + self, + messages: Optional[list[dict[str, Any]]] = None, + sender: Optional[Agent] = None, + config: Optional[GroupChat] = None, + ): + """Run a group chat asynchronously.""" + iostream = IOStream.get_default() + + if messages is None: + messages = self._oai_messages[sender] + message = messages[-1] + speaker = sender + groupchat = config + send_introductions = getattr(groupchat, "send_introductions", False) + silent = getattr(self, "_silent", False) + termination_reason = None + + if send_introductions: + # Broadcast the intro + intro = groupchat.introductions_msg() + for agent in groupchat.agents: + await self.a_send(intro, agent, request_reply=False, silent=True) + # NOTE: We do not also append to groupchat.messages, + # since groupchat handles its own introductions + + if self.client_cache is not None: + for a in groupchat.agents: + a.previous_cache = a.client_cache + a.client_cache = self.client_cache + for i in range(groupchat.max_round): + groupchat.append(message, speaker) + self._last_speaker = speaker + + if self._is_termination_msg(message): + # The conversation is over + termination_reason = f"Termination message condition on the GroupChatManager '{self.name}' met" + break + + # broadcast the message to all agents except the speaker + for agent in groupchat.agents: + if agent != speaker: + await self.a_send(message, agent, request_reply=False, silent=True) + if i == groupchat.max_round - 1: + # the last round + termination_reason = f"Maximum rounds ({groupchat.max_round}) reached" + break + try: + # select the next speaker + speaker = await groupchat.a_select_speaker(speaker, self) + # let the speaker speak + reply = await speaker.a_generate_reply(sender=self) + except KeyboardInterrupt: + # let the admin agent speak if interrupted + if groupchat.admin_name in groupchat.agent_names: + # admin agent is one of the participants + speaker = groupchat.agent_by_name(groupchat.admin_name) + reply = await speaker.a_generate_reply(sender=self) + else: + # admin agent is not found in the participants + raise + except NoEligibleSpeakerError: + # No eligible speaker, terminate the conversation + termination_reason = "No next speaker selected" + break + + if reply is None: + # no reply is generated, exit the chat + termination_reason = "No reply generated" + break + + # The speaker sends the message without requesting a reply + await speaker.a_send(reply, self, request_reply=False, silent=silent) + message = self.last_message(speaker) + if self.client_cache is not None: + for a in groupchat.agents: + a.client_cache = a.previous_cache + a.previous_cache = None + + if termination_reason: + iostream.send(TerminationEvent(termination_reason=termination_reason)) + + return 
True, None + + def resume( + self, + messages: Union[list[dict[str, Any]], str], + remove_termination_string: Optional[Union[str, Callable[[str], str]]] = None, + silent: Optional[bool] = False, + ) -> tuple[ConversableAgent, dict[str, Any]]: + """Resumes a group chat using the previous messages as a starting point. Requires the agents, group chat, and group chat manager to be established + as per the original group chat. + + Args: + messages: The content of the previous chat's messages, either as a Json string or a list of message dictionaries. + remove_termination_string: Remove the termination string from the last message to prevent immediate termination + If a string is provided, this string will be removed from last message. + If a function is provided, the last message will be passed to this function. + silent: (Experimental) whether to print the messages for this conversation. Default is False. + + Returns: + A tuple containing the last agent who spoke and their message + """ + # Convert messages from string to messages list, if needed + if isinstance(messages, str): + messages = self.messages_from_string(messages) + elif isinstance(messages, list) and all(isinstance(item, dict) for item in messages): + messages = copy.deepcopy(messages) + else: + raise Exception("Messages is not of type str or List[Dict]") + + # Clean up the objects, ensuring there are no messages in the agents and group chat + + # Clear agent message history + for agent in self._groupchat.agents: + if isinstance(agent, ConversableAgent): + agent.clear_history() + + # Clear Manager message history + self.clear_history() + + # Clear GroupChat messages + self._groupchat.reset() + + # Validation of message and agents + + try: + self._valid_resume_messages(messages) + except: + raise + + # Load the messages into the group chat + for i, message in enumerate(messages): + if "name" in message: + message_speaker_agent = self._groupchat.agent_by_name(message["name"]) + else: + # If there's no name, assign the group chat manager (this is an indication the ChatResult messages was used instead of groupchat.messages as state) + message_speaker_agent = self + message["name"] = self.name + + # If it wasn't an agent speaking, it may be the manager + if not message_speaker_agent and message["name"] == self.name: + message_speaker_agent = self + + # Add previous messages to each agent (except the last message, as we'll kick off the conversation with it) + if i != len(messages) - 1: + for agent in self._groupchat.agents: + if agent.name == message["name"]: + # An agent`s message is sent to the Group Chat Manager + agent.send(message, self, request_reply=False, silent=True) + else: + # Otherwise, messages are sent from the Group Chat Manager to the agent + self.send(message, agent, request_reply=False, silent=True) + + # Add previous message to the new groupchat, if it's an admin message the name may not match so add the message directly + if message_speaker_agent: + self._groupchat.append(message, message_speaker_agent) + else: + self._groupchat.messages.append(message) + + # Last speaker agent + last_speaker_name = message["name"] + + # Last message to check for termination (we could avoid this by ignoring termination check for resume in the future) + last_message = message + + # Get last speaker as an agent + previous_last_agent = self._groupchat.agent_by_name(name=last_speaker_name) + + # If we didn't match a last speaker agent, we check that it's the group chat's admin name and assign the manager, if so + if not 
previous_last_agent and ( + last_speaker_name == self._groupchat.admin_name or last_speaker_name == self.name + ): + previous_last_agent = self + + # Termination removal and check + self._process_resume_termination(remove_termination_string, messages) + + if not silent: + iostream = IOStream.get_default() + iostream.send(GroupChatResumeEvent(last_speaker_name=last_speaker_name, events=messages, silent=silent)) + + # Update group chat settings for resuming + self._groupchat.send_introductions = False + + return previous_last_agent, last_message + + async def a_resume( + self, + messages: Union[list[dict[str, Any]], str], + remove_termination_string: Optional[Union[str, Callable[[str], str]]] = None, + silent: Optional[bool] = False, + ) -> tuple[ConversableAgent, dict[str, Any]]: + """Resumes a group chat using the previous messages as a starting point, asynchronously. Requires the agents, group chat, and group chat manager to be established + as per the original group chat. + + Args: + messages: The content of the previous chat's messages, either as a Json string or a list of message dictionaries. + remove_termination_string: Remove the termination string from the last message to prevent immediate termination + If a string is provided, this string will be removed from last message. + If a function is provided, the last message will be passed to this function, and the function returns the string after processing. + silent: (Experimental) whether to print the messages for this conversation. Default is False. + + Returns: + A tuple containing the last agent who spoke and their message + """ + # Convert messages from string to messages list, if needed + if isinstance(messages, str): + messages = self.messages_from_string(messages) + elif isinstance(messages, list) and all(isinstance(item, dict) for item in messages): + messages = copy.deepcopy(messages) + else: + raise Exception("Messages is not of type str or List[Dict]") + + # Clean up the objects, ensuring there are no messages in the agents and group chat + + # Clear agent message history + for agent in self._groupchat.agents: + if isinstance(agent, ConversableAgent): + agent.clear_history() + + # Clear Manager message history + self.clear_history() + + # Clear GroupChat messages + self._groupchat.reset() + + # Validation of message and agents + + try: + self._valid_resume_messages(messages) + except: + raise + + # Load the messages into the group chat + for i, message in enumerate(messages): + if "name" in message: + message_speaker_agent = self._groupchat.agent_by_name(message["name"]) + else: + # If there's no name, assign the group chat manager (this is an indication the ChatResult messages was used instead of groupchat.messages as state) + message_speaker_agent = self + message["name"] = self.name + + # If it wasn't an agent speaking, it may be the manager + if not message_speaker_agent and message["name"] == self.name: + message_speaker_agent = self + + # Add previous messages to each agent (except the last message, as we'll kick off the conversation with it) + if i != len(messages) - 1: + for agent in self._groupchat.agents: + if agent.name == message["name"]: + # An agent`s message is sent to the Group Chat Manager + await agent.a_send(message, self, request_reply=False, silent=True) + else: + # Otherwise, messages are sent from the Group Chat Manager to the agent + await self.a_send(message, agent, request_reply=False, silent=True) + + # Add previous message to the new groupchat, if it's an admin message the name may not match so 
add the message directly + if message_speaker_agent: + self._groupchat.append(message, message_speaker_agent) + else: + self._groupchat.messages.append(message) + + # Last speaker agent + last_speaker_name = message["name"] + + # Last message to check for termination (we could avoid this by ignoring termination check for resume in the future) + last_message = message + + # Get last speaker as an agent + previous_last_agent = self._groupchat.agent_by_name(name=last_speaker_name) + + # If we didn't match a last speaker agent, we check that it's the group chat's admin name and assign the manager, if so + if not previous_last_agent and ( + last_speaker_name == self._groupchat.admin_name or last_speaker_name == self.name + ): + previous_last_agent = self + + # Termination removal and check + self._process_resume_termination(remove_termination_string, messages) + + if not silent: + iostream = IOStream.get_default() + iostream.send(GroupChatResumeEvent(last_speaker_name=last_speaker_name, events=messages, silent=silent)) + + # Update group chat settings for resuming + self._groupchat.send_introductions = False + + return previous_last_agent, last_message + + def _valid_resume_messages(self, messages: list[dict[str, Any]]): + """Validates the messages used for resuming + + Args: + messages (List[Dict]): list of messages to resume with + + Returns: + - bool: Whether they are valid for resuming + """ + # Must have messages to start with, otherwise they should run run_chat + if not messages: + raise Exception( + "Cannot resume group chat as no messages were provided. Use GroupChatManager.run_chat or ConversableAgent.initiate_chat to start a new chat." + ) + + # Check that all agents in the chat messages exist in the group chat + for message in messages: + if message.get("name") and ( + not self._groupchat.agent_by_name(message["name"]) + and not message["name"] == self._groupchat.admin_name # ignore group chat's name + and not message["name"] == self.name # ignore group chat manager's name + ): + raise Exception(f"Agent name in message doesn't exist as agent in group chat: {message['name']}") + + def _process_resume_termination( + self, remove_termination_string: Union[str, Callable[[str], str]], messages: list[dict[str, Any]] + ): + """Removes termination string, if required, and checks if termination may occur. + + Args: + remove_termination_string: Remove the termination string from the last message to prevent immediate termination + If a string is provided, this string will be removed from last message. + If a function is provided, the last message will be passed to this function, and the function returns the string after processing. 
+            messages: List of chat messages
+
+        Returns:
+            None
+        """
+        last_message = messages[-1]
+
+        # Replace any given termination string in the last message
+        if isinstance(remove_termination_string, str):
+
+            def _remove_termination_string(content: str) -> str:
+                return content.replace(remove_termination_string, "")
+
+        else:
+            _remove_termination_string = remove_termination_string
+
+        if _remove_termination_string and messages[-1].get("content"):
+            messages[-1]["content"] = _remove_termination_string(messages[-1]["content"])
+
+        # Check if the last message meets termination (if it has one)
+        if self._is_termination_msg and self._is_termination_msg(last_message):
+            logger.warning("WARNING: Last message meets termination criteria and this may terminate the chat.")
+
+    def messages_from_string(self, message_string: str) -> list[dict[str, Any]]:
+        """Reads the saved state of messages in Json format for resume and returns as a messages list
+
+        Args:
+            message_string: Json string, the saved state
+
+        Returns:
+            A list of messages
+        """
+        try:
+            state = json.loads(message_string)
+        except json.JSONDecodeError:
+            raise Exception("Messages string is not a valid JSON string")
+
+        return state
+
+    def messages_to_string(self, messages: list[dict[str, Any]]) -> str:
+        """Converts the provided messages into a Json string that can be used for resuming the chat.
+        The state is made up of a list of messages
+
+        Args:
+            messages: set of messages to convert to a string
+
+        Returns:
+            A JSON representation of the messages which can be persisted for resuming later
+        """
+        return json.dumps(messages)
+
+    def _raise_exception_on_async_reply_functions(self) -> None:
+        """Raise an exception if any async reply functions are registered.
+
+        Raises:
+            RuntimeError: if any async reply functions are registered.
+        """
+        super()._raise_exception_on_async_reply_functions()
+
+        for agent in self._groupchat.agents:
+            agent._raise_exception_on_async_reply_functions()
+
+    def clear_agents_history(self, reply: dict[str, Any], groupchat: GroupChat) -> str:
+        """Clears history of messages for all agents or a selected one. Can preserve a selected number of last messages.
+        This function is called when the user manually provides the "clear history" phrase in their reply.
+        When "clear history" is provided, the history of messages for all agents is cleared.
+        When "clear history `<agent_name>`" is provided, the history of messages for the selected agent is cleared.
+        When "clear history `<nr_of_messages_to_preserve>`" is provided, the history of messages for all agents is cleared
+        except the last `<nr_of_messages_to_preserve>` messages.
+        When "clear history `<agent_name>` `<nr_of_messages_to_preserve>`" is provided, the history of messages for the selected
+        agent is cleared except the last `<nr_of_messages_to_preserve>` messages.
+        The "clear history" phrase and its optional arguments are cut out of the reply before it is passed to the chat.
+
+        Args:
+            reply (dict): reply message dict to analyze.
+            groupchat (GroupChat): GroupChat object.
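+
+        Example (illustrative; the agent name is hypothetical):
+            A reply containing "clear history Coder 2" clears the history of the agent named
+            "Coder" except for its last 2 messages, and the phrase and its arguments are removed
+            from the reply content that is returned.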
+ """ + iostream = IOStream.get_default() + + reply_content = reply["content"] + # Split the reply into words + words = reply_content.split() + # Find the position of "clear" to determine where to start processing + clear_word_index = next(i for i in reversed(range(len(words))) if words[i].upper() == "CLEAR") + # Extract potential agent name and steps + words_to_check = words[clear_word_index + 2 : clear_word_index + 4] + nr_messages_to_preserve = None + nr_messages_to_preserve_provided = False + agent_to_memory_clear = None + + for word in words_to_check: + if word.isdigit(): + nr_messages_to_preserve = int(word) + nr_messages_to_preserve_provided = True + elif word[:-1].isdigit(): # for the case when number of messages is followed by dot or other sign + nr_messages_to_preserve = int(word[:-1]) + nr_messages_to_preserve_provided = True + else: + for agent in groupchat.agents: + if agent.name == word or agent.name == word[:-1]: + agent_to_memory_clear = agent + break + # preserve last tool call message if clear history called inside of tool response + if "tool_responses" in reply and not nr_messages_to_preserve: + nr_messages_to_preserve = 1 + logger.warning( + "The last tool call message will be saved to prevent errors caused by tool response without tool call." + ) + # clear history + iostream.send( + ClearAgentsHistoryEvent(agent=agent_to_memory_clear, nr_events_to_preserve=nr_messages_to_preserve) + ) + if agent_to_memory_clear: + agent_to_memory_clear.clear_history(nr_messages_to_preserve=nr_messages_to_preserve) + else: + if nr_messages_to_preserve: + # clearing history for groupchat here + temp = groupchat.messages[-nr_messages_to_preserve:] + groupchat.messages.clear() + groupchat.messages.extend(temp) + else: + # clearing history for groupchat here + groupchat.messages.clear() + # clearing history for agents + for agent in groupchat.agents: + agent.clear_history(nr_messages_to_preserve=nr_messages_to_preserve) + + # Reconstruct the reply without the "clear history" command and parameters + skip_words_number = 2 + int(bool(agent_to_memory_clear)) + int(nr_messages_to_preserve_provided) + reply_content = " ".join(words[:clear_word_index] + words[clear_word_index + skip_words_number :]) + + return reply_content diff --git a/mm_agents/coact/autogen/agentchat/realtime/__init__.py b/mm_agents/coact/autogen/agentchat/realtime/__init__.py new file mode 100644 index 0000000..1cce1a2 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/realtime/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/mm_agents/coact/autogen/agentchat/realtime/experimental/__init__.py b/mm_agents/coact/autogen/agentchat/realtime/experimental/__init__.py new file mode 100644 index 0000000..1561d7f --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/realtime/experimental/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +from .audio_adapters import TwilioAudioAdapter, WebSocketAudioAdapter +from .audio_observer import AudioObserver +from .function_observer import FunctionObserver +from .realtime_agent import RealtimeAgent +from .realtime_observer import RealtimeObserver +from .realtime_swarm import register_swarm + +__all__ = [ + "AudioObserver", + "FunctionObserver", + "RealtimeAgent", + "RealtimeObserver", + "TwilioAudioAdapter", + 
"WebSocketAudioAdapter", + "register_swarm", +] diff --git a/mm_agents/coact/autogen/agentchat/realtime/experimental/audio_adapters/__init__.py b/mm_agents/coact/autogen/agentchat/realtime/experimental/audio_adapters/__init__.py new file mode 100644 index 0000000..90d269c --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/realtime/experimental/audio_adapters/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +from .twilio_audio_adapter import TwilioAudioAdapter +from .websocket_audio_adapter import WebSocketAudioAdapter + +__all__ = ["TwilioAudioAdapter", "WebSocketAudioAdapter"] diff --git a/mm_agents/coact/autogen/agentchat/realtime/experimental/audio_adapters/twilio_audio_adapter.py b/mm_agents/coact/autogen/agentchat/realtime/experimental/audio_adapters/twilio_audio_adapter.py new file mode 100644 index 0000000..592b6a0 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/realtime/experimental/audio_adapters/twilio_audio_adapter.py @@ -0,0 +1,148 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +import base64 +import json +from logging import Logger +from typing import TYPE_CHECKING, Optional + +from .....doc_utils import export_module +from ..realtime_events import AudioDelta, RealtimeEvent, SpeechStarted +from ..realtime_observer import RealtimeObserver + +if TYPE_CHECKING: + from ..websockets import WebSocketProtocol as WebSocket + + +LOG_EVENT_TYPES = [ + "error", + "response.content.done", + "rate_limits.updated", + "response.done", + "input_audio_buffer.committed", + "input_audio_buffer.speech_stopped", + "input_audio_buffer.speech_started", + "session.created", +] +SHOW_TIMING_MATH = False + + +@export_module("autogen.agentchat.realtime.experimental") +class TwilioAudioAdapter(RealtimeObserver): + """Adapter for streaming audio from Twilio to OpenAI Realtime API and vice versa.""" + + def __init__(self, websocket: "WebSocket", *, logger: Optional[Logger] = None): + """Adapter for streaming audio from Twilio to OpenAI Realtime API and vice versa. + + Args: + websocket: the websocket connection to the Twilio service + logger: the logger to use for logging events + """ + super().__init__(logger=logger) + self.websocket = websocket + + # Connection specific state + self.stream_sid = None + self.latest_media_timestamp = 0 + self.last_assistant_item: Optional[str] = None + self.mark_queue: list[str] = [] + self.response_start_timestamp_twilio: Optional[int] = None + + async def on_event(self, event: RealtimeEvent) -> None: + """Receive events from the OpenAI Realtime API, send audio back to Twilio.""" + logger = self.logger + + if isinstance(event, AudioDelta): + audio_payload = base64.b64encode(base64.b64decode(event.delta)).decode("utf-8") + audio_delta = {"event": "media", "streamSid": self.stream_sid, "media": {"payload": audio_payload}} + await self.websocket.send_json(audio_delta) + + if self.response_start_timestamp_twilio is None: + self.response_start_timestamp_twilio = self.latest_media_timestamp + if SHOW_TIMING_MATH: + logger.info(f"Setting start timestamp for new response: {self.response_start_timestamp_twilio}ms") + + # Update last_assistant_item safely + if event.item_id: + self.last_assistant_item = event.item_id + + await self.send_mark() + + # Trigger an interruption. 
Your use case might work better using `input_audio_buffer.speech_stopped`, or combining the two. + if isinstance(event, SpeechStarted): + logger.info("Speech start detected.") + if self.last_assistant_item: + logger.info(f"Interrupting response with id: {self.last_assistant_item}") + await self.handle_speech_started_event() + + async def handle_speech_started_event(self) -> None: + """Handle interruption when the caller's speech starts.""" + logger = self.logger + + logger.info("Handling speech started event.") + if self.mark_queue and self.response_start_timestamp_twilio is not None: + elapsed_time = self.latest_media_timestamp - self.response_start_timestamp_twilio + if SHOW_TIMING_MATH: + logger.info( + f"Calculating elapsed time for truncation: {self.latest_media_timestamp} - {self.response_start_timestamp_twilio} = {elapsed_time}ms" + ) + + if self.last_assistant_item: + if SHOW_TIMING_MATH: + logger.info(f"Truncating item with ID: {self.last_assistant_item}, Truncated at: {elapsed_time}ms") + + await self.realtime_client.truncate_audio( + audio_end_ms=elapsed_time, + content_index=0, + item_id=self.last_assistant_item, + ) + + await self.websocket.send_json({"event": "clear", "streamSid": self.stream_sid}) + + self.mark_queue.clear() + self.last_assistant_item = None + self.response_start_timestamp_twilio = None + + async def send_mark(self) -> None: + """Send a mark of audio interruption to the Twilio websocket.""" + if self.stream_sid: + mark_event = {"event": "mark", "streamSid": self.stream_sid, "mark": {"name": "responsePart"}} + await self.websocket.send_json(mark_event) + self.mark_queue.append("responsePart") + + async def run_loop(self) -> None: + """Run the adapter loop.""" + logger = self.logger + + async for message in self.websocket.iter_text(): + try: + data = json.loads(message) + if data["event"] == "media": + self.latest_media_timestamp = int(data["media"]["timestamp"]) + await self.realtime_client.send_audio(audio=data["media"]["payload"]) + elif data["event"] == "start": + self.stream_sid = data["start"]["streamSid"] + logger.info(f"Incoming stream has started {self.stream_sid}") + self.response_start_timestamp_twilio = None + self.latest_media_timestamp = 0 + self.last_assistant_item = None + elif data["event"] == "mark": + if self.mark_queue: + self.mark_queue.pop(0) + except Exception as e: + logger.warning(f"Error processing Twilio message: {e}", stack_info=True) + + async def initialize_session(self) -> None: + """Control initial session with OpenAI.""" + session_update = { + "input_audio_format": "g711_ulaw", + "output_audio_format": "g711_ulaw", + } + await self.realtime_client.session_update(session_update) + + +if TYPE_CHECKING: + + def twilio_audio_adapter(websocket: "WebSocket") -> RealtimeObserver: + return TwilioAudioAdapter(websocket) diff --git a/mm_agents/coact/autogen/agentchat/realtime/experimental/audio_adapters/websocket_audio_adapter.py b/mm_agents/coact/autogen/agentchat/realtime/experimental/audio_adapters/websocket_audio_adapter.py new file mode 100644 index 0000000..7f1335f --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/realtime/experimental/audio_adapters/websocket_audio_adapter.py @@ -0,0 +1,139 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +import base64 +import json +from logging import Logger +from typing import TYPE_CHECKING, Optional + +from .....doc_utils import export_module +from ..realtime_events import AudioDelta, 
RealtimeEvent, SpeechStarted
+from ..realtime_observer import RealtimeObserver
+
+if TYPE_CHECKING:
+    from ..websockets import WebSocketProtocol as WebSocket
+
+LOG_EVENT_TYPES = [
+    "error",
+    "response.content.done",
+    "rate_limits.updated",
+    "response.done",
+    "input_audio_buffer.committed",
+    "input_audio_buffer.speech_stopped",
+    "input_audio_buffer.speech_started",
+    "session.created",
+]
+SHOW_TIMING_MATH = False
+
+
+@export_module("autogen.agentchat.realtime.experimental")
+class WebSocketAudioAdapter(RealtimeObserver):
+    def __init__(self, websocket: "WebSocket", *, logger: Optional[Logger] = None) -> None:
+        """Adapter for streaming audio between a websocket client and the OpenAI Realtime API.
+
+        Args:
+            websocket (WebSocket): The websocket connection.
+            logger (Logger): The logger for the adapter.
+        """
+        super().__init__(logger=logger)
+        self.websocket = websocket
+
+        # Connection specific state
+        self.stream_sid = None
+        self.latest_media_timestamp = 0
+        self.last_assistant_item: Optional[str] = None
+        self.mark_queue: list[str] = []
+        self.response_start_timestamp_socket: Optional[int] = None
+
+    async def on_event(self, event: RealtimeEvent) -> None:
+        """Receive events from the OpenAI Realtime API, send audio back to websocket."""
+        logger = self.logger
+
+        if isinstance(event, AudioDelta):
+            audio_payload = base64.b64encode(base64.b64decode(event.delta)).decode("utf-8")
+            audio_delta = {"event": "media", "streamSid": self.stream_sid, "media": {"payload": audio_payload}}
+            await self.websocket.send_json(audio_delta)
+
+            if self.response_start_timestamp_socket is None:
+                self.response_start_timestamp_socket = self.latest_media_timestamp
+                if SHOW_TIMING_MATH:
+                    logger.info(f"Setting start timestamp for new response: {self.response_start_timestamp_socket}ms")
+
+            # Update last_assistant_item safely
+            if event.item_id:
+                self.last_assistant_item = event.item_id
+
+            await self.send_mark()
+
+        # Trigger an interruption. Your use case might work better using `input_audio_buffer.speech_stopped`, or combining the two.
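+        # When the caller starts speaking, truncate the assistant's in-flight audio at the elapsed playback
+        # time and ask the client to clear its buffered audio (see handle_speech_started_event below).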
+ if isinstance(event, SpeechStarted): + logger.info("Speech start detected.") + if self.last_assistant_item: + logger.info(f"Interrupting response with id: {self.last_assistant_item}") + await self.handle_speech_started_event() + + async def handle_speech_started_event(self) -> None: + """Handle interruption when the caller's speech starts.""" + logger = self.logger + logger.info("Handling speech started event.") + if self.mark_queue and self.response_start_timestamp_socket is not None: + elapsed_time = self.latest_media_timestamp - self.response_start_timestamp_socket + if SHOW_TIMING_MATH: + logger.info( + f"Calculating elapsed time for truncation: {self.latest_media_timestamp} - {self.response_start_timestamp_socket} = {elapsed_time}ms" + ) + + if self.last_assistant_item: + if SHOW_TIMING_MATH: + logger.info(f"Truncating item with ID: {self.last_assistant_item}, Truncated at: {elapsed_time}ms") + + await self.realtime_client.truncate_audio( + audio_end_ms=elapsed_time, + content_index=0, + item_id=self.last_assistant_item, + ) + + await self.websocket.send_json({"event": "clear", "streamSid": self.stream_sid}) + + self.mark_queue.clear() + self.last_assistant_item = None + self.response_start_timestamp_socket = None + + async def send_mark(self) -> None: + if self.stream_sid: + mark_event = {"event": "mark", "streamSid": self.stream_sid, "mark": {"name": "responsePart"}} + await self.websocket.send_json(mark_event) + self.mark_queue.append("responsePart") + + async def initialize_session(self) -> None: + """Control initial session with OpenAI.""" + session_update = {"input_audio_format": "pcm16", "output_audio_format": "pcm16"} + await self.realtime_client.session_update(session_update) + + async def run_loop(self) -> None: + """Reads data from websocket and sends it to the RealtimeClient.""" + logger = self.logger + async for message in self.websocket.iter_text(): + try: + data = json.loads(message) + if data["event"] == "media": + self.latest_media_timestamp = int(data["media"]["timestamp"]) + await self.realtime_client.send_audio(audio=data["media"]["payload"]) + elif data["event"] == "start": + self.stream_sid = data["start"]["streamSid"] + logger.info(f"Incoming stream has started {self.stream_sid}") + self.response_start_timestamp_socket = None + self.latest_media_timestamp = 0 + self.last_assistant_item = None + elif data["event"] == "mark": + if self.mark_queue: + self.mark_queue.pop(0) + except Exception as e: + logger.warning(f"Failed to process message: {e}", stack_info=True) + + +if TYPE_CHECKING: + + def websocket_audio_adapter(websocket: "WebSocket") -> RealtimeObserver: + return WebSocketAudioAdapter(websocket) diff --git a/mm_agents/coact/autogen/agentchat/realtime/experimental/audio_observer.py b/mm_agents/coact/autogen/agentchat/realtime/experimental/audio_observer.py new file mode 100644 index 0000000..dfa4ca6 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/realtime/experimental/audio_observer.py @@ -0,0 +1,42 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import TYPE_CHECKING, Optional + +from ....doc_utils import export_module +from .realtime_events import InputAudioBufferDelta, RealtimeEvent +from .realtime_observer import RealtimeObserver + +if TYPE_CHECKING: + from logging import Logger + + +@export_module("autogen.agentchat.realtime.experimental") +class AudioObserver(RealtimeObserver): + """Observer for user voice input""" + + def 
__init__(self, *, logger: Optional["Logger"] = None) -> None: + """Observer for user voice input""" + super().__init__(logger=logger) + + async def on_event(self, event: RealtimeEvent) -> None: + """Observe voice input events from the Realtime. + + Args: + event (dict[str, Any]): The event from the OpenAI Realtime API. + """ + if isinstance(event, InputAudioBufferDelta): + self.logger.info("Received audio buffer delta") + + async def initialize_session(self) -> None: + """No need to initialize session from this observer""" + pass + + async def run_loop(self) -> None: + """Run the observer loop.""" + pass + + +if TYPE_CHECKING: + function_observer: RealtimeObserver = AudioObserver() diff --git a/mm_agents/coact/autogen/agentchat/realtime/experimental/clients/__init__.py b/mm_agents/coact/autogen/agentchat/realtime/experimental/clients/__init__.py new file mode 100644 index 0000000..23e90a6 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/realtime/experimental/clients/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +from .gemini.client import GeminiRealtimeClient +from .oai.base_client import OpenAIRealtimeClient +from .realtime_client import RealtimeClientProtocol, Role, get_client + +__all__ = [ + "GeminiRealtimeClient", + "OpenAIRealtimeClient", + "RealtimeClientProtocol", + "Role", + "get_client", +] diff --git a/mm_agents/coact/autogen/agentchat/realtime/experimental/clients/gemini/__init__.py b/mm_agents/coact/autogen/agentchat/realtime/experimental/clients/gemini/__init__.py new file mode 100644 index 0000000..bec9e2c --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/realtime/experimental/clients/gemini/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +from .client import GeminiRealtimeClient + +__all__ = ["GeminiRealtimeClient"] diff --git a/mm_agents/coact/autogen/agentchat/realtime/experimental/clients/gemini/client.py b/mm_agents/coact/autogen/agentchat/realtime/experimental/clients/gemini/client.py new file mode 100644 index 0000000..fa6a76a --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/realtime/experimental/clients/gemini/client.py @@ -0,0 +1,274 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +import json +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from logging import Logger, getLogger +from typing import TYPE_CHECKING, Any, Callable, Optional, Union + +from ......doc_utils import export_module +from ......import_utils import optional_import_block, require_optional_import +from ......llm_config import LLMConfig +from ...realtime_events import AudioDelta, FunctionCall, RealtimeEvent, SessionCreated +from ..realtime_client import RealtimeClientBase, Role, register_realtime_client + +with optional_import_block(): + from websockets.asyncio.client import connect + + +if TYPE_CHECKING: + from websockets.asyncio.client import ClientConnection + + from ..realtime_client import RealtimeClientProtocol + +__all__ = ["GeminiRealtimeClient"] + +global_logger = getLogger(__name__) + + +HOST = "generativelanguage.googleapis.com" +API_VERSION = "v1alpha" + + +@register_realtime_client() +@require_optional_import("websockets", "gemini", except_for=["get_factory", "__init__"]) 
+@export_module("autogen.agentchat.realtime.experimental.clients") +class GeminiRealtimeClient(RealtimeClientBase): + """(Experimental) Client for Gemini Realtime API.""" + + def __init__( + self, + *, + llm_config: Union[LLMConfig, dict[str, Any]], + logger: Optional[Logger] = None, + ) -> None: + """(Experimental) Client for Gemini Realtime API. + + Args: + llm_config: The config for the client. + logger: The logger for the client. + """ + super().__init__() + self._llm_config = llm_config + self._logger = logger + + self._connection: Optional["ClientConnection"] = None + config = llm_config["config_list"][0] + + self._model: str = config["model"] + self._voice = config.get("voice", "charon") + self._temperature: float = config.get("temperature", 0.8) # type: ignore[union-attr] + + self._response_modality = "AUDIO" + + self._api_key = config.get("api_key", None) + # todo: add test with base_url just to make sure it works + self._base_url: str = config.get( + "base_url", + f"wss://{HOST}/ws/google.ai.generativelanguage.{API_VERSION}.GenerativeService.BidiGenerateContent?key={self._api_key}", + ) + self._final_config: dict[str, Any] = {} + self._pending_session_updates: dict[str, Any] = {} + self._is_reading_events = False + + @property + def logger(self) -> Logger: + """Get the logger for the Gemini Realtime API.""" + return self._logger or global_logger + + @property + def connection(self) -> "ClientConnection": + """Get the Gemini WebSocket connection.""" + if self._connection is None: + raise RuntimeError("Gemini WebSocket is not initialized") + return self._connection + + async def send_function_result(self, call_id: str, result: str) -> None: + """Send the result of a function call to the Gemini Realtime API. + + Args: + call_id (str): The ID of the function call. + result (str): The result of the function call. + """ + msg = { + "tool_response": {"function_responses": [{"id": call_id, "response": {"result": {"string_value": result}}}]} + } + if self._is_reading_events: + await self.connection.send(json.dumps(msg)) + + async def send_text(self, *, role: Role, text: str, turn_complete: bool = True) -> None: + """Send a text message to the Gemini Realtime API. + + Args: + role: The role of the message. + text: The text of the message. + turn_complete: A flag indicating if the turn is complete. + """ + msg = { + "client_content": { + "turn_complete": turn_complete, + "turns": [{"role": role, "parts": [{"text": text}]}], + } + } + if self._is_reading_events: + await self.connection.send(json.dumps(msg)) + + async def send_audio(self, audio: str) -> None: + """Send audio to the Gemini Realtime API. + + Args: + audio (str): The audio to send. 
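+
+        Note:
+            The chunk is queued locally as an `InputAudioBufferDelta` event for observers, and is
+            forwarded to Gemini as a `realtime_input` media chunk with MIME type `audio/pcm` once the
+            client is reading events.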
+ """ + msg = { + "realtime_input": { + "media_chunks": [ + { + "data": audio, + "mime_type": "audio/pcm", + } + ] + } + } + await self.queue_input_audio_buffer_delta(audio) + if self._is_reading_events: + await self.connection.send(json.dumps(msg)) + + async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None: + self.logger.info("This is not natively supported by Gemini Realtime API.") + pass + + async def _initialize_session(self) -> None: + """Initialize the session with the Gemini Realtime API.""" + session_config = { + "setup": { + "system_instruction": { + "role": "system", + "parts": [{"text": self._pending_session_updates.get("instructions", "")}], + }, + "model": f"models/{self._model}", + "tools": [ + { + "function_declarations": [ + { + "name": tool_schema["name"], + "description": tool_schema["description"], + "parameters": tool_schema["parameters"], + } + for tool_schema in self._pending_session_updates.get("tools", []) + ] + }, + ], + "generation_config": { + "response_modalities": [self._response_modality], + "speech_config": {"voiceConfig": {"prebuiltVoiceConfig": {"voiceName": self._voice}}}, + "temperature": self._temperature, + }, + } + } + + self.logger.info(f"Sending session update: {session_config}") + await self.connection.send(json.dumps(session_config)) + + async def session_update(self, session_options: dict[str, Any]) -> None: + """Record session updates to be applied when the connection is established. + + Args: + session_options (dict[str, Any]): The session options to update. + """ + if self._is_reading_events: + self.logger.warning("Is reading events. Session update will be ignored.") + else: + self._pending_session_updates.update(session_options) + + @asynccontextmanager + async def connect(self) -> AsyncGenerator[None, None]: + """Connect to the Gemini Realtime API.""" + try: + async with connect( + self._base_url, additional_headers={"Content-Type": "application/json"} + ) as self._connection: + yield + finally: + self._connection = None + + async def read_events(self) -> AsyncGenerator[RealtimeEvent, None]: + """Read Events from the Gemini Realtime Client""" + if self._connection is None: + raise RuntimeError("Client is not connected, call connect() first.") + await self._initialize_session() + + self._is_reading_events = True + + async for event in self._read_events(): + yield event + + async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]: + """Read messages from the Gemini Realtime connection.""" + async for raw_message in self.connection: + message = raw_message.decode("ascii") if isinstance(raw_message, bytes) else raw_message + events = self._parse_message(json.loads(message)) + for event in events: + yield event + + def _parse_message(self, response: dict[str, Any]) -> list[RealtimeEvent]: + """Parse a message from the Gemini Realtime API. + + Args: + response (dict[str, Any]): The response to parse. + + Returns: + list[RealtimeEvent]: The parsed events. 
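+
+        Note:
+            `serverContent.modelTurn` audio parts are mapped to `AudioDelta`, `toolCall` entries to
+            `FunctionCall`, and `setupComplete` to `SessionCreated`; any other payload is wrapped in a
+            plain `RealtimeEvent`.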
+ """ + if "serverContent" in response and "modelTurn" in response["serverContent"]: + try: + b64data = response["serverContent"]["modelTurn"]["parts"][0]["inlineData"].pop("data") + return [ + AudioDelta( + delta=b64data, + item_id=None, + raw_message=response, + ) + ] + except KeyError: + return [] + elif "toolCall" in response: + return [ + FunctionCall( + raw_message=response, + call_id=call["id"], + name=call["name"], + arguments=call["args"], + ) + for call in response["toolCall"]["functionCalls"] + ] + elif "setupComplete" in response: + return [ + SessionCreated(raw_message=response), + ] + else: + return [RealtimeEvent(raw_message=response)] + + @classmethod + def get_factory( + cls, llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any + ) -> Optional[Callable[[], "RealtimeClientProtocol"]]: + """Create a Realtime API client. + + Args: + llm_config: The LLM config for the client. + logger: The logger for the client. + **kwargs: Additional arguments. + + Returns: + RealtimeClientProtocol: The Realtime API client is returned if the model matches the pattern + """ + if llm_config["config_list"][0].get("api_type") == "google" and list(kwargs.keys()) == []: + return lambda: GeminiRealtimeClient(llm_config=llm_config, logger=logger, **kwargs) + return None + + +# needed for mypy to check if GeminiRealtimeClient implements RealtimeClientProtocol +if TYPE_CHECKING: + _client: RealtimeClientProtocol = GeminiRealtimeClient(llm_config={}) diff --git a/mm_agents/coact/autogen/agentchat/realtime/experimental/clients/oai/__init__.py b/mm_agents/coact/autogen/agentchat/realtime/experimental/clients/oai/__init__.py new file mode 100644 index 0000000..52d380a --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/realtime/experimental/clients/oai/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +from .base_client import OpenAIRealtimeClient +from .rtc_client import OpenAIRealtimeWebRTCClient + +__all__ = ["OpenAIRealtimeClient", "OpenAIRealtimeWebRTCClient"] diff --git a/mm_agents/coact/autogen/agentchat/realtime/experimental/clients/oai/base_client.py b/mm_agents/coact/autogen/agentchat/realtime/experimental/clients/oai/base_client.py new file mode 100644 index 0000000..1c4953c --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/realtime/experimental/clients/oai/base_client.py @@ -0,0 +1,220 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from logging import Logger, getLogger +from typing import TYPE_CHECKING, Any, Callable, Optional, Union + +from ......doc_utils import export_module +from ......import_utils import optional_import_block, require_optional_import +from ......llm_config import LLMConfig +from ...realtime_events import RealtimeEvent +from ..realtime_client import RealtimeClientBase, Role, register_realtime_client +from .utils import parse_oai_message + +with optional_import_block(): + from openai import DEFAULT_MAX_RETRIES, NOT_GIVEN, AsyncOpenAI + from openai.resources.beta.realtime.realtime import AsyncRealtimeConnection + + +if TYPE_CHECKING: + from ..realtime_client import RealtimeClientProtocol + +__all__ = ["OpenAIRealtimeClient"] + +global_logger = getLogger(__name__) + + +@register_realtime_client() 
+@require_optional_import("openai>=1.66.2", "openai-realtime", except_for=["get_factory", "__init__"]) +@export_module("autogen.agentchat.realtime.experimental.clients") +class OpenAIRealtimeClient(RealtimeClientBase): + """(Experimental) Client for OpenAI Realtime API.""" + + def __init__( + self, + *, + llm_config: Union[LLMConfig, dict[str, Any]], + logger: Optional[Logger] = None, + ) -> None: + """(Experimental) Client for OpenAI Realtime API. + + Args: + llm_config: The config for the client. + logger: the logger to use for logging events + """ + super().__init__() + self._llm_config = llm_config + self._logger = logger + + self._connection: Optional["AsyncRealtimeConnection"] = None + + self.config = llm_config["config_list"][0] + # model is passed to self._client.beta.realtime.connect function later + self._model: str = self.config["model"] + self._voice: str = self.config.get("voice", "alloy") + self._temperature: float = llm_config.get("temperature", 0.8) # type: ignore[union-attr] + + self._client: Optional["AsyncOpenAI"] = None + + @property + def logger(self) -> Logger: + """Get the logger for the OpenAI Realtime API.""" + return self._logger or global_logger + + @property + def connection(self) -> "AsyncRealtimeConnection": + """Get the OpenAI WebSocket connection.""" + if self._connection is None: + raise RuntimeError("OpenAI WebSocket is not initialized") + return self._connection + + async def send_function_result(self, call_id: str, result: str) -> None: + """Send the result of a function call to the OpenAI Realtime API. + + Args: + call_id (str): The ID of the function call. + result (str): The result of the function call. + """ + await self.connection.conversation.item.create( + item={ + "type": "function_call_output", + "call_id": call_id, + "output": result, + }, + ) + + await self.connection.response.create() + + async def send_text(self, *, role: Role, text: str) -> None: + """Send a text message to the OpenAI Realtime API. + + Args: + role (str): The role of the message. + text (str): The text of the message. + """ + await self.connection.response.cancel() + await self.connection.conversation.item.create( + item={"type": "message", "role": role, "content": [{"type": "input_text", "text": text}]} + ) + await self.connection.response.create() + + async def send_audio(self, audio: str) -> None: + """Send audio to the OpenAI Realtime API. + + Args: + audio (str): The audio to send. + """ + await self.queue_input_audio_buffer_delta(audio) + await self.connection.input_audio_buffer.append(audio=audio) + + async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None: + """Truncate audio in the OpenAI Realtime API. + + Args: + audio_end_ms (int): The end of the audio to truncate. + content_index (int): The index of the content to truncate. + item_id (str): The ID of the item to truncate. + """ + await self.connection.conversation.item.truncate( + audio_end_ms=audio_end_ms, content_index=content_index, item_id=item_id + ) + + async def _initialize_session(self) -> None: + """Control initial session with OpenAI.""" + session_update = { + "turn_detection": {"type": "server_vad"}, + "voice": self._voice, + "modalities": ["audio", "text"], + "temperature": self._temperature, + } + await self.session_update(session_options=session_update) + + async def session_update(self, session_options: dict[str, Any]) -> None: + """Send a session update to the OpenAI Realtime API. + + Args: + session_options (dict[str, Any]): The session options to update. 
+ """ + logger = self.logger + logger.info(f"Sending session update: {session_options}") + await self.connection.session.update(session=session_options) # type: ignore[arg-type] + logger.info("Sending session update finished") + + @asynccontextmanager + async def connect(self) -> AsyncGenerator[None, None]: + """Connect to the OpenAI Realtime API.""" + try: + if not self._client: + self._client = AsyncOpenAI( + api_key=self.config.get("api_key", None), + organization=self.config.get("organization", None), + project=self.config.get("project", None), + base_url=self.config.get("base_url", None), + websocket_base_url=self.config.get("websocket_base_url", None), + timeout=self.config.get("timeout", NOT_GIVEN), + max_retries=self.config.get("max_retries", DEFAULT_MAX_RETRIES), + default_headers=self.config.get("default_headers", None), + default_query=self.config.get("default_query", None), + ) + async with self._client.beta.realtime.connect( + model=self._model, + ) as self._connection: + await self._initialize_session() + yield + finally: + self._connection = None + + async def read_events(self) -> AsyncGenerator[RealtimeEvent, None]: + """Read messages from the OpenAI Realtime API.""" + if self._connection is None: + raise RuntimeError("Client is not connected, call connect() first.") + + try: + async for event in self._read_events(): + yield event + + finally: + self._connection = None + + async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]: + """Read messages from the OpenAI Realtime API.""" + async for message in self._connection: + for event in self._parse_message(message.model_dump()): + yield event + + def _parse_message(self, message: dict[str, Any]) -> list[RealtimeEvent]: + """Parse a message from the OpenAI Realtime API. + + Args: + message (dict[str, Any]): The message to parse. + + Returns: + RealtimeEvent: The parsed event. + """ + return [parse_oai_message(message)] + + @classmethod + def get_factory( + cls, llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any + ) -> Optional[Callable[[], "RealtimeClientProtocol"]]: + """Create a Realtime API client. + + Args: + llm_config: The config for the client. + logger: The logger to use for logging events. + kwargs: Additional arguments. 
+ + Returns: + RealtimeClientProtocol: The Realtime API client is returned if the model matches the pattern + """ + if llm_config["config_list"][0].get("api_type", "openai") == "openai" and list(kwargs.keys()) == []: + return lambda: OpenAIRealtimeClient(llm_config=llm_config, logger=logger, **kwargs) + return None + + +# needed for mypy to check if OpenAIRealtimeWebRTCClient implements RealtimeClientProtocol +if TYPE_CHECKING: + _client: RealtimeClientProtocol = OpenAIRealtimeClient(llm_config={}) diff --git a/mm_agents/coact/autogen/agentchat/realtime/experimental/clients/oai/rtc_client.py b/mm_agents/coact/autogen/agentchat/realtime/experimental/clients/oai/rtc_client.py new file mode 100644 index 0000000..3ba079c --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/realtime/experimental/clients/oai/rtc_client.py @@ -0,0 +1,243 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +import json +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from logging import Logger, getLogger +from typing import TYPE_CHECKING, Any, Callable, Optional, Union + +from autogen.import_utils import optional_import_block, require_optional_import + +from ......doc_utils import export_module +from ......llm_config import LLMConfig +from ...realtime_events import RealtimeEvent +from ..realtime_client import RealtimeClientBase, Role, register_realtime_client +from .utils import parse_oai_message + +if TYPE_CHECKING: + from ...websockets import WebSocketProtocol as WebSocket + from ..realtime_client import RealtimeClientProtocol + +with optional_import_block(): + import httpx + +__all__ = ["OpenAIRealtimeWebRTCClient"] + +global_logger = getLogger(__name__) + + +@register_realtime_client() +@require_optional_import("httpx", "openai-realtime", except_for="get_factory") +@export_module("autogen.agentchat.realtime.experimental.clients.oai") +class OpenAIRealtimeWebRTCClient(RealtimeClientBase): + """(Experimental) Client for OpenAI Realtime API that uses WebRTC protocol.""" + + def __init__( + self, + *, + llm_config: Union[LLMConfig, dict[str, Any]], + websocket: "WebSocket", + logger: Optional[Logger] = None, + ) -> None: + """(Experimental) Client for OpenAI Realtime API. + + Args: + llm_config: The config for the client. + websocket: the websocket to use for the connection + logger: the logger to use for logging events + """ + super().__init__() + self._llm_config = llm_config + self._logger = logger + self._websocket = websocket + + config = llm_config["config_list"][0] + self._model: str = config["model"] + self._voice: str = config.get("voice", "alloy") + self._temperature: float = llm_config.get("temperature", 0.8) # type: ignore[union-attr] + self._config = config + self._base_url = config.get("base_url", "https://api.openai.com/v1/realtime/sessions") + + @property + def logger(self) -> Logger: + """Get the logger for the OpenAI Realtime API.""" + return self._logger or global_logger + + async def send_function_result(self, call_id: str, result: str) -> None: + """Send the result of a function call to the OpenAI Realtime API. + + Args: + call_id (str): The ID of the function call. + result (str): The result of the function call. 
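+
+        Note:
+            In the WebRTC setup this process holds no direct connection to OpenAI; the
+            `conversation.item.create` and `response.create` events are sent over the AG2 websocket and
+            relayed to OpenAI by the JavaScript client that holds the actual connection.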
+ """ + await self._websocket.send_json({ + "type": "conversation.item.create", + "item": { + "type": "function_call_output", + "call_id": call_id, + "output": result, + }, + }) + await self._websocket.send_json({"type": "response.create"}) + + async def send_text(self, *, role: Role, text: str) -> None: + """Send a text message to the OpenAI Realtime API. + + Args: + role (str): The role of the message. + text (str): The text of the message. + """ + # await self.connection.response.cancel() #why is this here? + await self._websocket.send_json({ + "type": "response.cancel", + }) + await self._websocket.send_json({ + "type": "conversation.item.create", + "item": {"type": "message", "role": role, "content": [{"type": "input_text", "text": text}]}, + }) + # await self.connection.response.create() + await self._websocket.send_json({"type": "response.create"}) + + async def send_audio(self, audio: str) -> None: + """Send audio to the OpenAI Realtime API. + in case of WebRTC, audio is already sent by js client, so we just queue it in order to be logged. + + Args: + audio (str): The audio to send. + """ + await self.queue_input_audio_buffer_delta(audio) + + async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None: + """Truncate audio in the OpenAI Realtime API. + + Args: + audio_end_ms (int): The end of the audio to truncate. + content_index (int): The index of the content to truncate. + item_id (str): The ID of the item to truncate. + """ + await self._websocket.send_json({ + "type": "conversation.item.truncate", + "content_index": content_index, + "item_id": item_id, + "audio_end_ms": audio_end_ms, + }) + + async def session_update(self, session_options: dict[str, Any]) -> None: + """Send a session update to the OpenAI Realtime API. + + In the case of WebRTC we can not send it directly, but we can send it + to the javascript over the websocket, and rely on it to send session + update to OpenAI + + Args: + session_options (dict[str, Any]): The session options to update. + """ + logger = self.logger + logger.info(f"Sending session update: {session_options}") + # await self.connection.session.update(session=session_options) # type: ignore[arg-type] + await self._websocket.send_json({"type": "session.update", "session": session_options}) + logger.info("Sending session update finished") + + def session_init_data(self) -> list[dict[str, Any]]: + """Control initial session with OpenAI.""" + session_update = { + "turn_detection": {"type": "server_vad"}, + "voice": self._voice, + "modalities": ["audio", "text"], + "temperature": self._temperature, + } + return [{"type": "session.update", "session": session_update}] + + async def _initialize_session(self) -> None: ... + + @asynccontextmanager + async def connect(self) -> AsyncGenerator[None, None]: + """Connect to the OpenAI Realtime API. 
+ + In the case of WebRTC, we pass connection information over the + websocket, so that javascript on the other end of websocket open + actual connection to OpenAI + """ + try: + base_url = self._base_url + api_key = self._config.get("api_key", None) + headers = { + "Authorization": f"Bearer {api_key}", # Use os.getenv to get from environment + "Content-Type": "application/json", + } + data = { + # "model": "gpt-4o-realtime-preview-2024-12-17", + "model": self._model, + "voice": self._voice, + } + async with httpx.AsyncClient() as client: + response = await client.post(base_url, headers=headers, json=data) + response.raise_for_status() + json_data = response.json() + json_data["model"] = self._model + if self._websocket is not None: + session_init = self.session_init_data() + await self._websocket.send_json({"type": "ag2.init", "config": json_data, "init": session_init}) + yield + finally: + pass + + async def read_events(self) -> AsyncGenerator[RealtimeEvent, None]: + """Read events from the OpenAI Realtime API.""" + async for event in self._read_events(): + yield event + + async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]: + """Read messages from the OpenAI Realtime API connection. + Again, in case of WebRTC, we do not read OpenAI messages directly since we + do not hold connection to OpenAI. Instead we read messages from the websocket, and javascript + client on the other side of the websocket that is connected to OpenAI is relaying events to us. + """ + while True: + try: + message_json = await self._websocket.receive_text() + message = json.loads(message_json) + for event in self._parse_message(message): + yield event + except Exception as e: + self.logger.exception(f"Error reading from connection {e}") + break + + def _parse_message(self, message: dict[str, Any]) -> list[RealtimeEvent]: + """Parse a message from the OpenAI Realtime API. + + Args: + message (dict[str, Any]): The message to parse. + + Returns: + RealtimeEvent: The parsed event. + """ + return [parse_oai_message(message)] + + @classmethod + def get_factory( + cls, llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any + ) -> Optional[Callable[[], "RealtimeClientProtocol"]]: + """Create a Realtime API client. + + Args: + llm_config: The config for the client. + logger: The logger to use for logging events. + **kwargs: Additional arguments. 
+ + Returns: + RealtimeClientProtocol: The Realtime API client is returned if the model matches the pattern + """ + if llm_config["config_list"][0].get("api_type", "openai") == "openai" and list(kwargs.keys()) == ["websocket"]: + return lambda: OpenAIRealtimeWebRTCClient(llm_config=llm_config, logger=logger, **kwargs) + + return None + + +# needed for mypy to check if OpenAIRealtimeWebRTCClient implements RealtimeClientProtocol +if TYPE_CHECKING: + + def _rtc_client(websocket: "WebSocket") -> RealtimeClientProtocol: + return OpenAIRealtimeWebRTCClient(llm_config={}, websocket=websocket) diff --git a/mm_agents/coact/autogen/agentchat/realtime/experimental/clients/oai/utils.py b/mm_agents/coact/autogen/agentchat/realtime/experimental/clients/oai/utils.py new file mode 100644 index 0000000..28c00ea --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/realtime/experimental/clients/oai/utils.py @@ -0,0 +1,48 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +import json +from typing import Any + +from ...realtime_events import ( + AudioDelta, + FunctionCall, + InputAudioBufferDelta, + RealtimeEvent, + SessionCreated, + SessionUpdated, + SpeechStarted, +) + +__all__ = ["parse_oai_message"] + + +def parse_oai_message(message: dict[str, Any]) -> RealtimeEvent: + """Parse a message from the OpenAI Realtime API. + + Args: + message (dict[str, Any]): The message to parse. + + Returns: + RealtimeEvent: The parsed event. + """ + if message.get("type") == "session.created": + return SessionCreated(raw_message=message) + elif message.get("type") == "session.updated": + return SessionUpdated(raw_message=message) + elif message.get("type") == "response.audio.delta": + return AudioDelta(raw_message=message, delta=message["delta"], item_id=message["item_id"]) + elif message.get("type") == "input_audio_buffer.speech_started": + return SpeechStarted(raw_message=message) + elif message.get("type") == "input_audio_buffer.delta": + return InputAudioBufferDelta(delta=message["delta"], item_id=None, raw_message=message) + elif message.get("type") == "response.function_call_arguments.done": + return FunctionCall( + raw_message=message, + call_id=message["call_id"], + name=message["name"], + arguments=json.loads(message["arguments"]), + ) + else: + return RealtimeEvent(raw_message=message) diff --git a/mm_agents/coact/autogen/agentchat/realtime/experimental/clients/realtime_client.py b/mm_agents/coact/autogen/agentchat/realtime/experimental/clients/realtime_client.py new file mode 100644 index 0000000..edc6f7a --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/realtime/experimental/clients/realtime_client.py @@ -0,0 +1,190 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +from collections.abc import AsyncGenerator +from logging import Logger +from typing import Any, AsyncContextManager, Callable, Literal, Optional, Protocol, TypeVar, Union, runtime_checkable + +from asyncer import create_task_group + +from .....doc_utils import export_module +from .....llm_config import LLMConfig +from ..realtime_events import InputAudioBufferDelta, RealtimeEvent + +__all__ = ["RealtimeClientProtocol", "Role", "get_client", "register_realtime_client"] + +# define role literal type for typing +Role = Literal["user", "assistant", "system"] + + +@runtime_checkable 
+@export_module("autogen.agentchat.realtime.experimental.clients") +class RealtimeClientProtocol(Protocol): + async def send_function_result(self, call_id: str, result: str) -> None: + """Send the result of a function call to a Realtime API. + + Args: + call_id (str): The ID of the function call. + result (str): The result of the function call. + """ + ... + + async def send_text(self, *, role: Role, text: str) -> None: + """Send a text message to a Realtime API. + + Args: + role (str): The role of the message. + text (str): The text of the message. + """ + ... + + async def send_audio(self, audio: str) -> None: + """Send audio to a Realtime API. + + Args: + audio (str): The audio to send. + """ + ... + + async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None: + """Truncate audio in a Realtime API. + + Args: + audio_end_ms (int): The end of the audio to truncate. + content_index (int): The index of the content to truncate. + item_id (str): The ID of the item to truncate. + """ + ... + + async def session_update(self, session_options: dict[str, Any]) -> None: + """Send a session update to a Realtime API. + + Args: + session_options (dict[str, Any]): The session options to update. + """ + ... + + def connect(self) -> AsyncContextManager[None]: ... + + def read_events(self) -> AsyncGenerator[RealtimeEvent, None]: + """Read events from a Realtime Client.""" + ... + + async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]: + """Read events from a Realtime connection.""" + ... + + def _parse_message(self, message: dict[str, Any]) -> list[RealtimeEvent]: + """Parse a message from a Realtime API. + + Args: + message (dict[str, Any]): The message to parse. + + Returns: + list[RealtimeEvent]: The parsed events. + """ + ... + + @classmethod + def get_factory( + cls, llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any + ) -> Optional[Callable[[], "RealtimeClientProtocol"]]: + """Create a Realtime API client. + + Args: + llm_config: The config for the client. + logger: The logger to use for logging events. + **kwargs: Additional arguments. + + Returns: + RealtimeClientProtocol: The Realtime API client is returned if the model matches the pattern + """ + ... + + +class RealtimeClientBase: + def __init__(self): + self._eventQueue = asyncio.Queue() + + async def add_event(self, event: Optional[RealtimeEvent]): + await self._eventQueue.put(event) + + async def get_event(self) -> Optional[RealtimeEvent]: + return await self._eventQueue.get() + + async def _read_from_connection_task(self): + async for event in self._read_from_connection(): + await self.add_event(event) + await self.add_event(None) + + async def _read_events(self) -> AsyncGenerator[RealtimeEvent, None]: + """Read events from a Realtime Client.""" + async with create_task_group() as tg: + tg.start_soon(self._read_from_connection_task) + while True: + try: + event = await self._eventQueue.get() + if event is not None: + yield event + else: + break + except Exception: + break + + async def queue_input_audio_buffer_delta(self, audio: str) -> None: + """queue InputAudioBufferDelta. + + Args: + audio (str): The audio. + """ + await self.add_event(InputAudioBufferDelta(delta=audio, item_id=None, raw_message=dict())) + + +_realtime_client_classes: dict[str, type[RealtimeClientProtocol]] = {} + +T = TypeVar("T", bound=RealtimeClientProtocol) + + +def register_realtime_client() -> Callable[[type[T]], type[T]]: + """Register a Realtime API client. 
+ + Returns: + Callable[[type[T]], type[T]]: The decorator to register the Realtime API client + """ + + def decorator(client_cls: type[T]) -> type[T]: + """Register a Realtime API client. + + Args: + client_cls: The client to register. + """ + global _realtime_client_classes + fqn = f"{client_cls.__module__}.{client_cls.__name__}" + _realtime_client_classes[fqn] = client_cls + + return client_cls + + return decorator + + +@export_module("autogen.agentchat.realtime.experimental.clients") +def get_client(llm_config: Union[LLMConfig, dict[str, Any]], logger: Logger, **kwargs: Any) -> "RealtimeClientProtocol": + """Get a registered Realtime API client. + + Args: + llm_config: The config for the client. + logger: The logger to use for logging events. + **kwargs: Additional arguments. + + Returns: + RealtimeClientProtocol: The Realtime API client. + """ + global _realtime_client_classes + for _, client_cls in _realtime_client_classes.items(): + factory = client_cls.get_factory(llm_config=llm_config, logger=logger, **kwargs) + if factory: + return factory() + + raise ValueError("Realtime API client not found.") diff --git a/mm_agents/coact/autogen/agentchat/realtime/experimental/function_observer.py b/mm_agents/coact/autogen/agentchat/realtime/experimental/function_observer.py new file mode 100644 index 0000000..98df085 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/realtime/experimental/function_observer.py @@ -0,0 +1,85 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import json +from typing import TYPE_CHECKING, Any, Optional + +from asyncer import asyncify +from pydantic import BaseModel + +from ....doc_utils import export_module +from .realtime_events import FunctionCall, RealtimeEvent +from .realtime_observer import RealtimeObserver + +if TYPE_CHECKING: + from logging import Logger + + +@export_module("autogen.agentchat.realtime.experimental") +class FunctionObserver(RealtimeObserver): + """Observer for handling function calls from the OpenAI Realtime API.""" + + def __init__(self, *, logger: Optional["Logger"] = None) -> None: + """Observer for handling function calls from the OpenAI Realtime API.""" + super().__init__(logger=logger) + + async def on_event(self, event: RealtimeEvent) -> None: + """Handle function call events from the OpenAI Realtime API. + + Args: + event (dict[str, Any]): The event from the OpenAI Realtime API. + """ + if isinstance(event, FunctionCall): + self.logger.info("Received function call event") + await self.call_function( + call_id=event.call_id, + name=event.name, + kwargs=event.arguments, + ) + + async def call_function(self, call_id: str, name: str, kwargs: dict[str, Any]) -> None: + """Call a function registered with the agent. + + Args: + call_id (str): The ID of the function call. + name (str): The name of the function to call. + kwargs (Any[str, Any]): The arguments to pass to the function. 
+ """ + if name in self.agent.registered_realtime_tools: + func = self.agent.registered_realtime_tools[name].func + func = func if asyncio.iscoroutinefunction(func) else asyncify(func) + try: + result = await func(**kwargs) + except Exception: + result = "Function call failed" + self.logger.info(f"Function call failed: {name=}, {kwargs=}", stack_info=True) + + if isinstance(result, BaseModel): + result = result.model_dump_json() + elif not isinstance(result, str): + try: + result = json.dumps(result) + except Exception: + result = str(result) + + await self.realtime_client.send_function_result(call_id, result) + else: + self.logger.warning(f"Function {name} called, but is not registered with the realtime agent.") + + async def initialize_session(self) -> None: + """Add registered tools to OpenAI with a session update.""" + session_update = { + "tools": [tool.realtime_tool_schema for tool in self.agent.registered_realtime_tools.values()], + "tool_choice": "auto", + } + await self.realtime_client.session_update(session_update) + + async def run_loop(self) -> None: + """Run the observer loop.""" + pass + + +if TYPE_CHECKING: + function_observer: RealtimeObserver = FunctionObserver() diff --git a/mm_agents/coact/autogen/agentchat/realtime/experimental/realtime_agent.py b/mm_agents/coact/autogen/agentchat/realtime/experimental/realtime_agent.py new file mode 100644 index 0000000..703820c --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/realtime/experimental/realtime_agent.py @@ -0,0 +1,158 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from logging import Logger, getLogger +from typing import Any, Callable, Optional, TypeVar, Union + +from anyio import lowlevel +from asyncer import create_task_group + +from ....doc_utils import export_module +from ....llm_config import LLMConfig +from ....tools import Tool +from .clients.realtime_client import RealtimeClientProtocol, get_client +from .function_observer import FunctionObserver +from .realtime_observer import RealtimeObserver + +F = TypeVar("F", bound=Callable[..., Any]) + +global_logger = getLogger(__name__) + + +@dataclass +class RealtimeAgentCallbacks: + """Callbacks for the Realtime Agent.""" + + # async empty placeholder function + on_observers_ready: Callable[[], Any] = lambda: lowlevel.checkpoint() + + +@export_module("autogen.agentchat.realtime.experimental") +class RealtimeAgent: + def __init__( + self, + *, + name: str, + audio_adapter: Optional[RealtimeObserver] = None, + system_message: str = "You are a helpful AI Assistant.", + llm_config: Optional[Union[LLMConfig, dict[str, Any]]] = None, + logger: Optional[Logger] = None, + observers: Optional[list[RealtimeObserver]] = None, + **client_kwargs: Any, + ): + """(Experimental) Agent for interacting with the Realtime Clients. + + Args: + name (str): The name of the agent. + audio_adapter (Optional[RealtimeObserver] = None): The audio adapter for the agent. + system_message (str): The system message for the agent. + llm_config (LLMConfig, dict[str, Any], bool): The config for the agent. + logger (Optional[Logger]): The logger for the agent. + observers (Optional[list[RealtimeObserver]]): The additional observers for the agent. + **client_kwargs (Any): The keyword arguments for the client. 
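+
+        Example:
+            A minimal sketch (assumes `llm_config` contains a realtime-capable model entry and
+            `websocket` is an accepted connection, e.g. from FastAPI):
+
+            ```python
+            agent = RealtimeAgent(
+                name="assistant",
+                system_message="You are a helpful voice assistant.",
+                llm_config=llm_config,
+                audio_adapter=WebSocketAudioAdapter(websocket),
+            )
+            await agent.run()
+            ```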
+ """ + self._logger = logger + self._name = name + self._system_message = system_message + + llm_config = LLMConfig.get_current_llm_config(llm_config) + + self._realtime_client: RealtimeClientProtocol = get_client( + llm_config=llm_config, logger=self.logger, **client_kwargs + ) + + self._registered_realtime_tools: dict[str, Tool] = {} + self._observers: list[RealtimeObserver] = observers if observers else [] + self._observers.append(FunctionObserver(logger=logger)) + if audio_adapter: + self._observers.append(audio_adapter) + + self.callbacks = RealtimeAgentCallbacks() + + @property + def system_message(self) -> str: + """Get the system message for the agent.""" + return self._system_message + + @property + def logger(self) -> Logger: + """Get the logger for the agent.""" + return self._logger or global_logger + + @property + def realtime_client(self) -> RealtimeClientProtocol: + """Get the OpenAI Realtime Client.""" + return self._realtime_client + + @property + def registered_realtime_tools(self) -> dict[str, Tool]: + """Get the registered realtime tools.""" + return self._registered_realtime_tools + + def register_observer(self, observer: RealtimeObserver) -> None: + """Register an observer with the Realtime Agent. + + Args: + observer (RealtimeObserver): The observer to register. + """ + self._observers.append(observer) + + async def start_observers(self) -> None: + for observer in self._observers: + self._tg.soonify(observer.run)(self) + + # wait for the observers to be ready + for observer in self._observers: + await observer.wait_for_ready() + + await self.callbacks.on_observers_ready() + + async def run(self) -> None: + """Run the agent.""" + # everything is run in the same task group to enable easy cancellation using self._tg.cancel_scope.cancel() + async with create_task_group() as self._tg: # noqa: SIM117 + # connect with the client first (establishes a connection and initializes a session) + async with self._realtime_client.connect(): + # start the observers and wait for them to be ready + await self.realtime_client.session_update(session_options={"instructions": self.system_message}) + await self.start_observers() + + # iterate over the events + async for event in self.realtime_client.read_events(): + for observer in self._observers: + await observer.on_event(event) + + def register_realtime_function( + self, + *, + name: Optional[str] = None, + description: Optional[str] = None, + ) -> Callable[[Union[F, Tool]], Tool]: + """Decorator for registering a function to be used by an agent. + + Args: + name (str): The name of the function. + description (str): The description of the function. + + Returns: + Callable[[Union[F, Tool]], Tool]: The decorator for registering a function. + """ + + def _decorator(func_or_tool: Union[F, Tool]) -> Tool: + """Decorator for registering a function to be used by an agent. + + Args: + func_or_tool (Union[F, Tool]): The function or tool to register. + + Returns: + Tool: The registered tool. 
+ """ + tool = Tool(func_or_tool=func_or_tool, name=name, description=description) + + self._registered_realtime_tools[tool.name] = tool + + return tool + + return _decorator diff --git a/mm_agents/coact/autogen/agentchat/realtime/experimental/realtime_events.py b/mm_agents/coact/autogen/agentchat/realtime/experimental/realtime_events.py new file mode 100644 index 0000000..a94f8b3 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/realtime/experimental/realtime_events.py @@ -0,0 +1,42 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Literal + +from pydantic import BaseModel + + +class RealtimeEvent(BaseModel): + raw_message: dict[str, Any] + + +class SessionCreated(RealtimeEvent): + type: Literal["session.created"] = "session.created" + + +class SessionUpdated(RealtimeEvent): + type: Literal["session.updated"] = "session.updated" + + +class AudioDelta(RealtimeEvent): + type: Literal["response.audio.delta"] = "response.audio.delta" + delta: str + item_id: Any + + +class InputAudioBufferDelta(RealtimeEvent): + type: Literal["input_audio_buffer.delta"] = "input_audio_buffer.delta" + delta: str + item_id: Any + + +class SpeechStarted(RealtimeEvent): + type: Literal["input_audio_buffer.speech_started"] = "input_audio_buffer.speech_started" + + +class FunctionCall(RealtimeEvent): + type: Literal["response.function_call_arguments.done"] = "response.function_call_arguments.done" + name: str + arguments: dict[str, Any] + call_id: str diff --git a/mm_agents/coact/autogen/agentchat/realtime/experimental/realtime_observer.py b/mm_agents/coact/autogen/agentchat/realtime/experimental/realtime_observer.py new file mode 100644 index 0000000..ca6890b --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/realtime/experimental/realtime_observer.py @@ -0,0 +1,100 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from logging import Logger, getLogger +from typing import TYPE_CHECKING, Optional + +from anyio import Event + +from ....doc_utils import export_module +from .clients.realtime_client import RealtimeClientProtocol +from .realtime_events import RealtimeEvent + +if TYPE_CHECKING: + from .realtime_agent import RealtimeAgent + +__all__ = ["RealtimeObserver"] + +global_logger = getLogger(__name__) + + +@export_module("autogen.agentchat.realtime.experimental") +class RealtimeObserver(ABC): + """Observer for the OpenAI Realtime API.""" + + def __init__(self, *, logger: Optional[Logger] = None) -> None: + """Observer for the OpenAI Realtime API. + + Args: + logger (Logger): The logger for the observer. + """ + self._ready_event = Event() + self._agent: Optional[RealtimeAgent] = None + self._logger = logger + + @property + def logger(self) -> Logger: + return self._logger or global_logger + + @property + def agent(self) -> "RealtimeAgent": + if self._agent is None: + raise RuntimeError("Agent has not been set.") + return self._agent + + @property + def realtime_client(self) -> RealtimeClientProtocol: + if self._agent is None: + raise RuntimeError("Agent has not been set.") + if self._agent.realtime_client is None: + raise RuntimeError("Realtime client has not been set.") + + return self._agent.realtime_client + + async def run(self, agent: "RealtimeAgent") -> None: + """Run the observer with the agent. 
+ + When implementing, be sure to call `self._ready_event.set()` when the observer is ready to process events. + + Args: + agent (RealtimeAgent): The realtime agent attached to the observer. + """ + self._agent = agent + await self.initialize_session() + self._ready_event.set() + + await self.run_loop() + + @abstractmethod + async def run_loop(self) -> None: + """Run the loop if needed. + + This method is called after the observer is ready to process events. + Events will be processed by the on_event method, this is just a hook for additional processing. + Use initialize_session to set up the session. + """ + ... + + @abstractmethod + async def initialize_session(self) -> None: + """Initialize the session for the observer.""" + ... + + async def wait_for_ready(self) -> None: + """Get the event that is set when the observer is ready.""" + await self._ready_event.wait() + + @abstractmethod + async def on_event(self, event: RealtimeEvent) -> None: + """Handle an event from the OpenAI Realtime API. + + Args: + event (RealtimeServerEvent): The event from the OpenAI Realtime API. + """ + ... + + async def on_close(self) -> None: + """Handle close of RealtimeClient.""" + ... diff --git a/mm_agents/coact/autogen/agentchat/realtime/experimental/realtime_swarm.py b/mm_agents/coact/autogen/agentchat/realtime/experimental/realtime_swarm.py new file mode 100644 index 0000000..ca7f2c3 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/realtime/experimental/realtime_swarm.py @@ -0,0 +1,483 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +import logging +import warnings +from collections import defaultdict +from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union + +import anyio +from asyncer import asyncify, create_task_group, syncify + +from ....agentchat.contrib.swarm_agent import AfterWorkOption, initiate_swarm_chat +from ....cache import AbstractCache +from ....code_utils import content_str +from ....doc_utils import export_module +from ... import Agent, ChatResult, ConversableAgent, LLMAgent +from ...utils import consolidate_chat_info, gather_usage_summary + +if TYPE_CHECKING: + from .clients import Role + from .realtime_agent import RealtimeAgent + +__all__ = ["register_swarm"] + +SWARM_SYSTEM_MESSAGE = ( + "You are a helpful voice assistant. Your task is to listen to user and to coordinate the tasks based on his/her inputs." + "You can and will communicate using audio output only." +) + +QUESTION_ROLE: "Role" = "user" +QUESTION_MESSAGE = ( + "I have a question/information for myself. DO NOT ANSWER YOURSELF, GET THE ANSWER FROM ME. " + "repeat the question to me **WITH AUDIO OUTPUT** and AFTER YOU GET THE ANSWER FROM ME call 'answer_task_question' with the answer in first person\n\n" + "IMPORTANT: repeat just the question, without any additional information or context\n\n" + "The question is: '{}'\n\n" +) +QUESTION_TIMEOUT_SECONDS = 20 + +logger = logging.getLogger(__name__) + +F = TypeVar("F", bound=Callable[..., Any]) + + +def message_to_dict(message: Union[dict[str, Any], str]) -> dict[str, Any]: + if isinstance(message, str): + return {"content": message} + elif isinstance(message, dict): + return message + else: + return dict(message) + + +def parse_oai_message(message: Union[dict[str, Any], str], role: str, adressee: Agent) -> dict[str, Any]: + """ + Parse a message into an OpenAI-compatible message format. + + Args: + message: The message to parse. 
+ role: The role associated with the message. + adressee: The agent that will receive the message. + + Returns: + The parsed message in OpenAI-compatible format. + + Raises: + ValueError: If the message lacks required fields like 'content', 'function_call', or 'tool_calls'. + """ + message = message_to_dict(message) + + # Extract relevant fields while ensuring none are None + oai_message = { + key: message[key] + for key in ("content", "function_call", "tool_calls", "tool_responses", "tool_call_id", "name", "context") + if key in message and message[key] is not None + } + + # Validate or set the content field + if "content" not in oai_message: + if "function_call" in oai_message or "tool_calls" in oai_message: + oai_message["content"] = None + else: + raise ValueError("Message must have either 'content', 'function_call', or 'tool_calls' field.") + + # Determine and assign the role + if message.get("role") in ["function", "tool"]: + oai_message["role"] = message["role"] + # Ensure all tool responses have string content + for tool_response in oai_message.get("tool_responses", []): + tool_response["content"] = str(tool_response["content"]) + elif "override_role" in message: + oai_message["role"] = message["override_role"] + else: + oai_message["role"] = role + + # Enforce specific role requirements for assistant messages + if oai_message.get("function_call") or oai_message.get("tool_calls"): + oai_message["role"] = "assistant" + + # Add a name field if missing + if "name" not in oai_message: + oai_message["name"] = adressee.name + + return oai_message + + +class SwarmableAgent(Agent): + """A class for an agent that can participate in a swarm chat.""" + + def __init__( + self, + name: str, + system_message: str = "You are a helpful AI Assistant.", + is_termination_msg: Optional[Callable[..., bool]] = None, + description: Optional[str] = None, + silent: Optional[bool] = None, + ): + self._oai_messages: dict[Agent, Any] = defaultdict(list) + + self._system_message = system_message + self._description = description if description is not None else system_message + self._is_termination_msg = ( + is_termination_msg + if is_termination_msg is not None + else (lambda x: content_str(x.get("content")) == "TERMINATE") + ) + self.silent = silent + + self._name = name + + # Initialize standalone client cache object. + self.client_cache = None + self.previous_cache = None + + self.reply_at_receive: dict[Agent, bool] = defaultdict(bool) + + @property + def system_message(self) -> str: + return self._system_message + + def update_system_message(self, system_message: str) -> None: + """Update this agent's system message. + + Args: + system_message (str): system message for inference. 
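+
+        Example (illustrative only):
+            `agent.update_system_message("Reply with one short sentence.")`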
+ """ + self._system_message = system_message + + @property + def name(self) -> str: + return self._name + + @property + def description(self) -> str: + return self._description + + def send( + self, + message: Union[dict[str, Any], str], + recipient: Agent, + request_reply: Optional[bool] = None, + silent: Optional[bool] = False, + ) -> None: + self._oai_messages[recipient].append(parse_oai_message(message, "assistant", recipient)) + recipient.receive(message, self, request_reply) + + def receive( + self, + message: Union[dict[str, Any], str], + sender: Agent, + request_reply: Optional[bool] = None, + silent: Optional[bool] = False, + ) -> None: + self._oai_messages[sender].append(parse_oai_message(message, "user", self)) + if request_reply is False or (request_reply is None and self.reply_at_receive[sender] is False): + return + reply = self.generate_reply(messages=self.chat_messages[sender], sender=sender) + if reply is not None: + self.send(reply, sender, silent=silent) + + def generate_reply( + self, + messages: Optional[list[dict[str, Any]]] = None, + sender: Optional["Agent"] = None, + **kwargs: Any, + ) -> Union[str, dict[str, Any], None]: + if messages is None: + if sender is None: + raise ValueError("Either messages or sender must be provided.") + messages = self._oai_messages[sender] + + _, reply = self.check_termination_and_human_reply(messages=messages, sender=sender, config=None) + + return reply + + def check_termination_and_human_reply( + self, + messages: Optional[list[dict[str, Any]]] = None, + sender: Optional[Agent] = None, + config: Optional[Any] = None, + ) -> tuple[bool, Union[str, None]]: + raise NotImplementedError + + def initiate_chat( + self, + recipient: ConversableAgent, + message: Union[dict[str, Any], str], + clear_history: bool = True, + silent: Optional[bool] = False, + cache: Optional[AbstractCache] = None, + summary_args: Optional[dict[str, Any]] = {}, + **kwargs: dict[str, Any], + ) -> ChatResult: + _chat_info = locals().copy() + _chat_info["sender"] = self + consolidate_chat_info(_chat_info, uniform_sender=self) + recipient._raise_exception_on_async_reply_functions() + recipient.previous_cache = recipient.client_cache # type: ignore[attr-defined] + recipient.client_cache = cache # type: ignore[attr-defined, assignment] + + self._prepare_chat(recipient, clear_history) + self.send(message, recipient, silent=silent) + summary = self._last_msg_as_summary(self, recipient, summary_args) + + recipient.client_cache = recipient.previous_cache # type: ignore[attr-defined] + recipient.previous_cache = None # type: ignore[attr-defined] + + chat_result = ChatResult( + chat_history=self.chat_messages[recipient], + summary=summary, + cost=gather_usage_summary([self, recipient]), # type: ignore[arg-type] + human_input=[], + ) + return chat_result + + async def a_generate_reply( + self, + messages: Optional[list[dict[str, Any]]] = None, + sender: Optional["Agent"] = None, + **kwargs: Any, + ) -> Union[str, dict[str, Any], None]: + return self.generate_reply(messages=messages, sender=sender, **kwargs) + + async def a_receive( + self, + message: Union[dict[str, Any], str], + sender: "Agent", + request_reply: Optional[bool] = None, + ) -> None: + self.receive(message, sender, request_reply) + + async def a_send( + self, + message: Union[dict[str, Any], str], + recipient: "Agent", + request_reply: Optional[bool] = None, + ) -> None: + self.send(message, recipient, request_reply) + + @property + def chat_messages(self) -> dict[Agent, list[dict[str, Any]]]: + """A dictionary 
of conversations from agent to list of messages.""" + return self._oai_messages + + def last_message(self, agent: Optional[Agent] = None) -> Optional[dict[str, Any]]: + if agent is None: + n_conversations = len(self._oai_messages) + if n_conversations == 0: + return None + if n_conversations == 1: + for conversation in self._oai_messages.values(): + return conversation[-1] # type: ignore[no-any-return] + raise ValueError("More than one conversation is found. Please specify the sender to get the last message.") + if agent not in self._oai_messages(): + raise KeyError( + f"The agent '{agent.name}' is not present in any conversation. No history available for this agent." + ) + return self._oai_messages[agent][-1] # type: ignore[no-any-return] + + def _prepare_chat( + self, + recipient: ConversableAgent, + clear_history: bool, + prepare_recipient: bool = True, + reply_at_receive: bool = True, + ) -> None: + self.reply_at_receive[recipient] = reply_at_receive + if clear_history: + self._oai_messages[recipient].clear() + if prepare_recipient: + recipient._prepare_chat(self, clear_history, False, reply_at_receive) # type: ignore[arg-type] + + def _raise_exception_on_async_reply_functions(self) -> None: + pass + + def set_ui_tools(self, tools: Optional[list] = None) -> None: + """Set UI tools for the agent.""" + pass + + def unset_ui_tools(self) -> None: + """Unset UI tools for the agent.""" + pass + + @staticmethod + def _last_msg_as_summary(sender: Agent, recipient: Agent, summary_args: Optional[dict[str, Any]]) -> str: + """Get a chat summary from the last message of the recipient.""" + summary = "" + try: + content = recipient.last_message(sender)["content"] # type: ignore[attr-defined] + if isinstance(content, str): + summary = content.replace("TERMINATE", "") + elif isinstance(content, list): + summary = "\n".join( + x["text"].replace("TERMINATE", "") for x in content if isinstance(x, dict) and "text" in x + ) + except (IndexError, AttributeError) as e: + warnings.warn(f"Cannot extract summary using last_msg: {e}. Using an empty str as summary.", UserWarning) + return summary + + +# check that the SwarmableAgent class is implementing LLMAgent protocol +if TYPE_CHECKING: + + def _create_swarmable_agent( + name: str, + system_message: str, + is_termination_msg: Optional[Callable[..., bool]], + description: Optional[str], + silent: Optional[bool], + ) -> LLMAgent: + return SwarmableAgent( + name=name, + system_message=system_message, + is_termination_msg=is_termination_msg, + description=description, + silent=silent, + ) + + +class SwarmableRealtimeAgent(SwarmableAgent): + def __init__( + self, + realtime_agent: "RealtimeAgent", + initial_agent: ConversableAgent, + agents: list[ConversableAgent], + question_message: Optional[str] = None, + ) -> None: + self._initial_agent = initial_agent + self._agents = agents + self._realtime_agent = realtime_agent + + self._answer_event: anyio.Event = anyio.Event() + self._answer: str = "" + self.question_message = question_message or QUESTION_MESSAGE + + super().__init__( + name=realtime_agent._name, + is_termination_msg=None, + description=None, + silent=None, + ) + + def reset_answer(self) -> None: + """Reset the answer event.""" + self._answer_event = anyio.Event() + + def set_answer(self, answer: str) -> str: + """Set the answer to the question.""" + self._answer = answer + self._answer_event.set() + return "Answer set successfully." 
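+
+    # Note on the question/answer handshake: `ask_question` resets the answer
+    # event, repeats the question to the user until it is answered, and
+    # `get_answer` blocks on the same event. `set_answer` (registered later in
+    # `configure_realtime_agent` as the "answer_task_question" realtime
+    # function) stores the reply and sets the event, unblocking both.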
+ + async def get_answer(self) -> str: + """Get the answer to the question.""" + await self._answer_event.wait() + return self._answer + + async def ask_question(self, question: str, question_timeout: int) -> None: + """Send a question for the user to the agent and wait for the answer. + If the answer is not received within the timeout, the question is repeated. + + Args: + question: The question to ask the user. + question_timeout: The time in seconds to wait for the answer. + """ + self.reset_answer() + realtime_client = self._realtime_agent._realtime_client + await realtime_client.send_text(role=QUESTION_ROLE, text=question) + + async def _check_event_set(timeout: int = question_timeout) -> bool: + for _ in range(timeout): + if self._answer_event.is_set(): + return True + await anyio.sleep(1) + return False + + while not await _check_event_set(): + await realtime_client.send_text(role=QUESTION_ROLE, text=question) + + def check_termination_and_human_reply( + self, + messages: Optional[list[dict[str, Any]]] = None, + sender: Optional[Agent] = None, + config: Optional[Any] = None, + ) -> tuple[bool, Optional[str]]: + """Check if the conversation should be terminated and if the agent should reply. + + Called when its agents turn in the chat conversation. + + Args: + messages (list[dict[str, Any]]): The messages in the conversation. + sender (Agent): The agent that sent the message. + config (Optional[Any]): The configuration for the agent. + """ + if not messages: + return False, None + + async def get_input() -> None: + async with create_task_group() as tg: + tg.soonify(self.ask_question)( + self.question_message.format(messages[-1]["content"]), + question_timeout=QUESTION_TIMEOUT_SECONDS, + ) + + syncify(get_input)() + + return True, {"role": "user", "content": self._answer} # type: ignore[return-value] + + def start_chat(self) -> None: + raise NotImplementedError + + def configure_realtime_agent(self, system_message: Optional[str]) -> None: + realtime_agent = self._realtime_agent + + logger = realtime_agent.logger + if not system_message: + if realtime_agent.system_message != "You are a helpful AI Assistant.": + logger.warning( + "Overriding system message set up in `__init__`, please use `system_message` parameter of the `register_swarm` function instead." + ) + system_message = SWARM_SYSTEM_MESSAGE + + realtime_agent._system_message = system_message + + realtime_agent.register_realtime_function( + name="answer_task_question", description="Answer question from the task" + )(self.set_answer) + + async def on_observers_ready() -> None: + self._realtime_agent._tg.soonify(asyncify(initiate_swarm_chat))( + initial_agent=self._initial_agent, + agents=self._agents, + user_agent=self, # type: ignore[arg-type] + messages="Find out what the user wants.", + after_work=AfterWorkOption.REVERT_TO_USER, + ) + + self._realtime_agent.callbacks.on_observers_ready = on_observers_ready + + +@export_module("autogen.agentchat.realtime.experimental") +def register_swarm( + *, + realtime_agent: "RealtimeAgent", + initial_agent: ConversableAgent, + agents: list[ConversableAgent], + system_message: Optional[str] = None, + question_message: Optional[str] = None, +) -> None: + """Create a SwarmableRealtimeAgent. + + Args: + realtime_agent (RealtimeAgent): The RealtimeAgent to create the SwarmableRealtimeAgent from. + initial_agent (ConversableAgent): The initial agent. + agents (list[ConversableAgent]): The agents in the swarm. + system_message (Optional[str]): The system message to set for the agent. 
If None, the default system message is used. + question_message (Optional[str]): The question message to set for the agent. If None, the default QUESTION_MESSAGE is used. + """ + swarmable_agent = SwarmableRealtimeAgent( + realtime_agent=realtime_agent, initial_agent=initial_agent, agents=agents, question_message=question_message + ) + + swarmable_agent.configure_realtime_agent(system_message=system_message) diff --git a/mm_agents/coact/autogen/agentchat/realtime/experimental/websockets.py b/mm_agents/coact/autogen/agentchat/realtime/experimental/websockets.py new file mode 100644 index 0000000..e5ab408 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/realtime/experimental/websockets.py @@ -0,0 +1,21 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import AsyncIterator +from typing import Any, Protocol, runtime_checkable + +__all__ = ["WebSocketProtocol"] + + +@runtime_checkable +class WebSocketProtocol(Protocol): + """WebSocket protocol for sending and receiving JSON data modelled after FastAPI's WebSocket.""" + + async def send_json(self, data: Any, mode: str = "text") -> None: ... + + async def receive_json(self, mode: str = "text") -> Any: ... + + async def receive_text(self) -> str: ... + + def iter_text(self) -> AsyncIterator[str]: ... diff --git a/mm_agents/coact/autogen/agentchat/realtime_agent/__init__.py b/mm_agents/coact/autogen/agentchat/realtime_agent/__init__.py new file mode 100644 index 0000000..6e14b66 --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/realtime_agent/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +from ..realtime.experimental import ( + FunctionObserver, + RealtimeAgent, + RealtimeObserver, + TwilioAudioAdapter, + WebSocketAudioAdapter, + register_swarm, +) + +__all__ = [ + "FunctionObserver", + "RealtimeAgent", + "RealtimeObserver", + "TwilioAudioAdapter", + "WebSocketAudioAdapter", + "register_swarm", +] diff --git a/mm_agents/coact/autogen/agentchat/user_proxy_agent.py b/mm_agents/coact/autogen/agentchat/user_proxy_agent.py new file mode 100644 index 0000000..134a74b --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/user_proxy_agent.py @@ -0,0 +1,111 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 +# +# Portions derived from https://github.com/microsoft/autogen are under the MIT License. +# SPDX-License-Identifier: MIT +from typing import Any, Callable, Literal, Optional, Union + +from ..doc_utils import export_module +from ..llm_config import LLMConfig +from ..runtime_logging import log_new_agent, logging_enabled +from .conversable_agent import ConversableAgent + + +@export_module("autogen") +class UserProxyAgent(ConversableAgent): + """(In preview) A proxy agent for the user, that can execute code and provide feedback to the other agents. + + UserProxyAgent is a subclass of ConversableAgent configured with `human_input_mode` to ALWAYS + and `llm_config` to False. By default, the agent will prompt for human input every time a message is received. + Code execution is enabled by default. LLM-based auto reply is disabled by default. + To modify auto reply, register a method with [`register_reply`](../ConversableAgent#register-reply). 
+ To modify the way to get human input, override `get_human_input` method. + To modify the way to execute code blocks, single code block, or function call, override `execute_code_blocks`, + `run_code`, and `execute_function` methods respectively. + """ + + # Default UserProxyAgent.description values, based on human_input_mode + DEFAULT_USER_PROXY_AGENT_DESCRIPTIONS = { + "ALWAYS": "An attentive HUMAN user who can answer questions about the task, and can perform tasks such as running Python code or inputting command line commands at a Linux terminal and reporting back the execution results.", + "TERMINATE": "A user that can run Python code or input command line commands at a Linux terminal and report back the execution results.", + "NEVER": "A computer terminal that performs no other action than running Python scripts (provided to it quoted in ```python code blocks), or sh shell scripts (provided to it quoted in ```sh code blocks).", + } + + def __init__( + self, + name: str, + is_termination_msg: Optional[Callable[[dict[str, Any]], bool]] = None, + max_consecutive_auto_reply: Optional[int] = None, + human_input_mode: Literal["ALWAYS", "TERMINATE", "NEVER"] = "ALWAYS", + function_map: Optional[dict[str, Callable[..., Any]]] = None, + code_execution_config: Union[dict[str, Any], Literal[False]] = {}, + default_auto_reply: Optional[Union[str, dict[str, Any]]] = "", + llm_config: Optional[Union[LLMConfig, dict[str, Any], Literal[False]]] = False, + system_message: Optional[Union[str, list[str]]] = "", + description: Optional[str] = None, + **kwargs: Any, + ): + """Args: + name (str): name of the agent. + is_termination_msg (function): a function that takes a message in the form of a dictionary + and returns a boolean value indicating if this received message is a termination message. + The dict can contain the following keys: "content", "role", "name", "function_call". + max_consecutive_auto_reply (int): the maximum number of consecutive auto replies. + default to None (no limit provided, class attribute MAX_CONSECUTIVE_AUTO_REPLY will be used as the limit in this case). + The limit only plays a role when human_input_mode is not "ALWAYS". + human_input_mode (str): whether to ask for human inputs every time a message is received. + Possible values are "ALWAYS", "TERMINATE", "NEVER". + (1) When "ALWAYS", the agent prompts for human input every time a message is received. + Under this mode, the conversation stops when the human input is "exit", + or when is_termination_msg is True and there is no human input. + (2) When "TERMINATE", the agent only prompts for human input only when a termination message is received or + the number of auto reply reaches the max_consecutive_auto_reply. + (3) When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops + when the number of auto reply reaches the max_consecutive_auto_reply or when is_termination_msg is True. + function_map (dict[str, callable]): Mapping function names (passed to openai) to callable functions. + code_execution_config (dict or False): config for the code execution. + To disable code execution, set to False. Otherwise, set to a dictionary with the following keys: + - work_dir (Optional, str): The working directory for the code execution. + If None, a default working directory will be used. + The default working directory is the "extensions" directory under + "path_to_autogen". + - use_docker (Optional, list, str or bool): The docker image to use for code execution. 
+ Default is True, which means the code will be executed in a docker container. A default list of images will be used. + If a list or a str of image name(s) is provided, the code will be executed in a docker container + with the first image successfully pulled. + If False, the code will be executed in the current environment. + We strongly recommend using docker for code execution. + - timeout (Optional, int): The maximum execution time in seconds. + - last_n_messages (Experimental, Optional, int): The number of messages to look back for code execution. Default to 1. + default_auto_reply (str or dict or None): the default auto reply message when no code execution or llm based reply is generated. + llm_config (LLMConfig or dict or False or None): llm inference configuration. + Please refer to [OpenAIWrapper.create](https://docs.ag2.ai/latest/docs/api-reference/autogen/OpenAIWrapper/#autogen.OpenAIWrapper.create) + for available options. + Default to False, which disables llm-based auto reply. + When set to None, will use self.DEFAULT_CONFIG, which defaults to False. + system_message (str or List): system message for ChatCompletion inference. + Only used when llm_config is not False. Use it to reprogram the agent. + description (str): a short description of the agent. This description is used by other agents + (e.g. the GroupChatManager) to decide when to call upon this agent. (Default: system_message) + **kwargs (dict): Please refer to other kwargs in + [ConversableAgent](https://docs.ag2.ai/latest/docs/api-reference/autogen/ConversableAgent). + """ + super().__init__( + name=name, + system_message=system_message, + is_termination_msg=is_termination_msg, + max_consecutive_auto_reply=max_consecutive_auto_reply, + human_input_mode=human_input_mode, + function_map=function_map, + code_execution_config=code_execution_config, + llm_config=llm_config, + default_auto_reply=default_auto_reply, + description=( + description if description is not None else self.DEFAULT_USER_PROXY_AGENT_DESCRIPTIONS[human_input_mode] + ), + **kwargs, + ) + + if logging_enabled(): + log_new_agent(self, locals()) diff --git a/mm_agents/coact/autogen/agentchat/utils.py b/mm_agents/coact/autogen/agentchat/utils.py new file mode 100644 index 0000000..d2784bc --- /dev/null +++ b/mm_agents/coact/autogen/agentchat/utils.py @@ -0,0 +1,206 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 +# +# Portions derived from https://github.com/microsoft/autogen are under the MIT License. +# SPDX-License-Identifier: MIT +import re +from typing import Any, Optional, Union + +from ..doc_utils import export_module +from .agent import Agent + + +def consolidate_chat_info( + chat_info: Union[dict[str, Any], list[dict[str, Any]]], uniform_sender: Optional[Agent] = None +) -> None: + if isinstance(chat_info, dict): + chat_info = [chat_info] + for c in chat_info: + if uniform_sender is None: + assert "sender" in c, "sender must be provided." + sender = c["sender"] + else: + sender = uniform_sender + assert "recipient" in c, "recipient must be provided." + summary_method = c.get("summary_method") + assert ( + summary_method is None or callable(summary_method) or summary_method in ("last_msg", "reflection_with_llm") + ), "summary_method must be a string chosen from 'reflection_with_llm' or 'last_msg' or a callable, or None." 
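+        # "reflection_with_llm" summaries are produced by an extra LLM call, so at
+        # least one side of the chat must carry an LLM client (checked below).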
+ if summary_method == "reflection_with_llm": + assert sender.client is not None or c["recipient"].client is not None, ( + "llm client must be set in either the recipient or sender when summary_method is reflection_with_llm." + ) + + +@export_module("autogen") +def gather_usage_summary(agents: list[Agent]) -> dict[str, dict[str, Any]]: + r"""Gather usage summary from all agents. + + Args: + agents: (list): List of agents. + + Returns: + dictionary: A dictionary containing two keys: + - "usage_including_cached_inference": Cost information on the total usage, including the tokens in cached inference. + - "usage_excluding_cached_inference": Cost information on the usage of tokens, excluding the tokens in cache. No larger than "usage_including_cached_inference". + + Example: + ```python + { + "usage_including_cached_inference": { + "total_cost": 0.0006090000000000001, + "gpt-35-turbo": { + "cost": 0.0006090000000000001, + "prompt_tokens": 242, + "completion_tokens": 123, + "total_tokens": 365, + }, + }, + "usage_excluding_cached_inference": { + "total_cost": 0.0006090000000000001, + "gpt-35-turbo": { + "cost": 0.0006090000000000001, + "prompt_tokens": 242, + "completion_tokens": 123, + "total_tokens": 365, + }, + }, + } + ``` + + Note: + If none of the agents incurred any cost (not having a client), then the usage_including_cached_inference and usage_excluding_cached_inference will be `{'total_cost': 0}`. + """ + + def aggregate_summary(usage_summary: dict[str, Any], agent_summary: dict[str, Any]) -> None: + if agent_summary is None: + return + usage_summary["total_cost"] += agent_summary.get("total_cost", 0) + for model, data in agent_summary.items(): + if model != "total_cost": + if model not in usage_summary: + usage_summary[model] = data.copy() + else: + usage_summary[model]["cost"] += data.get("cost", 0) + usage_summary[model]["prompt_tokens"] += data.get("prompt_tokens", 0) + usage_summary[model]["completion_tokens"] += data.get("completion_tokens", 0) + usage_summary[model]["total_tokens"] += data.get("total_tokens", 0) + + usage_including_cached_inference = {"total_cost": 0} + usage_excluding_cached_inference = {"total_cost": 0} + + for agent in agents: + if getattr(agent, "client", None): + aggregate_summary(usage_including_cached_inference, agent.client.total_usage_summary) # type: ignore[attr-defined] + aggregate_summary(usage_excluding_cached_inference, agent.client.actual_usage_summary) # type: ignore[attr-defined] + + return { + "usage_including_cached_inference": usage_including_cached_inference, + "usage_excluding_cached_inference": usage_excluding_cached_inference, + } + + +def parse_tags_from_content(tag: str, content: Union[str, list[dict[str, Any]]]) -> list[dict[str, Any]]: + """Parses HTML style tags from message contents. + + The parsing is done by looking for patterns in the text that match the format of HTML tags. The tag to be parsed is + specified as an argument to the function. The function looks for this tag in the text and extracts its content. The + content of a tag is everything that is inside the tag, between the opening and closing angle brackets. The content + can be a single string or a set of attribute-value pairs. + + Examples: + ` -> [{"tag": "img", "attr": {"src": "http://example.com/image.png"}, "match": re.Match}]` + ```