feat: 增强任务步骤注入与a11y状态表达,提升树形交互稳定性
- 打通 metadata.steps 传递链路,将任务步骤注入 agent 预测上下文 - 优化 a11y tree 线性化输出:使用中心坐标并新增 states 列(expanded/collapsed/selected 等) - 放宽可保留节点条件,保留无文本输入类控件(edit/textfield/searchbox 等) - 强化输出约束:单轮仅允许动作代码或 WAIT/DONE/FAIL,禁止动作与 DONE 同轮返回 - 补充 avogadro 示例步骤:展开 aromatics 并选择 benzene.cjson
This commit is contained in:
@@ -126,7 +126,7 @@ def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"):
|
||||
raise ValueError("Invalid platform, must be 'ubuntu' or 'windows'")
|
||||
|
||||
filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree), platform)
|
||||
linearized_accessibility_tree = ["tag\tname\ttext\tclass\tdescription\tposition (top-left x&y)\tsize (w&h)"]
|
||||
linearized_accessibility_tree = ["tag\tname\ttext\tposition (center x&y)\tsize (w&h)\tstates"]
|
||||
|
||||
# Linearize the accessibility tree nodes into a table format
|
||||
for node in filtered_nodes:
|
||||
@@ -145,14 +145,36 @@ def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"):
|
||||
else:
|
||||
text = '""'
|
||||
|
||||
# Compute center coordinates from top-left + size/2
|
||||
coords_str = node.get('{{{:}}}screencoord'.format(_component_ns), "")
|
||||
size_str = node.get('{{{:}}}size'.format(_component_ns), "")
|
||||
if coords_str and size_str:
|
||||
try:
|
||||
cx, cy = coords_str.strip('()').split(', ')
|
||||
sw, sh = size_str.strip('()').split(', ')
|
||||
center_x = int(cx) + int(sw) // 2
|
||||
center_y = int(cy) + int(sh) // 2
|
||||
center_str = "({:d}, {:d})".format(center_x, center_y)
|
||||
except (ValueError, IndexError):
|
||||
center_str = coords_str
|
||||
else:
|
||||
center_str = coords_str
|
||||
|
||||
# Extract useful UI states (expanded/collapsed/checked/selected/focused)
|
||||
state_flags = []
|
||||
for state_name in ["expanded", "collapsed", "checked", "selected", "focused", "pressed"]:
|
||||
val = node.get("{{{:}}}{:}".format(_state_ns, state_name), "")
|
||||
if val == "true":
|
||||
state_flags.append(state_name)
|
||||
states_str = ",".join(state_flags) if state_flags else ""
|
||||
|
||||
linearized_accessibility_tree.append(
|
||||
"{:}\t{:}\t{:}\t{:}\t{:}\t{:}\t{:}".format(
|
||||
"{:}\t{:}\t{:}\t{:}\t{:}\t{:}".format(
|
||||
node.tag, node.get("name", ""),
|
||||
text,
|
||||
node.get("{{{:}}}class".format(_attributes_ns), "") if platform == "ubuntu" else node.get("{{{:}}}class".format(class_ns_windows), ""),
|
||||
node.get("{{{:}}}description".format(_attributes_ns), ""),
|
||||
node.get('{{{:}}}screencoord'.format(_component_ns), ""),
|
||||
node.get('{{{:}}}size'.format(_component_ns), "")
|
||||
center_str,
|
||||
size_str,
|
||||
states_str
|
||||
)
|
||||
)
|
||||
|
||||
@@ -332,11 +354,13 @@ class PromptAgent:
|
||||
|
||||
self.system_message = self.system_message.format(CLIENT_PASSWORD=self.client_password, SCREEN_WIDTH=self.screen_width, SCREEN_HEIGHT=self.screen_height)
|
||||
|
||||
def predict(self, instruction: str, obs: Dict) -> List:
|
||||
def predict(self, instruction: str, obs: Dict, metadata_steps: str = "") -> List:
|
||||
"""
|
||||
Predict the next action(s) based on the current observation.
|
||||
"""
|
||||
system_message = self.system_message + "\nYou are asked to complete the following task: {}".format(instruction)
|
||||
if metadata_steps:
|
||||
system_message += "\n\nHere are the reference steps from the software tutorial, which may help you complete the task:\n{}".format(metadata_steps)
|
||||
|
||||
# Prepare the payload for the API call
|
||||
messages = []
|
||||
|
||||
Reference in New Issue
Block a user