diff --git a/mm_agents/agent.py b/mm_agents/agent.py
index a6b9ed8..8d9494d 100644
--- a/mm_agents/agent.py
+++ b/mm_agents/agent.py
@@ -1,12 +1,15 @@
 import base64
+import hashlib
 import json
 import logging
 import os
 import re
+import tempfile
 import time
 import xml.etree.ElementTree as ET
 from http import HTTPStatus
 from io import BytesIO
+from pathlib import Path
 from typing import Dict, List
 
 import backoff
@@ -32,6 +35,25 @@ def encode_image(image_content):
     return base64.b64encode(image_content).decode('utf-8')
 
 
+def encoded_img_to_pil_img(data_str):
+    base64_str = data_str.replace("data:image/png;base64,", "")
+    image_data = base64.b64decode(base64_str)
+    image = Image.open(BytesIO(image_data))
+
+    return image
+
+
+def save_to_tmp_img_file(data_str):
+    base64_str = data_str.replace("data:image/png;base64,", "")
+    image_data = base64.b64decode(base64_str)
+    image = Image.open(BytesIO(image_data))
+
+    tmp_img_path = os.path.join(tempfile.mkdtemp(), "tmp_img.png")
+    image.save(tmp_img_path)
+
+    return tmp_img_path
+
+
 def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"):
     # leaf_nodes = find_leaf_nodes(accessibility_tree)
     filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree), platform)
@@ -695,14 +717,7 @@ class PromptAgent:
                 print("Failed to call LLM: ", response.status_code)
                 return ""
 
-        elif self.model.startswith("gemini"):
-            def encoded_img_to_pil_img(data_str):
-                base64_str = data_str.replace("data:image/png;base64,", "")
-                image_data = base64.b64decode(base64_str)
-                image = Image.open(BytesIO(image_data))
-
-                return image
-
+        elif self.model in ["gemini-pro", "gemini-pro-vision"]:
             messages = payload["messages"]
             max_tokens = payload["max_tokens"]
             top_p = payload["top_p"]
@@ -732,7 +747,7 @@ class PromptAgent:
 
                     gemini_messages.append(gemini_message)
 
-            # the mistral not support system message in our endpoint, so we concatenate it at the first user message
+            # Gemini does not support a system message at our endpoint, so concatenate it onto the first user message
             if gemini_messages[0]['role'] == "system":
                 gemini_messages[1]['parts'][0] = gemini_messages[0]['parts'][0] + "\n" + gemini_messages[1]['parts'][0]
                 gemini_messages.pop(0)
@@ -775,6 +790,93 @@ class PromptAgent:
                 logger.error(f"count_tokens: {gemini_model.count_tokens(gemini_messages)}")
                 logger.error(f"generation_config: {max_tokens}, {top_p}, {temperature}")
                 return ""
+
+        elif self.model == "gemini-1.5-pro-latest":
+            messages = payload["messages"]
+            max_tokens = payload["max_tokens"]
+            top_p = payload["top_p"]
+            temperature = payload["temperature"]
+
+            uploaded_files = []
+
+            # def upload_if_needed(pathname: str) -> list[str]:
+            #     path = Path(pathname)
+            #     hash_id = hashlib.sha256(path.read_bytes()).hexdigest()
+            #     try:
+            #         existing_file = genai.get_file(name=hash_id)
+            #         return [existing_file.uri]
+            #     except:
+            #         pass
+            #     uploaded_files.append(genai.upload_file(path=path, display_name=hash_id))
+            #     return [uploaded_files[-1].uri]
+
+            role_mapping = {
+                "assistant": "model",
+                "user": "user",
+                "system": "system"
+            }
+            gemini_messages = []
+            for message in messages:
+                assert len(message["content"]) in [1, 2], "One text, or one text with one image"
+
+                # Collect all parts (text and/or image) into a single Gemini message
+                gemini_message = {
+                    "role": role_mapping[message["role"]],
+                    "parts": []
+                }
+                for part in message["content"]:
+                    if part['type'] == "image_url":
+                        gemini_message['parts'].append(encoded_img_to_pil_img(part['image_url']['url']))
+                    elif part['type'] == "text":
+                        gemini_message['parts'].append(part['text'])
+                    else:
+                        raise ValueError("Invalid content type: " + part['type'])
+
+                gemini_messages.append(gemini_message)
+
+            # gemini-1.5-pro-latest takes the system message as a model-initialization parameter
+            system_instruction = None
+            if gemini_messages[0]['role'] == "system":
+                system_instruction = gemini_messages[0]['parts'][0]
+                gemini_messages.pop(0)
+
+            api_key = os.environ.get("GENAI_API_KEY")
+            assert api_key is not None, "Please set the GENAI_API_KEY environment variable"
+            genai.configure(api_key=api_key)
+            logger.info("Generating content with Gemini model: %s", self.model)
+            request_options = {"timeout": 120}
+            gemini_model = genai.GenerativeModel(
+                self.model,
+                system_instruction=system_instruction
+            )
+            try:
+                response = gemini_model.generate_content(
+                    gemini_messages,
+                    generation_config={
+                        "candidate_count": 1,
+                        "max_output_tokens": max_tokens,
+                        "top_p": top_p,
+                        "temperature": temperature
+                    },
+                    safety_settings={
+                        "harassment": "block_none",
+                        "hate": "block_none",
+                        "sex": "block_none",
+                        "danger": "block_none"
+                    },
+                    request_options=request_options
+                )
+                for uploaded_file in uploaded_files:
+                    genai.delete_file(name=uploaded_file.name)
+                return response.text
+            except Exception as e:
+                logger.error("Exception when calling Gemini API: " + str(e.__class__.__name__) + ": " + str(e))
+                logger.error(f"count_tokens: {gemini_model.count_tokens(gemini_messages)}")
+                logger.error(f"generation_config: {max_tokens}, {top_p}, {temperature}")
+                for uploaded_file in uploaded_files:
+                    genai.delete_file(name=uploaded_file.name)
+                return ""
+
         elif self.model.startswith("qwen"):
             messages = payload["messages"]
             max_tokens = payload["max_tokens"]
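For reviewers who want to exercise the new branch in isolation, here is a minimal smoke-test sketch. It mirrors the message shape that `PromptAgent.call_llm` builds in this patch; the model name and `system_instruction` usage come from the diff, while the screenshot path, prompt text, and generation values are placeholders. It assumes `google-generativeai` is installed at a version that supports `system_instruction` and that `GENAI_API_KEY` is set.

```python
import os

import google.generativeai as genai
from PIL import Image

# Configure the client the same way the patched branch does.
genai.configure(api_key=os.environ["GENAI_API_KEY"])
model = genai.GenerativeModel(
    "gemini-1.5-pro-latest",
    system_instruction="You are an agent operating a desktop GUI.",  # placeholder
)

# One user message with a text part and an image part, matching the
# {"role": ..., "parts": [...]} structure the patch assembles.
screenshot = Image.open("screenshot.png")  # placeholder path
response = model.generate_content(
    [{"role": "user", "parts": ["What should I click next?", screenshot]}],
    generation_config={
        "candidate_count": 1,
        "max_output_tokens": 512,
        "top_p": 0.9,
        "temperature": 0.5,
    },
    request_options={"timeout": 120},
)
print(response.text)
```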