From 992d8f8fcedfe801a688a72bcaf8977742a5acd0 Mon Sep 17 00:00:00 2001
From: Timothyxxx <384084775@qq.com>
Date: Sat, 2 Dec 2023 17:52:00 +0800
Subject: [PATCH] Refactor with pyautogui
---
SERVER_SETUP.md | 21 +--
desktop_env/controllers/keyboard.py | 56 --------
desktop_env/controllers/mouse.py | 162 -----------------------
desktop_env/controllers/python.py | 40 +++---
desktop_env/controllers/xdotool.py | 11 --
desktop_env/envs/desktop_env.py | 194 ++++------------------------
desktop_env/server/main.py | 64 +++++++++
desktop_env/windows_server/main.py | 27 ----
gpt_4v_agent_exp.py | 19 ++-
main.py | 63 +++------
mm_agents/gpt_4v_agent.py | 2 +
11 files changed, 144 insertions(+), 515 deletions(-)
delete mode 100644 desktop_env/controllers/keyboard.py
delete mode 100644 desktop_env/controllers/mouse.py
delete mode 100644 desktop_env/controllers/xdotool.py
create mode 100644 desktop_env/server/main.py
delete mode 100644 desktop_env/windows_server/main.py
diff --git a/SERVER_SETUP.md b/SERVER_SETUP.md
index f67c01e..231bd9e 100644
--- a/SERVER_SETUP.md
+++ b/SERVER_SETUP.md
@@ -1,23 +1,6 @@
# Server Setup Guide
-- [Linux](#linux)
-- [Windows](#windows)
-
-## Linux
-
-
-
-1. `sudo apt install openssh-server`
-2. `sudo systemctl enable ssh --now`
-3. `sudo ufw disable` (disable firewall - safe for local network, otherwise `sudo ufw allow ssh`)
-4. `ip a` - find ip address
-5. ssh username@
-6. On host, run `ssh-copy-id @`
-
-
-## Windows
-
-1. Copy and paste the file `windows_server/main.py` to the windows vm
-2. Make sure `mouse` and `keyboard` are installed
+1. Copy and paste the file `server/main.py` to the windows vm
+2. Install the requirements `pip install -r requirements.txt`
3. Run the file `python main.py`
4. `ipconfig /all` and find the ip address
\ No newline at end of file
diff --git a/desktop_env/controllers/keyboard.py b/desktop_env/controllers/keyboard.py
deleted file mode 100644
index 5178491..0000000
--- a/desktop_env/controllers/keyboard.py
+++ /dev/null
@@ -1,56 +0,0 @@
-from abc import ABC, abstractmethod
-from fabric import Connection
-
-from .xdotool import XDoToolController
-from .python import PythonController
-
-class AbstractKeyboardController(ABC):
- @abstractmethod
- def type(self, text: str):
- raise NotImplementedError
-
- @abstractmethod
- def key(self, key: str):
- raise NotImplementedError
-
- @abstractmethod
- def key_down(self, key: str):
- raise NotImplementedError
-
- @abstractmethod
- def key_up(self, key: str):
- raise NotImplementedError
-
-
-class XDoToolKeyboardController(AbstractKeyboardController, XDoToolController):
- def __init__(self, ssh_connection: Connection):
- super().__init__(ssh_connection=ssh_connection)
-
- def type(self, text: str):
- self._execute_xdotool_command(f"type {text}")
-
- def key(self, key: str):
- self._execute_xdotool_command(f"key {key}")
-
- def key_down(self, key: str):
- self._execute_xdotool_command(f"keydown {key}")
-
- def key_up(self, key: str):
- self._execute_xdotool_command(f"keyup {key}")
-
-class PythonKeyboardController(AbstractKeyboardController, PythonController):
- def __init__(self, http_server: str):
- super().__init__(http_server=http_server)
- self.command = "python -c \"import keyboard; {command}\""
-
- def type(self, text: str):
- self._execute_python_command(self.command.format(command=f"keyboard.write('{text}')"))
-
- def key(self, key: str):
- self._execute_python_command(self.command.format(command=f"keyboard.press_and_release('{key}')"))
-
- def key_down(self, key: str):
- self._execute_python_command(self.command.format(command=f"keyboard.press('{key}')"))
-
- def key_up(self, key: str):
- self._execute_python_command(self.command.format(command=f"keyboard.release('{key}')"))
diff --git a/desktop_env/controllers/mouse.py b/desktop_env/controllers/mouse.py
deleted file mode 100644
index 833f4ee..0000000
--- a/desktop_env/controllers/mouse.py
+++ /dev/null
@@ -1,162 +0,0 @@
-from enum import Enum
-
-from abc import ABC, abstractmethod
-from fabric import Connection
-import re
-
-from .xdotool import XDoToolController
-from .python import PythonController
-class MouseClick(Enum):
- LEFT = 1
- MIDDLE = 2
- RIGHT = 3
- WHEEL_UP = 4
- WHEEL_DOWN = 5
-
-class AbstractMouseController(ABC):
- @abstractmethod
- def get_mouse(self):
- raise NotImplementedError
-
- @abstractmethod
- def mouse_move(self, x: int, y: int):
- raise NotImplementedError
-
- @abstractmethod
- def left_down(self):
- raise NotImplementedError
-
- @abstractmethod
- def left_up(self):
- raise NotImplementedError
-
- @abstractmethod
- def left_click(self):
- raise NotImplementedError
-
- @abstractmethod
- def middle_down(self):
- raise NotImplementedError
-
- @abstractmethod
- def middle_up(self):
- raise NotImplementedError
-
- @abstractmethod
- def middle_click(self):
- raise NotImplementedError
-
- @abstractmethod
- def right_down(self):
- raise NotImplementedError
-
- @abstractmethod
- def right_up(self):
- raise NotImplementedError
-
- @abstractmethod
- def right_click(self):
- raise NotImplementedError
-
- @abstractmethod
- def scroll_up(self):
- raise NotImplementedError
-
- @abstractmethod
- def scroll_down(self):
- raise NotImplementedError
-
-
-class XDoToolMouseController(AbstractMouseController, XDoToolController):
- def __init__(self, ssh_connection: Connection):
- super().__init__(ssh_connection=ssh_connection)
-
- def get_mouse(self):
- output = self._execute_xdotool_command(f"")
- parts = output.split(" ")
- x = int(parts[0].split(":")[1])
- y = int(parts[1].split(":")[1])
- return x, y
-
- def mouse_move(self, x: int, y: int):
- self._execute_xdotool_command(f"mousemove {x} {y}")
-
- def left_down(self):
- self._execute_xdotool_command(f"mousedown 1")
-
- def left_up(self):
- self._execute_xdotool_command(f"mouseup 1")
-
- def left_click(self):
- self._execute_xdotool_command(f"click 1")
-
- def middle_down(self):
- self._execute_xdotool_command(f"mousedown 2")
-
- def middle_up(self):
- self._execute_xdotool_command(f"mouseup 2")
-
- def middle_click(self):
- self._execute_xdotool_command(f"click 2")
-
- def right_down(self):
- self._execute_xdotool_command(f"mousedown 3")
-
- def right_up(self):
- self._execute_xdotool_command(f"mouseup 3")
-
- def right_click(self):
- self._execute_xdotool_command(f"click 3")
-
- def scroll_up(self):
- self._execute_xdotool_command(f"click 4")
-
- def scroll_down(self):
- self._execute_xdotool_command(f"click 5")
-
-class PythonMouseController(AbstractMouseController, PythonController):
- def __init__(self, http_server: str):
- super().__init__(http_server=http_server)
- self.command = "py -c \"import mouse; {command}\""
-
- def get_mouse(self):
- response = self._execute_python_command(self.command.format(command=f"print(mouse.get_position())"))
- numbers = re.findall(r'-?\d+', response["output"])
- x, y = map(int, numbers)
- return x, y
-
- def mouse_move(self, x: int, y: int):
- self._execute_python_command(self.command.format(command=f"mouse.move({x}, {y})"))
-
- def left_down(self):
- self._execute_python_command(self.command.format(command="mouse.press(button='left')"))
-
- def left_up(self):
- self._execute_python_command(self.command.format(command="mouse.release(button='left')"))
-
- def left_click(self):
- self._execute_python_command(self.command.format(command="mouse.click(button='left')"))
-
- def middle_down(self):
- self._execute_python_command(self.command.format(command="mouse.press(button='middle')"))
-
- def middle_up(self):
- self._execute_python_command(self.command.format(command="mouse.release(button='middle')"))
-
- def middle_click(self):
- self._execute_python_command(self.command.format(command="mouse.click(button='middle')"))
-
- def right_down(self):
- self._execute_python_command(self.command.format(command="mouse.press(button='right')"))
-
- def right_up(self):
- self._execute_python_command(self.command.format(command="mouse.release(button='right')"))
-
- def right_click(self):
- self._execute_python_command(self.command.format(command="mouse.click(button='right')"))
-
- def scroll_up(self):
- self._execute_python_command(self.command.format(command="mouse.wheel(10)"))
-
- def scroll_down(self):
- self._execute_python_command(self.command.format(command="mouse.wheel(-10)"))
\ No newline at end of file
diff --git a/desktop_env/controllers/python.py b/desktop_env/controllers/python.py
index 1e6a751..2e2c804 100644
--- a/desktop_env/controllers/python.py
+++ b/desktop_env/controllers/python.py
@@ -1,15 +1,31 @@
-import requests
import json
+import requests
+
class PythonController:
- def __init__(self, http_server: str):
+ def __init__(self, http_server: str, pkgs_prefix: str = "py -c \"import pyautogui; {command}\""):
self.http_server = http_server
+ self.pkgs_prefix = pkgs_prefix # fixme: this is a hacky way to execute python commands. fix it and combine it with installation of packages
- def _execute_python_command(self, command: str) -> None:
- payload = json.dumps({
- "command": command
- })
+ def get_screenshot(self):
+ """
+ Gets a screenshot from the server. With the cursor.
+ """
+ response = requests.get(self.http_server + "/screenshot")
+ if response.status_code == 200:
+ return response.content
+ else:
+ print("Failed to get screenshot. Status code:", response.status_code)
+ return None
+
+ def execute_python_command(self, command: str) -> None:
+ """
+ Executes a python command on the server.
+ It can be used to execute the pyautogui commands, or... any other python command. who knows?
+ """
+ command = self.pkgs_prefix.format(command=command)
+ payload = json.dumps({"command": command})
headers = {
'Content-Type': 'application/json'
}
@@ -23,15 +39,3 @@ class PythonController:
return response.json()
except requests.exceptions.RequestException as e:
print("An error occurred while trying to execute the command:", e)
-
-
-# example usage
-if __name__ == '__main__':
- # replace with your actual server URL of the vm
- server_url = "http://192.168.7.129:5000"
- controller = PythonController(server_url)
-
- # example commands
- python_command = "python -c \"import keyboard; keyboard.write('hello world')\""
- python_command = "python -c \"import mouse; mouse.move(100,100);mouse.right_click()\""
- controller._execute_python_command(python_command)
diff --git a/desktop_env/controllers/xdotool.py b/desktop_env/controllers/xdotool.py
deleted file mode 100644
index abb268f..0000000
--- a/desktop_env/controllers/xdotool.py
+++ /dev/null
@@ -1,11 +0,0 @@
-from fabric import Connection
-from typing import List
-
-
-class XDoToolController:
- def __init__(self, ssh_connection: Connection):
- self.ssh_connection = ssh_connection
-
- def _execute_xdotool_command(self, command: List[str]) -> None:
- result = self.ssh_connection.run(f"DISPLAY=:0 xdotool {command}", hide=True)
- return result.stdout.strip()
diff --git a/desktop_env/envs/desktop_env.py b/desktop_env/envs/desktop_env.py
index d032f87..5485fe0 100644
--- a/desktop_env/envs/desktop_env.py
+++ b/desktop_env/envs/desktop_env.py
@@ -1,34 +1,20 @@
+from __future__ import annotations
+
import os
-from enum import Enum
-from typing import Literal, List, Tuple
import subprocess
-from fabric import Connection
import time
+import uuid
+from typing import List
import gymnasium as gym
-from gymnasium import spaces
-import numpy as np
-import uuid
-from PIL import Image
-from desktop_env.controllers.mouse import MouseClick, AbstractMouseController, XDoToolMouseController, \
- PythonMouseController
-from desktop_env.controllers.keyboard import AbstractKeyboardController, XDoToolKeyboardController, \
- PythonKeyboardController
+from desktop_env.controllers.python import PythonController
-class Action(Enum):
- CLICK = 0
- MOUSE_DOWN = 1
- MOUSE_UP = 2
- MOUSE_MOVE = 3
- KEY = 4
- KEY_DOWN = 5
- KEY_UP = 6
- TYPE = 7
-
-
-VM_TYPE = Literal['ubuntu', 'windows']
+def _execute_command(command: List[str]) -> None:
+ result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=60, text=True)
+ if result.returncode != 0:
+ raise Exception("\033[91m" + result.stdout + result.stderr + "\033[0m")
class DesktopEnv(gym.Env):
@@ -37,67 +23,20 @@ class DesktopEnv(gym.Env):
def __init__(
self,
path_to_vm: str,
- username: str,
- password: str = None,
host: str = "192.168.7.128:5000",
snapshot_path: str = "base",
- vm_os: VM_TYPE = "ubuntu"
):
-
- # The path to the vmx file of your vm
+ # Initialize environment variables
self.path_to_vm = path_to_vm
-
- # username and password for your vm
- self.username = username
- self.password = password
-
self.host = host
self.snapshot_path = snapshot_path # todo: handling the logic of snapshot directory
- # Initialize emulator
+ # Initialize emulator and controller
print("Initializing...")
self._start_emulator()
+ self.controller = PythonController(http_server=self.host)
- # set up controllers
- self.mouse_controller, self.keyboard_controller = self._create_controllers(vm_os)
-
- # Get the screen size
- self.screen_width, self.screen_height = self._get_screensize()
-
- # Define the action and observation space
- self.action_space = spaces.Dict({
- "action_type": spaces.Discrete(len(Action)),
- "click_type": spaces.Discrete(len(MouseClick)),
- "x": spaces.Discrete(self.screen_width),
- "y": spaces.Discrete(self.screen_height),
- "key": spaces.MultiDiscrete([128] * 10), # max 10 characters, ASCII
- "text": spaces.MultiDiscrete([128] * 10) # max 10 characters, ASCII
- })
-
- self.observation_space = spaces.Box(low=0, high=255, shape=(self.screen_width, self.screen_height, 3),
- dtype=np.uint8)
-
- # Additional setup
- self.metadata = {'render.modes': ['rgb_array']}
-
-
- def _get_screensize(self):
- screenshot_path = self._get_obs()
- img = Image.open(screenshot_path)
- return img.size
-
- def _create_controllers(self, vm_os: VM_TYPE) -> Tuple[AbstractMouseController, AbstractKeyboardController]:
- if vm_os == "ubuntu":
- ssh_connection = Connection(host=self.host, user=self.username, connect_kwargs={"password": self.password})
- mouse_controller = XDoToolMouseController(ssh_connection)
- keyboard_controller = XDoToolKeyboardController(ssh_connection)
- elif vm_os == "windows":
- mouse_controller = PythonMouseController(http_server=self.host)
- keyboard_controller = PythonKeyboardController(http_server=self.host)
- else:
- raise NotImplementedError(vm_os)
-
- return mouse_controller, keyboard_controller
+ # todo: define the action space and the observation space as gym did
def _start_emulator(self):
while True:
@@ -109,52 +48,35 @@ class DesktopEnv(gym.Env):
break
else:
print("Starting VM...")
- self._execute_command(["vmrun", "-T", "ws", "start", self.path_to_vm])
+ _execute_command(["vmrun", "-T", "ws", "start", self.path_to_vm])
time.sleep(10)
except subprocess.CalledProcessError as e:
print(f"Error executing command: {e.output.decode().strip()}")
- def _execute_command(self, command: List[str]) -> None:
- result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=60, text=True)
- if result.returncode != 0:
- raise Exception("\033[91m" + result.stdout + result.stderr + "\033[0m")
-
def _save_state(self):
- self._execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_path])
+ _execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_path])
def _get_screenshot(self):
random_uuid = str(uuid.uuid4())
os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True)
image_path = os.path.join("tmp", random_uuid, "screenshot.png")
- if self.password:
- self._execute_command(
- ["vmrun", "-T", "ws", "-gu", self.username, "-gp", self.password, "captureScreen", self.path_to_vm, image_path])
- else:
- self._execute_command(
- ["vmrun", "-T", "ws", "-gu", self.username, "captureScreen", self.path_to_vm, image_path])
+ # Get the screenshot and save to the image_path
+ screenshot = self.controller.get_screenshot()
+ with open(image_path, "wb") as f:
+ f.write(screenshot)
return image_path
def _get_obs(self):
screenshot_image_path = self._get_screenshot()
- self._add_cursor(screenshot_image_path)
return screenshot_image_path
-
- def _add_cursor(self, img_path: str):
- x, y = self.mouse_controller.get_mouse()
- cursor_image = Image.open("./desktop_env/assets/cursor.png")
- cursor_image = cursor_image.resize((int(cursor_image.width / 2), int(cursor_image.height / 2)))
- screenshot = Image.open(img_path)
- screenshot.paste(cursor_image, (x, y), cursor_image)
- screenshot.save(img_path)
-
- def reset(self):
+ def reset(self, seed=None, options=None):
print("Resetting environment...")
print("Reverting to snapshot to {}...".format(self.snapshot_path))
- self._execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
+ _execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
time.sleep(5)
print("Starting emulator...")
@@ -165,75 +87,11 @@ class DesktopEnv(gym.Env):
return observation
def step(self, action):
- if isinstance(action, list):
- for a in action:
- observation, reward, done, info = self.step(a)
- return observation, reward, done, info
-
- # todo: handle the case when the action is not a single action
- try:
- action_type = Action(action['action_type'])
- except KeyError:
- done = True
- return self._get_obs(), 0, done, {}
-
- if action_type == Action.CLICK:
- click = MouseClick(action['click_type'])
- if click == MouseClick.LEFT:
- self.mouse_controller.left_click()
- elif click == MouseClick.MIDDLE:
- self.mouse_controller.middle_click()
- elif click == MouseClick.RIGHT:
- self.mouse_controller.right_click()
- elif click == MouseClick.WHEEL_UP:
- self.mouse_controller.scroll_up()
- elif click == MouseClick.WHEEL_DOWN:
- self.mouse_controller.scroll_down()
- elif action_type == Action.MOUSE_DOWN:
- click = MouseClick(action['click_type'])
- if click == MouseClick.LEFT:
- self.mouse_controller.left_down()
- elif click == MouseClick.MIDDLE:
- self.mouse_controller.middle_down()
- elif click == MouseClick.RIGHT:
- self.mouse_controller.right_down()
- elif click == MouseClick.WHEEL_UP:
- self.mouse_controller.scroll_up()
- elif click == MouseClick.WHEEL_DOWN:
- self.mouse_controller.scroll_down()
- elif action_type == Action.MOUSE_UP:
- click = MouseClick(action['click_type'])
- if click == MouseClick.LEFT:
- self.mouse_controller.left_up()
- elif click == MouseClick.MIDDLE:
- self.mouse_controller.middle_up()
- elif click == MouseClick.RIGHT:
- self.mouse_controller.right_up()
- elif click == MouseClick.WHEEL_UP:
- self.mouse_controller.scroll_up()
- elif click == MouseClick.WHEEL_DOWN:
- self.mouse_controller.scroll_down()
- elif action_type == Action.MOUSE_MOVE:
- self.mouse_controller.mouse_move(x=action['x'], y=action['y'])
- elif action_type == Action.KEY:
- self.keyboard_controller.key(action['key'])
- elif action_type == Action.KEY_DOWN:
- self.keyboard_controller.key_down(action['key'])
- elif action_type == Action.KEY_UP:
- self.keyboard_controller.key_up(action['key'])
- elif action_type == Action.TYPE:
- for key in action['text']:
- if key == "\r" or key == "\n":
- self.keyboard_controller.key("enter")
- else:
- self.keyboard_controller.key(key)
- # sleep for 0.05 seconds with some random noise
- time.sleep(0.05 + np.random.normal(0, 0.01))
-
- # Capture new state
+ # Our action space is the set of all possible python commands insides `pyautogui`
+ self.controller.execute_python_command(action)
observation = self._get_obs()
- reward = 0 # Define reward calculation
- done = False # Define episode termination condition
+ reward = 0 # todo: Define reward calculation for each example
+ done = False # todo: Define episode termination condition for each example
info = {}
return observation, reward, done, info
@@ -244,4 +102,4 @@ class DesktopEnv(gym.Env):
raise ValueError('Unsupported render mode: {}'.format(mode))
def close(self):
- self._execute_command(["vmrun", "stop", self.path_to_vm])
+ _execute_command(["vmrun", "stop", self.path_to_vm])
diff --git a/desktop_env/server/main.py b/desktop_env/server/main.py
new file mode 100644
index 0000000..228a08d
--- /dev/null
+++ b/desktop_env/server/main.py
@@ -0,0 +1,64 @@
+import os
+import platform
+import subprocess
+
+import Xlib.display
+import pyautogui
+from PIL import ImageGrab
+from flask import Flask, request, jsonify, send_file
+
+app = Flask(__name__)
+
+pyautogui.PAUSE = 0
+pyautogui.DARWIN_CATCH_UP_TIME = 0
+
+
+@app.route('/execute', methods=['POST'])
+def execute_command():
+ data = request.json
+ # The 'command' key in the JSON request should contain the command to be executed.
+ command = data.get('command', '')
+
+ # Execute the command without any safety checks.
+ try:
+ result = subprocess.run(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+ return jsonify({
+ 'status': 'success',
+ 'output': result.stdout,
+ 'error': result.stderr
+ })
+ except Exception as e:
+ return jsonify({
+ 'status': 'error',
+ 'message': str(e)
+ }), 500
+
+
+@app.route('/screenshot', methods=['GET'])
+def capture_screen_with_cursor():
+ file_path = os.path.join("screenshots", "screenshot.png")
+ user_platform = platform.system()
+
+ # Ensure the screenshots directory exists
+ os.makedirs(os.path.dirname(file_path), exist_ok=True)
+
+ if user_platform == "Windows":
+ screenshot = pyautogui.screenshot()
+ screenshot.save(file_path)
+ elif user_platform == "Linux":
+ # Use xlib to prevent scrot dependency for Linux
+ screen = Xlib.display.Display().screen()
+ size = screen.width_in_pixels, screen.height_in_pixels
+ screenshot = ImageGrab.grab(bbox=(0, 0, size[0], size[1]))
+ screenshot.save(file_path)
+ elif user_platform == "Darwin": # (Mac OS)
+ # Use the screencapture utility to capture the screen with the cursor
+ subprocess.run(["screencapture", "-C", file_path])
+ else:
+ print(f"The platform you're using ({user_platform}) is not currently supported")
+
+ return send_file(file_path, mimetype='image/png')
+
+
+if __name__ == '__main__':
+ app.run(debug=True, host="0.0.0.0")
diff --git a/desktop_env/windows_server/main.py b/desktop_env/windows_server/main.py
deleted file mode 100644
index 098dca5..0000000
--- a/desktop_env/windows_server/main.py
+++ /dev/null
@@ -1,27 +0,0 @@
-from flask import Flask, request, jsonify
-import subprocess
-
-app = Flask(__name__)
-
-@app.route('/execute', methods=['POST'])
-def execute_command():
- data = request.json
- # The 'command' key in the JSON request should contain the command to be executed.
- command = data.get('command', '')
-
- # Execute the command without any safety checks.
- try:
- result = subprocess.run(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
- return jsonify({
- 'status': 'success',
- 'output': result.stdout,
- 'error': result.stderr
- })
- except Exception as e:
- return jsonify({
- 'status': 'error',
- 'message': str(e)
- }), 500
-
-if __name__ == '__main__':
- app.run(debug=True, host="0.0.0.0")
diff --git a/gpt_4v_agent_exp.py b/gpt_4v_agent_exp.py
index bae0446..12b5dd0 100644
--- a/gpt_4v_agent_exp.py
+++ b/gpt_4v_agent_exp.py
@@ -8,29 +8,26 @@ import uuid
def gpt_4v_agent():
api_key = os.environ.get("OPENAI_API_KEY")
- # meta_info = {
- # "instruction": "Open WSJ website to get latest news",
- # "task_name": "open_wsj",
- # "snapshot_path": "base",
- # }
-
meta_info = {
- "instruction": "Clear the recycle bin",
- "task_name": "clean_recycle_bin",
+ "instruction": "Open WSJ website to get latest news",
+ "task_name": "open_wsj",
"snapshot_path": "base",
}
+ # meta_info = {
+ # "instruction": "Clear the recycle bin",
+ # "task_name": "clean_recycle_bin",
+ # "snapshot_path": "base",
+ # }
+
agent = GPT4v_Agent(api_key=api_key, instruction=meta_info["instruction"])
env = DesktopEnv(
path_to_vm=r"""C:\Users\tianbaox\Documents\Virtual Machines\Win10\Win10.vmx""",
# automitically load the snapshot and start the vm
# path_to_vm="/home/yuri/vmware/Ubuntu 64-bit/Ubuntu 64-bit.vmx",
snapshot_path="base",
- username="tianbaox",
- password="951753",
# host="192.168.7.128",
host="http://192.168.13.128:5000",
- vm_os="windows"
)
# reset the environment to certain snapshot
diff --git a/main.py b/main.py
index 2cbcd18..57c1926 100644
--- a/main.py
+++ b/main.py
@@ -1,55 +1,31 @@
-from pprint import pprint
-from desktop_env.envs.desktop_env import DesktopEnv, Action, MouseClick
-
-def get_human_action():
- """
- Prompts the human player for an action and returns a structured action.
- """
- print("\nAvailable actions:", [action.name for action in Action])
- action_type = None
- while action_type not in [action.value for action in Action]:
- action_type = Action[input("Enter the type of action: ".strip())].value
-
- action = {"action_type": action_type}
-
- if action_type == Action.CLICK.value or action_type == Action.MOUSE_DOWN.value or action_type == Action.MOUSE_UP.value:
- print("\n Available clicks:", [action.name for action in MouseClick])
- click_type = input("Enter click type: ")
- action["click_type"] = MouseClick[click_type].value
-
- if action_type == Action.MOUSE_MOVE.value:
- x = int(input("Enter x-coordinate for mouse move: "))
- y = int(input("Enter y-coordinate for mouse move: "))
- action["x"] = x
- action["y"] = y
-
- if action_type == Action.KEY.value:
- key = input("Enter the key to press: ")
- action["key"] = [ord(c) for c in key]
-
- if action_type == Action.TYPE.value:
- text = input("Enter the text to type: ")
- action["text"] = [ord(c) for c in text]
-
- return action
+from desktop_env.envs.desktop_env import DesktopEnv
def human_agent():
"""
Runs the Gym environment with human input.
"""
- env = DesktopEnv(path_to_vm="/home/yuri/vmware/Windows 10 x64/Windows 10 x64.vmx",
- # path_to_vm="/home/yuri/vmware/Ubuntu 64-bit/Ubuntu 64-bit.vmx",
- username="user",
- password="password",
- # host="192.168.7.128",
- host="http://192.168.7.129:5000",
- vm_os="windows")
- observation = env.reset()
+ env = DesktopEnv(
+ path_to_vm=r"""C:\Users\tianbaox\Documents\Virtual Machines\Win10\Win10.vmx""",
+ # path_to_vm="/home/yuri/vmware/Ubuntu 64-bit/Ubuntu 64-bit.vmx",
+ # host="192.168.7.128",
+ host="http://192.168.13.128:5000",
+ )
+
+ # reset the environment to certain snapshot
+ # observation = env.reset()
done = False
while not done:
- action = get_human_action()
+ # action = get_human_action()
+
+ # action = {
+ # "action_type": 0,
+ # "click_type": 3,
+ # }
+
+ action = "pyautogui.dragTo(100, 200, button='left')"
+
observation, reward, done, info = env.step(action)
print("Observation:", observation)
print("Reward:", reward)
@@ -64,5 +40,6 @@ def human_agent():
env.close()
print("Environment closed.")
+
if __name__ == "__main__":
human_agent()
diff --git a/mm_agents/gpt_4v_agent.py b/mm_agents/gpt_4v_agent.py
index c52b9c9..52b4dcf 100644
--- a/mm_agents/gpt_4v_agent.py
+++ b/mm_agents/gpt_4v_agent.py
@@ -88,6 +88,8 @@ class GPT4v_Agent:
traj_to_show = []
for i in range(len(self.trajectory)):
traj_to_show.append(self.trajectory[i]["content"][0]["text"])
+ if len(self.trajectory[i]["content"]) > 1:
+ traj_to_show.append("screenshot_obs")
print("Trajectory:", traj_to_show)
payload = {
"model": self.model,