Wrap up SeeAct implementation

This commit is contained in:
Timothyxxx
2024-01-20 19:19:37 +08:00
parent f88331416c
commit 6f27c5bf50
5 changed files with 437 additions and 1410 deletions

View File

@@ -235,7 +235,7 @@ class GPT4v_Agent:
for previous_obs, previous_action in zip(_observations, _actions):
if self.exp in ["both", "som", "seeact"]:
if self.exp == "both":
_screenshot = previous_obs["screenshot"]
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
@@ -244,7 +244,28 @@ class GPT4v_Agent:
"content": [
{
"type": "text",
"text": "Given the info from the tagged screenshot as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
"text": "Given the screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
_linearized_accessibility_tree)
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{_screenshot}",
"detail": "high"
}
}
]
})
elif self.exp in ["som", "seeact"]:
_screenshot = previous_obs["screenshot"]
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
messages.append({
"role": "user",
"content": [
{
"type": "text",
"text": "Given the tagged screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
_linearized_accessibility_tree)
},
{
@@ -369,7 +390,7 @@ class GPT4v_Agent:
"content": [
{
"type": "text",
"text": "Given the info from the tagged screenshot as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
"text": "Given the tagged screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
linearized_accessibility_tree)
},
{
@@ -383,8 +404,7 @@ class GPT4v_Agent:
})
elif self.exp == "seeact":
# Add som to the screenshot
masks, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
base64_image = encode_image(tagged_screenshot)
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
@@ -421,6 +441,8 @@ class GPT4v_Agent:
"max_tokens": self.max_tokens
})
print(response)
if self.exp == "seeact":
messages.append({
"role": "assistant",
@@ -448,6 +470,7 @@ class GPT4v_Agent:
"messages": messages,
"max_tokens": self.max_tokens
})
print(response)
try:
actions = self.parse_actions(response, masks)

File diff suppressed because it is too large. Load Diff