feat(agent): add screenshot compression and dynamic resolution support
This commit is contained in:
@@ -49,6 +49,48 @@ def encode_image(image_content):
|
||||
return base64.b64encode(image_content).decode('utf-8')
|
||||
|
||||
|
||||
def compress_screenshot(image_bytes, quality=75, resize_ratio=None):
    """
    Re-encode a PNG screenshot as JPEG to shrink its payload size.

    Args:
        image_bytes: Raw image bytes (PNG format).
        quality: JPEG quality setting, 1-100 (default 75).
        resize_ratio: Optional scale factor (e.g. 0.5 halves each dimension);
            None or 1.0 keeps the original resolution.

    Returns:
        JPEG-encoded bytes on success; the untouched input bytes if anything
        goes wrong while decoding or re-encoding (best-effort fallback).
    """
    try:
        img = Image.open(BytesIO(image_bytes))

        # Scale down only when an effective ratio was requested.
        if resize_ratio and resize_ratio != 1.0:
            w, h = img.size
            img = img.resize((int(w * resize_ratio), int(h * resize_ratio)),
                             Image.Resampling.LANCZOS)

        # JPEG has no alpha channel: flatten transparent modes onto white.
        if img.mode in ('RGBA', 'LA', 'P'):
            if img.mode == 'P':
                img = img.convert('RGBA')
            flattened = Image.new('RGB', img.size, (255, 255, 255))
            alpha = img.split()[-1] if img.mode in ('RGBA', 'LA') else None
            flattened.paste(img, mask=alpha)
            img = flattened

        # Encode to an in-memory JPEG buffer.
        output = BytesIO()
        img.save(output, format='JPEG', quality=quality, optimize=True)
        compressed_size = len(output.getvalue())

        logger.debug(f"Screenshot compressed: original={len(image_bytes)/1024:.1f}KB, compressed={compressed_size/1024:.1f}KB, ratio={compressed_size/len(image_bytes):.2%}")

        return output.getvalue()
    except Exception as e:
        logger.warning(f"Failed to compress screenshot: {e}, using original")
        return image_bytes
|
||||
|
||||
|
||||
def encoded_img_to_pil_img(data_str):
|
||||
base64_str = data_str.replace("data:image/png;base64,", "")
|
||||
image_data = base64.b64decode(base64_str)
|
||||
@@ -236,7 +278,9 @@ class PromptAgent:
|
||||
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
|
||||
max_trajectory_length=3,
|
||||
a11y_tree_max_tokens=10000,
|
||||
client_password="password"
|
||||
client_password="password",
|
||||
screen_width=1920,
|
||||
screen_height=1080
|
||||
):
|
||||
self.platform = platform
|
||||
self.model = model
|
||||
@@ -248,6 +292,8 @@ class PromptAgent:
|
||||
self.max_trajectory_length = max_trajectory_length
|
||||
self.a11y_tree_max_tokens = a11y_tree_max_tokens
|
||||
self.client_password = client_password
|
||||
self.screen_width = screen_width
|
||||
self.screen_height = screen_height
|
||||
|
||||
self.thoughts = []
|
||||
self.actions = []
|
||||
@@ -284,7 +330,7 @@ class PromptAgent:
|
||||
else:
|
||||
raise ValueError("Invalid experiment type: " + observation_type)
|
||||
|
||||
self.system_message = self.system_message.format(CLIENT_PASSWORD=self.client_password)
|
||||
self.system_message = self.system_message.format(CLIENT_PASSWORD=self.client_password, SCREEN_WIDTH=self.screen_width, SCREEN_HEIGHT=self.screen_height)
|
||||
|
||||
def predict(self, instruction: str, obs: Dict) -> List:
|
||||
"""
|
||||
@@ -342,8 +388,8 @@ class PromptAgent:
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{_screenshot}",
|
||||
"detail": "high"
|
||||
"url": f"data:image/jpeg;base64,{_screenshot}",
|
||||
"detail": "auto"
|
||||
}
|
||||
}
|
||||
]
|
||||
@@ -361,8 +407,8 @@ class PromptAgent:
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{_screenshot}",
|
||||
"detail": "high"
|
||||
"url": f"data:image/jpeg;base64,{_screenshot}",
|
||||
"detail": "auto"
|
||||
}
|
||||
}
|
||||
]
|
||||
@@ -380,8 +426,8 @@ class PromptAgent:
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{_screenshot}",
|
||||
"detail": "high"
|
||||
"url": f"data:image/jpeg;base64,{_screenshot}",
|
||||
"detail": "auto"
|
||||
}
|
||||
}
|
||||
]
|
||||
@@ -414,7 +460,9 @@ class PromptAgent:
|
||||
|
||||
# {{{1
|
||||
if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
|
||||
base64_image = encode_image(obs["screenshot"])
|
||||
# Compress screenshot to JPEG (keep original resolution for accurate coordinates)
|
||||
compressed_screenshot = compress_screenshot(obs["screenshot"], quality=75)
|
||||
base64_image = encode_image(compressed_screenshot)
|
||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"],
|
||||
platform=self.platform) if self.observation_type == "screenshot_a11y_tree" else None
|
||||
logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
|
||||
@@ -447,8 +495,8 @@ class PromptAgent:
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{base64_image}",
|
||||
"detail": "high"
|
||||
"url": f"data:image/jpeg;base64,{base64_image}",
|
||||
"detail": "auto"
|
||||
}
|
||||
}
|
||||
]
|
||||
@@ -481,7 +529,9 @@ class PromptAgent:
|
||||
# Add som to the screenshot
|
||||
masks, drew_nodes, tagged_screenshot, linearized_accessibility_tree = tag_screenshot(obs["screenshot"], obs[
|
||||
"accessibility_tree"], self.platform)
|
||||
base64_image = encode_image(tagged_screenshot)
|
||||
# Compress tagged screenshot (keep original resolution)
|
||||
compressed_screenshot = compress_screenshot(tagged_screenshot, quality=75)
|
||||
base64_image = encode_image(compressed_screenshot)
|
||||
logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
|
||||
|
||||
if linearized_accessibility_tree:
|
||||
@@ -504,8 +554,8 @@ class PromptAgent:
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{base64_image}",
|
||||
"detail": "high"
|
||||
"url": f"data:image/jpeg;base64,{base64_image}",
|
||||
"detail": "auto"
|
||||
}
|
||||
}
|
||||
]
|
||||
@@ -523,7 +573,7 @@ class PromptAgent:
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"max_tokens": self.max_tokens,
|
||||
"top_p": self.top_p,
|
||||
# "top_p": self.top_p,
|
||||
"temperature": self.temperature
|
||||
})
|
||||
except Exception as e:
|
||||
@@ -691,8 +741,8 @@ class PromptAgent:
|
||||
logger.debug("CLAUDE MESSAGE: %s", repr(claude_messages))
|
||||
|
||||
headers = {
|
||||
"x-api-key": os.environ["ANTHROPIC_API_KEY"],
|
||||
"anthropic-version": "2023-06-01",
|
||||
"x-api-key": os.environ["OPENAI_API_KEY"],
|
||||
# "anthropic-version": "2023-06-01",
|
||||
"content-type": "application/json"
|
||||
}
|
||||
|
||||
@@ -705,7 +755,7 @@ class PromptAgent:
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
"https://api.anthropic.com/v1/messages",
|
||||
"https://api.apiyi.com/v1/messages",
|
||||
headers=headers,
|
||||
json=payload
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user