Add Gemini agent implementation; add missing requirements; fix some small bugs
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import base64
|
||||
import json
|
||||
import re
|
||||
from typing import Dict
|
||||
from typing import Dict, List
|
||||
|
||||
import requests
|
||||
|
||||
@@ -63,7 +63,8 @@ def parse_code_from_string(input_string):
|
||||
|
||||
|
||||
class GPT4v_Agent:
|
||||
def __init__(self, api_key, model="gpt-4-vision-preview", max_tokens=300, action_space="computer_13"):
|
||||
def __init__(self, api_key, instruction, model="gpt-4-vision-preview", max_tokens=300, action_space="computer_13"):
|
||||
self.instruction = instruction
|
||||
self.model = model
|
||||
self.max_tokens = max_tokens
|
||||
self.action_space = action_space
|
||||
@@ -82,13 +83,13 @@ class GPT4v_Agent:
|
||||
"text": {
|
||||
"computer_13": SYS_PROMPT_ACTION,
|
||||
"pyautogui": SYS_PROMPT_CODE
|
||||
}[action_space]
|
||||
}[action_space] + "\nHere is the instruction for the task: {}".format(self.instruction)
|
||||
},
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
def predict(self, obs: Dict):
|
||||
def predict(self, obs: Dict) -> List:
|
||||
"""
|
||||
Predict the next action(s) based on the current observation.
|
||||
"""
|
||||
@@ -98,8 +99,7 @@ class GPT4v_Agent:
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "To accomplish the task '{}' and given the current screenshot, what's the next step?".format(
|
||||
obs["instruction"])
|
||||
"text": "What's the next step that you will do to help with the task?"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
@@ -128,23 +128,12 @@ class GPT4v_Agent:
|
||||
try:
|
||||
actions = self.parse_actions(response.json()['choices'][0]['message']['content'])
|
||||
except:
|
||||
# todo: add error handling
|
||||
print("Failed to parse action from response:", response.json()['choices'][0]['message']['content'])
|
||||
print("Failed to parse action from response:", response.json())
|
||||
actions = None
|
||||
|
||||
return actions
|
||||
|
||||
def parse_actions(self, response: str):
|
||||
# response example
|
||||
"""
|
||||
```json
|
||||
{
|
||||
"action_type": "CLICK",
|
||||
"click_type": "RIGHT"
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
# parse from the response
|
||||
if self.action_space == "computer_13":
|
||||
actions = parse_actions_from_string(response)
|
||||
|
||||
Reference in New Issue
Block a user