Accomplish the exp scripts v1; Add video recording and trajectory recording of desktop agent; Fix minor bugs

This commit is contained in:
Timothyxxx
2024-01-15 13:49:48 +08:00
parent f153a4c253
commit 24169a65d0
6 changed files with 127 additions and 21 deletions

View File

@@ -243,6 +243,32 @@ class PythonController:
else: else:
raise Exception(f"Unknown action type: {action_type}") raise Exception(f"Unknown action type: {action_type}")
# Record video
def start_recording(self):
    """
    Ask the server to begin recording the screen.

    Sends a POST to the ``/start_recording`` endpoint under ``self.http_server``
    and logs whether the server answered with HTTP 200. Nothing is returned;
    a non-200 answer is logged as an error, not raised.
    """
    endpoint = self.http_server + "/start_recording"
    status = requests.post(endpoint).status_code

    # Guard clause: report the failure and stop; only 200 counts as success.
    if status != 200:
        logger.error("Failed to start recording. Status code: %d", status)
        return
    logger.info("Recording started successfully")
def end_recording(self, dest: str):
    """
    Stop the screen recording on the server and save the video locally.

    Sends a POST to the ``/end_recording`` endpoint under ``self.http_server``.
    On HTTP 200 the response body is written to ``dest`` in binary mode;
    on any other status the failure is logged and nothing is written.

    Args:
        dest: Local file path the downloaded recording is written to.

    Returns:
        None in every case; failures are logged, not raised.
    """
    # stream=True defers downloading the body so that iter_content() below
    # really reads it in 8 KiB chunks instead of after the whole video has
    # already been loaded into memory. The `with` releases the connection
    # on every path, including the error branch where the body is not read.
    with requests.post(self.http_server + "/end_recording", stream=True) as response:
        if response.status_code != 200:
            logger.error("Failed to stop recording. Status code: %d", response.status_code)
            return None

        logger.info("Recording stopped successfully")
        with open(dest, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:  # skip empty keep-alive chunks
                    f.write(chunk)
    return None
# Additional info # Additional info
def get_vm_platform(self): def get_vm_platform(self):
""" """

View File

@@ -209,10 +209,6 @@ class SetupController:
if not command: if not command:
raise Exception("Empty command to launch.") raise Exception("Empty command to launch.")
if isinstance(command, str) and len(command.split()) > 1:
logger.warning("Command should be a list of strings. Now it is a string. Will split it by space.")
command = command.split()
payload = json.dumps({"command": command}) payload = json.dumps({"command": command})
headers = {"Content-Type": "application/json"} headers = {"Content-Type": "application/json"}

View File

@@ -86,6 +86,9 @@ def compare_images(image1_path, image2_path):
# score = compare_images('path_to_image1', 'path_to_image2') # score = compare_images('path_to_image1', 'path_to_image2')
# print("Similarity score:", score) # print("Similarity score:", score)
if not image1_path or not image2_path:
return 0
# Open the images and convert to grayscale # Open the images and convert to grayscale
image1 = Image.open(image1_path).convert('L') image1 = Image.open(image1_path).convert('L')
image2 = Image.open(image2_path).convert('L') image2 = Image.open(image2_path).convert('L')
@@ -119,6 +122,9 @@ def compare_audios(audio_path_1, audio_path_2, max_distance=1000):
# print(f'Similarity Score: {similarity}') # print(f'Similarity Score: {similarity}')
# Convert to common format if necessary and load audio # Convert to common format if necessary and load audio
if not audio_path_1 or not audio_path_2:
return 0
y1, sr1 = librosa.load(audio_path_1) y1, sr1 = librosa.load(audio_path_1)
y2, sr2 = librosa.load(audio_path_2) y2, sr2 = librosa.load(audio_path_2)

View File

@@ -78,3 +78,8 @@ Activating the window manager control requires the installation of `wmctrl`:
```bash ```bash
sudo apt install wmctrl sudo apt install wmctrl
``` ```
To enable recording in the virtual machine, you need to install `ffmpeg`:
```bash
sudo apt install ffmpeg
```

View File

@@ -1,6 +1,7 @@
import ctypes import ctypes
import os import os
import platform import platform
import shlex
import subprocess import subprocess
from pathlib import Path from pathlib import Path
from typing import Any, Optional from typing import Any, Optional
@@ -13,7 +14,7 @@ import pyautogui
import requests import requests
from PIL import Image from PIL import Image
from Xlib import display, X from Xlib import display, X
from flask import Flask, request, jsonify, send_file, abort from flask import Flask, request, jsonify, send_file, abort, send_from_directory
from lxml.etree import _Element from lxml.etree import _Element
from pyatspi import Accessible, StateType from pyatspi import Accessible, StateType
from pyatspi import Action as ATAction from pyatspi import Action as ATAction
@@ -29,7 +30,8 @@ pyautogui.PAUSE = 0
pyautogui.DARWIN_CATCH_UP_TIME = 0 pyautogui.DARWIN_CATCH_UP_TIME = 0
logger = app.logger logger = app.logger
recording_process = None # fixme: this is a temporary solution for recording, need to be changed to support multiple-process
recording_path = "/tmp/recording.mp4"
@app.route('/setup/execute', methods=['POST']) @app.route('/setup/execute', methods=['POST'])
@app.route('/execute', methods=['POST']) @app.route('/execute', methods=['POST'])
@@ -39,6 +41,9 @@ def execute_command():
shell = data.get('shell', False) shell = data.get('shell', False)
command = data.get('command', "" if shell else []) command = data.get('command', "" if shell else [])
if isinstance(command, str):
command = shlex.split(command)
# Execute the command without any safety checks. # Execute the command without any safety checks.
try: try:
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell, text=True) result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell, text=True)
@@ -60,6 +65,9 @@ def launch_app():
data = request.json data = request.json
command: List[str] = data.get("command", []) command: List[str] = data.get("command", [])
if isinstance(command, str):
command = shlex.split(command)
try: try:
subprocess.Popen(command) subprocess.Popen(command)
return "{:} launched successfully".format(" ".join(command)) return "{:} launched successfully".format(" ".join(command))
@@ -604,5 +612,42 @@ def activate_window():
return "File opened successfully", 200 return "File opened successfully", 200
@app.route('/start_recording', methods=['POST'])
def start_recording():
    """
    Start an ffmpeg screen capture of the X display into ``recording_path``.

    Returns a JSON success message, or a JSON error with HTTP 400 when a
    recording is already running. The spawned process is kept in the
    module-level ``recording_process`` so that the stop endpoint can end it.
    """
    global recording_process

    # Only a process that is still running counts as "in progress"; an ffmpeg
    # that already exited (e.g. crashed on startup) must not block a new start.
    if recording_process and recording_process.poll() is None:
        return jsonify({'status': 'error', 'message': 'Recording is already in progress.'}), 400

    # Query the screen size, then close the X connection so that repeated
    # calls do not accumulate open connections to the X server.
    d = display.Display()
    try:
        screen_width = d.screen().width_in_pixels
        screen_height = d.screen().height_in_pixels
    finally:
        d.close()

    # Argument list instead of a split string: same ffmpeg invocation, but the
    # output path is passed as one argument regardless of its content.
    start_command = [
        "ffmpeg", "-y",
        "-f", "x11grab",
        "-draw_mouse", "1",
        "-s", f"{screen_width}x{screen_height}",
        "-i", ":0.0",
        "-c:v", "libx264",
        "-r", "30",
        recording_path,
    ]

    # ffmpeg continuously writes progress to stderr. Nobody reads it while the
    # recording runs, so a PIPE would fill up and stall ffmpeg; discard it.
    recording_process = subprocess.Popen(
        start_command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
    )

    return jsonify({'status': 'success', 'message': 'Started recording.'})
@app.route('/end_recording', methods=['POST'])
def end_recording():
    """
    Stop the running ffmpeg recording and return the video file.

    Returns the file at ``recording_path`` as an attachment, a JSON error with
    HTTP 400 when no recording was started, or aborts with 404 when the
    recording file does not exist after ffmpeg has exited.
    """
    global recording_process

    if not recording_process:
        return jsonify({'status': 'error', 'message': 'No recording in progress to stop.'}), 400

    recording_process.terminate()
    # communicate() waits for exit while draining any piped output. Calling
    # wait() first can deadlock if the child blocks on a full stdout/stderr
    # pipe, so the single communicate() call replaces wait() + communicate().
    recording_process.communicate()
    logger.info("Recording process exited with return code %s", recording_process.returncode)
    recording_process = None

    # Return the recorded video file if ffmpeg produced one.
    if os.path.exists(recording_path):
        return send_file(recording_path, as_attachment=True)
    return abort(404, description="Recording failed")
if __name__ == '__main__': if __name__ == '__main__':
app.run(debug=True, host="0.0.0.0") app.run(debug=True, host="0.0.0.0")

View File

@@ -44,7 +44,7 @@ logger = logging.getLogger("desktopenv.experiment")
PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx" PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx"
def run_one_example(example, agent, max_steps=20, example_trajectory_dir="exp_trajectory"): def run_one_example(example, agent, max_steps=2, example_trajectory_dir="exp_trajectory", recording=True):
trajectory_recording_path = os.path.join(example_trajectory_dir, "trajectory.json") trajectory_recording_path = os.path.join(example_trajectory_dir, "trajectory.json")
env = DesktopEnv( env = DesktopEnv(
path_to_vm=PATH_TO_VM, path_to_vm=PATH_TO_VM,
@@ -57,25 +57,53 @@ def run_one_example(example, agent, max_steps=20, example_trajectory_dir="exp_tr
done = False done = False
step_num = 0 step_num = 0
# todo: save the screenshots and actions to a folder if recording:
# send a request to the server to start recording
env.controller.start_recording()
while not done and step_num < max_steps: while not done and step_num < max_steps:
actions = agent.predict(observation) actions = agent.predict(observation)
for action in actions: for action in actions:
step_num += 1
# Capture the timestamp before executing the action
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
observation, reward, done, info = env.step(action) observation, reward, done, info = env.step(action)
observation['instruction'] = example['instruction'] observation['instruction'] = example['instruction']
step_num += 1
logger.info("Step %d", step_num)
logger.info("Action: %s", actions)
observation.pop("accessibility_tree")
logger.info("Observation: %s", observation)
logger.info("Reward: %.2f", reward)
logger.info("Info: %s", info)
logger.info("================================\n") # Logging
logger.info("Step %d: %s", step_num, action)
logger.info("Reward: %.2f", reward)
logger.info("Done: %s", done)
logger.info("Info: %s", info)
if done:
logger.info("The episode is done.") # Save screenshot and trajectory information
break with open(os.path.join(example_trajectory_dir, f"step_{step_num}_{action_timestamp}.png"), "wb") as _f:
with open(observation['screenshot'], "rb") as __f:
screenshot = __f.read()
_f.write(screenshot)
with open(trajectory_recording_path, "a") as f:
f.write(json.dumps({
"step_num": step_num,
"action_timestamp": action_timestamp,
"action": action,
"reward": reward,
"done": done,
"info": info,
"screenshot_file": f"step_{step_num}_{action_timestamp}.png"
}))
f.write("\n")
if done:
logger.info("The episode is done.")
break
if recording:
# send a request to the server to stop recording
env.controller.end_recording(os.path.join(example_trajectory_dir, "recording.mp4"))
result = env.evaluate() result = env.evaluate()
logger.info("Result: %.2f", result) logger.info("Result: %.2f", result)
@@ -91,7 +119,7 @@ if __name__ == "__main__":
with open(f"evaluation_examples/examples/{example_class}/{example_id}.json", "r") as f: with open(f"evaluation_examples/examples/{example_class}/{example_id}.json", "r") as f:
example = json.load(f) example = json.load(f)
example["snapshot"] = "chrome_setup" example["snapshot"] = "exp_setup"
api_key = os.environ.get("OPENAI_API_KEY") api_key = os.environ.get("OPENAI_API_KEY")
agent = GPT4v_Agent(api_key=api_key, action_space=action_space) agent = GPT4v_Agent(api_key=api_key, action_space=action_space)
@@ -101,4 +129,4 @@ if __name__ == "__main__":
example_trajectory_dir = os.path.join(root_trajectory_dir, example_class, example_id) example_trajectory_dir = os.path.join(root_trajectory_dir, example_class, example_id)
os.makedirs(example_trajectory_dir, exist_ok=True) os.makedirs(example_trajectory_dir, exist_ok=True)
run_one_example(example, agent, 20, example_trajectory_dir) run_one_example(example, agent, 2, example_trajectory_dir)