feat(agent): add screenshot compression and dynamic resolution support

This commit is contained in:
cui0711
2026-01-30 16:28:02 +08:00
parent 7e9090e115
commit 47bcfc0f0b

View File

@@ -49,6 +49,48 @@ def encode_image(image_content):
return base64.b64encode(image_content).decode('utf-8')
def compress_screenshot(image_bytes, quality=75, resize_ratio=None):
"""
Compress screenshot to reduce file size while maintaining resolution.
Args:
image_bytes: Raw image bytes (PNG format)
quality: JPEG quality (1-100, default 75)
resize_ratio: Optional resize ratio (e.g., 0.5 for 50% size). None = keep original size.
Returns:
Compressed image bytes in JPEG format
"""
try:
# Open image from bytes
img = Image.open(BytesIO(image_bytes))
# Optionally resize if ratio is provided
if resize_ratio and resize_ratio != 1.0:
new_size = (int(img.size[0] * resize_ratio), int(img.size[1] * resize_ratio))
img = img.resize(new_size, Image.Resampling.LANCZOS)
# Convert to RGB if necessary (JPEG doesn't support alpha channel)
if img.mode in ('RGBA', 'LA', 'P'):
background = Image.new('RGB', img.size, (255, 255, 255))
if img.mode == 'P':
img = img.convert('RGBA')
background.paste(img, mask=img.split()[-1] if img.mode in ('RGBA', 'LA') else None)
img = background
# Save as JPEG with compression
output = BytesIO()
img.save(output, format='JPEG', quality=quality, optimize=True)
compressed_size = len(output.getvalue())
logger.debug(f"Screenshot compressed: original={len(image_bytes)/1024:.1f}KB, compressed={compressed_size/1024:.1f}KB, ratio={compressed_size/len(image_bytes):.2%}")
return output.getvalue()
except Exception as e:
logger.warning(f"Failed to compress screenshot: {e}, using original")
return image_bytes
def encoded_img_to_pil_img(data_str):
base64_str = data_str.replace("data:image/png;base64,", "")
image_data = base64.b64decode(base64_str)
@@ -236,7 +278,9 @@ class PromptAgent:
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
max_trajectory_length=3,
a11y_tree_max_tokens=10000,
client_password="password"
client_password="password",
screen_width=1920,
screen_height=1080
):
self.platform = platform
self.model = model
@@ -248,6 +292,8 @@ class PromptAgent:
self.max_trajectory_length = max_trajectory_length
self.a11y_tree_max_tokens = a11y_tree_max_tokens
self.client_password = client_password
self.screen_width = screen_width
self.screen_height = screen_height
self.thoughts = []
self.actions = []
@@ -284,7 +330,7 @@ class PromptAgent:
else:
raise ValueError("Invalid experiment type: " + observation_type)
self.system_message = self.system_message.format(CLIENT_PASSWORD=self.client_password)
self.system_message = self.system_message.format(CLIENT_PASSWORD=self.client_password, SCREEN_WIDTH=self.screen_width, SCREEN_HEIGHT=self.screen_height)
def predict(self, instruction: str, obs: Dict) -> List:
"""
@@ -342,8 +388,8 @@ class PromptAgent:
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{_screenshot}",
"detail": "high"
"url": f"data:image/jpeg;base64,{_screenshot}",
"detail": "auto"
}
}
]
@@ -361,8 +407,8 @@ class PromptAgent:
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{_screenshot}",
"detail": "high"
"url": f"data:image/jpeg;base64,{_screenshot}",
"detail": "auto"
}
}
]
@@ -380,8 +426,8 @@ class PromptAgent:
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{_screenshot}",
"detail": "high"
"url": f"data:image/jpeg;base64,{_screenshot}",
"detail": "auto"
}
}
]
@@ -414,7 +460,9 @@ class PromptAgent:
# {{{1
if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
base64_image = encode_image(obs["screenshot"])
# Compress screenshot to JPEG (keep original resolution for accurate coordinates)
compressed_screenshot = compress_screenshot(obs["screenshot"], quality=75)
base64_image = encode_image(compressed_screenshot)
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"],
platform=self.platform) if self.observation_type == "screenshot_a11y_tree" else None
logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
@@ -447,8 +495,8 @@ class PromptAgent:
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{base64_image}",
"detail": "high"
"url": f"data:image/jpeg;base64,{base64_image}",
"detail": "auto"
}
}
]
@@ -481,7 +529,9 @@ class PromptAgent:
# Add som to the screenshot
masks, drew_nodes, tagged_screenshot, linearized_accessibility_tree = tag_screenshot(obs["screenshot"], obs[
"accessibility_tree"], self.platform)
base64_image = encode_image(tagged_screenshot)
# Compress tagged screenshot (keep original resolution)
compressed_screenshot = compress_screenshot(tagged_screenshot, quality=75)
base64_image = encode_image(compressed_screenshot)
logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
if linearized_accessibility_tree:
@@ -504,8 +554,8 @@ class PromptAgent:
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{base64_image}",
"detail": "high"
"url": f"data:image/jpeg;base64,{base64_image}",
"detail": "auto"
}
}
]
@@ -523,7 +573,7 @@ class PromptAgent:
"model": self.model,
"messages": messages,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
# "top_p": self.top_p,
"temperature": self.temperature
})
except Exception as e:
@@ -691,8 +741,8 @@ class PromptAgent:
logger.debug("CLAUDE MESSAGE: %s", repr(claude_messages))
headers = {
"x-api-key": os.environ["ANTHROPIC_API_KEY"],
"anthropic-version": "2023-06-01",
"x-api-key": os.environ["OPENAI_API_KEY"],
# "anthropic-version": "2023-06-01",
"content-type": "application/json"
}
@@ -705,7 +755,7 @@ class PromptAgent:
}
response = requests.post(
"https://api.anthropic.com/v1/messages",
"https://api.apiyi.com/v1/messages",
headers=headers,
json=payload
)