diff --git a/mm_agents/agent.py b/mm_agents/agent.py
index 7599b02..744ee9c 100644
--- a/mm_agents/agent.py
+++ b/mm_agents/agent.py
@@ -5,19 +5,20 @@ import os
 import re
 import time
 import uuid
-import openai
 import xml.etree.ElementTree as ET
 from http import HTTPStatus
 from io import BytesIO
 from typing import Dict, List
-from google.api_core.exceptions import InvalidArgument
+
 import backoff
 import dashscope
 import google.generativeai as genai
+import openai
 import requests
 from PIL import Image
+from google.api_core.exceptions import InvalidArgument
 
-from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes
+from mm_agents.accessibility_tree_wrap.heuristic_retrieve import filter_nodes, draw_bounding_boxes
 from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
     SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \
     SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \
@@ -422,7 +423,6 @@ class PromptAgent:
         # with open("messages.json", "w") as f:
         #     f.write(json.dumps(messages, indent=4))
 
-        logger.info("Generating content with GPT model: %s", self.model)
         response = self.call_llm({
             "model": self.model,
             "messages": messages,
@@ -461,7 +461,7 @@ class PromptAgent:
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
             }
-            # logger.info("Generating content with GPT model: %s", self.model)
+            logger.info("Generating content with GPT model: %s", self.model)
             response = requests.post(
                 "https://api.openai.com/v1/chat/completions",
                 headers=headers,
@@ -495,7 +495,7 @@ class PromptAgent:
             temperature = payload["temperature"]
 
             claude_messages = []
-            
+
             for i, message in enumerate(messages):
                 claude_message = {
                     "role": message["role"],
@@ -503,17 +503,17 @@ class PromptAgent:
                 }
                 assert len(message["content"]) in [1, 2], "One text, or one text with one image"
                 for part in message["content"]:
-                    
+
                     if part['type'] == "image_url":
                         image_source = {}
                         image_source["type"] = "base64"
                         image_source["media_type"] = "image/png"
                         image_source["data"] = part['image_url']['url'].replace("data:image/png;base64,", "")
                         claude_message['content'].append({"type": "image", "source": image_source})
-                    
+
                     if part['type'] == "text":
                         claude_message['content'].append({"type": "text", "text": part['text']})
-                    
+
                 claude_messages.append(claude_message)
 
             # the claude not support system message in our endpoint, so we concatenate it at the first user message
@@ -522,7 +522,6 @@ class PromptAgent:
                 claude_messages[1]['content'].insert(0, claude_system_message_item)
                 claude_messages.pop(0)
 
-
             headers = {
                 "x-api-key": os.environ["ANTHROPIC_API_KEY"],
                 "anthropic-version": "2023-06-01",
@@ -540,7 +539,7 @@ class PromptAgent:
                 headers=headers,
                 json=payload
             )
-            
+
             if response.status_code != 200:
                 logger.error("Failed to call LLM: " + response.text)
 
@@ -674,6 +673,7 @@ class PromptAgent:
             try:
                 return response.text
             except Exception as e:
+                logger.error("Meet exception when calling Gemini API, " + str(e))
                 return ""
         elif self.model.startswith("qwen"):
             messages = payload["messages"]
diff --git a/run.py b/run.py
index 3014e87..5e8e664 100644
--- a/run.py
+++ b/run.py
@@ -6,6 +6,7 @@ import datetime
 import json
 import logging
 import os
+import random
 import sys
 
 from tqdm import tqdm
@@ -69,7 +70,7 @@ def config() -> argparse.Namespace:
             "screenshot_a11y_tree",
             "som"
         ],
-        default="som",
+        default="a11y_tree",
         help="Observation type",
     )
     parser.add_argument("--screen_width", type=int, default=1920)
@@ -82,7 +83,7 @@ def config() -> argparse.Namespace:
     parser.add_argument("--test_config_base_dir", type=str, default="evaluation_examples")
 
     # lm config
-    parser.add_argument("--model", type=str, default="gpt-4-vision-preview")
+    parser.add_argument("--model", type=str, default="gpt-4-0125-preview")
     parser.add_argument("--temperature", type=float, default=1.0)
     parser.add_argument("--top_p", type=float, default=0.9)
     parser.add_argument("--max_tokens", type=int, default=1500)
@@ -147,7 +148,7 @@ def test(
             try:
                 lib_run_single.run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir,
                                                   scores)
-            except Exception as e:
+            except TimeoutError as e:
                 env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
                 logger.error(f"Time limit exceeded in {domain}/{example_id}")
                 with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
@@ -155,6 +156,14 @@ def test(
                         "Error": f"Time limit exceeded in {domain}/{example_id}"
                     }))
                     f.write("\n")
+            except Exception as e:
+                env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
+                logger.error(f"Exception in {domain}/{example_id}" + str(e))
+                with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
+                    f.write(json.dumps({
+                        "Error": f"Exception in {domain}/{example_id}" + str(e)
+                    }))
+                    f.write("\n")
 
     env.close()
     logger.info(f"Average score: {sum(scores) / len(scores)}")
@@ -193,15 +202,13 @@ def get_unfinished(action_space, use_model, observation_type, result_dir, total_
 
 def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
     target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
+    if not os.path.exists(target_dir):
+        print("New experiment, no result yet.")
+        return None
 
     all_result = []
 
-    if not os.path.exists(target_dir):
-        return total_file_json
-
-    finished = {}
     for domain in os.listdir(target_dir):
-        finished[domain] = []
         domain_path = os.path.join(target_dir, domain)
         if os.path.isdir(domain_path):
             for example_id in os.listdir(domain_path):
@@ -209,10 +216,17 @@ def get_result(action_space, use_model, observation_type, result_dir, total_file
                 if os.path.isdir(example_path):
                     if "result.txt" in os.listdir(example_path):
                         # empty all files under example_id
-                        all_result.append(float(open(os.path.join(example_path, "result.txt"), "r").read()))
+                        try:
+                            all_result.append(float(open(os.path.join(example_path, "result.txt"), "r").read()))
+                        except:
+                            all_result.append(0.0)
 
-    print("Success Rate:", sum(all_result) / len(all_result) * 100, "%")
-    return all_result
+    if not all_result:
+        print("New experiment, no result yet.")
+        return None
+    else:
+        print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
+        return all_result
 
 
 if __name__ == '__main__':
@@ -242,4 +256,8 @@ if __name__ == '__main__':
         test_all_meta
     )
 
-    # test(args, test_all_meta)
+    # make the order of key random in test_all_meta
+    for domain in test_all_meta:
+        random.shuffle(test_all_meta[domain])
+
+    test(args, test_all_meta)