import time
from typing import Dict, List

import google.generativeai as genai

from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes
from mm_agents.gpt_4_prompt_action import SYS_PROMPT as SYS_PROMPT_ACTION
from mm_agents.gpt_4_prompt_code import SYS_PROMPT as SYS_PROMPT_CODE
from mm_agents.gpt_4v_agent import parse_actions_from_string, parse_code_from_string


class GeminiPro_Agent:
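    """Gemini Pro agent that linearizes the accessibility-tree observation into a
    table, prompts Gemini for the next step, and parses the response into actions
    for the configured action space ("computer_13" or "pyautogui")."""
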
    def __init__(self, api_key, instruction, model='gemini-pro', max_tokens=300, temperature=0.0,
                 action_space="computer_13"):
        genai.configure(api_key=api_key)
        self.instruction = instruction
        self.model = genai.GenerativeModel(model)
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.action_space = action_space

        # Seed the conversation with the system prompt for the chosen action space,
        # followed by the task instruction
        self.trajectory = [
            {
                "role": "system",
                "parts": [
                    {
                        "computer_13": SYS_PROMPT_ACTION,
                        "pyautogui": SYS_PROMPT_CODE
                    }[action_space] + "\nHere is the instruction for the task: {}".format(self.instruction)
                ]
            }
        ]

    def predict(self, obs: Dict) -> List:
        """
        Predict the next action(s) based on the current observation.
        Only supports a single-round conversation: only the latest accessibility
        tree observation is filled into the prompt.
        """
        accessibility_tree = obs["accessibility_tree"]

        leaf_nodes = find_leaf_nodes(accessibility_tree)
        filtered_nodes = filter_nodes(leaf_nodes)

        # Linearize the accessibility tree nodes into a tab-separated table;
        # missing attributes default to an empty string
        linearized_accessibility_tree = "tag\ttext\tposition\tsize\n"
        for node in filtered_nodes:
            linearized_accessibility_tree += node.tag + "\t"
            linearized_accessibility_tree += node.attrib.get('name', "") + "\t"
            linearized_accessibility_tree += node.attrib.get(
                '{uri:deskat:component.at-spi.gnome.org}screencoord', "") + "\t"
            linearized_accessibility_tree += node.attrib.get(
                '{uri:deskat:component.at-spi.gnome.org}size', "") + "\n"

        self.trajectory.append({
            "role": "user",
            "parts": [
                "Given the XML format of accessibility tree (converted and formatted into a table) as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
                    linearized_accessibility_tree)]
        })

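        # Flatten the whole trajectory into one ChatML-style string, because the
        # request below sends a single user turn rather than a multi-turn history.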
        # TODO: Remove this step once Gemini supports multi-round conversation
        all_message_str = ""
        for i in range(len(self.trajectory)):
            if i == 0:
                all_message_template = "<|im_start|>system\n{}\n<|im_end|>\n"
            elif i % 2 == 1:
                all_message_template = "<|im_start|>user\n{}\n<|im_end|>\n"
            else:
                all_message_template = "<|im_start|>assistant\n{}\n<|im_end|>\n"

            all_message_str += all_message_template.format(self.trajectory[i]["parts"][0])

        print("All message: >>>>>>>>>>>>>>>> ")
        print(all_message_str)

        message_for_gemini = {
            "role": "user",
            "parts": [all_message_str]
        }

        traj_to_show = []
        for i in range(len(self.trajectory)):
            traj_to_show.append(self.trajectory[i]["parts"][0])
            if len(self.trajectory[i]["parts"]) > 1:
                traj_to_show.append("screenshot_obs")

        print("Trajectory:", traj_to_show)

        # Retry on transient API errors until a response is returned
        while True:
            try:
                response = self.model.generate_content(
                    message_for_gemini,
                    generation_config={
                        "max_output_tokens": self.max_tokens,
                        "temperature": self.temperature
                    }
                )
                break
            except Exception:
                print("Failed to generate response, retrying...")
                time.sleep(5)

        try:
            response_text = response.text
        except Exception:
            # The response has no usable text (e.g. the candidate was blocked or empty)
            return []

        try:
            actions = self.parse_actions(response_text)
        except Exception:
            print("Failed to parse action from response:", response_text)
            actions = []

        return actions

    def parse_actions(self, response: str):
        # Parse actions from the response according to the configured action space
        if self.action_space == "computer_13":
            actions = parse_actions_from_string(response)
        elif self.action_space == "pyautogui":
            actions = parse_code_from_string(response)
        else:
            raise ValueError("Invalid action space: " + self.action_space)

        # Record the raw model response in the trajectory
        self.trajectory.append({
            "role": "assistant",
            "parts": [response]
        })

        return actions
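

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the agent): a minimal, hypothetical
# example of driving GeminiPro_Agent. It assumes the surrounding environment
# supplies an observation dict whose "accessibility_tree" entry is accepted by
# find_leaf_nodes; the API key and instruction below are placeholders.
# ---------------------------------------------------------------------------
# agent = GeminiPro_Agent(
#     api_key="YOUR_GOOGLE_API_KEY",          # placeholder
#     instruction="Open the text editor",     # placeholder task instruction
#     action_space="pyautogui",               # or "computer_13"
# )
# obs = {"accessibility_tree": ...}           # provided by the desktop environment
# actions = agent.predict(obs)                # returns a list of parsed actions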