Merge branch 'zdy'

This commit is contained in:
David Chang
2024-01-26 22:13:32 +08:00

View File

@@ -160,6 +160,7 @@ class GPT4v_Agent:
"Authorization": f"Bearer {api_key}"
}
self.thoughts = []
self.actions = []
self.observations = []
@@ -224,17 +225,21 @@ class GPT4v_Agent:
})
# Append trajectory
assert len(self.observations) == len(self.actions), "The number of observations and actions should be the same."
assert len(self.observations) == len(self.actions) and len(self.actions) == len(self.thoughts)\
, "The number of observations and actions should be the same."
if len(self.observations) > self.max_trajectory_length:
_observations = self.observations[-self.max_trajectory_length:]
_actions = self.actions[-self.max_trajectory_length:]
_thoughts = self.thoughts[-self.max_trajectory_length:]
else:
_observations = self.observations
_actions = self.actions
_thoughts = self.thoughts
for previous_obs, previous_action in zip(_observations, _actions):
for previous_obs, previous_action, previous_thought in zip(_observations, _actions, _thoughts):
# {{{1
if self.exp == "both":
_screenshot = previous_obs["screenshot"]
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
@@ -310,18 +315,19 @@ class GPT4v_Agent:
]
})
else:
raise ValueError("Invalid experiment type: " + self.exp)
raise ValueError("Invalid experiment type: " + self.exp) # 1}}}
messages.append({
"role": "assistant",
"content": [
{
"type": "text",
"text": "\n".join(previous_action) if len(previous_action) > 0 else "No valid action"
"text": previous_thought.strip() if len(previous_thought)>0 else "No valid action"
},
]
})
# {{{1
if self.exp in ["screenshot", "both"]:
base64_image = encode_image(obs["screenshot"])
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
@@ -430,7 +436,7 @@ class GPT4v_Agent:
]
})
else:
raise ValueError("Invalid experiment type: " + self.exp)
raise ValueError("Invalid experiment type: " + self.exp) # 1}}}
with open("messages.json", "w") as f:
f.write(json.dumps(messages, indent=4))
@@ -474,9 +480,11 @@ class GPT4v_Agent:
try:
actions = self.parse_actions(response, masks)
self.thoughts.append(response)
except Exception as e:
print("Failed to parse action from response", e)
actions = None
self.thoughts.append("")
return actions