feat(agent): add screenshot compression and dynamic resolution support
This commit is contained in:
@@ -49,6 +49,48 @@ def encode_image(image_content):
|
|||||||
return base64.b64encode(image_content).decode('utf-8')
|
return base64.b64encode(image_content).decode('utf-8')
|
||||||
|
|
||||||
|
|
||||||
|
def compress_screenshot(image_bytes, quality=75, resize_ratio=None):
    """
    Re-encode a screenshot as JPEG to reduce its size.

    Args:
        image_bytes: Raw encoded image bytes (PNG expected).
        quality: JPEG quality (1-100, default 75).
        resize_ratio: Optional resize ratio (e.g., 0.5 for 50% size).
            None (or 1.0) keeps the original size.

    Returns:
        Compressed image bytes in JPEG format. This is best effort: if any
        step fails, a warning is logged and the ORIGINAL ``image_bytes`` are
        returned unchanged, so the result is not guaranteed to be JPEG.
    """
    # Modes Pillow's JPEG writer accepts as-is; anything else must be converted.
    jpeg_modes = ('1', 'L', 'RGB', 'RGBX', 'CMYK', 'YCbCr')

    try:
        img = Image.open(BytesIO(image_bytes))

        # Normalise the mode BEFORE resizing: Pillow forces NEAREST resampling
        # for palette images, so resizing first would ignore LANCZOS.
        if img.mode == 'P':
            img = img.convert('RGBA')

        if img.mode in ('RGBA', 'LA'):
            # JPEG has no alpha channel: flatten onto a white background,
            # using the alpha band (always the last one) as the paste mask.
            background = Image.new('RGB', img.size, (255, 255, 255))
            background.paste(img, mask=img.split()[-1])
            img = background
        elif img.mode not in jpeg_modes:
            # e.g. 'I', 'F', 'I;16' -- not writable as JPEG without conversion.
            img = img.convert('RGB')

        # Optionally resize if a ratio is provided
        if resize_ratio and resize_ratio != 1.0:
            if resize_ratio < 0:
                # Raised inside the try so it takes the normal fallback path.
                raise ValueError(f"resize_ratio must be positive, got {resize_ratio}")
            # Clamp to 1px so a tiny ratio cannot truncate a dimension to 0.
            new_size = (
                max(1, int(img.size[0] * resize_ratio)),
                max(1, int(img.size[1] * resize_ratio)),
            )
            img = img.resize(new_size, Image.Resampling.LANCZOS)

        # Save as JPEG with compression
        output = BytesIO()
        img.save(output, format='JPEG', quality=quality, optimize=True)
        compressed = output.getvalue()  # getvalue() copies; call it once

        logger.debug(
            "Screenshot compressed: original=%.1fKB, compressed=%.1fKB, ratio=%.2f%%",
            len(image_bytes) / 1024,
            len(compressed) / 1024,
            100 * len(compressed) / len(image_bytes),
        )
        return compressed
    except Exception as e:
        # Deliberate best effort: never let compression break the agent loop.
        logger.warning("Failed to compress screenshot: %s, using original", e)
        return image_bytes
|
||||||
|
|
||||||
|
|
||||||
def encoded_img_to_pil_img(data_str):
|
def encoded_img_to_pil_img(data_str):
|
||||||
base64_str = data_str.replace("data:image/png;base64,", "")
|
base64_str = data_str.replace("data:image/png;base64,", "")
|
||||||
image_data = base64.b64decode(base64_str)
|
image_data = base64.b64decode(base64_str)
|
||||||
@@ -236,7 +278,9 @@ class PromptAgent:
|
|||||||
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
|
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
|
||||||
max_trajectory_length=3,
|
max_trajectory_length=3,
|
||||||
a11y_tree_max_tokens=10000,
|
a11y_tree_max_tokens=10000,
|
||||||
client_password="password"
|
client_password="password",
|
||||||
|
screen_width=1920,
|
||||||
|
screen_height=1080
|
||||||
):
|
):
|
||||||
self.platform = platform
|
self.platform = platform
|
||||||
self.model = model
|
self.model = model
|
||||||
@@ -248,6 +292,8 @@ class PromptAgent:
|
|||||||
self.max_trajectory_length = max_trajectory_length
|
self.max_trajectory_length = max_trajectory_length
|
||||||
self.a11y_tree_max_tokens = a11y_tree_max_tokens
|
self.a11y_tree_max_tokens = a11y_tree_max_tokens
|
||||||
self.client_password = client_password
|
self.client_password = client_password
|
||||||
|
self.screen_width = screen_width
|
||||||
|
self.screen_height = screen_height
|
||||||
|
|
||||||
self.thoughts = []
|
self.thoughts = []
|
||||||
self.actions = []
|
self.actions = []
|
||||||
@@ -284,7 +330,7 @@ class PromptAgent:
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Invalid experiment type: " + observation_type)
|
raise ValueError("Invalid experiment type: " + observation_type)
|
||||||
|
|
||||||
self.system_message = self.system_message.format(CLIENT_PASSWORD=self.client_password)
|
self.system_message = self.system_message.format(CLIENT_PASSWORD=self.client_password, SCREEN_WIDTH=self.screen_width, SCREEN_HEIGHT=self.screen_height)
|
||||||
|
|
||||||
def predict(self, instruction: str, obs: Dict) -> List:
|
def predict(self, instruction: str, obs: Dict) -> List:
|
||||||
"""
|
"""
|
||||||
@@ -342,8 +388,8 @@ class PromptAgent:
|
|||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
"image_url": {
|
"image_url": {
|
||||||
"url": f"data:image/png;base64,{_screenshot}",
|
"url": f"data:image/jpeg;base64,{_screenshot}",
|
||||||
"detail": "high"
|
"detail": "auto"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
@@ -361,8 +407,8 @@ class PromptAgent:
|
|||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
"image_url": {
|
"image_url": {
|
||||||
"url": f"data:image/png;base64,{_screenshot}",
|
"url": f"data:image/jpeg;base64,{_screenshot}",
|
||||||
"detail": "high"
|
"detail": "auto"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
@@ -380,8 +426,8 @@ class PromptAgent:
|
|||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
"image_url": {
|
"image_url": {
|
||||||
"url": f"data:image/png;base64,{_screenshot}",
|
"url": f"data:image/jpeg;base64,{_screenshot}",
|
||||||
"detail": "high"
|
"detail": "auto"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
@@ -414,7 +460,9 @@ class PromptAgent:
|
|||||||
|
|
||||||
# {{{1
|
# {{{1
|
||||||
if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
|
if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
|
||||||
base64_image = encode_image(obs["screenshot"])
|
# Compress screenshot to JPEG (keep original resolution for accurate coordinates)
|
||||||
|
compressed_screenshot = compress_screenshot(obs["screenshot"], quality=75)
|
||||||
|
base64_image = encode_image(compressed_screenshot)
|
||||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"],
|
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"],
|
||||||
platform=self.platform) if self.observation_type == "screenshot_a11y_tree" else None
|
platform=self.platform) if self.observation_type == "screenshot_a11y_tree" else None
|
||||||
logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
|
logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
|
||||||
@@ -447,8 +495,8 @@ class PromptAgent:
|
|||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
"image_url": {
|
"image_url": {
|
||||||
"url": f"data:image/png;base64,{base64_image}",
|
"url": f"data:image/jpeg;base64,{base64_image}",
|
||||||
"detail": "high"
|
"detail": "auto"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
@@ -481,7 +529,9 @@ class PromptAgent:
|
|||||||
# Add som to the screenshot
|
# Add som to the screenshot
|
||||||
masks, drew_nodes, tagged_screenshot, linearized_accessibility_tree = tag_screenshot(obs["screenshot"], obs[
|
masks, drew_nodes, tagged_screenshot, linearized_accessibility_tree = tag_screenshot(obs["screenshot"], obs[
|
||||||
"accessibility_tree"], self.platform)
|
"accessibility_tree"], self.platform)
|
||||||
base64_image = encode_image(tagged_screenshot)
|
# Compress tagged screenshot (keep original resolution)
|
||||||
|
compressed_screenshot = compress_screenshot(tagged_screenshot, quality=75)
|
||||||
|
base64_image = encode_image(compressed_screenshot)
|
||||||
logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
|
logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
|
||||||
|
|
||||||
if linearized_accessibility_tree:
|
if linearized_accessibility_tree:
|
||||||
@@ -504,8 +554,8 @@ class PromptAgent:
|
|||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
"image_url": {
|
"image_url": {
|
||||||
"url": f"data:image/png;base64,{base64_image}",
|
"url": f"data:image/jpeg;base64,{base64_image}",
|
||||||
"detail": "high"
|
"detail": "auto"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
@@ -523,7 +573,7 @@ class PromptAgent:
|
|||||||
"model": self.model,
|
"model": self.model,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"max_tokens": self.max_tokens,
|
"max_tokens": self.max_tokens,
|
||||||
"top_p": self.top_p,
|
# "top_p": self.top_p,
|
||||||
"temperature": self.temperature
|
"temperature": self.temperature
|
||||||
})
|
})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -691,8 +741,8 @@ class PromptAgent:
|
|||||||
logger.debug("CLAUDE MESSAGE: %s", repr(claude_messages))
|
logger.debug("CLAUDE MESSAGE: %s", repr(claude_messages))
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
"x-api-key": os.environ["ANTHROPIC_API_KEY"],
|
"x-api-key": os.environ["OPENAI_API_KEY"],
|
||||||
"anthropic-version": "2023-06-01",
|
# "anthropic-version": "2023-06-01",
|
||||||
"content-type": "application/json"
|
"content-type": "application/json"
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -705,7 +755,7 @@ class PromptAgent:
|
|||||||
}
|
}
|
||||||
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
"https://api.anthropic.com/v1/messages",
|
"https://api.apiyi.com/v1/messages",
|
||||||
headers=headers,
|
headers=headers,
|
||||||
json=payload
|
json=payload
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user