128 lines
4.2 KiB
Python
128 lines
4.2 KiB
Python
from typing import Callable, Any, Optional, Tuple
|
|
import os
|
|
from test_env import PythonController
|
|
|
|
|
|
class DesktopEnv:
    """Offline stand-in for the desktop environment, used in tests.

    Serves canned observations loaded from the ``test_observations``
    directory next to this file: ``screenshot0.jpg`` / ``a11y_tree0.txt``
    before any step of the current episode, and ``screenshot1.jpg`` /
    ``a11y_tree1.txt`` afterwards. Actions passed to ``step`` are recorded
    in ``action_history`` but not executed.
    """

    def __init__(
        self,
        action_space: str = "computer_13",
        screen_size: Tuple[int, int] = (1920, 1080),
        *args: Any,
        **kwargs: Any,
    ):
        """Create the mock environment and load its canned observations.

        Args:
            action_space: Action-space name. ``"claude_computer_use"`` makes
                ``step`` attach a ``tool_result`` message to its ``info`` dict.
            screen_size: ``(width, height)``. It is stored as ``self.resolution``
                and not otherwise used in this class.
            *args, **kwargs: Accepted and ignored (presumably so callers can
                pass the real environment's arguments — confirm).
        """
        self.obs_options = {}
        # Number of step() calls in the current episode; selects which canned
        # observation is served.
        self._step_no = 0
        self.action_history = []
        self.action_space = action_space
        self.resolution = screen_size
        self.controller = PythonController()

        # Load test screenshots and accessibility trees
        test_obs_dir = os.path.join(os.path.dirname(__file__), "test_observations")
        self.screenshots = [
            self._load_image(os.path.join(test_obs_dir, "screenshot0.jpg")),
            self._load_image(os.path.join(test_obs_dir, "screenshot1.jpg")),
        ]
        self.accessibility_trees = [
            self._load_accessibility_tree(os.path.join(test_obs_dir, "a11y_tree0.txt")),
            self._load_accessibility_tree(os.path.join(test_obs_dir, "a11y_tree1.txt")),
        ]

    def _get_screenshot(self):
        """Return screenshot 0 before the first step, screenshot 1 afterwards."""
        if self._step_no == 0:
            return self.screenshots[0]
        return self.screenshots[1]

    def _get_accessibility_tree(self):
        """Return tree 0 before the first step, tree 1 afterwards."""
        if self._step_no == 0:
            return self.accessibility_trees[0]
        return self.accessibility_trees[1]

    def set_obs_options(self, obs_options):
        """Store ``obs_options``; this mock does not consult them when building observations."""
        print(f"Setting obs options to {obs_options}")
        self.obs_options = obs_options

    def _load_image(self, image_path):
        """Return the raw (not Base64-encoded) bytes of the file at ``image_path``.

        This is best-effort. On an I/O error a message is printed and ``None``
        is returned, so a missing fixture does not abort construction.
        """
        try:
            with open(image_path, "rb") as image_file:
                return image_file.read()
        except FileNotFoundError:
            print(f"Error: File not found at {image_path}")
        except OSError as e:
            print(f"An error occurred: {e}")
        return None

    def _load_accessibility_tree(self, tree_path):
        """Return the text of the accessibility-tree file at ``tree_path``.

        The file is decoded as UTF-8 explicitly, so the result does not
        depend on the platform's locale encoding. This is best-effort: on an
        I/O or decode error a message is printed and ``None`` is returned.
        """
        try:
            with open(tree_path, "r", encoding="utf-8") as tree_file:
                return tree_file.read()
        except FileNotFoundError:
            print(f"Error: File not found at {tree_path}")
        except (OSError, UnicodeDecodeError) as e:
            print(f"An error occurred: {e}")
        return None

    def _get_obs(self):
        """Build the observation dict for the current step.

        Keys are ``screenshot`` (bytes), ``accessibility_tree`` (str), an empty
        ``terminal`` and a fixed ``instruction``.
        """
        obs = {}
        obs["screenshot"] = self._get_screenshot()
        obs["accessibility_tree"] = self._get_accessibility_tree()
        obs["terminal"] = ""
        obs["instruction"] = "Open Chrome browser"

        return obs

    def _start_video_recording(self):
        """No-op: the mock does not record video."""
        pass

    def _stop_video_recording(self):
        """No-op: the mock does not record video."""
        pass

    def _build_tool_result(self):
        """Build a canned ``tool_result`` user message carrying screenshot 1.

        The image source is declared as ``"type": "base64"``, so the screenshot
        bytes are Base64-encoded into an ASCII string. Raw bytes would be
        neither valid Base64 data nor JSON-serializable. If the screenshot
        failed to load (``None``), ``None`` is passed through. The
        ``tool_use_id`` is a hard-coded literal.
        """
        import base64  # local import: only needed for this action space

        screenshot = self.screenshots[1]
        encoded = (
            base64.b64encode(screenshot).decode("ascii")
            if screenshot is not None
            else None
        )
        return {
            "role": "user",
            "content": [
                {
                    "type": "tool_result",
                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
                    "content": [
                        {
                            "type": "image",
                            "source": {
                                "type": "base64",
                                "media_type": "image/jpeg",
                                "data": encoded,
                            },
                        }
                    ],
                }
            ],
        }

    def step(self, action, *args, **kwargs) -> Tuple[dict, int, bool, dict]:
        """Record ``action`` and return the next canned observation.

        Returns:
            ``(obs, reward, terminated, info)``:
            - ``reward`` is always 0.
            - ``terminated`` is True only for the ``'FAIL'`` / ``'DONE'`` actions.
            - ``info`` holds a ``tool_result`` message when the action space is
              ``"claude_computer_use"`` and the episode did not terminate.
        """
        self._step_no += 1
        self.action_history.append(action)

        info = {}
        # TODO: Define episode termination condition for each example
        terminated = action == 'FAIL' or action == 'DONE'

        if not terminated and self.action_space == "claude_computer_use":
            info["tool_result"] = self._build_tool_result()

        return (self._get_obs(), 0, terminated, info)

    def close(self):
        """Clear episode state and observation options, and drop the controller."""
        self._step_no = 0
        self.action_history = []
        self.obs_options = {}
        self.controller = None

    def reset(self, *args: Any, **kwargs: Any) -> dict:
        """Start a new episode and return its initial observation.

        Clears the step counter and action history, so the first canned
        observation (index 0) is served again even after earlier steps.
        Arguments are accepted and ignored.
        """
        self._step_no = 0
        self.action_history = []
        return self._get_obs()