feat: add fake env
This commit is contained in:
2
test_env/__init__.py
Normal file
2
test_env/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .fake_python_controller import PythonController
|
||||
from .fake_env import DesktopEnv
|
||||
128
test_env/fake_env.py
Normal file
128
test_env/fake_env.py
Normal file
@@ -0,0 +1,128 @@
|
||||
from typing import Callable, Any, Optional, Tuple
|
||||
import os
|
||||
from test_env import PythonController
|
||||
|
||||
|
||||
class DesktopEnv:
|
||||
def __init__(
|
||||
self,
|
||||
action_space: str = "computer_13",
|
||||
screen_size: Tuple[int] = (1920, 1080),
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
):
|
||||
self.obs_options = {}
|
||||
self._step_no = 0
|
||||
self.action_history = []
|
||||
self.action_space = action_space
|
||||
self.resolution = screen_size
|
||||
self.controller = PythonController()
|
||||
|
||||
|
||||
# Load test screenshots and accessibility trees
|
||||
test_obs_dir = os.path.join(os.path.dirname(__file__), "test_observations")
|
||||
|
||||
self.screenshots = [
|
||||
self._load_image(os.path.join(test_obs_dir, "screenshot0.jpg")),
|
||||
self._load_image(os.path.join(test_obs_dir, "screenshot1.jpg")),
|
||||
]
|
||||
self.accessibility_trees = [
|
||||
self._load_accessibility_tree(os.path.join(test_obs_dir, "a11y_tree0.txt")),
|
||||
self._load_accessibility_tree(os.path.join(test_obs_dir, "a11y_tree1.txt")),
|
||||
]
|
||||
|
||||
def _get_screenshot(self):
|
||||
if self._step_no == 0:
|
||||
return self.screenshots[0]
|
||||
return self.screenshots[1]
|
||||
|
||||
def _get_accessibility_tree(self):
|
||||
if self._step_no == 0:
|
||||
return self.accessibility_trees[0]
|
||||
return self.accessibility_trees[1]
|
||||
|
||||
def set_obs_options(self, obs_options):
|
||||
print(f"Setting obs options to {obs_options}")
|
||||
self.obs_options = obs_options
|
||||
|
||||
def _load_image(self, image_path):
|
||||
try:
|
||||
with open(image_path, "rb") as image_file:
|
||||
# Read the image file in binary mode
|
||||
image_data = image_file.read()
|
||||
# Encode the binary data as Base64
|
||||
return image_data
|
||||
except FileNotFoundError:
|
||||
print(f"Error: File not found at {image_path}")
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
|
||||
def _load_accessibility_tree(self, tree_path):
|
||||
try:
|
||||
with open(tree_path, "r") as tree_file:
|
||||
# Read the accessibility tree file
|
||||
tree_data = tree_file.read()
|
||||
return tree_data
|
||||
except FileNotFoundError:
|
||||
print(f"Error: File not found at {tree_path}")
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
|
||||
def _get_obs(self):
|
||||
obs = {}
|
||||
obs["screenshot"] = self._get_screenshot()
|
||||
obs["accessibility_tree"] = self._get_accessibility_tree()
|
||||
obs["terminal"] = ""
|
||||
obs["instruction"] = "Open Chrome browser"
|
||||
|
||||
return obs
|
||||
|
||||
def _start_video_recording(self):
|
||||
pass
|
||||
|
||||
def _stop_video_recording(self):
|
||||
pass
|
||||
|
||||
def step(self, action) -> Tuple:
|
||||
self._step_no += 1
|
||||
self.action_history.append(action)
|
||||
|
||||
info = {}
|
||||
terminated = False # todo: Define episode termination condition for each example
|
||||
|
||||
if action == 'FAIL' or action == 'DONE':
|
||||
terminated = True
|
||||
|
||||
else:
|
||||
if self.action_space == "claude_computer_use":
|
||||
tool_result = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_01A09q90qw90lq917835lq9",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/jpeg",
|
||||
"data": self.screenshots[1],
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
info.update({"tool_result": tool_result})
|
||||
|
||||
return (terminated, info)
|
||||
|
||||
def close(self):
|
||||
self._step_no = 0
|
||||
self.action_history = []
|
||||
self.obs_options = {}
|
||||
self.controller = None
|
||||
|
||||
def reset(self, *args: Any, **kwargs: Any) -> dict:
|
||||
return self._get_obs()
|
||||
50
test_env/fake_python_controller.py
Normal file
50
test_env/fake_python_controller.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
class PythonController:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def get_screenshot(self) -> Optional[bytes]:
|
||||
pass
|
||||
|
||||
def get_accessibility_tree(self) -> Optional[str]:
|
||||
pass
|
||||
|
||||
def get_terminal_output(self) -> Optional[str]:
|
||||
pass
|
||||
|
||||
def get_file(self, file_path: str) -> Optional[bytes]:
|
||||
pass
|
||||
|
||||
def execute_python_command(self, command: str) -> None:
|
||||
pass
|
||||
|
||||
def execute_action(self, action: Dict[str, Any]):
|
||||
pass
|
||||
|
||||
# Record video
|
||||
def start_recording(self):
|
||||
pass
|
||||
|
||||
def end_recording(self, dest: str):
|
||||
pass
|
||||
|
||||
# Additional info
|
||||
def get_vm_platform(self):
|
||||
pass
|
||||
|
||||
def get_vm_screen_size(self):
|
||||
pass
|
||||
|
||||
def get_vm_window_size(self, app_class_name: str):
|
||||
pass
|
||||
|
||||
def get_vm_wallpaper(self):
|
||||
pass
|
||||
|
||||
def get_vm_desktop_path(self) -> Optional[str]:
|
||||
pass
|
||||
|
||||
def get_vm_directory_tree(self, path) -> Optional[Dict[str, Any]]:
|
||||
pass
|
||||
1
test_env/test_observations/a11y_tree0.txt
Normal file
1
test_env/test_observations/a11y_tree0.txt
Normal file
File diff suppressed because one or more lines are too long
1
test_env/test_observations/a11y_tree1.txt
Normal file
1
test_env/test_observations/a11y_tree1.txt
Normal file
File diff suppressed because one or more lines are too long
BIN
test_env/test_observations/screenshot0.jpg
Normal file
BIN
test_env/test_observations/screenshot0.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 235 KiB |
BIN
test_env/test_observations/screenshot1.jpg
Normal file
BIN
test_env/test_observations/screenshot1.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 173 KiB |
Reference in New Issue
Block a user