This commit is contained in:
tsuky_chen
2024-02-02 14:50:16 +08:00
14 changed files with 490 additions and 182 deletions

View File

@@ -24,6 +24,8 @@ todo
- [x] Add accessibility tree from the OS into the observation space - [x] Add accessibility tree from the OS into the observation space
- [ ] Add pre-process and post-process action support for benchmarking setup and evaluation - [ ] Add pre-process and post-process action support for benchmarking setup and evaluation
- [ ] Multiprocess support, this can enable the reinforcement learning to be more efficient - [ ] Multiprocess support, this can enable the reinforcement learning to be more efficient
- [ ] Experiment logging and visualization system
- [ ] Add more tasks, maybe scale to 300 for v1.0.0, and create a dynamic leaderboard
## Road map of benchmark, tools and resources (Proposed) ## Road map of benchmark, tools and resources (Proposed)
- [ ] Improve the annotation tool base on DuckTrack, make it more robust which align on accessibility tree - [ ] Improve the annotation tool base on DuckTrack, make it more robust which align on accessibility tree

View File

@@ -11,7 +11,7 @@ logger = logging.getLogger("desktopenv.pycontroller")
class PythonController: class PythonController:
def __init__(self, vm_ip: str, pkgs_prefix: str = "import pyautogui; import time; {command}"): def __init__(self, vm_ip: str, pkgs_prefix: str = "import pyautogui; import time; pyautogui.FAILSAFE = False; {command}"):
self.vm_ip = vm_ip self.vm_ip = vm_ip
self.http_server = f"http://{vm_ip}:5000" self.http_server = f"http://{vm_ip}:5000"
self.pkgs_prefix = pkgs_prefix # fixme: this is a hacky way to execute python commands. fix it and combine it with installation of packages self.pkgs_prefix = pkgs_prefix # fixme: this is a hacky way to execute python commands. fix it and combine it with installation of packages

View File

@@ -7,6 +7,7 @@ import uuid
import tempfile import tempfile
from typing import Any, Union, Optional from typing import Any, Union, Optional
from typing import Dict, List from typing import Dict, List
import os
import requests import requests
from pydrive.auth import GoogleAuth from pydrive.auth import GoogleAuth
@@ -114,6 +115,7 @@ class SetupController:
if not os.path.exists(cache_path): if not os.path.exists(cache_path):
max_retries = 3 max_retries = 3
downloaded = False downloaded = False
e = None
for i in range(max_retries): for i in range(max_retries):
try: try:
response = requests.get(url, stream=True) response = requests.get(url, stream=True)
@@ -128,7 +130,7 @@ class SetupController:
break break
except requests.RequestException as e: except requests.RequestException as e:
logger.error(f"Failed to download {url}. Retrying... ({max_retries - i - 1} attempts left)") logger.error(f"Failed to download {url} caused by {e}. Retrying... ({max_retries - i - 1} attempts left)")
if not downloaded: if not downloaded:
raise requests.RequestException(f"Failed to download {url}. No retries left. Error: {e}") raise requests.RequestException(f"Failed to download {url}. No retries left. Error: {e}")
@@ -344,39 +346,49 @@ class SetupController:
port = 9222 # fixme: this port is hard-coded, need to be changed from config file port = 9222 # fixme: this port is hard-coded, need to be changed from config file
remote_debugging_url = f"http://{host}:{port}" remote_debugging_url = f"http://{host}:{port}"
with sync_playwright() as p: logger.info("Connect to Chrome @: %s", remote_debugging_url)
logger.debug("PLAYWRIGHT ENV: %s", repr(os.environ))
for attempt in range(15):
if attempt>0:
time.sleep(5)
browser = None browser = None
for attempt in range(15): with sync_playwright() as p:
try: try:
browser = p.chromium.connect_over_cdp(remote_debugging_url) browser = p.chromium.connect_over_cdp(remote_debugging_url)
break #break
except Exception as e: except Exception as e:
if attempt < 14: if attempt < 14:
logger.error(f"Attempt {attempt + 1}: Failed to connect, retrying. Error: {e}") logger.error(f"Attempt {attempt + 1}: Failed to connect, retrying. Error: {e}")
time.sleep(1) #time.sleep(10)
continue
else: else:
logger.error(f"Failed to connect after multiple attempts: {e}") logger.error(f"Failed to connect after multiple attempts: {e}")
raise e raise e
if not browser: if not browser:
return return
for i, url in enumerate(urls_to_open): logger.info("Opening %s...", urls_to_open)
# Use the first context (which should be the only one if using default profile) for i, url in enumerate(urls_to_open):
if i == 0: # Use the first context (which should be the only one if using default profile)
context = browser.contexts[0] if i == 0:
context = browser.contexts[0]
page = context.new_page() # Create a new page (tab) within the existing context page = context.new_page() # Create a new page (tab) within the existing context
page.goto(url, timeout=60000) try:
logger.info(f"Opened tab {i + 1}: {url}") page.goto(url, timeout=60000)
except:
logger.warning("Opening %s exceeds time limit", url) # only for human test
logger.info(f"Opened tab {i + 1}: {url}")
if i == 0: if i == 0:
# clear the default tab # clear the default tab
default_page = context.pages[0] default_page = context.pages[0]
default_page.close() default_page.close()
# Do not close the context or browser; they will remain open after script ends # Do not close the context or browser; they will remain open after script ends
return browser, context return browser, context
def _chrome_close_tabs_setup(self, urls_to_close: List[str]): def _chrome_close_tabs_setup(self, urls_to_close: List[str]):
time.sleep(5) # Wait for Chrome to finish launching time.sleep(5) # Wait for Chrome to finish launching
@@ -552,4 +564,4 @@ class SetupController:
else: else:
raise NotImplementedError raise NotImplementedError
return browser, context return browser, context

View File

@@ -191,7 +191,7 @@ To enable and use the HTTP interface in VLC Media Player for remote control and
#### 4. Configure Lua HTTP #### 4. Configure Lua HTTP
- Expand the `Main interfaces` node and select `Lua`. - Expand the `Main interfaces` node and select `Lua`.
- Under `Lua HTTP`, set a password in the `Lua HTTP` section. This password will be required to access the HTTP interface. - Under `Lua HTTP`, set a password `password` in the `Lua HTTP` section. This password will be required to access the HTTP interface.
#### 5. Save and Restart VLC #### 5. Save and Restart VLC
@@ -217,4 +217,4 @@ pip install opencv-python-headless Pillow imagehash
- If the port is in use by another application, you may change the port number in VLC's settings. - If the port is in use by another application, you may change the port number in VLC's settings.
## GIMP ## GIMP
Click on the "Keep" of the image loading pop-up. Click on the "Keep" of the image loading pop-up.

View File

@@ -2,7 +2,7 @@ import ctypes
import os import os
import platform import platform
import shlex import shlex
import subprocess import subprocess, signal
from pathlib import Path from pathlib import Path
from typing import Any, Optional from typing import Any, Optional
from typing import List, Dict, Tuple from typing import List, Dict, Tuple
@@ -997,7 +997,7 @@ def start_recording():
start_command = f"ffmpeg -y -f x11grab -draw_mouse 1 -s {screen_width}x{screen_height} -i :0.0 -c:v libx264 -r 30 {recording_path}" start_command = f"ffmpeg -y -f x11grab -draw_mouse 1 -s {screen_width}x{screen_height} -i :0.0 -c:v libx264 -r 30 {recording_path}"
recording_process = subprocess.Popen(shlex.split(start_command), stdout=subprocess.PIPE, stderr=subprocess.PIPE) recording_process = subprocess.Popen(shlex.split(start_command), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
return jsonify({'status': 'success', 'message': 'Started recording.'}) return jsonify({'status': 'success', 'message': 'Started recording.'})
@@ -1009,10 +1009,8 @@ def end_recording():
if not recording_process: if not recording_process:
return jsonify({'status': 'error', 'message': 'No recording in progress to stop.'}), 400 return jsonify({'status': 'error', 'message': 'No recording in progress to stop.'}), 400
recording_process.terminate() recording_process.send_signal(signal.SIGINT)
recording_process.wait() recording_process.wait()
# return_code = recording_process.returncode
output, error = recording_process.communicate()
recording_process = None recording_process = None
# return recording video file # return recording video file

View File

@@ -12,6 +12,16 @@
"--remote-debugging-port=9222" "--remote-debugging-port=9222"
] ]
} }
},
{
"type": "launch",
"parameters": {
"command": [
"socat",
"tcp-listen:9222,fork",
"tcp:localhost:1337"
]
}
} }
], ],
"trajectory": "trajectories/", "trajectory": "trajectories/",

View File

@@ -1,7 +1,7 @@
{ {
"id": "66399b0d-8fda-4618-95c4-bfc6191617e9", "id": "66399b0d-8fda-4618-95c4-bfc6191617e9",
"snapshot": "libreoffice_writer", "snapshot": "libreoffice_writer",
"instruction": "Could you help me insert a 7*5 empty table at the point of cursor?", "instruction": "Could you help me insert a 7(columns)*5(rows) empty table at the point of cursor?",
"source": "https://www.youtube.com/watch?v=l25Evu4ohKg", "source": "https://www.youtube.com/watch?v=l25Evu4ohKg",
"config": [ "config": [
{ {
@@ -27,7 +27,7 @@
"command": [ "command": [
"python", "python",
"-c", "-c",
"import pyautogui; import time; time.sleep(5); pyautogui.press(\"down\", presses=40, interval=10); time.sleep(1); pyautogui.scroll(-2)" "import pyautogui; import time; pyautogui.press(\"down\", presses=40, interval=0.01); time.sleep(1); pyautogui.scroll(-2)"
] ]
} }
} }

View File

@@ -38,7 +38,7 @@
"command": [ "command": [
"python", "python",
"-c", "-c",
"import pyautogui; import time; time.sleep(5); pyautogui.press(\"down\", presses=8, interval=3); time.sleep(1); pyautogui.scroll(-2)" "import pyautogui; import time; time.sleep(5); pyautogui.press(\"down\", presses=8, interval=0.01); time.sleep(1); pyautogui.scroll(-2)"
] ]
} }
} }
@@ -81,7 +81,7 @@
}, },
"expected": { "expected": {
"type": "cloud_file", "type": "cloud_file",
"path": "https://drive.usercontent.google.com/download?id=1xbhlfqGrPutHHi2aHg66jwXD-yaZpe9j&export=download&authuser=0&confirm=t&uuid=427765e0-3f97-4a72-92db-a1fe7cdde73b&at=APZUnTUhNLh2PDu4OGkCVQW-LPCd:1704173991269", "path": "https://drive.usercontent.google.com/download?id=1xbhlfqGrPutHHi2aHg66jwXD-yaZpe9j&export=download&authuser=0&confirm=t&uuid=802d477e-d97b-4641-84fb-9eaf8805c35c&at=APZUnTWS0KOqHCPnufPJfDEfGE2u:1706822844322",
"dest": "Viewing_Your_Class_Schedule_and_Textbooks_Gold.docx" "dest": "Viewing_Your_Class_Schedule_and_Textbooks_Gold.docx"
} }
} }

View File

@@ -1,7 +1,7 @@
{ {
"id": "8f080098-ddb1-424c-b438-4e96e5e4786e", "id": "8f080098-ddb1-424c-b438-4e96e5e4786e",
"snapshot": "base_setup", "snapshot": "base_setup",
"instruction": "Could you download the song from this music video and save it as an MP3 file? I'd like to have it on my device to play whenever I want. Please title the file \"Baby Justin Bieber.mp3.\" I really appreciate your help!", "instruction": "Could you download the song from this music video and save it as an MP3 file? I'd like to have it on my device to play whenever I want. Please save the file just on the desktop and title the file \"Baby Justin Bieber.mp3.\" I really appreciate your help!",
"source": "https://medium.com/@jetscribe_ai/how-to-extract-mp3-audio-from-videos-using-vlc-media-player-beeef644ebfb", "source": "https://medium.com/@jetscribe_ai/how-to-extract-mp3-audio-from-videos-using-vlc-media-player-beeef644ebfb",
"config": [ "config": [
{ {

View File

@@ -242,18 +242,18 @@ if __name__ == '__main__':
# main("libreoffice_calc", example_id) # main("libreoffice_calc", example_id)
impress_list = [ impress_list = [
# "5d901039-a89c-4bfb-967b-bf66f4df075e", "5d901039-a89c-4bfb-967b-bf66f4df075e",
# "550ce7e7-747b-495f-b122-acdc4d0b8e54", "550ce7e7-747b-495f-b122-acdc4d0b8e54",
# "455d3c66-7dc6-4537-a39a-36d3e9119df7", "455d3c66-7dc6-4537-a39a-36d3e9119df7",
# "af23762e-2bfd-4a1d-aada-20fa8de9ce07", "af23762e-2bfd-4a1d-aada-20fa8de9ce07",
# "c59742c0-4323-4b9d-8a02-723c251deaa0", "c59742c0-4323-4b9d-8a02-723c251deaa0",
# "ef9d12bd-bcee-4ba0-a40e-918400f43ddf", "ef9d12bd-bcee-4ba0-a40e-918400f43ddf",
# "9ec204e4-f0a3-42f8-8458-b772a6797cab", "9ec204e4-f0a3-42f8-8458-b772a6797cab",
# "0f84bef9-9790-432e-92b7-eece357603fb", "0f84bef9-9790-432e-92b7-eece357603fb",
# "ce88f674-ab7a-43da-9201-468d38539e4a", "ce88f674-ab7a-43da-9201-468d38539e4a",
# "3b27600c-3668-4abd-8f84-7bcdebbccbdb", "3b27600c-3668-4abd-8f84-7bcdebbccbdb",
# "a097acff-6266-4291-9fbd-137af7ecd439", "a097acff-6266-4291-9fbd-137af7ecd439",
# "bf4e9888-f10f-47af-8dba-76413038b73c", "bf4e9888-f10f-47af-8dba-76413038b73c",
"21760ecb-8f62-40d2-8d85-0cee5725cb72" "21760ecb-8f62-40d2-8d85-0cee5725cb72"
] ]
# for example_id in impress_list: # for example_id in impress_list:
@@ -326,27 +326,106 @@ if __name__ == '__main__':
# logger.error("An error occurred while running the example: %s", e) # logger.error("An error occurred while running the example: %s", e)
# continue # continue
from tqdm import tqdm
# for example_id in tqdm(vlc_list):
# try:
# main("vlc", example_id, gpt4_model="gpt-3.5-turbo-16k")
# except Exception as e:
# print(f"An error occurred while running the example: {e}")
# continue
chrome_list = [
"bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
"7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
"06fe7178-4491-4589-810f-2e2bc9502122",
"e1e75309-3ddb-4d09-92ec-de869c928143",
"35253b65-1c19-4304-8aa4-6884b8218fc0",
"2ad9387a-65d8-4e33-ad5b-7580065a27ca",
"7a5a7856-f1b6-42a4-ade9-1ca81ca0f263",
"44ee5668-ecd5-4366-a6ce-c1c9b8d4e938",
"2ae9ba84-3a0d-4d4c-8338-3a1478dc5fe3",
"480bcfea-d68f-4aaa-a0a9-2589ef319381",
"af630914-714e-4a24-a7bb-f9af687d3b91"
]
# for example_id in tqdm(chrome_list):
# try:
# main("chrome", example_id, gpt4_model="gpt-3.5-turbo-16k")
# except Exception as e:
# print(f"An error occurred while running the example: {e}")
# continue
vs_code_list = [
# "0ed39f63-6049-43d4-ba4d-5fa2fe04a951",
# "53ad5833-3455-407b-bbc6-45b4c79ab8fb",
# "eabc805a-bfcf-4460-b250-ac92135819f6",
# "982d12a5-beab-424f-8d38-d2a48429e511",
# "4e60007a-f5be-4bfc-9723-c39affa0a6d3",
# "e2b5e914-ffe1-44d2-8e92-58f8c5d92bb2",
# "9439a27b-18ae-42d8-9778-5f68f891805e",
# "ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
# "930fdb3b-11a8-46fe-9bac-577332e2640e",
# "276cc624-87ea-4f08-ab93-f770e3790175",
# "9d425400-e9b2-4424-9a4b-d4c7abac4140"
]
for example_id in tqdm(vs_code_list):
try:
main("vs_code", example_id, gpt4_model="gpt-3.5-turbo-16k")
except Exception as e:
print(f"An error occurred while running the example: {e}")
continue
thunderbird_list = [
# "bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
# "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
"12086550-11c0-466b-b367-1d9e75b3910e",
"06fe7178-4491-4589-810f-2e2bc9502122",
"6766f2b8-8a72-417f-a9e5-56fcaa735837",
"e1e75309-3ddb-4d09-92ec-de869c928143",
"3d1682a7-0fb0-49ae-a4dc-a73afd2d06d5",
"35253b65-1c19-4304-8aa4-6884b8218fc0",
"d088f539-cab4-4f9a-ac92-9999fc3a656e",
"2ad9387a-65d8-4e33-ad5b-7580065a27ca",
"480bcfea-d68f-4aaa-a0a9-2589ef319381",
"030eeff7-b492-4218-b312-701ec99ee0cc",
"94760984-3ff5-41ee-8347-cf1af709fea0",
"99146c54-4f37-4ab8-9327-5f3291665e1e",
"c9e7eaf2-b1a1-4efc-a982-721972fa9f02"
]
# for example_id in tqdm(thunderbird_list):
# try:
# main("thunderbird", example_id, gpt4_model="gpt-3.5-turbo-16k")
# except Exception as e:
# print(f"An error occurred while running the example: {e}")
# continue
multiple_list = [ multiple_list = [
"f8cfa149-d1c1-4215-8dac-4a0932bad3c2", # "f8cfa149-d1c1-4215-8dac-4a0932bad3c2",
"897e3b53-5d4d-444b-85cb-2cdc8a97d903", # "897e3b53-5d4d-444b-85cb-2cdc8a97d903",
"4e9f0faf-2ecc-4ae8-a804-28c9a75d1ddc", "2fe4b718-3bd7-46ec-bdce-b184f5653624",
"b52b40a5-ad70-4c53-b5b0-5650a8387052",
"46407397-a7d5-4c6b-92c6-dbe038b1457b",
"2b9493d7-49b8-493a-a71b-56cd1f4d6908",
"51f5801c-18b3-4f25-b0c3-02f85507a078",
"2c9fc0de-3ee7-45e1-a5df-c86206ad78b5",
"510f64c8-9bcc-4be1-8d30-638705850618",
"937087b6-f668-4ba6-9110-60682ee33441",
"ee9a3c83-f437-4879-8918-be5efbb9fac7",
"3680a5ee-6870-426a-a997-eba929a0d25c", "3680a5ee-6870-426a-a997-eba929a0d25c",
"e135df7c-7687-4ac0-a5f0-76b74438b53e", # "4e9f0faf-2ecc-4ae8-a804-28c9a75d1ddc",
# "b52b40a5-ad70-4c53-b5b0-5650a8387052",
# "46407397-a7d5-4c6b-92c6-dbe038b1457b",
# "2b9493d7-49b8-493a-a71b-56cd1f4d6908",
# "51f5801c-18b3-4f25-b0c3-02f85507a078",
"58565672-7bfe-48ab-b828-db349231de6b", "58565672-7bfe-48ab-b828-db349231de6b",
"2fe4b718-3bd7-46ec-bdce-b184f5653624" # "2c9fc0de-3ee7-45e1-a5df-c86206ad78b5",
# "510f64c8-9bcc-4be1-8d30-638705850618",
# "937087b6-f668-4ba6-9110-60682ee33441",
# "ee9a3c83-f437-4879-8918-be5efbb9fac7",
# "3680a5ee-6870-426a-a997-eba929a0d25c",
# "e135df7c-7687-4ac0-a5f0-76b74438b53e",
"ee9a3c83-f437-4879-8918-be5efbb9fac7",
# "58565672-7bfe-48ab-b828-db349231de6b",
# "2fe4b718-3bd7-46ec-bdce-b184f5653624"
] ]
for example_id in multiple_list: for example_id in multiple_list:
try: try:
main("multi_apps", example_id) main("multi_apps", example_id, gpt4_model="gpt-3.5-turbo-16k")
except Exception as e: except Exception as e:
logger.error("An error occurred while running the example: %s", e) logger.error("An error occurred while running the example: %s", e)
continue continue

View File

@@ -123,9 +123,8 @@ def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_tr
logger.info("Environment closed.") logger.info("Environment closed.")
def main(example_class, example_id): def main(example_class, example_id, gpt4_model = "gpt-4-vision-preview"):
action_space = "pyautogui" action_space = "pyautogui"
gpt4_model = "gpt-4-vision-preview"
gemini_model = "gemini-pro-vision" gemini_model = "gemini-pro-vision"
logger.info("Running example %s/%s", example_class, example_id) logger.info("Running example %s/%s", example_class, example_id)
@@ -137,7 +136,7 @@ def main(example_class, example_id):
example["snapshot"] = "exp_v5" example["snapshot"] = "exp_v5"
api_key = os.environ.get("OPENAI_API_KEY") api_key = os.environ.get("OPENAI_API_KEY")
agent = GPT4v_Agent(api_key=api_key, instruction=example['instruction'], action_space=action_space, agent = GPT4v_Agent(api_key=api_key, model=gpt4_model, instruction=example['instruction'], action_space=action_space,
exp="screenshot") exp="screenshot")
# #
# api_key = os.environ.get("GENAI_API_KEY") # api_key = os.environ.get("GENAI_API_KEY")
@@ -168,35 +167,49 @@ if __name__ == '__main__':
"af630914-714e-4a24-a7bb-f9af687d3b91" "af630914-714e-4a24-a7bb-f9af687d3b91"
] ]
calc_list = [ calc_list = [
# "eb03d19a-b88d-4de4-8a64-ca0ac66f426b", "a9f325aa-8c05-4e4f-8341-9e4358565f4f",
# "0bf05a7d-b28b-44d2-955a-50b41e24012a", "ecb0df7a-4e8d-4a03-b162-053391d3afaf",
# "7a4e4bc8-922c-4c84-865c-25ba34136be1", "7efeb4b1-3d19-4762-b163-63328d66303b",
# "2bd59342-0664-4ccb-ba87-79379096cc08", "4e6fcf72-daf3-439f-a232-c434ce416af6",
# "ecb0df7a-4e8d-4a03-b162-053391d3afaf", "6054afcb-5bab-4702-90a0-b259b5d3217c",
# "7efeb4b1-3d19-4762-b163-63328d66303b", "abed40dc-063f-4598-8ba5-9fe749c0615d",
# "4e6fcf72-daf3-439f-a232-c434ce416af6", "01b269ae-2111-4a07-81fd-3fcd711993b0",
# "6054afcb-5bab-4702-90a0-b259b5d3217c", "8b1ce5f2-59d2-4dcc-b0b0-666a714b9a14",
# "abed40dc-063f-4598-8ba5-9fe749c0615d", "af2b02f7-acee-4be4-8b66-499fab394915",
# "01b269ae-2111-4a07-81fd-3fcd711993b0", "da1d63b8-fa12-417b-ba18-f748e5f770f3",
# "8b1ce5f2-59d2-4dcc-b0b0-666a714b9a14", "636380ea-d5f6-4474-b6ca-b2ed578a20f1",
"0cecd4f3-74de-457b-ba94-29ad6b5dafb6", "5ba77536-05c5-4aae-a9ff-6e298d094c3e",
"4188d3a4-077d-46b7-9c86-23e1a036f6c1", "4bc4eaf4-ca5e-4db2-8138-8d4e65af7c0b",
"51b11269-2ca8-4b2a-9163-f21758420e78", "672a1b02-c62f-4ae2-acf0-37f5fb3052b0",
"7e429b8d-a3f0-4ed0-9b58-08957d00b127", "648fe544-16ba-44af-a587-12ccbe280ea6",
"347ef137-7eeb-4c80-a3bb-0951f26a8aff", "8985d1e4-5b99-4711-add4-88949ebb2308",
"6e99a1ad-07d2-4b66-a1ce-ece6d99c20a5", "9e606842-2e27-43bf-b1d1-b43289c9589b",
"3aaa4e37-dc91-482e-99af-132a612d40f3", "fcb6e45b-25c4-4087-9483-03d714f473a9",
"37608790-6147-45d0-9f20-1137bb35703d", "68c0c5b7-96f3-4e87-92a7-6c1b967fd2d2",
"f9584479-3d0d-4c79-affa-9ad7afdd8850", "fff629ea-046e-4793-8eec-1a5a15c3eb35",
"d681960f-7bc3-4286-9913-a8812ba3261a", "5c9a206c-bb00-4fb6-bb46-ee675c187df5",
"21df9241-f8d7-4509-b7f1-37e501a823f7", "e975ae74-79bd-4672-8d1c-dc841a85781d",
"1334ca3e-f9e3-4db8-9ca7-b4c653be7d17", "34a6938a-58da-4897-8639-9b90d6db5391",
"357ef137-7eeb-4c80-a3bb-0951f26a8aff", "b5a22759-b4eb-4bf2-aeed-ad14e8615f19",
"aa3a8974-2e85-438b-b29e-a64df44deb4b", "2f9913a1-51ed-4db6-bfe0-7e1c95b3139e",
"a01fbce3-2793-461f-ab86-43680ccbae25", "2558031e-401d-4579-8e00-3ecf540fb492",
"4f07fbe9-70de-4927-a4d5-bb28bc12c52c", "0cecd4f3-74de-457b-ba94-29ad6b5dafb6",
] "4188d3a4-077d-46b7-9c86-23e1a036f6c1",
"51b11269-2ca8-4b2a-9163-f21758420e78",
"7e429b8d-a3f0-4ed0-9b58-08957d00b127",
"347ef137-7eeb-4c80-a3bb-0951f26a8aff",
"6e99a1ad-07d2-4b66-a1ce-ece6d99c20a5",
"3aaa4e37-dc91-482e-99af-132a612d40f3",
"37608790-6147-45d0-9f20-1137bb35703d",
"f9584479-3d0d-4c79-affa-9ad7afdd8850",
"d681960f-7bc3-4286-9913-a8812ba3261a",
"21df9241-f8d7-4509-b7f1-37e501a823f7",
"1334ca3e-f9e3-4db8-9ca7-b4c653be7d17",
"357ef137-7eeb-4c80-a3bb-0951f26a8aff",
"aa3a8974-2e85-438b-b29e-a64df44deb4b",
"a01fbce3-2793-461f-ab86-43680ccbae25",
"4f07fbe9-70de-4927-a4d5-bb28bc12c52c"
]
# for example_id in calc_list: # for example_id in calc_list:
# main("libreoffice_calc", example_id) # main("libreoffice_calc", example_id)
@@ -246,13 +259,13 @@ if __name__ == '__main__':
# #
vs_code_list = [ vs_code_list = [
"0ed39f63-6049-43d4-ba4d-5fa2fe04a951", # "0ed39f63-6049-43d4-ba4d-5fa2fe04a951",
"53ad5833-3455-407b-bbc6-45b4c79ab8fb", # "53ad5833-3455-407b-bbc6-45b4c79ab8fb",
"eabc805a-bfcf-4460-b250-ac92135819f6", # "eabc805a-bfcf-4460-b250-ac92135819f6",
"982d12a5-beab-424f-8d38-d2a48429e511", # "982d12a5-beab-424f-8d38-d2a48429e511",
"4e60007a-f5be-4bfc-9723-c39affa0a6d3", # "4e60007a-f5be-4bfc-9723-c39affa0a6d3",
"e2b5e914-ffe1-44d2-8e92-58f8c5d92bb2", # "e2b5e914-ffe1-44d2-8e92-58f8c5d92bb2",
"9439a27b-18ae-42d8-9778-5f68f891805e", # "9439a27b-18ae-42d8-9778-5f68f891805e",
"ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae", "ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
"930fdb3b-11a8-46fe-9bac-577332e2640e", "930fdb3b-11a8-46fe-9bac-577332e2640e",
"276cc624-87ea-4f08-ab93-f770e3790175", "276cc624-87ea-4f08-ab93-f770e3790175",
@@ -266,28 +279,28 @@ if __name__ == '__main__':
# logger.error("An error occurred while running the example: %s", e) # logger.error("An error occurred while running the example: %s", e)
# continue # continue
multiple_list = [ # multiple_list = [
"f8cfa149-d1c1-4215-8dac-4a0932bad3c2", # "f8cfa149-d1c1-4215-8dac-4a0932bad3c2",
"897e3b53-5d4d-444b-85cb-2cdc8a97d903", # "897e3b53-5d4d-444b-85cb-2cdc8a97d903",
"4e9f0faf-2ecc-4ae8-a804-28c9a75d1ddc", # "4e9f0faf-2ecc-4ae8-a804-28c9a75d1ddc",
"b52b40a5-ad70-4c53-b5b0-5650a8387052", # "b52b40a5-ad70-4c53-b5b0-5650a8387052",
"46407397-a7d5-4c6b-92c6-dbe038b1457b", # "46407397-a7d5-4c6b-92c6-dbe038b1457b",
"2b9493d7-49b8-493a-a71b-56cd1f4d6908", # "2b9493d7-49b8-493a-a71b-56cd1f4d6908",
"51f5801c-18b3-4f25-b0c3-02f85507a078", # "51f5801c-18b3-4f25-b0c3-02f85507a078",
"2c9fc0de-3ee7-45e1-a5df-c86206ad78b5", # "2c9fc0de-3ee7-45e1-a5df-c86206ad78b5",
"510f64c8-9bcc-4be1-8d30-638705850618", # "510f64c8-9bcc-4be1-8d30-638705850618",
"937087b6-f668-4ba6-9110-60682ee33441", # "937087b6-f668-4ba6-9110-60682ee33441",
"ee9a3c83-f437-4879-8918-be5efbb9fac7", # "ee9a3c83-f437-4879-8918-be5efbb9fac7",
"3680a5ee-6870-426a-a997-eba929a0d25c", # "3680a5ee-6870-426a-a997-eba929a0d25c",
"e135df7c-7687-4ac0-a5f0-76b74438b53e", # "e135df7c-7687-4ac0-a5f0-76b74438b53e",
"58565672-7bfe-48ab-b828-db349231de6b", # "58565672-7bfe-48ab-b828-db349231de6b",
"2fe4b718-3bd7-46ec-bdce-b184f5653624" # "2fe4b718-3bd7-46ec-bdce-b184f5653624"
] # ]
#
for example_id in multiple_list: # for example_id in multiple_list:
try: # try:
main("multi_apps", example_id) # main("multi_apps", example_id)
except Exception as e: # except Exception as e:
logger.error("An error occurred while running the example: %s", e) # logger.error("An error occurred while running the example: %s", e)
continue # continue

View File

@@ -43,7 +43,7 @@ logger.addHandler(sdebug_handler)
logger = logging.getLogger("desktopenv.experiment") logger = logging.getLogger("desktopenv.experiment")
PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx" PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu2\Ubuntu2.vmx"
# PATH_TO_VM = "../../../../大文件/镜像/Ubuntu-1218/Ubuntu/Ubuntu.vmx" # PATH_TO_VM = "../../../../大文件/镜像/Ubuntu-1218/Ubuntu/Ubuntu.vmx"
@@ -185,34 +185,49 @@ if __name__ == '__main__':
# continue # continue
calc_list = [ calc_list = [
# "eb03d19a-b88d-4de4-8a64-ca0ac66f426b", "a9f325aa-8c05-4e4f-8341-9e4358565f4f",
# "0bf05a7d-b28b-44d2-955a-50b41e24012a", "ecb0df7a-4e8d-4a03-b162-053391d3afaf",
# "7a4e4bc8-922c-4c84-865c-25ba34136be1", "7efeb4b1-3d19-4762-b163-63328d66303b",
# "2bd59342-0664-4ccb-ba87-79379096cc08", "4e6fcf72-daf3-439f-a232-c434ce416af6",
# "ecb0df7a-4e8d-4a03-b162-053391d3afaf", "6054afcb-5bab-4702-90a0-b259b5d3217c",
# "7efeb4b1-3d19-4762-b163-63328d66303b", "abed40dc-063f-4598-8ba5-9fe749c0615d",
# "4e6fcf72-daf3-439f-a232-c434ce416af6", "01b269ae-2111-4a07-81fd-3fcd711993b0",
# "6054afcb-5bab-4702-90a0-b259b5d3217c", "8b1ce5f2-59d2-4dcc-b0b0-666a714b9a14",
# "abed40dc-063f-4598-8ba5-9fe749c0615d", "af2b02f7-acee-4be4-8b66-499fab394915",
# "01b269ae-2111-4a07-81fd-3fcd711993b0", "da1d63b8-fa12-417b-ba18-f748e5f770f3",
# "8b1ce5f2-59d2-4dcc-b0b0-666a714b9a14", "636380ea-d5f6-4474-b6ca-b2ed578a20f1",
# "0cecd4f3-74de-457b-ba94-29ad6b5dafb6", "5ba77536-05c5-4aae-a9ff-6e298d094c3e",
# "4188d3a4-077d-46b7-9c86-23e1a036f6c1", "4bc4eaf4-ca5e-4db2-8138-8d4e65af7c0b",
# "51b11269-2ca8-4b2a-9163-f21758420e78", "672a1b02-c62f-4ae2-acf0-37f5fb3052b0",
# "7e429b8d-a3f0-4ed0-9b58-08957d00b127", "648fe544-16ba-44af-a587-12ccbe280ea6",
# "347ef137-7eeb-4c80-a3bb-0951f26a8aff", "8985d1e4-5b99-4711-add4-88949ebb2308",
# "6e99a1ad-07d2-4b66-a1ce-ece6d99c20a5", "9e606842-2e27-43bf-b1d1-b43289c9589b",
# "3aaa4e37-dc91-482e-99af-132a612d40f3", "fcb6e45b-25c4-4087-9483-03d714f473a9",
# "37608790-6147-45d0-9f20-1137bb35703d", "68c0c5b7-96f3-4e87-92a7-6c1b967fd2d2",
# "f9584479-3d0d-4c79-affa-9ad7afdd8850", "fff629ea-046e-4793-8eec-1a5a15c3eb35",
"d681960f-7bc3-4286-9913-a8812ba3261a", "5c9a206c-bb00-4fb6-bb46-ee675c187df5",
"21df9241-f8d7-4509-b7f1-37e501a823f7", "e975ae74-79bd-4672-8d1c-dc841a85781d",
"1334ca3e-f9e3-4db8-9ca7-b4c653be7d17", "34a6938a-58da-4897-8639-9b90d6db5391",
"357ef137-7eeb-4c80-a3bb-0951f26a8aff", "b5a22759-b4eb-4bf2-aeed-ad14e8615f19",
"aa3a8974-2e85-438b-b29e-a64df44deb4b", "2f9913a1-51ed-4db6-bfe0-7e1c95b3139e",
"a01fbce3-2793-461f-ab86-43680ccbae25", "2558031e-401d-4579-8e00-3ecf540fb492",
"4f07fbe9-70de-4927-a4d5-bb28bc12c52c", "0cecd4f3-74de-457b-ba94-29ad6b5dafb6",
] "4188d3a4-077d-46b7-9c86-23e1a036f6c1",
"51b11269-2ca8-4b2a-9163-f21758420e78",
"7e429b8d-a3f0-4ed0-9b58-08957d00b127",
"347ef137-7eeb-4c80-a3bb-0951f26a8aff",
"6e99a1ad-07d2-4b66-a1ce-ece6d99c20a5",
"3aaa4e37-dc91-482e-99af-132a612d40f3",
"37608790-6147-45d0-9f20-1137bb35703d",
"f9584479-3d0d-4c79-affa-9ad7afdd8850",
"d681960f-7bc3-4286-9913-a8812ba3261a",
"21df9241-f8d7-4509-b7f1-37e501a823f7",
"1334ca3e-f9e3-4db8-9ca7-b4c653be7d17",
"357ef137-7eeb-4c80-a3bb-0951f26a8aff",
"aa3a8974-2e85-438b-b29e-a64df44deb4b",
"a01fbce3-2793-461f-ab86-43680ccbae25",
"4f07fbe9-70de-4927-a4d5-bb28bc12c52c"
]
# for example_id in calc_list: # for example_id in calc_list:
# try: # try:
@@ -283,9 +298,64 @@ if __name__ == '__main__':
"2fe4b718-3bd7-46ec-bdce-b184f5653624" "2fe4b718-3bd7-46ec-bdce-b184f5653624"
] ]
for example_id in multiple_list: # for example_id in multiple_list:
# try:
# main("multi_apps", example_id)
# except Exception as e:
# logger.error("An error occurred while running the example: %s", e)
# continue
chrome_list = [
# "bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
"7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
"06fe7178-4491-4589-810f-2e2bc9502122",
"e1e75309-3ddb-4d09-92ec-de869c928143",
"35253b65-1c19-4304-8aa4-6884b8218fc0",
"2ad9387a-65d8-4e33-ad5b-7580065a27ca",
"7a5a7856-f1b6-42a4-ade9-1ca81ca0f263",
"44ee5668-ecd5-4366-a6ce-c1c9b8d4e938",
"2ae9ba84-3a0d-4d4c-8338-3a1478dc5fe3",
"480bcfea-d68f-4aaa-a0a9-2589ef319381",
"af630914-714e-4a24-a7bb-f9af687d3b91"
]
# for example_id in chrome_list:
# try:
# main("chrome", example_id)
# except Exception as e:
# logger.error("An error occurred while running the example: %s", e)
# continue
writer_list = [
"6ada715d-3aae-4a32-a6a7-429b2e43fb93",
"ecc2413d-8a48-416e-a3a2-d30106ca36cb",
"0e47de2a-32e0-456c-a366-8c607ef7a9d2",
"4bcb1253-a636-4df4-8cb0-a35c04dfef31",
"0810415c-bde4-4443-9047-d5f70165a697",
"e528b65e-1107-4b8c-8988-490e4fece599",
"66399b0d-8fda-4618-95c4-bfc6191617e9",
"936321ce-5236-426a-9a20-e0e3c5dc536f",
"3ef2b351-8a84-4ff2-8724-d86eae9b842e",
"0b17a146-2934-46c7-8727-73ff6b6483e8",
"0e763496-b6bb-4508-a427-fad0b6c3e195",
"f178a4a9-d090-4b56-bc4c-4b72a61a035d",
"adf5e2c3-64c7-4644-b7b6-d2f0167927e7",
"0a0faba3-5580-44df-965d-f562a99b291c",
"e246f6d8-78d7-44ac-b668-fcf47946cb50",
"8472fece-c7dd-4241-8d65-9b3cd1a0b568",
"88fe4b2d-3040-4c70-9a70-546a47764b48",
"d53ff5ee-3b1a-431e-b2be-30ed2673079b",
"72b810ef-4156-4d09-8f08-a0cf57e7cefe",
"6f81754e-285d-4ce0-b59e-af7edb02d108",
"b21acd93-60fd-4127-8a43-2f5178f4a830"
]
for example_id in writer_list:
try: try:
main("multi_apps", example_id) main("libreoffice_writer", example_id)
except Exception as e: except Exception as e:
logger.error("An error occurred while running the example: %s", e) logger.error("An error occurred while running the example: %s", e)
continue continue

View File

@@ -151,7 +151,111 @@ def main(example_class, example_id):
if __name__ == '__main__': if __name__ == '__main__':
xx_list = [ from tqdm import tqdm
# impress_list = [
# # "5d901039-a89c-4bfb-967b-bf66f4df075e",
# "550ce7e7-747b-495f-b122-acdc4d0b8e54",
# "455d3c66-7dc6-4537-a39a-36d3e9119df7",
# "af23762e-2bfd-4a1d-aada-20fa8de9ce07",
# "c59742c0-4323-4b9d-8a02-723c251deaa0",
# "ef9d12bd-bcee-4ba0-a40e-918400f43ddf",
# "9ec204e4-f0a3-42f8-8458-b772a6797cab",
# "0f84bef9-9790-432e-92b7-eece357603fb",
# "ce88f674-ab7a-43da-9201-468d38539e4a",
# "3b27600c-3668-4abd-8f84-7bcdebbccbdb",
# "a097acff-6266-4291-9fbd-137af7ecd439",
# "bf4e9888-f10f-47af-8dba-76413038b73c",
# "21760ecb-8f62-40d2-8d85-0cee5725cb72"
# ]
# for example_id in impress_list:
# main("libreoffice_impress", example_id)
vlc_list = [
"8ba5ae7a-5ae5-4eab-9fcc-5dd4fe3abf89",
"8ba5ae7a-5ae5-4eab-9fcc-5dd4fe3abf89",
"8f080098-ddb1-424c-b438-4e96e5e4786e",
"bba3381f-b5eb-4439-bd9e-80c22218d5a7",
"fba2c100-79e8-42df-ae74-b592418d54f4",
"efcf0d81-0835-4880-b2fd-d866e8bc2294",
"8d9fd4e2-6fdb-46b0-b9b9-02f06495c62f",
"aa4b5023-aef6-4ed9-bdc9-705f59ab9ad6",
"386dbd0e-0241-4a0a-b6a2-6704fba26b1c",
"9195653c-f4aa-453d-aa95-787f6ccfaae9",
"d06f0d4d-2cd5-4ede-8de9-598629438c6e",
"a5bbbcd5-b398-4c91-83d4-55e1e31bbb81",
"f3977615-2b45-4ac5-8bba-80c17dbe2a37",
"215dfd39-f493-4bc3-a027-8a97d72c61bf"
] ]
for example_id in xx_list:
main("xx", example_id) # for example_id in tqdm(vlc_list):
# try:
# main("vlc", example_id)
# except Exception as e:
# print(f"An error occurred while running the example: {e}")
# continue
chrome_list = [
"bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
"7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
"06fe7178-4491-4589-810f-2e2bc9502122",
"e1e75309-3ddb-4d09-92ec-de869c928143",
"35253b65-1c19-4304-8aa4-6884b8218fc0",
"2ad9387a-65d8-4e33-ad5b-7580065a27ca",
"7a5a7856-f1b6-42a4-ade9-1ca81ca0f263",
"44ee5668-ecd5-4366-a6ce-c1c9b8d4e938",
"2ae9ba84-3a0d-4d4c-8338-3a1478dc5fe3",
"480bcfea-d68f-4aaa-a0a9-2589ef319381",
"af630914-714e-4a24-a7bb-f9af687d3b91"
]
for example_id in tqdm(chrome_list):
try:
main("chrome", example_id)
except Exception as e:
print(f"An error occurred while running the example: {e}")
continue
vs_code_list = [
"0ed39f63-6049-43d4-ba4d-5fa2fe04a951",
"53ad5833-3455-407b-bbc6-45b4c79ab8fb",
"eabc805a-bfcf-4460-b250-ac92135819f6",
"982d12a5-beab-424f-8d38-d2a48429e511",
"4e60007a-f5be-4bfc-9723-c39affa0a6d3",
"e2b5e914-ffe1-44d2-8e92-58f8c5d92bb2",
"9439a27b-18ae-42d8-9778-5f68f891805e",
"ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
"930fdb3b-11a8-46fe-9bac-577332e2640e",
"276cc624-87ea-4f08-ab93-f770e3790175",
"9d425400-e9b2-4424-9a4b-d4c7abac4140"
]
for example_id in tqdm(vs_code_list):
try:
main("vs_code", example_id)
except Exception as e:
print(f"An error occurred while running the example: {e}")
continue
thunderbird_list = [
"bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
"7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
"12086550-11c0-466b-b367-1d9e75b3910e",
"06fe7178-4491-4589-810f-2e2bc9502122",
"6766f2b8-8a72-417f-a9e5-56fcaa735837",
"e1e75309-3ddb-4d09-92ec-de869c928143",
"3d1682a7-0fb0-49ae-a4dc-a73afd2d06d5",
"35253b65-1c19-4304-8aa4-6884b8218fc0",
"d088f539-cab4-4f9a-ac92-9999fc3a656e",
"2ad9387a-65d8-4e33-ad5b-7580065a27ca",
"480bcfea-d68f-4aaa-a0a9-2589ef319381",
"030eeff7-b492-4218-b312-701ec99ee0cc",
"94760984-3ff5-41ee-8347-cf1af709fea0",
"99146c54-4f37-4ab8-9327-5f3291665e1e",
"c9e7eaf2-b1a1-4efc-a982-721972fa9f02"
]
for example_id in tqdm(thunderbird_list):
try:
main("thunderbird", example_id)
except Exception as e:
print(f"An error occurred while running the example: {e}")
continue

View File

@@ -15,12 +15,10 @@ import google.generativeai as genai
import openai import openai
import requests import requests
from PIL import Image from PIL import Image
from openai.error import ( from openai import (
APIConnectionError, APIConnectionError,
APIError, APIError,
RateLimitError, RateLimitError
ServiceUnavailableError,
InvalidRequestError
) )
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes
@@ -115,6 +113,7 @@ def parse_actions_from_string(input_string):
def parse_code_from_string(input_string): def parse_code_from_string(input_string):
input_string = input_string.replace(";", "\n")
if input_string.strip() in ['WAIT', 'DONE', 'FAIL']: if input_string.strip() in ['WAIT', 'DONE', 'FAIL']:
return [input_string.strip()] return [input_string.strip()]
@@ -475,19 +474,15 @@ class GPT4v_Agent:
with open("messages.json", "w") as f: with open("messages.json", "w") as f:
f.write(json.dumps(messages, indent=4)) f.write(json.dumps(messages, indent=4))
try:
response = self.call_llm({ response = self.call_llm({
"model": self.model, "model": self.model,
"messages": messages, "messages": messages,
"max_tokens": self.max_tokens "max_tokens": self.max_tokens
}) })
except Exception as e:
logger.warning("LLM INVOCATION ERROR: %s", str(e))
response = ""
logger.debug("RESPONSE: %s", response) logger.debug("RESPONSE: %s", response)
# {{{
if self.exp == "seeact": if self.exp == "seeact":
messages.append({ messages.append({
"role": "assistant", "role": "assistant",
@@ -523,13 +518,13 @@ class GPT4v_Agent:
except Exception as e: except Exception as e:
print("Failed to parse action from response", e) print("Failed to parse action from response", e)
actions = None actions = None
self.thoughts.append("") # }}} self.thoughts.append("")
return actions return actions
@backoff.on_exception( @backoff.on_exception(
backoff.expo, backoff.expo,
(APIError, RateLimitError, APIConnectionError, ServiceUnavailableError, InvalidRequestError), (APIError, RateLimitError, APIConnectionError),
max_tries=10 max_tries=10
) )
def call_llm(self, payload): def call_llm(self, payload):
@@ -582,23 +577,34 @@ class GPT4v_Agent:
misrtal_messages[1]['content'] = misrtal_messages[0]['content'] + "\n" + misrtal_messages[1]['content'] misrtal_messages[1]['content'] = misrtal_messages[0]['content'] + "\n" + misrtal_messages[1]['content']
misrtal_messages.pop(0) misrtal_messages.pop(0)
openai.api_base = "http://localhost:8000/v1" # openai.api_base = "http://localhost:8000/v1"
openai.api_key = "test" # openai.api_key = "test"
response = openai.ChatCompletion.create( # response = openai.ChatCompletion.create(
# messages=misrtal_messages,
# model="Mixtral-8x7B-Instruct-v0.1"
# )
from openai import OpenAI
TOGETHER_API_KEY = "d011650e7537797148fb6170ec1e0be7ae75160375686fae02277136078e90d2"
client = OpenAI(api_key=TOGETHER_API_KEY,
base_url='https://api.together.xyz',
)
response = client.chat.completions.create(
messages=misrtal_messages, messages=misrtal_messages,
model="Mixtral-8x7B-Instruct-v0.1" model="mistralai/Mixtral-8x7B-Instruct-v0.1",
max_tokens=1024
) )
try: try:
return response['choices'][0]['message']['content'] # return response['choices'][0]['message']['content']
return response.choices[0].message.content
except Exception as e: except Exception as e:
print("Failed to call LLM: " + str(e)) print("Failed to call LLM: " + str(e))
return "" return ""
elif self.model.startswith("gemini"): elif self.model.startswith("gemini"):
api_key = os.environ.get("GENAI_API_KEY")
genai.api_key = api_key
def encoded_img_to_pil_img(data_str): def encoded_img_to_pil_img(data_str):
base64_str = data_str.replace("data:image/png;base64,", "") base64_str = data_str.replace("data:image/png;base64,", "")
image_data = base64.b64decode(base64_str) image_data = base64.b64decode(base64_str)
@@ -611,8 +617,13 @@ class GPT4v_Agent:
gemini_messages = [] gemini_messages = []
for i, message in enumerate(messages): for i, message in enumerate(messages):
role_mapping = {
"assistant": "model",
"user": "user",
"system": "system"
}
gemini_message = { gemini_message = {
"role": message["role"], "role": role_mapping[message["role"]],
"parts": [] "parts": []
} }
assert len(message["content"]) in [1, 2], "One text, or one text with one image" assert len(message["content"]) in [1, 2], "One text, or one text with one image"
@@ -628,6 +639,15 @@ class GPT4v_Agent:
gemini_messages.append(gemini_message) gemini_messages.append(gemini_message)
# the mistral not support system message in our endpoint, so we concatenate it at the first user message
if gemini_messages[0]['role'] == "system":
gemini_messages[1]['parts'][0] = gemini_messages[0]['parts'][0] + "\n" + gemini_messages[1]['parts'][0]
gemini_messages.pop(0)
print(gemini_messages)
api_key = os.environ.get("GENAI_API_KEY")
assert api_key is not None, "Please set the GENAI_API_KEY environment variable"
genai.configure(api_key=api_key)
response = genai.GenerativeModel(self.model).generate_content( response = genai.GenerativeModel(self.model).generate_content(
gemini_messages, gemini_messages,
generation_config={ generation_config={