Add gemini agent implementation; Add missed requirements; Minor fix some small bugs

This commit is contained in:
Timothyxxx
2024-01-15 21:58:33 +08:00
parent c68796e842
commit 493b719821
10 changed files with 82 additions and 83 deletions

View File

@@ -1,7 +1,7 @@
import base64
import json
import re
from typing import Dict
from typing import Dict, List
import requests
@@ -63,7 +63,8 @@ def parse_code_from_string(input_string):
class GPT4v_Agent:
def __init__(self, api_key, model="gpt-4-vision-preview", max_tokens=300, action_space="computer_13"):
def __init__(self, api_key, instruction, model="gpt-4-vision-preview", max_tokens=300, action_space="computer_13"):
self.instruction = instruction
self.model = model
self.max_tokens = max_tokens
self.action_space = action_space
@@ -82,13 +83,13 @@ class GPT4v_Agent:
"text": {
"computer_13": SYS_PROMPT_ACTION,
"pyautogui": SYS_PROMPT_CODE
}[action_space]
}[action_space] + "\nHere is the instruction for the task: {}".format(self.instruction)
},
]
}
]
def predict(self, obs: Dict):
def predict(self, obs: Dict) -> List:
"""
Predict the next action(s) based on the current observation.
"""
@@ -98,8 +99,7 @@ class GPT4v_Agent:
"content": [
{
"type": "text",
"text": "To accomplish the task '{}' and given the current screenshot, what's the next step?".format(
obs["instruction"])
"text": "What's the next step that you will do to help with the task?"
},
{
"type": "image_url",
@@ -128,23 +128,12 @@ class GPT4v_Agent:
try:
actions = self.parse_actions(response.json()['choices'][0]['message']['content'])
except:
# todo: add error handling
print("Failed to parse action from response:", response.json()['choices'][0]['message']['content'])
print("Failed to parse action from response:", response.json())
actions = None
return actions
def parse_actions(self, response: str):
# response example
"""
```json
{
"action_type": "CLICK",
"click_type": "RIGHT"
}
```
"""
# parse from the response
if self.action_space == "computer_13":
actions = parse_actions_from_string(response)