From 47bcfc0f0b048de7cb870e7cbd230d4f9cb08860 Mon Sep 17 00:00:00 2001 From: cui0711 <1729461967@qq.com> Date: Fri, 30 Jan 2026 16:28:02 +0800 Subject: [PATCH] feat(agent): add screenshot compression and dynamic resolution support --- mm_agents/agent.py | 86 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 68 insertions(+), 18 deletions(-) diff --git a/mm_agents/agent.py b/mm_agents/agent.py index b2c3dc9..351b371 100644 --- a/mm_agents/agent.py +++ b/mm_agents/agent.py @@ -49,6 +49,48 @@ def encode_image(image_content): return base64.b64encode(image_content).decode('utf-8') +def compress_screenshot(image_bytes, quality=75, resize_ratio=None): + """ + Compress screenshot to reduce file size while maintaining resolution. + + Args: + image_bytes: Raw image bytes (PNG format) + quality: JPEG quality (1-100, default 75) + resize_ratio: Optional resize ratio (e.g., 0.5 for 50% size). None = keep original size. + + Returns: + Compressed image bytes in JPEG format + """ + try: + # Open image from bytes + img = Image.open(BytesIO(image_bytes)) + + # Optionally resize if ratio is provided + if resize_ratio and resize_ratio != 1.0: + new_size = (int(img.size[0] * resize_ratio), int(img.size[1] * resize_ratio)) + img = img.resize(new_size, Image.Resampling.LANCZOS) + + # Convert to RGB if necessary (JPEG doesn't support alpha channel) + if img.mode in ('RGBA', 'LA', 'P'): + background = Image.new('RGB', img.size, (255, 255, 255)) + if img.mode == 'P': + img = img.convert('RGBA') + background.paste(img, mask=img.split()[-1] if img.mode in ('RGBA', 'LA') else None) + img = background + + # Save as JPEG with compression + output = BytesIO() + img.save(output, format='JPEG', quality=quality, optimize=True) + compressed_size = len(output.getvalue()) + + logger.debug(f"Screenshot compressed: original={len(image_bytes)/1024:.1f}KB, compressed={compressed_size/1024:.1f}KB, ratio={compressed_size/len(image_bytes):.2%}") + + return output.getvalue() + except Exception as e: + logger.warning(f"Failed to compress screenshot: {e}, using original") + return image_bytes + + def encoded_img_to_pil_img(data_str): base64_str = data_str.replace("data:image/png;base64,", "") image_data = base64.b64decode(base64_str) @@ -236,7 +278,9 @@ class PromptAgent: # observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"] max_trajectory_length=3, a11y_tree_max_tokens=10000, - client_password="password" + client_password="password", + screen_width=1920, + screen_height=1080 ): self.platform = platform self.model = model @@ -248,6 +292,8 @@ class PromptAgent: self.max_trajectory_length = max_trajectory_length self.a11y_tree_max_tokens = a11y_tree_max_tokens self.client_password = client_password + self.screen_width = screen_width + self.screen_height = screen_height self.thoughts = [] self.actions = [] @@ -284,7 +330,7 @@ class PromptAgent: else: raise ValueError("Invalid experiment type: " + observation_type) - self.system_message = self.system_message.format(CLIENT_PASSWORD=self.client_password) + self.system_message = self.system_message.format(CLIENT_PASSWORD=self.client_password, SCREEN_WIDTH=self.screen_width, SCREEN_HEIGHT=self.screen_height) def predict(self, instruction: str, obs: Dict) -> List: """ @@ -342,8 +388,8 @@ class PromptAgent: { "type": "image_url", "image_url": { - "url": f"data:image/png;base64,{_screenshot}", - "detail": "high" + "url": f"data:image/jpeg;base64,{_screenshot}", + "detail": "auto" } } ] @@ -361,8 +407,8 @@ class PromptAgent: { "type": "image_url", "image_url": { - "url": f"data:image/png;base64,{_screenshot}", - "detail": "high" + "url": f"data:image/jpeg;base64,{_screenshot}", + "detail": "auto" } } ] @@ -380,8 +426,8 @@ class PromptAgent: { "type": "image_url", "image_url": { - "url": f"data:image/png;base64,{_screenshot}", - "detail": "high" + "url": f"data:image/jpeg;base64,{_screenshot}", + "detail": "auto" } } ] @@ -414,7 +460,9 @@ class PromptAgent: # {{{1 if self.observation_type in ["screenshot", "screenshot_a11y_tree"]: - base64_image = encode_image(obs["screenshot"]) + # Compress screenshot to JPEG (keep original resolution for accurate coordinates) + compressed_screenshot = compress_screenshot(obs["screenshot"], quality=75) + base64_image = encode_image(compressed_screenshot) linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"], platform=self.platform) if self.observation_type == "screenshot_a11y_tree" else None logger.debug("LINEAR AT: %s", linearized_accessibility_tree) @@ -447,8 +495,8 @@ class PromptAgent: { "type": "image_url", "image_url": { - "url": f"data:image/png;base64,{base64_image}", - "detail": "high" + "url": f"data:image/jpeg;base64,{base64_image}", + "detail": "auto" } } ] @@ -481,7 +529,9 @@ class PromptAgent: # Add som to the screenshot masks, drew_nodes, tagged_screenshot, linearized_accessibility_tree = tag_screenshot(obs["screenshot"], obs[ "accessibility_tree"], self.platform) - base64_image = encode_image(tagged_screenshot) + # Compress tagged screenshot (keep original resolution) + compressed_screenshot = compress_screenshot(tagged_screenshot, quality=75) + base64_image = encode_image(compressed_screenshot) logger.debug("LINEAR AT: %s", linearized_accessibility_tree) if linearized_accessibility_tree: @@ -504,8 +554,8 @@ class PromptAgent: { "type": "image_url", "image_url": { - "url": f"data:image/png;base64,{base64_image}", - "detail": "high" + "url": f"data:image/jpeg;base64,{base64_image}", + "detail": "auto" } } ] @@ -523,7 +573,7 @@ class PromptAgent: "model": self.model, "messages": messages, "max_tokens": self.max_tokens, - "top_p": self.top_p, + # "top_p": self.top_p, "temperature": self.temperature }) except Exception as e: @@ -691,8 +741,8 @@ class PromptAgent: logger.debug("CLAUDE MESSAGE: %s", repr(claude_messages)) headers = { - "x-api-key": os.environ["ANTHROPIC_API_KEY"], - "anthropic-version": "2023-06-01", + "x-api-key": os.environ["OPENAI_API_KEY"], + # "anthropic-version": "2023-06-01", "content-type": "application/json" } @@ -705,7 +755,7 @@ class PromptAgent: } response = requests.post( - "https://api.anthropic.com/v1/messages", + "https://api.apiyi.com/v1/messages", headers=headers, json=payload )