Merge branch 'main' into zdy
18
README.md
@@ -1,7 +1,8 @@
|
|||||||
# DesktopEnv: A Learning Environment for Human-like Computer Task Mastery
|
# DesktopEnv: An Environment towards Human-like Computer Task Mastery
|
||||||
|
|
||||||
## Setup guide
|
## Setup guide
|
||||||
|
|
||||||
|
### For members of the team
|
||||||
1. Download OS image
|
1. Download OS image
|
||||||
1. Download kubuntu from <https://kubuntu.org/getkubuntu/>
|
1. Download kubuntu from <https://kubuntu.org/getkubuntu/>
|
||||||
2. Download ubuntu from <https://ubuntu.com/download/desktop>
|
2. Download ubuntu from <https://ubuntu.com/download/desktop>
|
||||||
@@ -13,7 +14,7 @@
|
|||||||
3. Set up bridge for connecting to VM
|
3. Set up bridge for connecting to VM
|
||||||
1. Option 1: Install [xdotool](https://github.com/jordansissel/xdotool) on VM
|
1. Option 1: Install [xdotool](https://github.com/jordansissel/xdotool) on VM
|
||||||
2. Option 2: Install [mouse](https://github.com/boppreh/mouse/)
|
2. Option 2: Install [mouse](https://github.com/boppreh/mouse/)
|
||||||
4. Set up SSH server on VM | [Guide](./SSH_SERVER_SETUP.md)
|
4. Set up SSH server on VM | [Guide](./SERVER_SETUP.md)
|
||||||
5. Install screenshot tool (in vm)
|
5. Install screenshot tool (in vm)
|
||||||
1. `sudo apt install imagemagick-6.q16hdri`
|
1. `sudo apt install imagemagick-6.q16hdri`
|
||||||
2. `DISPLAY=:0 import -window root screenshot.png`
|
2. `DISPLAY=:0 import -window root screenshot.png`
|
||||||
@@ -22,12 +23,8 @@
|
|||||||
2. `rm -rf ~/screenshot.png`
|
2. `rm -rf ~/screenshot.png`
|
||||||
7. Set up python and install [mouse](https://github.com/boppreh/mouse/) and [keyboard](https://github.com/jordansissel/xdotool)
|
7. Set up python and install [mouse](https://github.com/boppreh/mouse/) and [keyboard](https://github.com/jordansissel/xdotool)
|
||||||
|
|
||||||
## Windows setup guide
|
### For users of the environment
|
||||||
|
todo
|
||||||
1. Copy and paste the file `windows_server/main.py` to the windows vm
|
|
||||||
2. Make sure `mouse` and `keyboard` are installed
|
|
||||||
3. Run the file `pythonw main.py`
|
|
||||||
4. `ipconfig /all` and find the ip address
|
|
||||||
|
|
||||||
## Road map (Proposed)
|
## Road map (Proposed)
|
||||||
|
|
||||||
@@ -36,6 +33,11 @@
|
|||||||
- MacOS is closed source and cannot be legally installed
|
- MacOS is closed source and cannot be legally installed
|
||||||
- Windows is available legally and can be installed
|
- Windows is available legally and can be installed
|
||||||
- [x] Build gym-like python interface for controlling the VM
|
- [x] Build gym-like python interface for controlling the VM
|
||||||
|
- [ ] Make configuration much easier from code perspective
|
||||||
|
- [ ] README
|
||||||
|
- [ ] Make it easier to install the dependencies
|
||||||
|
- [ ] Make it easier to install the VM
|
||||||
|
- [ ] Make it easier to set up the VM
|
||||||
- [ ] Recording of actions (mouse movement, click, keyboard) for human to annotate, and we can replay it
|
- [ ] Recording of actions (mouse movement, click, keyboard) for human to annotate, and we can replay it
|
||||||
- [ ] This part may be conflict with work from [Aran Komatsuzaki](https://twitter.com/arankomatsuzaki) team, a.k.a. [Duck AI](https://duckai.org/)
|
- [ ] This part may be conflict with work from [Aran Komatsuzaki](https://twitter.com/arankomatsuzaki) team, a.k.a. [Duck AI](https://duckai.org/)
|
||||||
- [ ] Build a simple task, e.g. open a browser, open a website, click on a button, and close the browser
|
- [ ] Build a simple task, e.g. open a browser, open a website, click on a button, and close the browser
|
||||||
|
|||||||
6
SERVER_SETUP.md
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
# Server Setup Guide
|
||||||
|
|
||||||
|
1. Copy and paste the file `server/main.py` to the windows vm
|
||||||
|
2. Install the requirements `pip install -r requirements.txt`
|
||||||
|
3. Run the file `python main.py`
|
||||||
|
4. `ipconfig /all` and find the ip address
|
||||||
@@ -1,50 +0,0 @@
|
|||||||
# SSH Server Setup Guide
|
|
||||||
|
|
||||||
- [Linux](#linux)
|
|
||||||
- [Windows](#windows)
|
|
||||||
|
|
||||||
## Linux
|
|
||||||
|
|
||||||
<https://averagelinuxuser.com/ssh-into-virtualbox/>
|
|
||||||
|
|
||||||
1. `sudo apt install openssh-server`
|
|
||||||
2. `sudo systemctl enable ssh --now`
|
|
||||||
3. `sudo ufw disable` (disable firewall - safe for local network, otherwise `sudo ufw allow ssh`)
|
|
||||||
4. `ip a` - find ip address
|
|
||||||
5. ssh username@<ip_address>
|
|
||||||
6. On host, run `ssh-copy-id <username>@<ip_address>`
|
|
||||||
|
|
||||||
## Windows
|
|
||||||
|
|
||||||
To SSH into a Windows machine and control it using the terminal, you need to set up an SSH server on the Windows machine and then connect to it from an SSH client. Microsoft has integrated an OpenSSH server and client in Windows 10 and later, which makes this process more straightforward. Here's how you can do it:
|
|
||||||
|
|
||||||
> Make sure that your windows account has a password
|
|
||||||
|
|
||||||
### Setting Up SSH Server on Windows
|
|
||||||
|
|
||||||
1. **Enable OpenSSH Server:**
|
|
||||||
|
|
||||||
- Open **Settings** → **Apps** → **Optional Features**.
|
|
||||||
- Scan the list to see if OpenSSH Server is installed. If it's not, click on **Add a feature**, then find **OpenSSH Server**, and click **Install**.
|
|
||||||
|
|
||||||
2. **Start the SSH Service:**
|
|
||||||
|
|
||||||
- Open **Services** from the Start menu.
|
|
||||||
- Find the **OpenSSH SSH Server** service, right-click it, and select **Properties**.
|
|
||||||
- Set the startup type to **Automatic** and then start the service.
|
|
||||||
|
|
||||||
3. **Configure the Firewall (if necessary):**
|
|
||||||
|
|
||||||
- In most cases, Windows Firewall will automatically allow SSH connections. However, if you have a third-party firewall or if connections are being blocked, you may need to manually open port 22 (default SSH port).
|
|
||||||
|
|
||||||
4. **Add ssh key to windows**
|
|
||||||
|
|
||||||
- Add the public ssh key to `C:\ProgramData\ssh\administrators_authorized_keys` and `~/.ssh/authorized_keys`
|
|
||||||
|
|
||||||
### Connecting to the Windows Machine from SSH Client
|
|
||||||
|
|
||||||
1. **From a Linux/Mac Client:**
|
|
||||||
|
|
||||||
- Open the terminal.
|
|
||||||
- Use the command `ssh username@windows-ip-address`. Replace `username` with the Windows account username and `windows-ip-address` with the IP address of the Windows machine.
|
|
||||||
- Accept the fingerprint (if it's the first time connecting) and enter the password when prompted.
|
|
||||||
172
annotation/.gitignore
vendored
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# poetry
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||||
|
#poetry.lock
|
||||||
|
|
||||||
|
# pdm
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||||
|
#pdm.lock
|
||||||
|
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||||
|
# in version control.
|
||||||
|
# https://pdm.fming.dev/#use-with-ide
|
||||||
|
.pdm.toml
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||||
|
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
|
#.idea/
|
||||||
|
|
||||||
|
# experiments
|
||||||
|
experiments/**/*.png
|
||||||
|
experiments/**/*.csv
|
||||||
|
experiments/**/*.mp4
|
||||||
|
experiments/**/*.jsonl
|
||||||
|
experiments/**/*.json
|
||||||
|
experiments/**/*.md
|
||||||
|
experiments/**/*.txt
|
||||||
|
|
||||||
|
# macos
|
||||||
|
*DS_Store*
|
||||||
21
annotation/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2023 DuckAI
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
50
annotation/OBS_SETUP.md
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
# OBS Setup
|
||||||
|
|
||||||
|
These are instructions on setting up OBS (Open Broadcaster Software) to record screen activity for creating the multimodal computer dataset.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
1. Go to the OBS Project website: [https://obsproject.com/](https://obsproject.com/).
|
||||||
|
2. Choose the appropriate installer for your operating system.
|
||||||
|
3.
|
||||||
|

|
||||||
|
|
||||||
|
3. Run the installer from your downloads folder and grant OBS the necessary permissions for installation.
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
4. Keep the default settings and proceed through the installation wizard by clicking "Next" and then "Finish."
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
5. OBS should now be open. If not, search for and open the application.
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
## Enabling OBS WebSocket Server
|
||||||
|
|
||||||
|
1. Click on "Tools" in the Navigation Bar within OBS, and then select "WebSocket Server Settings." A pop-up window will appear.
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
2. Check the box next to "Enable WebSocket server" and uncheck the box next to "Enable Authentication." Click "Apply," then "Ok." You should return to the main OBS page.
|
||||||
|
Make sure the port is set to 4455.
|
||||||
|

|
||||||
|
|
||||||
|
## Adding Display Capture and Recording
|
||||||
|
|
||||||
|
1. Now, back on the home page of OBS, select "Scene." Under "Sources," click the "+" button and then click "Display Capture." (in MacOS this is MacOS Screen Capture)
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
2. Select "Ok."
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
3. Make sure the "Display" is set to your main display, and you should see your screen on the canvas. Select "Ok." _(in MacOS if your screen is black with a red square in the top left try to disable then re-enable OBS Screen Recording permissions, this has worked before)_
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
4. Now you can close OBS and OBS will opened and controlled automatically when you launch the Computer Tracker App. Also, the Computer Tracker app creates a new OBS profile so you don't have to worry about your previous settings being messed up.
|
||||||
|
|
||||||
|

|
||||||
98
annotation/README.md
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
# DuckTrack
|
||||||
|
|
||||||
|
This is the repository for the DuckAI DuckTrack app which records all keyboard and mouse input as well as the screen for use in a multimodal computer interaction dataset.
|
||||||
|
|
||||||
|
## Installation & Setup
|
||||||
|
|
||||||
|
### Download Application
|
||||||
|
|
||||||
|
<!-- TODO: add prebuilt applications in github releases -->
|
||||||
|
Download the pre-built application for your system [here](https://github.com/TheDuckAI/DuckTrack/releases/).
|
||||||
|
|
||||||
|
Make sure you have OBS downloaded with the following configuration:
|
||||||
|
1. Have a screen capture source recording your whole main screen.
|
||||||
|
2. Enable desktop audio and mute microphone.
|
||||||
|
3. Make sure the default websocket is enabled.
|
||||||
|
|
||||||
|
More detailed instructions for OBS setup and installation located [here](OBS_SETUP.md).
|
||||||
|
|
||||||
|
If you are on MacOS, make sure to enable to the following Privacy & Security permissions before running the app:
|
||||||
|
|
||||||
|
1. Accessibility (for playing back actions)
|
||||||
|
2. Input Monitoring (for reading keyboard inputs)
|
||||||
|
|
||||||
|
Make sure to accept all other security permission dialogues to ensure that the app works properly.
|
||||||
|
|
||||||
|
### Build from source
|
||||||
|
|
||||||
|
Have Python >=3.11.
|
||||||
|
|
||||||
|
Clone this repo and `cd` into it:
|
||||||
|
```bash
|
||||||
|
$ git clone https://github.com/TheDuckAI/DuckTrack
|
||||||
|
$ cd DuckTrack
|
||||||
|
```
|
||||||
|
|
||||||
|
Install the dependencies for this project:
|
||||||
|
```bash
|
||||||
|
$ pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
Build the application:
|
||||||
|
```bash
|
||||||
|
$ python3 build.py
|
||||||
|
```
|
||||||
|
|
||||||
|
The built application should be located in the generated `dist` directory. After this, follow the remaining relevant setup instructions.
|
||||||
|
|
||||||
|
## Running the App
|
||||||
|
|
||||||
|
You can run the app like any other desktop app on your computer. If you decided to not download the app or build it from source, just run `python main.py` and it should work the same. You will be interacting with the app through an app tray icon or a small window.
|
||||||
|
|
||||||
|
### Recording
|
||||||
|
|
||||||
|
From the app tray or GUI, you can start and stop a recording as well as pause and resume a recording. Pausing and resuming is important for when you want to hide sensitive information like credit card of login credentials. You can optionally name your recording and give it a description upon stopping a recording. You can also view your recordings by pressing the "Show Recordings" option.
|
||||||
|
|
||||||
|
### Playback
|
||||||
|
|
||||||
|
You can playback a recording, i.e. simulate the series of events from the recording, by pressing "Play Latest Recording", which plays the latest created recording, or by pressing "Play Custom Recording", which lets you choose a recording to play. You can easily replay the most recently played recording by pressing "Replay Recording".
|
||||||
|
|
||||||
|
To stop the app mid-playback, just press `shift`+`esc` on your keyboard.
|
||||||
|
|
||||||
|
### Misc
|
||||||
|
|
||||||
|
To quit the app, you just press the "Quit" option.
|
||||||
|
|
||||||
|
## Recording Format
|
||||||
|
|
||||||
|
Recordings are stored in `Documents/DuckTrack_Recordings`. Each recording is a directory containing:
|
||||||
|
|
||||||
|
1. `events.jsonl` file - sequence of all computer actions that happened. A sample event may look like this:
|
||||||
|
```json
|
||||||
|
{"time_stamp": 1234567.89, "action": "move", "x": 69.0, "y": 420.0}
|
||||||
|
```
|
||||||
|
1. `metadata.json` - stores metadata about the computer that made the recording
|
||||||
|
2. `README.md` - stores the description for the recording
|
||||||
|
3. MP4 file - the screen recording from OBS of the recording.
|
||||||
|
|
||||||
|
Here is a [sample recording](example) for further reference.
|
||||||
|
|
||||||
|
## Technical Overview
|
||||||
|
|
||||||
|
<!-- maybe put a nice graphical representation of the app here -->
|
||||||
|
|
||||||
|
*TDB*
|
||||||
|
|
||||||
|
## Known Bugs
|
||||||
|
|
||||||
|
- After doing lots of playbacks on macOS, a segfault will occur.
|
||||||
|
- Mouse movement is not captured when the current application is using raw input, i.e. video games.
|
||||||
|
- OBS may not open in the background properly on some Linux machines.
|
||||||
|
|
||||||
|
## Things To Do
|
||||||
|
|
||||||
|
- Add logging
|
||||||
|
- Testing
|
||||||
|
- CI (with builds and testing)
|
||||||
|
- Add way to hide/show window from the app tray (and it saves that as a preference?)
|
||||||
|
- Make saving preferences a thing generally, like with natural scrolling too
|
||||||
BIN
annotation/assets/duck.ico
Normal file
|
After Width: | Height: | Size: 6.5 KiB |
BIN
annotation/assets/duck.png
Normal file
|
After Width: | Height: | Size: 2.4 KiB |
27
annotation/build.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
import shutil
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from platform import system
|
||||||
|
from subprocess import CalledProcessError, run
|
||||||
|
|
||||||
|
project_dir = Path(".")
|
||||||
|
assets_dir = project_dir / "assets"
|
||||||
|
main_py = project_dir / "main.py"
|
||||||
|
icon_file = assets_dir / ("duck.ico" if system() == "Windows" else "duck.png")
|
||||||
|
|
||||||
|
for dir_to_remove in ["dist", "build"]:
|
||||||
|
dir_path = project_dir / dir_to_remove
|
||||||
|
if dir_path.exists():
|
||||||
|
shutil.rmtree(dir_path)
|
||||||
|
|
||||||
|
pyinstaller_cmd = [
|
||||||
|
"pyinstaller", "--onefile", "--windowed",
|
||||||
|
f"--add-data={assets_dir}{';' if system() == 'Windows' else ':'}{assets_dir}",
|
||||||
|
f"--name=DuckTrack", f"--icon={icon_file}", str(main_py)
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
run(pyinstaller_cmd, check=True)
|
||||||
|
except CalledProcessError as e:
|
||||||
|
print("An error occurred while running PyInstaller:", e)
|
||||||
|
sys.exit(1)
|
||||||
1
annotation/ducktrack/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from .app import MainInterface
|
||||||
251
annotation/ducktrack/app.py
Normal file
@@ -0,0 +1,251 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from platform import system
|
||||||
|
|
||||||
|
from PyQt6.QtCore import QTimer, pyqtSlot
|
||||||
|
from PyQt6.QtGui import QAction, QIcon
|
||||||
|
from PyQt6.QtWidgets import (QApplication, QCheckBox, QDialog, QFileDialog,
|
||||||
|
QFormLayout, QLabel, QLineEdit, QMenu,
|
||||||
|
QMessageBox, QPushButton, QSystemTrayIcon,
|
||||||
|
QTextEdit, QVBoxLayout, QWidget)
|
||||||
|
|
||||||
|
from .obs_client import close_obs, is_obs_running, open_obs
|
||||||
|
from .playback import Player, get_latest_recording
|
||||||
|
from .recorder import Recorder
|
||||||
|
from .util import get_recordings_dir, open_file
|
||||||
|
|
||||||
|
|
||||||
|
class TitleDescriptionDialog(QDialog):
|
||||||
|
def __init__(self, parent=None):
|
||||||
|
super().__init__(parent)
|
||||||
|
|
||||||
|
self.setWindowTitle("Recording Details")
|
||||||
|
|
||||||
|
layout = QVBoxLayout(self)
|
||||||
|
|
||||||
|
self.form_layout = QFormLayout()
|
||||||
|
|
||||||
|
self.title_label = QLabel("Title:")
|
||||||
|
self.title_input = QLineEdit(self)
|
||||||
|
self.form_layout.addRow(self.title_label, self.title_input)
|
||||||
|
|
||||||
|
self.description_label = QLabel("Description:")
|
||||||
|
self.description_input = QTextEdit(self)
|
||||||
|
self.form_layout.addRow(self.description_label, self.description_input)
|
||||||
|
|
||||||
|
layout.addLayout(self.form_layout)
|
||||||
|
|
||||||
|
self.submit_button = QPushButton("Save", self)
|
||||||
|
self.submit_button.clicked.connect(self.accept)
|
||||||
|
layout.addWidget(self.submit_button)
|
||||||
|
|
||||||
|
def get_values(self):
|
||||||
|
return self.title_input.text(), self.description_input.toPlainText()
|
||||||
|
|
||||||
|
class MainInterface(QWidget):
|
||||||
|
def __init__(self, app: QApplication):
|
||||||
|
super().__init__()
|
||||||
|
self.tray = QSystemTrayIcon(QIcon(resource_path("assets/duck.png")))
|
||||||
|
self.tray.show()
|
||||||
|
|
||||||
|
self.app = app
|
||||||
|
|
||||||
|
self.init_tray()
|
||||||
|
self.init_window()
|
||||||
|
|
||||||
|
if not is_obs_running():
|
||||||
|
self.obs_process = open_obs()
|
||||||
|
|
||||||
|
def init_window(self):
|
||||||
|
self.setWindowTitle("DuckTrack")
|
||||||
|
layout = QVBoxLayout(self)
|
||||||
|
|
||||||
|
self.toggle_record_button = QPushButton("Start Recording", self)
|
||||||
|
self.toggle_record_button.clicked.connect(self.toggle_record)
|
||||||
|
layout.addWidget(self.toggle_record_button)
|
||||||
|
|
||||||
|
self.toggle_pause_button = QPushButton("Pause Recording", self)
|
||||||
|
self.toggle_pause_button.clicked.connect(self.toggle_pause)
|
||||||
|
self.toggle_pause_button.setEnabled(False)
|
||||||
|
layout.addWidget(self.toggle_pause_button)
|
||||||
|
|
||||||
|
self.show_recordings_button = QPushButton("Show Recordings", self)
|
||||||
|
self.show_recordings_button.clicked.connect(lambda: open_file(get_recordings_dir()))
|
||||||
|
layout.addWidget(self.show_recordings_button)
|
||||||
|
|
||||||
|
self.play_latest_button = QPushButton("Play Latest Recording", self)
|
||||||
|
self.play_latest_button.clicked.connect(self.play_latest_recording)
|
||||||
|
layout.addWidget(self.play_latest_button)
|
||||||
|
|
||||||
|
self.play_custom_button = QPushButton("Play Custom Recording", self)
|
||||||
|
self.play_custom_button.clicked.connect(self.play_custom_recording)
|
||||||
|
layout.addWidget(self.play_custom_button)
|
||||||
|
|
||||||
|
self.replay_recording_button = QPushButton("Replay Recording", self)
|
||||||
|
self.replay_recording_button.clicked.connect(self.replay_recording)
|
||||||
|
self.replay_recording_button.setEnabled(False)
|
||||||
|
layout.addWidget(self.replay_recording_button)
|
||||||
|
|
||||||
|
self.quit_button = QPushButton("Quit", self)
|
||||||
|
self.quit_button.clicked.connect(self.quit)
|
||||||
|
layout.addWidget(self.quit_button)
|
||||||
|
|
||||||
|
self.natural_scrolling_checkbox = QCheckBox("Natural Scrolling", self, checked=system() == "Darwin")
|
||||||
|
layout.addWidget(self.natural_scrolling_checkbox)
|
||||||
|
|
||||||
|
self.natural_scrolling_checkbox.stateChanged.connect(self.toggle_natural_scrolling)
|
||||||
|
|
||||||
|
self.setLayout(layout)
|
||||||
|
|
||||||
|
def init_tray(self):
|
||||||
|
self.menu = QMenu()
|
||||||
|
self.tray.setContextMenu(self.menu)
|
||||||
|
|
||||||
|
self.toggle_record_action = QAction("Start Recording")
|
||||||
|
self.toggle_record_action.triggered.connect(self.toggle_record)
|
||||||
|
self.menu.addAction(self.toggle_record_action)
|
||||||
|
|
||||||
|
self.toggle_pause_action = QAction("Pause Recording")
|
||||||
|
self.toggle_pause_action.triggered.connect(self.toggle_pause)
|
||||||
|
self.toggle_pause_action.setVisible(False)
|
||||||
|
self.menu.addAction(self.toggle_pause_action)
|
||||||
|
|
||||||
|
self.show_recordings_action = QAction("Show Recordings")
|
||||||
|
self.show_recordings_action.triggered.connect(lambda: open_file(get_recordings_dir()))
|
||||||
|
self.menu.addAction(self.show_recordings_action)
|
||||||
|
|
||||||
|
self.play_latest_action = QAction("Play Latest Recording")
|
||||||
|
self.play_latest_action.triggered.connect(self.play_latest_recording)
|
||||||
|
self.menu.addAction(self.play_latest_action)
|
||||||
|
|
||||||
|
self.play_custom_action = QAction("Play Custom Recording")
|
||||||
|
self.play_custom_action.triggered.connect(self.play_custom_recording)
|
||||||
|
self.menu.addAction(self.play_custom_action)
|
||||||
|
|
||||||
|
self.replay_recording_action = QAction("Replay Recording")
|
||||||
|
self.replay_recording_action.triggered.connect(self.replay_recording)
|
||||||
|
self.menu.addAction(self.replay_recording_action)
|
||||||
|
self.replay_recording_action.setVisible(False)
|
||||||
|
|
||||||
|
self.quit_action = QAction("Quit")
|
||||||
|
self.quit_action.triggered.connect(self.quit)
|
||||||
|
self.menu.addAction(self.quit_action)
|
||||||
|
|
||||||
|
self.menu.addSeparator()
|
||||||
|
|
||||||
|
self.natural_scrolling_option = QAction("Natural Scrolling", checkable=True, checked=system() == "Darwin")
|
||||||
|
self.natural_scrolling_option.triggered.connect(self.toggle_natural_scrolling)
|
||||||
|
self.menu.addAction(self.natural_scrolling_option)
|
||||||
|
|
||||||
|
@pyqtSlot()
|
||||||
|
def replay_recording(self):
|
||||||
|
player = Player()
|
||||||
|
if hasattr(self, "last_played_recording_path"):
|
||||||
|
player.play(self.last_played_recording_path)
|
||||||
|
else:
|
||||||
|
self.display_error_message("No recording has been played yet!")
|
||||||
|
|
||||||
|
@pyqtSlot()
|
||||||
|
def play_latest_recording(self):
|
||||||
|
player = Player()
|
||||||
|
recording_path = get_latest_recording()
|
||||||
|
self.last_played_recording_path = recording_path
|
||||||
|
self.replay_recording_action.setVisible(True)
|
||||||
|
self.replay_recording_button.setEnabled(True)
|
||||||
|
player.play(recording_path)
|
||||||
|
|
||||||
|
@pyqtSlot()
|
||||||
|
def play_custom_recording(self):
|
||||||
|
player = Player()
|
||||||
|
directory = QFileDialog.getExistingDirectory(None, "Select Recording", get_recordings_dir())
|
||||||
|
if directory:
|
||||||
|
self.last_played_recording_path = directory
|
||||||
|
self.replay_recording_button.setEnabled(True)
|
||||||
|
self.replay_recording_action.setVisible(True)
|
||||||
|
player.play(directory)
|
||||||
|
|
||||||
|
@pyqtSlot()
|
||||||
|
def quit(self):
|
||||||
|
if hasattr(self, "recorder_thread"):
|
||||||
|
self.toggle_record()
|
||||||
|
if hasattr(self, "obs_process"):
|
||||||
|
close_obs(self.obs_process)
|
||||||
|
self.app.quit()
|
||||||
|
|
||||||
|
def closeEvent(self, event):
|
||||||
|
self.quit()
|
||||||
|
|
||||||
|
@pyqtSlot()
|
||||||
|
def toggle_natural_scrolling(self):
|
||||||
|
sender = self.sender()
|
||||||
|
|
||||||
|
if sender == self.natural_scrolling_checkbox:
|
||||||
|
state = self.natural_scrolling_checkbox.isChecked()
|
||||||
|
self.natural_scrolling_option.setChecked(state)
|
||||||
|
else:
|
||||||
|
state = self.natural_scrolling_option.isChecked()
|
||||||
|
self.natural_scrolling_checkbox.setChecked(state)
|
||||||
|
|
||||||
|
@pyqtSlot()
|
||||||
|
def toggle_pause(self):
|
||||||
|
if self.recorder_thread._is_paused:
|
||||||
|
self.recorder_thread.resume_recording()
|
||||||
|
self.toggle_pause_action.setText("Pause Recording")
|
||||||
|
self.toggle_pause_button.setText("Pause Recording")
|
||||||
|
else:
|
||||||
|
self.recorder_thread.pause_recording()
|
||||||
|
self.toggle_pause_action.setText("Resume Recording")
|
||||||
|
self.toggle_pause_button.setText("Resume Recording")
|
||||||
|
|
||||||
|
@pyqtSlot()
|
||||||
|
def toggle_record(self):
|
||||||
|
if not hasattr(self, "recorder_thread"):
|
||||||
|
self.recorder_thread = Recorder(natural_scrolling=self.natural_scrolling_checkbox.isChecked())
|
||||||
|
self.recorder_thread.recording_stopped.connect(self.on_recording_stopped)
|
||||||
|
self.recorder_thread.start()
|
||||||
|
self.update_menu(True)
|
||||||
|
else:
|
||||||
|
self.recorder_thread.stop_recording()
|
||||||
|
self.recorder_thread.terminate()
|
||||||
|
|
||||||
|
recording_dir = self.recorder_thread.recording_path
|
||||||
|
|
||||||
|
del self.recorder_thread
|
||||||
|
|
||||||
|
dialog = TitleDescriptionDialog()
|
||||||
|
QTimer.singleShot(0, dialog.raise_)
|
||||||
|
result = dialog.exec()
|
||||||
|
|
||||||
|
if result == QDialog.DialogCode.Accepted:
|
||||||
|
title, description = dialog.get_values()
|
||||||
|
|
||||||
|
if title:
|
||||||
|
renamed_dir = os.path.join(os.path.dirname(recording_dir), title)
|
||||||
|
os.rename(recording_dir, renamed_dir)
|
||||||
|
|
||||||
|
with open(os.path.join(renamed_dir, 'README.md'), 'w') as f:
|
||||||
|
f.write(description)
|
||||||
|
|
||||||
|
self.on_recording_stopped()
|
||||||
|
|
||||||
|
@pyqtSlot()
|
||||||
|
def on_recording_stopped(self):
|
||||||
|
self.update_menu(False)
|
||||||
|
|
||||||
|
def update_menu(self, is_recording: bool):
|
||||||
|
self.toggle_record_button.setText("Stop Recording" if is_recording else "Start Recording")
|
||||||
|
self.toggle_record_action.setText("Stop Recording" if is_recording else "Start Recording")
|
||||||
|
|
||||||
|
self.toggle_pause_button.setEnabled(is_recording)
|
||||||
|
self.toggle_pause_action.setVisible(is_recording)
|
||||||
|
|
||||||
|
def display_error_message(self, message):
|
||||||
|
QMessageBox.critical(None, "Error", message)
|
||||||
|
|
||||||
|
def resource_path(relative_path: str) -> str:
|
||||||
|
if hasattr(sys, '_MEIPASS'):
|
||||||
|
base_path = getattr(sys, "_MEIPASS")
|
||||||
|
else:
|
||||||
|
base_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..')
|
||||||
|
|
||||||
|
return os.path.join(base_path, relative_path)
|
||||||
33
annotation/ducktrack/keycomb.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
from pynput.keyboard import Listener
|
||||||
|
|
||||||
|
from .util import name_to_key
|
||||||
|
|
||||||
|
|
||||||
|
class KeyCombinationListener:
|
||||||
|
"""
|
||||||
|
Simple and bad key combination listener.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.current_keys = set()
|
||||||
|
self.callbacks = {}
|
||||||
|
self.listener = Listener(on_press=self.on_key_press, on_release=self.on_key_release)
|
||||||
|
|
||||||
|
def add_comb(self, keys, callback):
|
||||||
|
self.callbacks[tuple([name_to_key(key_name) for key_name in sorted(keys)])] = callback
|
||||||
|
|
||||||
|
def on_key_press(self, key):
|
||||||
|
self.current_keys.add(key)
|
||||||
|
for comb, callback in self.callbacks.items():
|
||||||
|
if all(k in self.current_keys for k in comb):
|
||||||
|
return callback()
|
||||||
|
|
||||||
|
def on_key_release(self, key):
|
||||||
|
if key in self.current_keys:
|
||||||
|
self.current_keys.remove(key)
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
self.listener.start()
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
self.listener.stop()
|
||||||
60
annotation/ducktrack/metadata.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from platform import uname
|
||||||
|
|
||||||
|
from screeninfo import get_monitors
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataManager:
|
||||||
|
"""
|
||||||
|
Handles various system metadata collection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, recording_path: str, natural_scrolling: bool):
|
||||||
|
self.recording_path = recording_path
|
||||||
|
|
||||||
|
self.metadata = uname()._asdict()
|
||||||
|
|
||||||
|
self.metadata["id"] = uuid.getnode()
|
||||||
|
|
||||||
|
main_monitor = get_monitors()[0]
|
||||||
|
self.metadata["screen_width"] = main_monitor.width
|
||||||
|
self.metadata["screen_height"] = main_monitor.height
|
||||||
|
|
||||||
|
try:
|
||||||
|
match self.metadata["system"]:
|
||||||
|
case "Windows":
|
||||||
|
import wmi
|
||||||
|
for item in wmi.WMI().Win32_ComputerSystem():
|
||||||
|
self.metadata["model"] = item.Model
|
||||||
|
break
|
||||||
|
case "Darwin":
|
||||||
|
import subprocess
|
||||||
|
model = subprocess.check_output(["sysctl", "-n", "hw.model"]).decode().strip()
|
||||||
|
self.metadata["model"] = model
|
||||||
|
case "Linux":
|
||||||
|
with open("/sys/devices/virtual/dmi/id/product_name", "r") as f:
|
||||||
|
self.metadata["model"] = f.read().strip()
|
||||||
|
except:
|
||||||
|
self.metadata["model"] = "Unknown"
|
||||||
|
|
||||||
|
self.metadata["scroll_direction"] = -1 if natural_scrolling else 1
|
||||||
|
|
||||||
|
def save_metadata(self):
|
||||||
|
metadata_path = os.path.join(self.recording_path, "metadata.json")
|
||||||
|
with open(metadata_path, "w") as f:
|
||||||
|
json.dump(self.metadata, f, indent=4)
|
||||||
|
|
||||||
|
def collect(self):
|
||||||
|
self.metadata["start_time"] = self._get_time_stamp()
|
||||||
|
|
||||||
|
def end_collect(self):
|
||||||
|
self.metadata["stop_time"] = self._get_time_stamp()
|
||||||
|
|
||||||
|
def add_obs_record_state_timings(self, record_state_events: dict[str, float]):
|
||||||
|
self.metadata["obs_record_state_timings"] = record_state_events
|
||||||
|
|
||||||
|
def _get_time_stamp(self):
|
||||||
|
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
200
annotation/ducktrack/obs_client.py
Normal file
@@ -0,0 +1,200 @@
|
|||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
from platform import system
|
||||||
|
|
||||||
|
import obsws_python as obs
|
||||||
|
import psutil
|
||||||
|
|
||||||
|
|
||||||
|
def is_obs_running() -> bool:
|
||||||
|
try:
|
||||||
|
for process in psutil.process_iter(attrs=["pid", "name"]):
|
||||||
|
if "obs" in process.info["name"].lower():
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
except:
|
||||||
|
raise Exception("Could not check if OBS is running already. Please check manually.")
|
||||||
|
|
||||||
|
def close_obs(obs_process: subprocess.Popen):
|
||||||
|
if obs_process:
|
||||||
|
obs_process.terminate()
|
||||||
|
try:
|
||||||
|
obs_process.wait(timeout=5)
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
obs_process.kill()
|
||||||
|
|
||||||
|
def find_obs() -> str:
|
||||||
|
common_paths = {
|
||||||
|
"Windows": [
|
||||||
|
"C:\\Program Files\\obs-studio\\bin\\64bit\\obs64.exe",
|
||||||
|
"C:\\Program Files (x86)\\obs-studio\\bin\\32bit\\obs32.exe"
|
||||||
|
],
|
||||||
|
"Darwin": [
|
||||||
|
"/Applications/OBS.app/Contents/MacOS/OBS",
|
||||||
|
"/opt/homebrew/bin/obs"
|
||||||
|
],
|
||||||
|
"Linux": [
|
||||||
|
"/usr/bin/obs",
|
||||||
|
"/usr/local/bin/obs"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
for path in common_paths.get(system(), []):
|
||||||
|
if os.path.exists(path):
|
||||||
|
return path
|
||||||
|
|
||||||
|
try:
|
||||||
|
if system() == "Windows":
|
||||||
|
obs_path = subprocess.check_output("where obs", shell=True).decode().strip()
|
||||||
|
else:
|
||||||
|
obs_path = subprocess.check_output("which obs", shell=True).decode().strip()
|
||||||
|
|
||||||
|
if os.path.exists(obs_path):
|
||||||
|
return obs_path
|
||||||
|
except subprocess.CalledProcessError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return "obs"
|
||||||
|
|
||||||
|
def open_obs() -> subprocess.Popen:
|
||||||
|
try:
|
||||||
|
obs_path = find_obs()
|
||||||
|
if system() == "Windows":
|
||||||
|
# you have to change the working directory first for OBS to find the correct locale on windows
|
||||||
|
os.chdir(os.path.dirname(obs_path))
|
||||||
|
obs_path = os.path.basename(obs_path)
|
||||||
|
return subprocess.Popen([obs_path, "--startreplaybuffer", "--minimize-to-tray"])
|
||||||
|
except:
|
||||||
|
raise Exception("Failed to find OBS, please open OBS manually.")
|
||||||
|
|
||||||
|
class OBSClient:
|
||||||
|
"""
|
||||||
|
Controls the OBS client via the OBS websocket.
|
||||||
|
Sets all the correct settings for recording.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
recording_path: str,
|
||||||
|
metadata: dict,
|
||||||
|
fps=30,
|
||||||
|
output_width=1280,
|
||||||
|
output_height=720,
|
||||||
|
):
|
||||||
|
self.metadata = metadata
|
||||||
|
|
||||||
|
self.req_client = obs.ReqClient()
|
||||||
|
self.event_client = obs.EventClient()
|
||||||
|
|
||||||
|
self.record_state_events = {}
|
||||||
|
|
||||||
|
def on_record_state_changed(data):
|
||||||
|
output_state = data.output_state
|
||||||
|
print("record state changed:", output_state)
|
||||||
|
if output_state not in self.record_state_events:
|
||||||
|
self.record_state_events[output_state] = []
|
||||||
|
self.record_state_events[output_state].append(time.perf_counter())
|
||||||
|
|
||||||
|
self.event_client.callback.register(on_record_state_changed)
|
||||||
|
|
||||||
|
self.old_profile = self.req_client.get_profile_list().current_profile_name
|
||||||
|
|
||||||
|
if "computer_tracker" not in self.req_client.get_profile_list().profiles:
|
||||||
|
self.req_client.create_profile("computer_tracker")
|
||||||
|
else:
|
||||||
|
self.req_client.set_current_profile("computer_tracker")
|
||||||
|
self.req_client.create_profile("temp")
|
||||||
|
self.req_client.remove_profile("temp")
|
||||||
|
self.req_client.set_current_profile("computer_tracker")
|
||||||
|
|
||||||
|
base_width = metadata["screen_width"]
|
||||||
|
base_height = metadata["screen_height"]
|
||||||
|
|
||||||
|
if metadata["system"] == "Darwin":
|
||||||
|
# for retina displays
|
||||||
|
# TODO: check if external displays are messed up by this
|
||||||
|
base_width *= 2
|
||||||
|
base_height *= 2
|
||||||
|
|
||||||
|
scaled_width, scaled_height = _scale_resolution(base_width, base_height, output_width, output_height)
|
||||||
|
|
||||||
|
self.req_client.set_profile_parameter("Video", "BaseCX", str(base_width))
|
||||||
|
self.req_client.set_profile_parameter("Video", "BaseCY", str(base_height))
|
||||||
|
self.req_client.set_profile_parameter("Video", "OutputCX", str(scaled_width))
|
||||||
|
self.req_client.set_profile_parameter("Video", "OutputCY", str(scaled_height))
|
||||||
|
self.req_client.set_profile_parameter("Video", "ScaleType", "lanczos")
|
||||||
|
|
||||||
|
self.req_client.set_profile_parameter("AdvOut", "RescaleRes", f"{base_width}x{base_height}")
|
||||||
|
self.req_client.set_profile_parameter("AdvOut", "RecRescaleRes", f"{base_width}x{base_height}")
|
||||||
|
self.req_client.set_profile_parameter("AdvOut", "FFRescaleRes", f"{base_width}x{base_height}")
|
||||||
|
|
||||||
|
self.req_client.set_profile_parameter("Video", "FPSCommon", str(fps))
|
||||||
|
self.req_client.set_profile_parameter("Video", "FPSInt", str(fps))
|
||||||
|
self.req_client.set_profile_parameter("Video", "FPSNum", str(fps))
|
||||||
|
self.req_client.set_profile_parameter("Video", "FPSDen", "1")
|
||||||
|
|
||||||
|
self.req_client.set_profile_parameter("SimpleOutput", "RecFormat2", "mp4")
|
||||||
|
|
||||||
|
bitrate = int(_get_bitrate_mbps(scaled_width, scaled_height, fps=fps) * 1000 / 50) * 50
|
||||||
|
self.req_client.set_profile_parameter("SimpleOutput", "VBitrate", str(bitrate))
|
||||||
|
|
||||||
|
# do this in order to get pause & resume
|
||||||
|
self.req_client.set_profile_parameter("SimpleOutput", "RecQuality", "Small")
|
||||||
|
|
||||||
|
self.req_client.set_profile_parameter("SimpleOutput", "FilePath", recording_path)
|
||||||
|
|
||||||
|
# TODO: not all OBS configs have this, maybe just instruct the user to mute themselves
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.req_client.set_input_mute("Mic/Aux", muted=True)
|
||||||
|
except obs.error.OBSSDKRequestError :
|
||||||
|
# In case there is no Mic/Aux input, this will throw an error
|
||||||
|
pass
|
||||||
|
|
||||||
|
def start_recording(self):
|
||||||
|
self.req_client.start_record()
|
||||||
|
|
||||||
|
def stop_recording(self):
|
||||||
|
self.req_client.stop_record()
|
||||||
|
self.req_client.set_current_profile(self.old_profile) # restore old profile
|
||||||
|
|
||||||
|
def pause_recording(self):
|
||||||
|
self.req_client.pause_record()
|
||||||
|
|
||||||
|
def resume_recording(self):
|
||||||
|
self.req_client.resume_record()
|
||||||
|
|
||||||
|
def _get_bitrate_mbps(width: int, height: int, fps=30) -> float:
|
||||||
|
"""
|
||||||
|
Gets the YouTube recommended bitrate in Mbps for a given resolution and framerate.
|
||||||
|
Refer to https://support.google.com/youtube/answer/1722171?hl=en#zippy=%2Cbitrate
|
||||||
|
"""
|
||||||
|
resolutions = {
|
||||||
|
(7680, 4320): {30: 120, 60: 180},
|
||||||
|
(3840, 2160): {30: 40, 60: 60.5},
|
||||||
|
(2160, 1440): {30: 16, 60: 24},
|
||||||
|
(1920, 1080): {30: 8, 60: 12},
|
||||||
|
(1280, 720): {30: 5, 60: 7.5},
|
||||||
|
(640, 480): {30: 2.5, 60: 4},
|
||||||
|
(480, 360): {30: 1, 60: 1.5}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (width, height) in resolutions:
|
||||||
|
return resolutions[(width, height)].get(fps)
|
||||||
|
else:
|
||||||
|
# approximate the bitrate using a simple linear model
|
||||||
|
area = width * height
|
||||||
|
multiplier = 3.5982188179592543e-06 if fps == 30 else 5.396175171097084e-06
|
||||||
|
constant = 2.418399836285939 if fps == 30 else 3.742780056500365
|
||||||
|
return multiplier * area + constant
|
||||||
|
|
||||||
|
def _scale_resolution(base_width: int, base_height: int, target_width: int, target_height: int) -> tuple[int, int]:
|
||||||
|
target_area = target_width * target_height
|
||||||
|
aspect_ratio = base_width / base_height
|
||||||
|
|
||||||
|
scaled_height = int((target_area / aspect_ratio) ** 0.5)
|
||||||
|
scaled_width = int(aspect_ratio * scaled_height)
|
||||||
|
|
||||||
|
return scaled_width, scaled_height
|
||||||
188
annotation/ducktrack/playback.py
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
import json
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
import pyautogui
|
||||||
|
from pynput.keyboard import Controller as KeyboardController
|
||||||
|
from pynput.keyboard import Key
|
||||||
|
from pynput.mouse import Button
|
||||||
|
from pynput.mouse import Controller as MouseController
|
||||||
|
|
||||||
|
from .keycomb import KeyCombinationListener
|
||||||
|
from .util import (fix_windows_dpi_scaling, get_recordings_dir, name_to_button,
|
||||||
|
name_to_key)
|
||||||
|
|
||||||
|
pyautogui.PAUSE = 0
|
||||||
|
pyautogui.DARWIN_CATCH_UP_TIME = 0
|
||||||
|
|
||||||
|
class Player:
|
||||||
|
"""
|
||||||
|
Plays back recordings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.stop_playback = False
|
||||||
|
self.listener = KeyCombinationListener()
|
||||||
|
|
||||||
|
def stop_comb_pressed():
|
||||||
|
self.stop_playback = True
|
||||||
|
return False
|
||||||
|
|
||||||
|
self.listener.add_comb(("shift", "esc"), stop_comb_pressed)
|
||||||
|
self.listener.start()
|
||||||
|
|
||||||
|
def play(self, recording_path: str):
|
||||||
|
with open(os.path.join(recording_path, "events.jsonl"), "r") as f:
|
||||||
|
events = [json.loads(line) for line in f.readlines()]
|
||||||
|
|
||||||
|
with open(os.path.join(recording_path, "metadata.json"), "r") as f:
|
||||||
|
metadata = json.load(f)
|
||||||
|
|
||||||
|
self.playback(events, metadata)
|
||||||
|
|
||||||
|
def playback(self, events: list[dict], metadata: dict):
|
||||||
|
if metadata["system"] == "Windows":
|
||||||
|
fix_windows_dpi_scaling()
|
||||||
|
|
||||||
|
mouse_controller = MouseController()
|
||||||
|
keyboard_controller = KeyboardController()
|
||||||
|
|
||||||
|
if not events:
|
||||||
|
self.listener.stop()
|
||||||
|
return
|
||||||
|
|
||||||
|
presses_to_skip = 0
|
||||||
|
releases_to_skip = 0
|
||||||
|
|
||||||
|
in_click_sequence = False
|
||||||
|
|
||||||
|
for i, event in enumerate(events):
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
if self.stop_playback:
|
||||||
|
return
|
||||||
|
|
||||||
|
def do_mouse_press(button):
|
||||||
|
for j, second_event in enumerate(events[i+1:]):
|
||||||
|
# make sure the time between mouse clicks is less than 500ms
|
||||||
|
if second_event["time_stamp"] - event["time_stamp"] > 0.5:
|
||||||
|
break
|
||||||
|
|
||||||
|
if "x" in second_event and "y" in second_event:
|
||||||
|
# if the mouse moves out of the click radius/rectangle, it is not a click sequence
|
||||||
|
if math.sqrt((second_event["y"] - event["y"]) ** 2 +
|
||||||
|
(second_event["x"] - event["x"]) ** 2) > 4:
|
||||||
|
break
|
||||||
|
|
||||||
|
if second_event["action"] == "click" and second_event["pressed"]:
|
||||||
|
for k, third_event in enumerate(events[i+j+2:]):
|
||||||
|
if third_event["time_stamp"] - second_event["time_stamp"] > 0.5:
|
||||||
|
break
|
||||||
|
|
||||||
|
if "x" in third_event and "y" in third_event:
|
||||||
|
if math.sqrt((third_event["y"] - event["y"]) ** 2 +
|
||||||
|
(third_event["x"] - event["x"]) ** 2) > 5:
|
||||||
|
break
|
||||||
|
|
||||||
|
if third_event["action"] == "click" and third_event["pressed"]:
|
||||||
|
mouse_controller.click(button, 3)
|
||||||
|
return 2, 2
|
||||||
|
|
||||||
|
mouse_controller.click(button, 2)
|
||||||
|
return 1, 1
|
||||||
|
|
||||||
|
mouse_controller.press(button)
|
||||||
|
return 0, 0
|
||||||
|
|
||||||
|
if event["action"] == "move":
|
||||||
|
mouse_controller.position = (event["x"], event["y"])
|
||||||
|
|
||||||
|
elif event["action"] == "click":
|
||||||
|
button = name_to_button(event["button"])
|
||||||
|
|
||||||
|
if event["pressed"]:
|
||||||
|
if presses_to_skip == 0:
|
||||||
|
presses, releases = do_mouse_press(button)
|
||||||
|
presses_to_skip += presses
|
||||||
|
releases_to_skip += releases
|
||||||
|
|
||||||
|
if presses > 0:
|
||||||
|
in_click_sequence = True
|
||||||
|
else:
|
||||||
|
presses_to_skip -= 1
|
||||||
|
else:
|
||||||
|
if releases_to_skip == 0:
|
||||||
|
mouse_controller.release(button)
|
||||||
|
|
||||||
|
if in_click_sequence:
|
||||||
|
keyboard_controller.press(Key.shift)
|
||||||
|
mouse_controller.click(Button.left)
|
||||||
|
keyboard_controller.release(Key.shift)
|
||||||
|
in_click_sequence = False
|
||||||
|
else:
|
||||||
|
releases_to_skip -= 1
|
||||||
|
|
||||||
|
elif event["action"] == "scroll":
|
||||||
|
if metadata["system"] == "Windows":
|
||||||
|
# for some reason on windows, pynput scroll is correct but pyautogui is not
|
||||||
|
mouse_controller.scroll(metadata["scroll_direction"] * event["dx"], metadata["scroll_direction"] * event["dy"])
|
||||||
|
else:
|
||||||
|
pyautogui.hscroll(clicks=metadata["scroll_direction"] * event["dx"])
|
||||||
|
pyautogui.vscroll(clicks=metadata["scroll_direction"] * event["dy"])
|
||||||
|
|
||||||
|
elif event["action"] in ["press", "release"]:
|
||||||
|
key = name_to_key(event["name"])
|
||||||
|
if event["action"] == "press":
|
||||||
|
keyboard_controller.press(key)
|
||||||
|
else:
|
||||||
|
keyboard_controller.release(key)
|
||||||
|
|
||||||
|
# sleep for the correct amount of time
|
||||||
|
|
||||||
|
end_time = time.perf_counter()
|
||||||
|
execution_time = end_time - start_time
|
||||||
|
|
||||||
|
if i + 1 < len(events):
|
||||||
|
desired_delay = events[i + 1]["time_stamp"] - event["time_stamp"]
|
||||||
|
delay = desired_delay - execution_time
|
||||||
|
if delay < 0:
|
||||||
|
print(f"warning: behind by {-delay * 1000:.3f} ms")
|
||||||
|
elif delay != 0:
|
||||||
|
wait_until = time.perf_counter() + delay
|
||||||
|
while time.perf_counter() < wait_until:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.listener.stop()
|
||||||
|
|
||||||
|
def get_latest_recording() -> str:
|
||||||
|
recordings_dir = get_recordings_dir()
|
||||||
|
if not os.path.exists(recordings_dir):
|
||||||
|
raise Exception("The recordings directory does not exist")
|
||||||
|
|
||||||
|
recordings = [os.path.join(recordings_dir, f) for f in os.listdir(recordings_dir) if os.path.isdir(os.path.join(recordings_dir, f))]
|
||||||
|
|
||||||
|
if len(recordings) == 0:
|
||||||
|
raise Exception("You have no recordings to play back")
|
||||||
|
|
||||||
|
latest_recording = max(recordings, key=os.path.getctime)
|
||||||
|
|
||||||
|
return latest_recording
|
||||||
|
|
||||||
|
def main():
|
||||||
|
player = Player()
|
||||||
|
|
||||||
|
if len(sys.argv) > 1:
|
||||||
|
recording_path = sys.argv[1]
|
||||||
|
else:
|
||||||
|
recording_path = get_latest_recording()
|
||||||
|
|
||||||
|
player.play(recording_path)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
n = 3
|
||||||
|
print("press shift+esc to stop the playback")
|
||||||
|
print(f"starting in {n} seconds...")
|
||||||
|
time.sleep(n)
|
||||||
|
main()
|
||||||
145
annotation/ducktrack/recorder.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from platform import system
|
||||||
|
from queue import Queue
|
||||||
|
|
||||||
|
from pynput import keyboard, mouse
|
||||||
|
from pynput.keyboard import KeyCode
|
||||||
|
from PyQt6.QtCore import QThread, pyqtSignal
|
||||||
|
|
||||||
|
from .metadata import MetadataManager
|
||||||
|
from .obs_client import OBSClient
|
||||||
|
from .util import fix_windows_dpi_scaling, get_recordings_dir
|
||||||
|
|
||||||
|
|
||||||
|
class Recorder(QThread):
|
||||||
|
"""
|
||||||
|
Makes recordings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
recording_stopped = pyqtSignal()
|
||||||
|
|
||||||
|
def __init__(self, natural_scrolling: bool):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if system() == "Windows":
|
||||||
|
fix_windows_dpi_scaling()
|
||||||
|
|
||||||
|
self.recording_path = self._get_recording_path()
|
||||||
|
|
||||||
|
self._is_recording = False
|
||||||
|
self._is_paused = False
|
||||||
|
|
||||||
|
self.event_queue = Queue()
|
||||||
|
self.events_file = open(os.path.join(self.recording_path, "events.jsonl"), "a")
|
||||||
|
|
||||||
|
self.metadata_manager = MetadataManager(
|
||||||
|
recording_path=self.recording_path,
|
||||||
|
natural_scrolling=natural_scrolling
|
||||||
|
)
|
||||||
|
self.obs_client = OBSClient(recording_path=self.recording_path,
|
||||||
|
metadata=self.metadata_manager.metadata)
|
||||||
|
|
||||||
|
self.mouse_listener = mouse.Listener(
|
||||||
|
on_move=self.on_move,
|
||||||
|
on_click=self.on_click,
|
||||||
|
on_scroll=self.on_scroll)
|
||||||
|
|
||||||
|
self.keyboard_listener = keyboard.Listener(
|
||||||
|
on_press=self.on_press,
|
||||||
|
on_release=self.on_release)
|
||||||
|
|
||||||
|
def on_move(self, x, y):
|
||||||
|
if not self._is_paused:
|
||||||
|
self.event_queue.put({"time_stamp": time.perf_counter(),
|
||||||
|
"action": "move",
|
||||||
|
"x": x,
|
||||||
|
"y": y}, block=False)
|
||||||
|
|
||||||
|
def on_click(self, x, y, button, pressed):
|
||||||
|
if not self._is_paused:
|
||||||
|
self.event_queue.put({"time_stamp": time.perf_counter(),
|
||||||
|
"action": "click",
|
||||||
|
"x": x,
|
||||||
|
"y": y,
|
||||||
|
"button": button.name,
|
||||||
|
"pressed": pressed}, block=False)
|
||||||
|
|
||||||
|
def on_scroll(self, x, y, dx, dy):
|
||||||
|
if not self._is_paused:
|
||||||
|
self.event_queue.put({"time_stamp": time.perf_counter(),
|
||||||
|
"action": "scroll",
|
||||||
|
"x": x,
|
||||||
|
"y": y,
|
||||||
|
"dx": dx,
|
||||||
|
"dy": dy}, block=False)
|
||||||
|
|
||||||
|
def on_press(self, key):
|
||||||
|
if not self._is_paused:
|
||||||
|
self.event_queue.put({"time_stamp": time.perf_counter(),
|
||||||
|
"action": "press",
|
||||||
|
"name": key.char if type(key) == KeyCode else key.name}, block=False)
|
||||||
|
|
||||||
|
def on_release(self, key):
|
||||||
|
if not self._is_paused:
|
||||||
|
self.event_queue.put({"time_stamp": time.perf_counter(),
|
||||||
|
"action": "release",
|
||||||
|
"name": key.char if type(key) == KeyCode else key.name}, block=False)
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
self._is_recording = True
|
||||||
|
|
||||||
|
self.metadata_manager.collect()
|
||||||
|
self.obs_client.start_recording()
|
||||||
|
|
||||||
|
self.mouse_listener.start()
|
||||||
|
self.keyboard_listener.start()
|
||||||
|
|
||||||
|
while self._is_recording:
|
||||||
|
event = self.event_queue.get()
|
||||||
|
self.events_file.write(json.dumps(event) + "\n")
|
||||||
|
|
||||||
|
def stop_recording(self):
|
||||||
|
if self._is_recording:
|
||||||
|
self._is_recording = False
|
||||||
|
|
||||||
|
self.metadata_manager.end_collect()
|
||||||
|
|
||||||
|
self.mouse_listener.stop()
|
||||||
|
self.keyboard_listener.stop()
|
||||||
|
|
||||||
|
self.obs_client.stop_recording()
|
||||||
|
self.metadata_manager.add_obs_record_state_timings(self.obs_client.record_state_events)
|
||||||
|
self.events_file.close()
|
||||||
|
self.metadata_manager.save_metadata()
|
||||||
|
|
||||||
|
self.recording_stopped.emit()
|
||||||
|
|
||||||
|
def pause_recording(self):
|
||||||
|
if not self._is_paused and self._is_recording:
|
||||||
|
self._is_paused = True
|
||||||
|
self.obs_client.pause_recording()
|
||||||
|
self.event_queue.put({"time_stamp": time.perf_counter(),
|
||||||
|
"action": "pause"}, block=False)
|
||||||
|
|
||||||
|
def resume_recording(self):
|
||||||
|
if self._is_paused and self._is_recording:
|
||||||
|
self._is_paused = False
|
||||||
|
self.obs_client.resume_recording()
|
||||||
|
self.event_queue.put({"time_stamp": time.perf_counter(),
|
||||||
|
"action": "resume"}, block=False)
|
||||||
|
|
||||||
|
def _get_recording_path(self) -> str:
|
||||||
|
recordings_dir = get_recordings_dir()
|
||||||
|
|
||||||
|
if not os.path.exists(recordings_dir):
|
||||||
|
os.mkdir(recordings_dir)
|
||||||
|
|
||||||
|
current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||||
|
|
||||||
|
recording_path = os.path.join(recordings_dir, f"recording-{current_time}")
|
||||||
|
os.mkdir(recording_path)
|
||||||
|
|
||||||
|
return recording_path
|
||||||
38
annotation/ducktrack/util.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from pynput.keyboard import Key, KeyCode
|
||||||
|
from pynput.mouse import Button
|
||||||
|
|
||||||
|
|
||||||
|
def name_to_key(name: str) -> Key | KeyCode:
|
||||||
|
try:
|
||||||
|
return getattr(Key, name)
|
||||||
|
except AttributeError:
|
||||||
|
return KeyCode.from_char(name)
|
||||||
|
|
||||||
|
def name_to_button(name: str) -> Button:
|
||||||
|
return getattr(Button, name)
|
||||||
|
|
||||||
|
def get_recordings_dir() -> str:
|
||||||
|
documents_folder = Path.home() / 'Documents' / 'DuckTrack_Recordings'
|
||||||
|
return str(documents_folder)
|
||||||
|
|
||||||
|
def fix_windows_dpi_scaling():
|
||||||
|
"""
|
||||||
|
Fixes DPI scaling issues with legacy windows applications
|
||||||
|
Reference: https://pynput.readthedocs.io/en/latest/mouse.html#ensuring-consistent-coordinates-between-listener-and-controller-on-windows
|
||||||
|
"""
|
||||||
|
import ctypes
|
||||||
|
PROCESS_PER_MONITOR_DPI_AWARE = 2
|
||||||
|
ctypes.windll.shcore.SetProcessDpiAwareness(PROCESS_PER_MONITOR_DPI_AWARE)
|
||||||
|
|
||||||
|
def open_file(path):
|
||||||
|
if platform.system() == "Windows":
|
||||||
|
os.startfile(path)
|
||||||
|
elif platform.system() == "Darwin":
|
||||||
|
subprocess.Popen(["open", path])
|
||||||
|
else:
|
||||||
|
subprocess.Popen(["xdg-open", path])
|
||||||
48
annotation/experiments/delays/delay.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
import glob
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import seaborn as sns
|
||||||
|
from scipy.stats import sem, t
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_confidence_interval(data, confidence=0.95):
|
||||||
|
n = len(data)
|
||||||
|
m = np.mean(data)
|
||||||
|
std_err = sem(data)
|
||||||
|
h = std_err * t.ppf((1 + confidence) / 2, n - 1)
|
||||||
|
return m, m-h, m+h
|
||||||
|
|
||||||
|
runs = glob.glob("run*.txt")
|
||||||
|
TOTAL_EVENTS = 22509
|
||||||
|
percent_delays = []
|
||||||
|
all_delays = []
|
||||||
|
|
||||||
|
for run in runs:
|
||||||
|
with open(run, "r") as f:
|
||||||
|
delays = [float(line.split()[3]) for line in f if float(line.split()[3]) > 0] # consider only positive delays
|
||||||
|
percent_delays.append((len(delays) / TOTAL_EVENTS) * 100)
|
||||||
|
all_delays.extend(delays)
|
||||||
|
|
||||||
|
average_percent_delays = np.mean(percent_delays)
|
||||||
|
confidence_interval_percent_delays = calculate_confidence_interval(percent_delays)
|
||||||
|
print(f"Average percentage of delayed events across all runs: {average_percent_delays:.2f}%")
|
||||||
|
print(f"95% Confidence interval: ({confidence_interval_percent_delays[1]:.2f}%, {confidence_interval_percent_delays[2]:.2f}%)")
|
||||||
|
|
||||||
|
if all_delays:
|
||||||
|
mean_delay = np.mean(all_delays)
|
||||||
|
confidence_interval_delays = calculate_confidence_interval(all_delays)
|
||||||
|
print(f"Mean delay time: {mean_delay:.2f}")
|
||||||
|
print(f"95% Confidence interval for delay time: ({confidence_interval_delays[1]:.2f}, {confidence_interval_delays[2]:.2f})")
|
||||||
|
else:
|
||||||
|
print("No delay data available for calculation.")
|
||||||
|
|
||||||
|
sns.histplot(all_delays, bins=30, kde=False)
|
||||||
|
plt.xlabel('Delay Time (ms)')
|
||||||
|
plt.ylabel('Frequency')
|
||||||
|
plt.yscale('log')
|
||||||
|
plt.title('Histogram of Delay Times (macOS)')
|
||||||
|
|
||||||
|
plt.savefig('delays.png', dpi=300)
|
||||||
|
|
||||||
|
plt.show()
|
||||||
110
annotation/experiments/drawing/drawing.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
import glob
|
||||||
|
import os
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import scipy.stats as stats
|
||||||
|
from skimage.metrics import structural_similarity as ssim
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# use this: https://sketch.io
|
||||||
|
|
||||||
|
def calculate_rmse(imageA, imageB):
|
||||||
|
err = np.sum((imageA - imageB) ** 2)
|
||||||
|
err /= float(imageA.shape[0] * imageA.shape[1])
|
||||||
|
return np.sqrt(err)
|
||||||
|
|
||||||
|
def compare_images(ground_truth_path, sample_paths):
|
||||||
|
results = []
|
||||||
|
gt_image = cv2.imread(ground_truth_path, cv2.IMREAD_GRAYSCALE)
|
||||||
|
|
||||||
|
if gt_image is None:
|
||||||
|
raise ValueError("Ground truth image could not be read. Please check the file path.")
|
||||||
|
|
||||||
|
gt_image = gt_image.astype("float") / 255.0
|
||||||
|
|
||||||
|
for path in tqdm(sample_paths):
|
||||||
|
sample_image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
|
||||||
|
|
||||||
|
if sample_image is None:
|
||||||
|
print(f"WARNING: Sample image at path {path} could not be read. Skipping this image.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
sample_image = sample_image.astype("float") / 255.0
|
||||||
|
|
||||||
|
rmse_value = calculate_rmse(gt_image, sample_image)
|
||||||
|
ssim_value, _ = ssim(gt_image, sample_image, full=True, data_range=1) # Corrected line
|
||||||
|
|
||||||
|
diff_mask = cv2.absdiff(gt_image, sample_image)
|
||||||
|
|
||||||
|
# plt.imshow(diff_mask * 255, cmap='gray')
|
||||||
|
# plt.title(f'Difference Mask for {os.path.basename(path)}\nRMSE: {rmse_value:.5f} - SSIM: {ssim_value:.5f}')
|
||||||
|
# plt.show()
|
||||||
|
|
||||||
|
results.append({
|
||||||
|
'path': path,
|
||||||
|
'rmse': rmse_value,
|
||||||
|
'ssim': ssim_value,
|
||||||
|
'diff_mask': diff_mask
|
||||||
|
})
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
ground_truth = 'ground_truth.png'
|
||||||
|
sample_images = glob.glob("samples/*.png")
|
||||||
|
|
||||||
|
results = compare_images(ground_truth, sample_images)
|
||||||
|
|
||||||
|
for res in results:
|
||||||
|
print(f"Image: {res['path']} - RMSE: {res['rmse']} - SSIM: {res['ssim']}")
|
||||||
|
|
||||||
|
def calculate_confidence_interval(data, confidence_level=0.95):
|
||||||
|
mean = np.mean(data)
|
||||||
|
sem = stats.sem(data)
|
||||||
|
df = len(data) - 1
|
||||||
|
me = sem * stats.t.ppf((1 + confidence_level) / 2, df)
|
||||||
|
return mean - me, mean + me
|
||||||
|
|
||||||
|
rmse_values = [res['rmse'] for res in results]
|
||||||
|
ssim_values = [res['ssim'] for res in results]
|
||||||
|
|
||||||
|
rmse_mean = np.mean(rmse_values)
|
||||||
|
rmse_median = np.median(rmse_values)
|
||||||
|
rmse_stdev = np.std(rmse_values, ddof=1)
|
||||||
|
|
||||||
|
ssim_mean = np.mean(ssim_values)
|
||||||
|
ssim_median = np.median(ssim_values)
|
||||||
|
ssim_stdev = np.std(ssim_values, ddof=1)
|
||||||
|
|
||||||
|
rmse_ci = calculate_confidence_interval(rmse_values)
|
||||||
|
ssim_ci = calculate_confidence_interval(ssim_values)
|
||||||
|
|
||||||
|
print(f"\nRMSE - Mean: {rmse_mean}, Median: {rmse_median}, Std Dev: {rmse_stdev}, 95% CI: {rmse_ci}")
|
||||||
|
print(f"SSIM - Mean: {ssim_mean}, Median: {ssim_median}, Std Dev: {ssim_stdev}, 95% CI: {ssim_ci}")
|
||||||
|
|
||||||
|
print(f"RMSE: {rmse_mean} ± {rmse_ci[1] - rmse_mean}")
|
||||||
|
print(f"SSIM: {ssim_mean} ± {ssim_ci[1] - ssim_mean}")
|
||||||
|
|
||||||
|
def save_average_diff_map(results, save_path='average_diff_map.png'):
|
||||||
|
if not results:
|
||||||
|
print("No results available to create an average diff map.")
|
||||||
|
return
|
||||||
|
|
||||||
|
avg_diff_map = None
|
||||||
|
|
||||||
|
for res in results:
|
||||||
|
if avg_diff_map is None:
|
||||||
|
avg_diff_map = np.zeros_like(res['diff_mask'])
|
||||||
|
|
||||||
|
avg_diff_map += res['diff_mask']
|
||||||
|
|
||||||
|
avg_diff_map /= len(results)
|
||||||
|
|
||||||
|
avg_diff_map = (avg_diff_map * 255).astype(np.uint8)
|
||||||
|
|
||||||
|
cv2.imwrite(save_path, avg_diff_map)
|
||||||
|
|
||||||
|
# Usage
|
||||||
|
save_average_diff_map(results)
|
||||||
4
annotation/experiments/recaptcha/recaptcha.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
success = 10
|
||||||
|
total = 10
|
||||||
|
|
||||||
|
print(success / total)
|
||||||
48
annotation/experiments/sleep_testing/calc_errors.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
import csv
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
def check_sleep(duration, sleep_function):
|
||||||
|
start = time.perf_counter()
|
||||||
|
sleep_function(duration)
|
||||||
|
end = time.perf_counter()
|
||||||
|
elapsed = end - start
|
||||||
|
return abs(elapsed - duration)
|
||||||
|
|
||||||
|
def busy_sleep(duration):
|
||||||
|
end_time = time.perf_counter() + duration
|
||||||
|
while time.perf_counter() < end_time:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def measure_accuracy(sleep_function, durations, iterations=100):
|
||||||
|
average_errors = []
|
||||||
|
for duration in tqdm(durations):
|
||||||
|
errors = [check_sleep(duration, sleep_function) for _ in range(iterations)]
|
||||||
|
average_error = np.mean(errors)
|
||||||
|
average_errors.append(average_error)
|
||||||
|
return average_errors
|
||||||
|
|
||||||
|
durations = np.arange(0.001, 0.101, 0.001) # From 1ms to 100ms in 1ms increments
|
||||||
|
iterations = 100
|
||||||
|
|
||||||
|
sleep_errors = measure_accuracy(time.sleep, durations, iterations)
|
||||||
|
busy_sleep_errors = measure_accuracy(busy_sleep, durations, iterations)
|
||||||
|
|
||||||
|
def save_to_csv(filename, durations, sleep_errors, busy_sleep_errors):
|
||||||
|
with open(filename, 'w', newline='') as csvfile:
|
||||||
|
fieldnames = ['duration', 'sleep_error', 'busy_sleep_error']
|
||||||
|
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||||
|
|
||||||
|
writer.writeheader()
|
||||||
|
for duration, sleep_error, busy_sleep_error in zip(durations, sleep_errors, busy_sleep_errors):
|
||||||
|
writer.writerow({
|
||||||
|
'duration': duration,
|
||||||
|
'sleep_error': sleep_error,
|
||||||
|
'busy_sleep_error': busy_sleep_error
|
||||||
|
})
|
||||||
|
print("Data saved to", filename)
|
||||||
|
|
||||||
|
save_to_csv('sleep_data.csv', durations * 1000, np.array(sleep_errors) * 1000, np.array(busy_sleep_errors) * 1000)
|
||||||
33
annotation/experiments/sleep_testing/plot_errors.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
import csv
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
|
def plot_from_csv(filename, save_plot=False):
|
||||||
|
durations = []
|
||||||
|
sleep_errors = []
|
||||||
|
busy_sleep_errors = []
|
||||||
|
|
||||||
|
with open(filename, 'r') as csvfile:
|
||||||
|
reader = csv.DictReader(csvfile)
|
||||||
|
for row in reader:
|
||||||
|
durations.append(float(row['duration']))
|
||||||
|
sleep_errors.append(float(row['sleep_error']))
|
||||||
|
busy_sleep_errors.append(float(row['busy_sleep_error']))
|
||||||
|
|
||||||
|
plt.figure(figsize=(10, 5))
|
||||||
|
plt.plot(durations, sleep_errors, label='time.sleep()', marker='o')
|
||||||
|
plt.plot(durations, busy_sleep_errors, label='busy_sleep()', marker='x')
|
||||||
|
plt.xlabel('Desired Delay (ms)')
|
||||||
|
plt.ylabel('Average Error (ms)')
|
||||||
|
plt.title('Sleep Accuracy: time.sleep() vs Busy-Wait Loop (macOS)')
|
||||||
|
plt.legend()
|
||||||
|
plt.grid(True)
|
||||||
|
|
||||||
|
if save_plot:
|
||||||
|
plt.savefig('sleep_accuracy_plot.png', dpi=300)
|
||||||
|
print("Plot saved as sleep_accuracy_plot.png")
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
plot_from_csv('sleep_data.csv', save_plot=True)
|
||||||
110
annotation/experiments/stopwatch/stopwatch.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
import glob
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import scipy.stats as stats
|
||||||
|
import seaborn as sns
|
||||||
|
|
||||||
|
# use this: https://www.estopwatch.net/
|
||||||
|
|
||||||
|
def read_file(file_path):
|
||||||
|
df = pd.read_csv(file_path)
|
||||||
|
df['Elapsed time'] = pd.to_datetime(df['Elapsed time'], errors='coerce')
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def analyze_new_error(run_df, groundtruth_df):
|
||||||
|
cumulative_errors = run_df['Elapsed time'] - groundtruth_df['Elapsed time']
|
||||||
|
cumulative_errors_in_seconds = cumulative_errors.dt.total_seconds()
|
||||||
|
|
||||||
|
new_errors_in_seconds = cumulative_errors_in_seconds.diff().fillna(cumulative_errors_in_seconds[0])
|
||||||
|
new_error_points = new_errors_in_seconds[new_errors_in_seconds != 0].index.tolist()
|
||||||
|
|
||||||
|
return new_errors_in_seconds[new_error_points]
|
||||||
|
|
||||||
|
def calculate_statistics(errors):
|
||||||
|
if len(errors) == 0:
|
||||||
|
return {
|
||||||
|
'mean_error': 0,
|
||||||
|
'median_error': 0,
|
||||||
|
'stddev_error': 0,
|
||||||
|
'rmse_error': 0,
|
||||||
|
'confidence_interval': (0, 0),
|
||||||
|
'error_frequency': 0
|
||||||
|
}
|
||||||
|
|
||||||
|
mean_error = np.mean(errors)
|
||||||
|
median_error = np.median(errors)
|
||||||
|
stddev_error = np.std(errors)
|
||||||
|
rmse_error = np.sqrt(np.mean(np.square(errors)))
|
||||||
|
|
||||||
|
ci_low, ci_high = stats.t.interval(
|
||||||
|
confidence=0.95,
|
||||||
|
df=len(errors) - 1,
|
||||||
|
loc=mean_error,
|
||||||
|
scale=stats.sem(errors) if len(errors) > 1 else 0
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'mean_error': mean_error,
|
||||||
|
'median_error': median_error,
|
||||||
|
'stddev_error': stddev_error,
|
||||||
|
'rmse_error': rmse_error,
|
||||||
|
'confidence_interval': (ci_low, ci_high),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
groundtruth_file = 'groundtruth.csv'
|
||||||
|
run_files = glob.glob('runs/*.csv')
|
||||||
|
|
||||||
|
groundtruth_df = read_file(groundtruth_file)
|
||||||
|
run_dfs = {f'run{i+1}': read_file(file) for i, file in enumerate(run_files)}
|
||||||
|
|
||||||
|
total_errors = []
|
||||||
|
total_points = 0
|
||||||
|
all_errors = []
|
||||||
|
|
||||||
|
for run, df in run_dfs.items():
|
||||||
|
errors = analyze_new_error(df, groundtruth_df)
|
||||||
|
total_errors.extend(errors)
|
||||||
|
all_errors.extend(errors)
|
||||||
|
total_points += len(df)
|
||||||
|
|
||||||
|
results = calculate_statistics(errors)
|
||||||
|
error_frequency = len(errors) / len(df)
|
||||||
|
|
||||||
|
print(f"Results for {run}:")
|
||||||
|
print(f"Mean New Error: {results['mean_error']:.5f} seconds")
|
||||||
|
print(f"Median New Error: {results['median_error']:.5f} seconds")
|
||||||
|
print(f"Standard Deviation of New Error: {results['stddev_error']:.5f} seconds")
|
||||||
|
print(f"RMSE of New Error: {results['rmse_error']:.5f} seconds")
|
||||||
|
print(f"95% Confidence Interval of New Error: ({results['confidence_interval'][0]:.5f}, {results['confidence_interval'][1]:.5f}) seconds")
|
||||||
|
print(f"New Error Frequency: {error_frequency*100:.5f} %")
|
||||||
|
print('-----------------------------------------')
|
||||||
|
|
||||||
|
total_results = calculate_statistics(total_errors)
|
||||||
|
total_error_frequency = len(total_errors) / total_points
|
||||||
|
|
||||||
|
print("Total Statistics:")
|
||||||
|
print(f"Mean New Error: {total_results['mean_error']:.5f} seconds")
|
||||||
|
print(f"Median New Error: {total_results['median_error']:.5f} seconds")
|
||||||
|
print(f"Standard Deviation of New Error: {total_results['stddev_error']:.5f} seconds")
|
||||||
|
print(f"RMSE of New Error: {total_results['rmse_error']:.5f} seconds")
|
||||||
|
print(f"95% Confidence Interval of New Error: ({total_results['confidence_interval'][0]:.5f}, {total_results['confidence_interval'][1]:.5f}) seconds")
|
||||||
|
print(f"New Error Frequency: {total_error_frequency*100:.5f} %")
|
||||||
|
|
||||||
|
# do plus minus
|
||||||
|
print(f"New Error: {total_results['mean_error']:.5f} ± {total_results['confidence_interval'][1] - total_results['mean_error']:.5f} seconds")
|
||||||
|
|
||||||
|
plt.figure(figsize=(10, 5))
|
||||||
|
sns.histplot(all_errors, bins=12, kde=False)
|
||||||
|
plt.title('Distribution of Newly Introduced Errors (macOS)')
|
||||||
|
plt.xlabel('Error Duration (seconds)')
|
||||||
|
plt.ylabel('Frequency')
|
||||||
|
plt.savefig('error_dist', dpi=300)
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
39
annotation/main.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
from PyQt6.QtWidgets import QApplication
|
||||||
|
|
||||||
|
from ducktrack import MainInterface
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
app = QApplication(sys.argv)
|
||||||
|
app.setQuitOnLastWindowClosed(False)
|
||||||
|
signal.signal(signal.SIGINT, signal.SIG_DFL)
|
||||||
|
interface = MainInterface(app)
|
||||||
|
interface.show()
|
||||||
|
|
||||||
|
# TODO: come up with a better error solution to this
|
||||||
|
|
||||||
|
original_excepthook = sys.excepthook
|
||||||
|
def handle_exception(exc_type, exc_value, exc_traceback):
|
||||||
|
print("Exception type:", exc_type)
|
||||||
|
print("Exception value:", exc_value)
|
||||||
|
|
||||||
|
trace_details = traceback.format_exception(exc_type, exc_value, exc_traceback)
|
||||||
|
trace_string = "".join(trace_details)
|
||||||
|
|
||||||
|
print("Exception traceback:", trace_string)
|
||||||
|
|
||||||
|
message = f"An error occurred!\n\n{exc_value}\n\n{trace_string}"
|
||||||
|
interface.display_error_message(message)
|
||||||
|
|
||||||
|
original_excepthook(exc_type, exc_value, exc_traceback)
|
||||||
|
|
||||||
|
sys.excepthook = handle_exception
|
||||||
|
|
||||||
|
sys.exit(app.exec())
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
BIN
annotation/readme_images/Screenshot 2023-06-17 220155.png
Normal file
|
After Width: | Height: | Size: 1.8 MiB |
BIN
annotation/readme_images/Screenshot 2023-06-17 221407.png
Normal file
|
After Width: | Height: | Size: 156 KiB |
BIN
annotation/readme_images/Screenshot 2023-06-17 221553.png
Normal file
|
After Width: | Height: | Size: 176 KiB |
BIN
annotation/readme_images/Screenshot 2023-06-17 222626.png
Normal file
|
After Width: | Height: | Size: 21 KiB |
BIN
annotation/readme_images/Screenshot 2023-06-24 103752.png
Normal file
|
After Width: | Height: | Size: 170 KiB |
BIN
annotation/readme_images/Screenshot 2023-06-24 104203.png
Normal file
|
After Width: | Height: | Size: 156 KiB |
BIN
annotation/readme_images/Screenshot 2023-06-24 110033.png
Normal file
|
After Width: | Height: | Size: 127 KiB |
BIN
annotation/readme_images/Screenshot 2023-06-24 110113.png
Normal file
|
After Width: | Height: | Size: 94 KiB |
BIN
annotation/readme_images/Screenshot 2023-06-24 110823.png
Normal file
|
After Width: | Height: | Size: 192 KiB |
BIN
annotation/readme_images/Screenshot 2023-06-24 111017.png
Normal file
|
After Width: | Height: | Size: 26 KiB |
BIN
annotation/readme_images/Screenshot 2023-06-24 111110.png
Normal file
|
After Width: | Height: | Size: 185 KiB |
BIN
annotation/readme_images/Screenshot 2023-06-24 111422.png
Normal file
|
After Width: | Height: | Size: 166 KiB |
BIN
annotation/readme_images/Screenshot 2023-06-24 111634.png
Normal file
|
After Width: | Height: | Size: 56 KiB |
BIN
annotation/readme_images/Screenshot 2023-06-24 111654.png
Normal file
|
After Width: | Height: | Size: 68 KiB |
BIN
annotation/readme_images/Screenshot 2023-06-24 111809.png
Normal file
|
After Width: | Height: | Size: 57 KiB |
BIN
annotation/readme_images/Screenshot 2023-06-24 111841.png
Normal file
|
After Width: | Height: | Size: 64 KiB |
BIN
annotation/readme_images/Screenshot 2023-06-24 112001.png
Normal file
|
After Width: | Height: | Size: 442 KiB |
BIN
annotation/readme_images/Screenshot 2023-06-24 113548.png
Normal file
|
After Width: | Height: | Size: 88 KiB |
BIN
annotation/readme_images/Screenshot 2023-06-24 115916.png
Normal file
|
After Width: | Height: | Size: 836 KiB |
BIN
annotation/readme_images/Screenshot 2023-06-24 120133.png
Normal file
|
After Width: | Height: | Size: 674 KiB |
BIN
annotation/readme_images/Screenshot 2023-06-24 120347.png
Normal file
|
After Width: | Height: | Size: 248 KiB |
BIN
annotation/readme_images/Screenshot 2023-06-24 121017.png
Normal file
|
After Width: | Height: | Size: 109 KiB |
BIN
annotation/readme_images/Screenshot 2023-06-24 121222.png
Normal file
|
After Width: | Height: | Size: 144 KiB |
BIN
annotation/readme_images/Screenshot 2023-06-24 122006.png
Normal file
|
After Width: | Height: | Size: 190 KiB |
BIN
annotation/readme_images/Screenshot 2023-06-24 162423.png
Normal file
|
After Width: | Height: | Size: 13 KiB |
9
annotation/requirements.txt
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
git+https://github.com/moses-palmer/pynput.git@refs/pull/541/head # to make sure that it works on Apple Silicon
|
||||||
|
pyautogui
|
||||||
|
obsws-python
|
||||||
|
PyQt6
|
||||||
|
Pillow
|
||||||
|
screeninfo
|
||||||
|
wmi
|
||||||
|
psutil
|
||||||
|
pyinstaller
|
||||||
0
annotation/tests/__init__.py
Normal file
BIN
desktop_env/assets/cursor.png
Normal file
|
After Width: | Height: | Size: 4.7 KiB |
@@ -1,35 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from fabric import Connection
|
|
||||||
|
|
||||||
from .xdotool import XDoToolController
|
|
||||||
from .python import PythonController
|
|
||||||
|
|
||||||
class AbstractKeyboardController(ABC):
|
|
||||||
@abstractmethod
|
|
||||||
def type(self, text: str):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def key(self, key: str):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
class XDoToolKeyboardController(AbstractKeyboardController, XDoToolController):
|
|
||||||
def __init__(self, ssh_connection: Connection):
|
|
||||||
super().__init__(ssh_connection=ssh_connection)
|
|
||||||
|
|
||||||
def type(self, text: str):
|
|
||||||
self._execute_xdotool_command(f"type {text}")
|
|
||||||
|
|
||||||
def key(self, key: str):
|
|
||||||
self._execute_xdotool_command(f"key {key}")
|
|
||||||
|
|
||||||
class PythonKeyboardController(AbstractKeyboardController, PythonController):
|
|
||||||
def __init__(self, http_server: str):
|
|
||||||
super().__init__(http_server=http_server)
|
|
||||||
self.command = "python -c \"import keyboard; {command}\""
|
|
||||||
|
|
||||||
def type(self, text: str):
|
|
||||||
self._execute_python_command(self.command.format(command=f"keyboard.write('{text}')"))
|
|
||||||
|
|
||||||
def key(self, key: str):
|
|
||||||
self._execute_python_command(self.command.format(command=f"keyboard.press_and_release('{key}')"))
|
|
||||||
@@ -1,144 +0,0 @@
|
|||||||
from enum import Enum
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from fabric import Connection
|
|
||||||
|
|
||||||
from .xdotool import XDoToolController
|
|
||||||
from .python import PythonController
|
|
||||||
class MouseClick(Enum):
|
|
||||||
LEFT = 1
|
|
||||||
MIDDLE = 2
|
|
||||||
RIGHT = 3
|
|
||||||
WHEEL_UP = 4
|
|
||||||
WHEEL_DOWN = 5
|
|
||||||
|
|
||||||
class AbstractMouseController(ABC):
|
|
||||||
@abstractmethod
|
|
||||||
def mouse_move(self, x: int, y: int):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def left_down(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def left_up(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def left_click(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def middle_down(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def middle_up(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def middle_click(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def right_down(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def right_up(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def right_click(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def scroll_up(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def scroll_down(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class XDoToolMouseController(AbstractMouseController, XDoToolController):
|
|
||||||
def __init__(self, ssh_connection: Connection):
|
|
||||||
super().__init__(ssh_connection=ssh_connection)
|
|
||||||
|
|
||||||
def mouse_move(self, x: int, y: int):
|
|
||||||
self._execute_xdotool_command(f"mousemove {x} {y}")
|
|
||||||
|
|
||||||
def left_down(self):
|
|
||||||
self._execute_xdotool_command(f"mousedown 1")
|
|
||||||
|
|
||||||
def left_up(self):
|
|
||||||
self._execute_xdotool_command(f"mouseup 1")
|
|
||||||
|
|
||||||
def left_click(self):
|
|
||||||
self._execute_xdotool_command(f"click 1")
|
|
||||||
|
|
||||||
def middle_down(self):
|
|
||||||
self._execute_xdotool_command(f"mousedown 2")
|
|
||||||
|
|
||||||
def middle_up(self):
|
|
||||||
self._execute_xdotool_command(f"mouseup 2")
|
|
||||||
|
|
||||||
def middle_click(self):
|
|
||||||
self._execute_xdotool_command(f"click 2")
|
|
||||||
|
|
||||||
def right_down(self):
|
|
||||||
self._execute_xdotool_command(f"mousedown 3")
|
|
||||||
|
|
||||||
def right_up(self):
|
|
||||||
self._execute_xdotool_command(f"mouseup 3")
|
|
||||||
|
|
||||||
def right_click(self):
|
|
||||||
self._execute_xdotool_command(f"click 3")
|
|
||||||
|
|
||||||
def scroll_up(self):
|
|
||||||
self._execute_xdotool_command(f"click 4")
|
|
||||||
|
|
||||||
def scroll_down(self):
|
|
||||||
self._execute_xdotool_command(f"click 5")
|
|
||||||
|
|
||||||
class PythonMouseController(AbstractMouseController, PythonController):
|
|
||||||
def __init__(self, http_server: str):
|
|
||||||
super().__init__(http_server=http_server)
|
|
||||||
self.command = "python -c \"import mouse; {command}\""
|
|
||||||
|
|
||||||
def mouse_move(self, x: int, y: int):
|
|
||||||
self._execute_python_command(self.command.format(command=f"mouse.move({x}, {y})"))
|
|
||||||
|
|
||||||
def left_down(self):
|
|
||||||
self._execute_python_command(self.command.format(command="mouse.press(button='left')"))
|
|
||||||
|
|
||||||
def left_up(self):
|
|
||||||
self._execute_python_command(self.command.format(command="mouse.release(button='left')"))
|
|
||||||
|
|
||||||
def left_click(self):
|
|
||||||
self._execute_python_command(self.command.format(command="mouse.click(button='left')"))
|
|
||||||
|
|
||||||
def middle_down(self):
|
|
||||||
self._execute_python_command(self.command.format(command="mouse.press(button='middle')"))
|
|
||||||
|
|
||||||
def middle_up(self):
|
|
||||||
self._execute_python_command(self.command.format(command="mouse.release(button='middle')"))
|
|
||||||
|
|
||||||
def middle_click(self):
|
|
||||||
self._execute_python_command(self.command.format(command="mouse.click(button='middle')"))
|
|
||||||
|
|
||||||
def right_down(self):
|
|
||||||
self._execute_python_command(self.command.format(command="mouse.press(button='right')"))
|
|
||||||
|
|
||||||
def right_up(self):
|
|
||||||
self._execute_python_command(self.command.format(command="mouse.release(button='right')"))
|
|
||||||
|
|
||||||
def right_click(self):
|
|
||||||
self._execute_python_command(self.command.format(command="mouse.click(button='right')"))
|
|
||||||
|
|
||||||
def scroll_up(self):
|
|
||||||
self._execute_python_command(self.command.format(command="mouse.wheel(10)"))
|
|
||||||
|
|
||||||
def scroll_down(self):
|
|
||||||
self._execute_python_command(self.command.format(command="mouse.wheel(-10)"))
|
|
||||||
@@ -1,34 +1,208 @@
|
|||||||
import requests
|
|
||||||
import json
|
import json
|
||||||
|
from typing import Any, Dict
|
||||||
|
import requests
|
||||||
|
from desktop_env.envs.actions import KEYBOARD_KEYS
|
||||||
|
|
||||||
|
|
||||||
class PythonController:
|
class PythonController:
|
||||||
def __init__(self, http_server: str):
|
def __init__(self, http_server: str, pkgs_prefix: str = "python -c \"import pyautogui; {command}\""):
|
||||||
self.http_server = http_server
|
self.http_server = http_server
|
||||||
|
self.pkgs_prefix = pkgs_prefix # fixme: this is a hacky way to execute python commands. fix it and combine it with installation of packages
|
||||||
def _execute_python_command(self, command: str) -> None:
|
|
||||||
payload = json.dumps({
|
def get_screenshot(self):
|
||||||
"command": command
|
"""
|
||||||
})
|
Gets a screenshot from the server. With the cursor.
|
||||||
|
"""
|
||||||
|
response = requests.get(self.http_server + "/screenshot")
|
||||||
|
if response.status_code == 200:
|
||||||
|
return response.content
|
||||||
|
else:
|
||||||
|
print("Failed to get screenshot. Status code:", response.status_code)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_file(self, file_path: str):
|
||||||
|
"""
|
||||||
|
Gets a file from the server.
|
||||||
|
"""
|
||||||
|
response = requests.post(self.http_server + "/file", data={"file_path": file_path})
|
||||||
|
if response.status_code == 200:
|
||||||
|
print("File downloaded successfully")
|
||||||
|
return response.content
|
||||||
|
else:
|
||||||
|
print("Failed to get file. Status code:", response.status_code)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def execute_python_command(self, command: str) -> None:
|
||||||
|
"""
|
||||||
|
Executes a python command on the server.
|
||||||
|
It can be used to execute the pyautogui commands, or... any other python command. who knows?
|
||||||
|
"""
|
||||||
|
command = self.pkgs_prefix.format(command=command)
|
||||||
|
payload = json.dumps({"command": command})
|
||||||
headers = {
|
headers = {
|
||||||
'Content-Type': 'application/json'
|
'Content-Type': 'application/json'
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = requests.post(self.http_server + "/execute", headers=headers, data=payload)
|
response = requests.post(self.http_server + "/execute", headers=headers, data=payload)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
print("Command executed successfully:", response.text)
|
print("Command executed successfully:", response.text)
|
||||||
else:
|
else:
|
||||||
print("Failed to execute command. Status code:", response.status_code)
|
print("Failed to execute command. Status code:", response.status_code)
|
||||||
|
return response.json()
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
print("An error occurred while trying to execute the command:", e)
|
print("An error occurred while trying to execute the command:", e)
|
||||||
|
|
||||||
# example usage
|
def execute_action(self, action: Dict[str, Any]):
|
||||||
if __name__ == '__main__':
|
"""
|
||||||
# replace with your actual server URL of the vm
|
Executes an action on the server computer.
|
||||||
server_url = "http://192.168.7.129:5000"
|
"""
|
||||||
controller = PythonController(server_url)
|
|
||||||
|
|
||||||
# example commands
|
action_type = action["action_type"]
|
||||||
python_command = "python -c \"import keyboard; keyboard.write('hello world')\""
|
parameters = action["parameters"] if "parameters" in action else {}
|
||||||
python_command = "python -c \"import mouse; mouse.move(100,100);mouse.right_click()\""
|
|
||||||
controller._execute_python_command(python_command)
|
if action_type == "MOVE_TO":
|
||||||
|
if parameters == {} or None:
|
||||||
|
self.execute_python_command(f"pyautogui.moveTo()")
|
||||||
|
elif "x" in parameters and "y" in parameters:
|
||||||
|
x = parameters["x"]
|
||||||
|
y = parameters["y"]
|
||||||
|
self.execute_python_command(f"pyautogui.moveTo({x}, {y})")
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unknown parameters: {parameters}")
|
||||||
|
|
||||||
|
elif action_type == "CLICK":
|
||||||
|
if parameters == {} or None:
|
||||||
|
self.execute_python_command(f"pyautogui.click()")
|
||||||
|
elif "button" in parameters and "x" in parameters and "y" in parameters:
|
||||||
|
button = parameters["button"]
|
||||||
|
x = parameters["x"]
|
||||||
|
y = parameters["y"]
|
||||||
|
if "num_clicks" in parameters:
|
||||||
|
num_clicks = parameters["num_clicks"]
|
||||||
|
self.execute_python_command(f"pyautogui.click(button='{button}', x={x}, y={y}, clicks={num_clicks})")
|
||||||
|
else:
|
||||||
|
self.execute_python_command(f"pyautogui.click(button='{button}', x={x}, y={y})")
|
||||||
|
elif "button" in parameters and "x" not in parameters and "y" not in parameters:
|
||||||
|
button = parameters["button"]
|
||||||
|
if "num_clicks" in parameters:
|
||||||
|
num_clicks = parameters["num_clicks"]
|
||||||
|
self.execute_python_command(f"pyautogui.click(button='{button}', clicks={num_clicks})")
|
||||||
|
else:
|
||||||
|
self.execute_python_command(f"pyautogui.click(button='{button}')")
|
||||||
|
elif "button" not in parameters and "x" in parameters and "y" in parameters:
|
||||||
|
x = parameters["x"]
|
||||||
|
y = parameters["y"]
|
||||||
|
if "num_clicks" in parameters:
|
||||||
|
num_clicks = parameters["num_clicks"]
|
||||||
|
self.execute_python_command(f"pyautogui.click(x={x}, y={y}, clicks={num_clicks})")
|
||||||
|
else:
|
||||||
|
self.execute_python_command(f"pyautogui.click(x={x}, y={y})")
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unknown parameters: {parameters}")
|
||||||
|
|
||||||
|
elif action_type == "MOUSE_DOWN":
|
||||||
|
if parameters == {} or None:
|
||||||
|
self.execute_python_command(f"pyautogui.mouseDown()")
|
||||||
|
elif "button" in parameters:
|
||||||
|
button = parameters["button"]
|
||||||
|
self.execute_python_command(f"pyautogui.mouseDown(button='{button}')")
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unknown parameters: {parameters}")
|
||||||
|
|
||||||
|
elif action_type == "MOUSE_UP":
|
||||||
|
if parameters == {} or None:
|
||||||
|
self.execute_python_command(f"pyautogui.mouseUp()")
|
||||||
|
elif "button" in parameters:
|
||||||
|
button = parameters["button"]
|
||||||
|
self.execute_python_command(f"pyautogui.mouseUp(button='{button}')")
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unknown parameters: {parameters}")
|
||||||
|
|
||||||
|
elif action_type == "RIGHT_CLICK":
|
||||||
|
if parameters == {} or None:
|
||||||
|
self.execute_python_command(f"pyautogui.rightClick()")
|
||||||
|
elif "x" in parameters and "y" in parameters:
|
||||||
|
x = parameters["x"]
|
||||||
|
y = parameters["y"]
|
||||||
|
self.execute_python_command(f"pyautogui.rightClick(x={x}, y={y})")
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unknown parameters: {parameters}")
|
||||||
|
|
||||||
|
elif action_type == "DOUBLE_CLICK":
|
||||||
|
if parameters == {} or None:
|
||||||
|
self.execute_python_command(f"pyautogui.doubleClick()")
|
||||||
|
elif "x" in parameters and "y" in parameters:
|
||||||
|
x = parameters["x"]
|
||||||
|
y = parameters["y"]
|
||||||
|
self.execute_python_command(f"pyautogui.doubleClick(x={x}, y={y})")
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unknown parameters: {parameters}")
|
||||||
|
|
||||||
|
elif action_type == "DRAG_TO":
|
||||||
|
if "x" in parameters and "y" in parameters:
|
||||||
|
x = parameters["x"]
|
||||||
|
y = parameters["y"]
|
||||||
|
self.execute_python_command(f"pyautogui.dragTo({x}, {y}, duration=1.0, button='left', mouseDownUp=True)")
|
||||||
|
|
||||||
|
elif action_type == "SCROLL":
|
||||||
|
# todo: check if it is related to the operating system, as https://github.com/TheDuckAI/DuckTrack/blob/main/ducktrack/playback.py pointed out
|
||||||
|
if "dx" in parameters and "dy" in parameters:
|
||||||
|
dx = parameters["dx"]
|
||||||
|
dy = parameters["dy"]
|
||||||
|
self.execute_python_command(f"pyautogui.hscroll({dx})")
|
||||||
|
self.execute_python_command(f"pyautogui.vscroll({dy})")
|
||||||
|
elif "dx" in parameters and "dy" not in parameters:
|
||||||
|
dx = parameters["dx"]
|
||||||
|
self.execute_python_command(f"pyautogui.hscroll({dx})")
|
||||||
|
elif "dx" not in parameters and "dy" in parameters:
|
||||||
|
dy = parameters["dy"]
|
||||||
|
self.execute_python_command(f"pyautogui.vscroll({dy})")
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unknown parameters: {parameters}")
|
||||||
|
|
||||||
|
elif action_type == "TYPING":
|
||||||
|
if "text" not in parameters:
|
||||||
|
raise Exception(f"Unknown parameters: {parameters}")
|
||||||
|
text = parameters["text"]
|
||||||
|
self.execute_python_command(f"pyautogui.typewrite('{text}')")
|
||||||
|
|
||||||
|
elif action_type == "PRESS":
|
||||||
|
if "key" not in parameters:
|
||||||
|
raise Exception(f"Unknown parameters: {parameters}")
|
||||||
|
key = parameters["key"]
|
||||||
|
if key.lower() not in KEYBOARD_KEYS:
|
||||||
|
raise Exception(f"Key must be one of {KEYBOARD_KEYS}")
|
||||||
|
self.execute_python_command(f"pyautogui.press('{key}')")
|
||||||
|
|
||||||
|
elif action_type == "KEY_DOWN":
|
||||||
|
if "key" not in parameters:
|
||||||
|
raise Exception(f"Unknown parameters: {parameters}")
|
||||||
|
key = parameters["key"]
|
||||||
|
if key.lower() not in KEYBOARD_KEYS:
|
||||||
|
raise Exception(f"Key must be one of {KEYBOARD_KEYS}")
|
||||||
|
self.execute_python_command(f"pyautogui.keyDown('{key}')")
|
||||||
|
|
||||||
|
elif action_type == "KEY_UP":
|
||||||
|
if "key" not in parameters:
|
||||||
|
raise Exception(f"Unknown parameters: {parameters}")
|
||||||
|
key = parameters["key"]
|
||||||
|
if key.lower() not in KEYBOARD_KEYS:
|
||||||
|
raise Exception(f"Key must be one of {KEYBOARD_KEYS}")
|
||||||
|
self.execute_python_command(f"pyautogui.keyUp('{key}')")
|
||||||
|
|
||||||
|
elif action_type == "HOTKEY":
|
||||||
|
if "keys" not in parameters:
|
||||||
|
raise Exception(f"Unknown parameters: {parameters}")
|
||||||
|
keys = parameters["keys"]
|
||||||
|
if not isinstance(keys, list):
|
||||||
|
raise Exception(f"Keys must be a list of keys")
|
||||||
|
for key in keys:
|
||||||
|
if key.lower() not in KEYBOARD_KEYS:
|
||||||
|
raise Exception(f"Key must be one of {KEYBOARD_KEYS}")
|
||||||
|
|
||||||
|
keys_para_rep = "', '".join(keys)
|
||||||
|
self.execute_python_command(f"pyautogui.hotkey('{keys_para_rep}')")
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unknown action type: {action_type}")
|
||||||
|
|||||||
96
desktop_env/controllers/setup.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
import requests
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
class SetupController:
|
||||||
|
def __init__(self, http_server: str):
|
||||||
|
self.http_server = http_server + "/setup"
|
||||||
|
|
||||||
|
def setup(self, config):
|
||||||
|
"""
|
||||||
|
Setup Config:
|
||||||
|
{
|
||||||
|
download: list[tuple[string]], # a list of tuples of url of file to download and the save path
|
||||||
|
...
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
self._download_setup(config)
|
||||||
|
self._change_wallpaper(config)
|
||||||
|
# self._tidy_desktop(config) todo: implement this
|
||||||
|
self._open_setup(config)
|
||||||
|
# can add other setup steps
|
||||||
|
|
||||||
|
def _download_setup(self, config):
|
||||||
|
if not config:
|
||||||
|
return
|
||||||
|
if not 'download' in config:
|
||||||
|
return
|
||||||
|
for url, path in config['download']:
|
||||||
|
if not url or not path:
|
||||||
|
raise Exception(f"Setup Download - Invalid URL ({url}) or path ({path}).")
|
||||||
|
|
||||||
|
payload = json.dumps({"url": url, "path": path})
|
||||||
|
headers = {
|
||||||
|
'Content-Type': 'application/json'
|
||||||
|
}
|
||||||
|
|
||||||
|
# send request to server to download file
|
||||||
|
try:
|
||||||
|
response = requests.post(self.http_server + "/download_file", headers=headers, data=payload)
|
||||||
|
if response.status_code == 200:
|
||||||
|
print("Command executed successfully:", response.text)
|
||||||
|
else:
|
||||||
|
print("Failed to download file. Status code:", response.text)
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
print("An error occurred while trying to send the request:", e)
|
||||||
|
|
||||||
|
def _change_wallpaper(self, config):
|
||||||
|
if not config:
|
||||||
|
return
|
||||||
|
if not 'wallpaper' in config:
|
||||||
|
return
|
||||||
|
path = config['wallpaper']
|
||||||
|
if not path:
|
||||||
|
raise Exception(f"Setup Wallpaper - Invalid path ({path}).")
|
||||||
|
|
||||||
|
payload = json.dumps({"path": path})
|
||||||
|
headers = {
|
||||||
|
'Content-Type': 'application/json'
|
||||||
|
}
|
||||||
|
|
||||||
|
# send request to server to change wallpaper
|
||||||
|
try:
|
||||||
|
response = requests.post(self.http_server + "/change_wallpaper", headers=headers, data=payload)
|
||||||
|
if response.status_code == 200:
|
||||||
|
print("Command executed successfully:", response.text)
|
||||||
|
else:
|
||||||
|
print("Failed to change wallpaper. Status code:", response.text)
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
print("An error occurred while trying to send the request:", e)
|
||||||
|
|
||||||
|
def _tidy_desktop(self, config):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _open_setup(self, config):
|
||||||
|
if not config:
|
||||||
|
return
|
||||||
|
if not 'open' in config:
|
||||||
|
return
|
||||||
|
for path in config['open']:
|
||||||
|
if not path:
|
||||||
|
raise Exception(f"Setup Open - Invalid path ({path}).")
|
||||||
|
|
||||||
|
payload = json.dumps({"path": path})
|
||||||
|
headers = {
|
||||||
|
'Content-Type': 'application/json'
|
||||||
|
}
|
||||||
|
|
||||||
|
# send request to server to open file
|
||||||
|
try:
|
||||||
|
response = requests.post(self.http_server + "/open_file", headers=headers, data=payload)
|
||||||
|
if response.status_code == 200:
|
||||||
|
print("Command executed successfully:", response.text)
|
||||||
|
else:
|
||||||
|
print("Failed to open file. Status code:", response.text)
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
print("An error occurred while trying to send the request:", e)
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
from fabric import Connection
|
|
||||||
|
|
||||||
class XDoToolController:
|
|
||||||
def __init__(self, ssh_connection: Connection):
|
|
||||||
self.ssh_connection = ssh_connection
|
|
||||||
|
|
||||||
def _execute_xdotool_command(self, command: list[str]) -> None:
|
|
||||||
result = self.ssh_connection.run(f"DISPLAY=:0 xdotool {command}", hide=True)
|
|
||||||
return result.stdout.strip()
|
|
||||||
190
desktop_env/envs/actions.py
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
X_MAX = 1920 # TODO: get the screen resolution
|
||||||
|
Y_MAX = 1080
|
||||||
|
|
||||||
|
KEYBOARD_KEYS = ['\t', '\n', '\r', ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~', 'accept', 'add', 'alt', 'altleft', 'altright', 'apps', 'backspace', 'browserback', 'browserfavorites', 'browserforward', 'browserhome', 'browserrefresh', 'browsersearch', 'browserstop', 'capslock', 'clear', 'convert', 'ctrl', 'ctrlleft', 'ctrlright', 'decimal', 'del', 'delete', 'divide', 'down', 'end', 'enter', 'esc', 'escape', 'execute', 'f1', 'f10', 'f11', 'f12', 'f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f2', 'f20', 'f21', 'f22', 'f23', 'f24', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9', 'final', 'fn', 'hanguel', 'hangul', 'hanja', 'help', 'home', 'insert', 'junja', 'kana', 'kanji', 'launchapp1', 'launchapp2', 'launchmail', 'launchmediaselect', 'left', 'modechange', 'multiply', 'nexttrack', 'nonconvert', 'num0', 'num1', 'num2', 'num3', 'num4', 'num5', 'num6', 'num7', 'num8', 'num9', 'numlock', 'pagedown', 'pageup', 'pause', 'pgdn', 'pgup', 'playpause', 'prevtrack', 'print', 'printscreen', 'prntscrn', 'prtsc', 'prtscr', 'return', 'right', 'scrolllock', 'select', 'separator', 'shift', 'shiftleft', 'shiftright', 'sleep', 'stop', 'subtract', 'tab', 'up', 'volumedown', 'volumemute', 'volumeup', 'win', 'winleft', 'winright', 'yen', 'command', 'option', 'optionleft', 'optionright']
|
||||||
|
|
||||||
|
ACTION_SPACE = [
|
||||||
|
{
|
||||||
|
"action_type": "MOVE_TO",
|
||||||
|
"note": "move the cursor to the specified position",
|
||||||
|
"parameters": {
|
||||||
|
"x": {
|
||||||
|
"type": float,
|
||||||
|
"range": [0, X_MAX],
|
||||||
|
"optional": False,
|
||||||
|
},
|
||||||
|
"y": {
|
||||||
|
"type": float,
|
||||||
|
"range": [0, Y_MAX],
|
||||||
|
"optional": False,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"action_type": "CLICK",
|
||||||
|
"note": "click the left button if the button not specified, otherwise click the specified button; click at the current position if x and y are not specified, otherwise click at the specified position",
|
||||||
|
"parameters": {
|
||||||
|
"button": {
|
||||||
|
"type": str,
|
||||||
|
"range": ["left", "right", "middle"],
|
||||||
|
"optional": True,
|
||||||
|
},
|
||||||
|
"x": {
|
||||||
|
"type": float,
|
||||||
|
"range": [0, X_MAX],
|
||||||
|
"optional": True,
|
||||||
|
},
|
||||||
|
"y": {
|
||||||
|
"type": float,
|
||||||
|
"range": [0, Y_MAX],
|
||||||
|
"optional": True,
|
||||||
|
},
|
||||||
|
"num_clicks": {
|
||||||
|
"type": int,
|
||||||
|
"range": [1, 2, 3],
|
||||||
|
"optional": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"action_type": "MOUSE_DOWN",
|
||||||
|
"note": "press the left button if the button not specified, otherwise press the specified button",
|
||||||
|
"parameters": {
|
||||||
|
"button": {
|
||||||
|
"type": str,
|
||||||
|
"range": ["left", "right", "middle"],
|
||||||
|
"optional": True,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"action_type": "MOUSE_UP",
|
||||||
|
"note": "release the left button if the button not specified, otherwise release the specified button",
|
||||||
|
"parameters": {
|
||||||
|
"button": {
|
||||||
|
"type": str,
|
||||||
|
"range": ["left", "right", "middle"],
|
||||||
|
"optional": True,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"action_type": "RIGHT_CLICK",
|
||||||
|
"note": "right click at the current position if x and y are not specified, otherwise right click at the specified position",
|
||||||
|
"parameters": {
|
||||||
|
"x": {
|
||||||
|
"type": float,
|
||||||
|
"range": [0, X_MAX],
|
||||||
|
"optional": True,
|
||||||
|
},
|
||||||
|
"y": {
|
||||||
|
"type": float,
|
||||||
|
"range": [0, Y_MAX],
|
||||||
|
"optional": True,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"action_type": "DOUBLE_CLICK",
|
||||||
|
"note": "double click at the current position if x and y are not specified, otherwise double click at the specified position",
|
||||||
|
"parameters": {
|
||||||
|
"x": {
|
||||||
|
"type": float,
|
||||||
|
"range": [0, X_MAX],
|
||||||
|
"optional": True,
|
||||||
|
},
|
||||||
|
"y": {
|
||||||
|
"type": float,
|
||||||
|
"range": [0, Y_MAX],
|
||||||
|
"optional": True,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"action_type": "DRAG_TO",
|
||||||
|
"note": "drag the cursor to the specified position with the left button pressed",
|
||||||
|
"parameters": {
|
||||||
|
"x": {
|
||||||
|
"type": float,
|
||||||
|
"range": [0, X_MAX],
|
||||||
|
"optional": False,
|
||||||
|
},
|
||||||
|
"y": {
|
||||||
|
"type": float,
|
||||||
|
"range": [0, Y_MAX],
|
||||||
|
"optional": False,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"action_type": "SCROLL",
|
||||||
|
"note": "scroll the mouse wheel up or down",
|
||||||
|
"parameters": {
|
||||||
|
"dx": {
|
||||||
|
"type": int,
|
||||||
|
"range": None,
|
||||||
|
"optional": False,
|
||||||
|
},
|
||||||
|
"dy": {
|
||||||
|
"type": int,
|
||||||
|
"range": None,
|
||||||
|
"optional": False,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"action_type": "TYPING",
|
||||||
|
"note": "type the specified text",
|
||||||
|
"parameters": {
|
||||||
|
"text": {
|
||||||
|
"type": str,
|
||||||
|
"range": None,
|
||||||
|
"optional": False,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"action_type": "PRESS",
|
||||||
|
"note": "press the specified key and release it",
|
||||||
|
"parameters": {
|
||||||
|
"key": {
|
||||||
|
"type": str,
|
||||||
|
"range": KEYBOARD_KEYS,
|
||||||
|
"optional": False,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"action_type": "KEY_DOWN",
|
||||||
|
"note": "press the specified key",
|
||||||
|
"parameters": {
|
||||||
|
"key": {
|
||||||
|
"type": str,
|
||||||
|
"range": KEYBOARD_KEYS,
|
||||||
|
"optional": False,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"action_type": "KEY_UP",
|
||||||
|
"note": "release the specified key",
|
||||||
|
"parameters": {
|
||||||
|
"key": {
|
||||||
|
"type": str,
|
||||||
|
"range": KEYBOARD_KEYS,
|
||||||
|
"optional": False,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"action_type": "HOTKEY",
|
||||||
|
"note": "press the specified key combination",
|
||||||
|
"parameters": {
|
||||||
|
"keys": {
|
||||||
|
"type": list,
|
||||||
|
"range": [KEYBOARD_KEYS],
|
||||||
|
"optional": False,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
@@ -1,203 +1,186 @@
|
|||||||
from enum import Enum
|
from __future__ import annotations
|
||||||
from typing import Literal
|
|
||||||
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
from fabric import Connection
|
|
||||||
import time
|
import time
|
||||||
|
import uuid
|
||||||
|
import platform
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium import spaces
|
import requests
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from desktop_env.controllers.mouse import MouseClick, AbstractMouseController, XDoToolMouseController, PythonMouseController
|
from desktop_env.controllers.python import PythonController
|
||||||
from desktop_env.controllers.keyboard import AbstractKeyboardController, XDoToolKeyboardController, PythonKeyboardController
|
from desktop_env.controllers.setup import SetupController
|
||||||
|
from desktop_env.evaluators import eval_funcs
|
||||||
|
|
||||||
class Action(Enum):
|
|
||||||
CLICK = 0
|
|
||||||
MOUSE_DOWN = 1
|
|
||||||
MOUSE_UP = 2
|
|
||||||
MOUSE_MOVE = 3
|
|
||||||
KEY = 4
|
|
||||||
TYPE = 5
|
|
||||||
|
|
||||||
VM_TYPE = Literal['ubuntu', 'windows']
|
def _execute_command(command: List[str]) -> None:
|
||||||
|
if command[:4] == ["vmrun", "-T", "ws", "start"]:
|
||||||
|
p = subprocess.Popen(command)
|
||||||
|
p.wait()
|
||||||
|
else:
|
||||||
|
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=60, text=True)
|
||||||
|
if result.returncode != 0:
|
||||||
|
raise Exception("\033[91m" + result.stdout + result.stderr + "\033[0m")
|
||||||
|
return result.stdout
|
||||||
|
|
||||||
|
|
||||||
class DesktopEnv(gym.Env):
|
class DesktopEnv(gym.Env):
|
||||||
"""DesktopEnv with OpenAI Gym interface."""
|
"""DesktopEnv with OpenAI Gym interface."""
|
||||||
|
|
||||||
def __init__(self, path_to_vm: str, username: str, password: str,
|
def __init__(
|
||||||
host: str, snapshot_path: str = "snapshot", vm_os: VM_TYPE = "ubuntu"):
|
self,
|
||||||
|
path_to_vm: str,
|
||||||
|
snapshot_path: str = "base",
|
||||||
|
instruction: str = None,
|
||||||
|
config: dict = None,
|
||||||
|
evaluator: dict = None,
|
||||||
|
action_space: str = "computer_13",
|
||||||
|
):
|
||||||
|
# Initialize environment variables
|
||||||
self.path_to_vm = path_to_vm
|
self.path_to_vm = path_to_vm
|
||||||
self.username = username
|
self.snapshot_path = snapshot_path # todo: handling the logic of snapshot directory
|
||||||
self.password = password
|
|
||||||
self.host = host
|
|
||||||
self.snapshot_path = snapshot_path
|
|
||||||
|
|
||||||
self.screen_width = 800
|
|
||||||
self.screen_height = 800
|
|
||||||
# Define the action and observation space
|
|
||||||
self.action_space = spaces.Dict({
|
|
||||||
"action_type": spaces.Discrete(len(Action)),
|
|
||||||
"click_type": spaces.Discrete(len(MouseClick)),
|
|
||||||
"x": spaces.Discrete(self.screen_width),
|
|
||||||
"y": spaces.Discrete(self.screen_height),
|
|
||||||
"key": spaces.MultiDiscrete([128] * 10), # max 10 characters, ASCII
|
|
||||||
"text": spaces.MultiDiscrete([128] * 10) # max 10 characters, ASCII
|
|
||||||
})
|
|
||||||
|
|
||||||
self.observation_space = spaces.Box(low=0, high=255, shape=(self.screen_width, self.screen_height, 3), dtype=np.uint8)
|
# Initialize emulator and controller
|
||||||
|
print("Initializing...")
|
||||||
# Additional setup
|
|
||||||
self.metadata = {'render.modes': ['rgb_array']}
|
|
||||||
self._start_emulator()
|
self._start_emulator()
|
||||||
self._wait_for_emulator_load()
|
self.host = f"http://{self._get_vm_ip()}:5000"
|
||||||
|
self.controller = PythonController(http_server=self.host)
|
||||||
|
self.setup_controller = SetupController(http_server=self.host)
|
||||||
|
self.instruction = instruction
|
||||||
|
self.config = config
|
||||||
|
self.evaluator = evaluator
|
||||||
|
|
||||||
# set up controllers
|
# mode: human or machine
|
||||||
self.mouse_controller, self.keyboard_controller = self._create_controllers(vm_os)
|
assert action_space in ["computer_13", "pyautogui"]
|
||||||
|
self.action_space = action_space
|
||||||
def _create_controllers(self, vm_os: VM_TYPE) -> tuple[AbstractMouseController, AbstractKeyboardController]:
|
# todo: define the action space and the observation space as gym did, or extend theirs
|
||||||
if vm_os == "ubuntu":
|
|
||||||
ssh_connection = Connection(host=self.host, user=self.username, connect_kwargs={"password": self.password})
|
|
||||||
mouse_controller = XDoToolMouseController(ssh_connection)
|
|
||||||
keyboard_controller = XDoToolKeyboardController(ssh_connection)
|
|
||||||
elif vm_os == "windows":
|
|
||||||
mouse_controller = PythonMouseController(http_server=self.host)
|
|
||||||
keyboard_controller = PythonKeyboardController(http_server=self.host)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(vm_os)
|
|
||||||
|
|
||||||
return mouse_controller, keyboard_controller
|
|
||||||
|
|
||||||
def _start_emulator(self):
|
def _start_emulator(self):
|
||||||
self._execute_command(["vmrun", "start", self.path_to_vm])
|
|
||||||
|
|
||||||
def _wait_for_emulator_load(self):
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
output = subprocess.check_output("vmrun -T ws list", shell=True, stderr=subprocess.STDOUT)
|
output = subprocess.check_output("vmrun -T ws list", shell=True, stderr=subprocess.STDOUT)
|
||||||
output = output.decode()
|
output = output.decode()
|
||||||
if self.path_to_vm.lstrip("~/") in output:
|
if self.path_to_vm.lstrip("~/") in output:
|
||||||
print("VM is running.")
|
print("VM is running.")
|
||||||
return
|
break
|
||||||
else:
|
else:
|
||||||
print("Waiting for VM to start...")
|
print("Starting VM...")
|
||||||
time.sleep(5)
|
_execute_command(["vmrun", "-T", "ws", "start", self.path_to_vm])
|
||||||
|
time.sleep(3)
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
print(f"Error executing command: {e.output.decode().strip()}")
|
print(f"Error executing command: {e.output.decode().strip()}")
|
||||||
return
|
|
||||||
|
|
||||||
def _execute_command(self, command: list[str]) -> None:
|
def _get_vm_ip(self):
|
||||||
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
max_retries = 10
|
||||||
stdout, stderr = process.communicate()
|
print("Getting IP Address...")
|
||||||
if process.returncode != 0:
|
for _ in range(max_retries):
|
||||||
print(f"Error executing command: {command}")
|
try:
|
||||||
print(stderr.decode())
|
output = _execute_command(["vmrun", "-T", "ws", "getGuestIPAddress", self.path_to_vm]).strip()
|
||||||
return None
|
print(f"IP address: {output}")
|
||||||
else:
|
return output
|
||||||
return stdout.decode()
|
except:
|
||||||
|
time.sleep(5)
|
||||||
def _execute_xdotool_command(self, command: list[str]) -> None:
|
print("Retrying...")
|
||||||
result = self.ssh_connection.run(f"DISPLAY=:0 xdotool {command}", hide=True)
|
raise Exception("Failed to get VM IP address!")
|
||||||
return result.stdout.strip()
|
|
||||||
|
|
||||||
def _save_state(self):
|
def _save_state(self):
|
||||||
self._execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_path])
|
_execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_path])
|
||||||
|
|
||||||
def _click(self, click: MouseClick):
|
|
||||||
self._execute_xdotool_command(f"click {click.value}")
|
|
||||||
|
|
||||||
def _mousedown(self, click: MouseClick):
|
|
||||||
self._execute_xdotool_command(f"mousedown {click.value}")
|
|
||||||
|
|
||||||
def _mouseup(self, click: MouseClick):
|
|
||||||
self._execute_xdotool_command(f"mouseup {click.value}")
|
|
||||||
|
|
||||||
def _mouse_move(self, x: int, y: int):
|
|
||||||
self._execute_xdotool_command(f"mousemove {x} {y}")
|
|
||||||
|
|
||||||
def _key(self, key: str):
|
|
||||||
self._execute_xdotool_command(f"key {key}")
|
|
||||||
|
|
||||||
def _type(self, text: str):
|
|
||||||
self._execute_xdotool_command(f"type {text}")
|
|
||||||
|
|
||||||
def _get_screenshot(self):
|
def _get_screenshot(self):
|
||||||
image_path = "./screenshot.png"
|
random_uuid = str(uuid.uuid4())
|
||||||
self._execute_command(["vmrun", "-T", "ws", "-gu", self.username, "-gp", self.password, "captureScreen", self.path_to_vm, image_path])
|
os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True)
|
||||||
|
image_path = os.path.join("tmp", random_uuid, "screenshot.png")
|
||||||
|
|
||||||
|
# Get the screenshot and save to the image_path
|
||||||
|
screenshot = self.controller.get_screenshot()
|
||||||
|
with open(image_path, "wb") as f:
|
||||||
|
f.write(screenshot)
|
||||||
|
|
||||||
return image_path
|
return image_path
|
||||||
|
|
||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
print("OBS 1")
|
|
||||||
screenshot_image_path = self._get_screenshot()
|
screenshot_image_path = self._get_screenshot()
|
||||||
print("OBS 2")
|
return screenshot_image_path
|
||||||
with Image.open(screenshot_image_path) as img:
|
|
||||||
return np.array(img)
|
|
||||||
|
|
||||||
def reset(self):
|
def reset(self, seed=None, options=None):
|
||||||
input("Reset #1 PE")
|
print("Resetting environment...")
|
||||||
#self._execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
|
|
||||||
input("Revert to snapshot #2 PE")
|
print("Reverting to snapshot to {}...".format(self.snapshot_path))
|
||||||
|
_execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
print("Starting emulator...")
|
||||||
self._start_emulator()
|
self._start_emulator()
|
||||||
input("Started emulator #3 PE")
|
print("Emulator started.")
|
||||||
self._wait_for_emulator_load()
|
|
||||||
observation = self._get_obs()
|
|
||||||
|
|
||||||
|
print("Setting up environment...")
|
||||||
|
self.setup_controller.setup(self.config)
|
||||||
|
|
||||||
|
time.sleep(5)
|
||||||
|
print("Environment setup complete.")
|
||||||
|
|
||||||
|
observation = self._get_obs()
|
||||||
return observation
|
return observation
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action, pause=0.5):
|
||||||
action_type = Action(action['action_type'])
|
# fixme: add reminding logic here, decide if the action is valid for the current action_space
|
||||||
if action_type == Action.CLICK:
|
if self.action_space == "computer_13":
|
||||||
click = MouseClick(action['click_type'])
|
# the set of all possible actions defined in the action representation
|
||||||
if click == MouseClick.LEFT:
|
self.controller.execute_action(action)
|
||||||
self.mouse_controller.left_click()
|
elif self.action_space == "pyautogui":
|
||||||
elif click == MouseClick.MIDDLE:
|
# the set of all possible python commands insides `pyautogui`
|
||||||
self.mouse_controller.middle_click()
|
self.controller.execute_python_command(action)
|
||||||
elif click == MouseClick.RIGHT:
|
|
||||||
self.mouse_controller.right_click()
|
|
||||||
elif click == MouseClick.WHEEL_UP:
|
|
||||||
self.mouse_controller.scroll_up()
|
|
||||||
elif click == MouseClick.WHEEL_DOWN:
|
|
||||||
self.mouse_controller.scroll_down()
|
|
||||||
elif action_type == Action.MOUSE_DOWN:
|
|
||||||
click = MouseClick(action['click_type'])
|
|
||||||
if click == MouseClick.LEFT:
|
|
||||||
self.mouse_controller.left_down()
|
|
||||||
elif click == MouseClick.MIDDLE:
|
|
||||||
self.mouse_controller.middle_down()
|
|
||||||
elif click == MouseClick.RIGHT:
|
|
||||||
self.mouse_controller.right_down()
|
|
||||||
elif click == MouseClick.WHEEL_UP:
|
|
||||||
self.mouse_controller.scroll_up()
|
|
||||||
elif click == MouseClick.WHEEL_DOWN:
|
|
||||||
self.mouse_controller.scroll_down()
|
|
||||||
elif action_type == Action.MOUSE_UP:
|
|
||||||
click = MouseClick(action['click_type'])
|
|
||||||
if click == MouseClick.LEFT:
|
|
||||||
self.mouse_controller.left_up()
|
|
||||||
elif click == MouseClick.MIDDLE:
|
|
||||||
self.mouse_controller.middle_up()
|
|
||||||
elif click == MouseClick.RIGHT:
|
|
||||||
self.mouse_controller.right_up()
|
|
||||||
elif click == MouseClick.WHEEL_UP:
|
|
||||||
self.mouse_controller.scroll_up()
|
|
||||||
elif click == MouseClick.WHEEL_DOWN:
|
|
||||||
self.mouse_controller.scroll_down()
|
|
||||||
elif action_type == Action.MOUSE_MOVE:
|
|
||||||
self.mouse_controller.mouse_move(x = action['x'], y = action['y'])
|
|
||||||
elif action_type == Action.KEY:
|
|
||||||
key_sequence = ''.join(map(chr, action['key'])) # Convert integer array to string
|
|
||||||
self.keyboard_controller.key(key_sequence)
|
|
||||||
elif action_type == Action.TYPE:
|
|
||||||
text = ''.join(map(chr, action['text'])) # Convert integer array to string
|
|
||||||
self.keyboard_controller.type(text)
|
|
||||||
|
|
||||||
# Capture new state
|
# todo: maybe for the better here we need to add a logic to wait until the rendering is done
|
||||||
observation = self._get_obs()
|
time.sleep(pause)
|
||||||
reward = 0 # Define reward calculation
|
observation = {
|
||||||
done = False # Define episode termination condition
|
"screenshot": self._get_obs(),
|
||||||
|
"instruction": self.instruction
|
||||||
|
}
|
||||||
|
reward = 0 # todo: Define reward calculation for each example
|
||||||
|
done = False # todo: Define episode termination condition for each example
|
||||||
info = {}
|
info = {}
|
||||||
return observation, reward, done, info
|
return observation, reward, done, info
|
||||||
|
|
||||||
|
def evaluate(self):
|
||||||
|
"""
|
||||||
|
Evaluate whether the task is successfully completed.
|
||||||
|
"""
|
||||||
|
def copy_file_to_local(_file_info):
|
||||||
|
random_uuid = str(uuid.uuid4())
|
||||||
|
os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True)
|
||||||
|
_path = os.path.join("tmp", random_uuid, "tmp.xlsx")
|
||||||
|
if _file_info["type"] == "cloud_file":
|
||||||
|
url = _file_info["path"]
|
||||||
|
response = requests.get(url, stream=True)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
with open(_path, 'wb') as f:
|
||||||
|
for chunk in response.iter_content(chunk_size=8192):
|
||||||
|
if chunk:
|
||||||
|
f.write(chunk)
|
||||||
|
elif _file_info["type"] == "vm_file":
|
||||||
|
# fixme: stream this part maybe as well
|
||||||
|
file = self.controller.get_file(_file_info["path"])
|
||||||
|
with open(_path, "wb") as f:
|
||||||
|
f.write(file)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
return _path
|
||||||
|
|
||||||
|
# todo: make this more flexible by refactoring
|
||||||
|
eval_func = eval_funcs[self.evaluator["func"]]
|
||||||
|
eval_func_vars = {}
|
||||||
|
|
||||||
|
for var_name, file_info in self.evaluator["paths"].items():
|
||||||
|
path = copy_file_to_local(file_info)
|
||||||
|
eval_func_vars[var_name] = path
|
||||||
|
|
||||||
|
return eval_func(**eval_func_vars)
|
||||||
|
|
||||||
def render(self, mode='rgb_array'):
|
def render(self, mode='rgb_array'):
|
||||||
if mode == 'rgb_array':
|
if mode == 'rgb_array':
|
||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
@@ -205,4 +188,4 @@ class DesktopEnv(gym.Env):
|
|||||||
raise ValueError('Unsupported render mode: {}'.format(mode))
|
raise ValueError('Unsupported render mode: {}'.format(mode))
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self._execute_command(["vmrun", "stop", self.path_to_vm])
|
_execute_command(["vmrun", "stop", self.path_to_vm])
|
||||||
|
|||||||
5
desktop_env/evaluators/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
from .table import compare_table
|
||||||
|
|
||||||
|
eval_funcs = {
|
||||||
|
"compare_table(expected, actual)": compare_table
|
||||||
|
}
|
||||||
0
desktop_env/evaluators/replay.py
Normal file
14
desktop_env/evaluators/table.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
def compare_table(expected, actual):
|
||||||
|
import pandas as pd
|
||||||
|
df1 = pd.read_excel(expected)
|
||||||
|
df2 = pd.read_excel(actual)
|
||||||
|
|
||||||
|
# Compare the DataFrames
|
||||||
|
return 1 if df1.equals(df2) else 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
path1 = ""
|
||||||
|
path2 = ""
|
||||||
|
|
||||||
|
print(compare_table(path1, path2))
|
||||||
190
desktop_env/server/main.py
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
import platform
|
||||||
|
import subprocess
|
||||||
|
import requests
|
||||||
|
|
||||||
|
import Xlib.display
|
||||||
|
import pyautogui
|
||||||
|
from PIL import ImageGrab, Image
|
||||||
|
from flask import Flask, request, jsonify, send_file
|
||||||
|
|
||||||
|
app = Flask(__name__)
|
||||||
|
|
||||||
|
pyautogui.PAUSE = 0
|
||||||
|
pyautogui.DARWIN_CATCH_UP_TIME = 0
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/execute', methods=['POST'])
|
||||||
|
def execute_command():
|
||||||
|
data = request.json
|
||||||
|
# The 'command' key in the JSON request should contain the command to be executed.
|
||||||
|
command = data.get('command', '')
|
||||||
|
|
||||||
|
# Execute the command without any safety checks.
|
||||||
|
try:
|
||||||
|
result = subprocess.run(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||||
|
return jsonify({
|
||||||
|
'status': 'success',
|
||||||
|
'output': result.stdout,
|
||||||
|
'error': result.stderr
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
return jsonify({
|
||||||
|
'status': 'error',
|
||||||
|
'message': str(e)
|
||||||
|
}), 500
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/screenshot', methods=['GET'])
|
||||||
|
def capture_screen_with_cursor():
|
||||||
|
# fixme: when running on virtual machines, the cursor is not captured, don't know why
|
||||||
|
|
||||||
|
file_path = os.path.join("screenshots", "screenshot.png")
|
||||||
|
user_platform = platform.system()
|
||||||
|
|
||||||
|
# Ensure the screenshots directory exists
|
||||||
|
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
# fixme: This is a temporary fix for the cursor not being captured on Windows and Linux
|
||||||
|
if user_platform == "Windows" or user_platform == "Linux":
|
||||||
|
def _download_image(url, path):
|
||||||
|
response = requests.get(url)
|
||||||
|
with open(path, 'wb') as file:
|
||||||
|
file.write(response.content)
|
||||||
|
|
||||||
|
cursor_path = os.path.join("screenshots", "cursor.png")
|
||||||
|
if not os.path.exists(cursor_path):
|
||||||
|
cursor_url = "https://vip.helloimg.com/images/2023/12/02/oQPzmt.png"
|
||||||
|
_download_image(cursor_url, cursor_path)
|
||||||
|
screenshot = pyautogui.screenshot()
|
||||||
|
cursor_x, cursor_y = pyautogui.position()
|
||||||
|
cursor = Image.open(cursor_path)
|
||||||
|
# make the cursor smaller
|
||||||
|
cursor = cursor.resize((int(cursor.width / 1.5), int(cursor.height / 1.5)))
|
||||||
|
screenshot.paste(cursor, (cursor_x, cursor_y), cursor)
|
||||||
|
screenshot.save(file_path)
|
||||||
|
# elif user_platform == "Linux":
|
||||||
|
# # Use xlib to prevent scrot dependency for Linux
|
||||||
|
# screen = Xlib.display.Display().screen()
|
||||||
|
# size = screen.width_in_pixels, screen.height_in_pixels
|
||||||
|
# screenshot = ImageGrab.grab(bbox=(0, 0, size[0], size[1]))
|
||||||
|
# screenshot.save(file_path)
|
||||||
|
elif user_platform == "Darwin": # (Mac OS)
|
||||||
|
# Use the screencapture utility to capture the screen with the cursor
|
||||||
|
subprocess.run(["screencapture", "-C", file_path])
|
||||||
|
else:
|
||||||
|
print(f"The platform you're using ({user_platform}) is not currently supported")
|
||||||
|
|
||||||
|
return send_file(file_path, mimetype='image/png')
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/file', methods=['POST'])
|
||||||
|
def get_file():
|
||||||
|
# Retrieve filename from the POST request
|
||||||
|
if 'file_path' in request.form:
|
||||||
|
file_path = request.form['file_path']
|
||||||
|
else:
|
||||||
|
return jsonify({"error": "file_path is required"}), 400
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check if the file exists and send it to the user
|
||||||
|
return send_file(file_path, as_attachment=True)
|
||||||
|
except FileNotFoundError:
|
||||||
|
# If the file is not found, return a 404 error
|
||||||
|
return jsonify({"error": "File not found"}), 404
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/platform', methods=['GET'])
|
||||||
|
def get_platform():
|
||||||
|
return platform.system()
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/cursor_position', methods=['GET'])
|
||||||
|
def get_cursor_position():
|
||||||
|
return pyautogui.position().x, pyautogui.position().y
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/setup/change_wallpaper", methods=['POST'])
|
||||||
|
def change_wallpaper():
|
||||||
|
data = request.json
|
||||||
|
path = data.get('path', None)
|
||||||
|
|
||||||
|
if not path:
|
||||||
|
return "Path not supplied!", 400
|
||||||
|
|
||||||
|
path = Path(path)
|
||||||
|
|
||||||
|
if not path.exists():
|
||||||
|
return f"File not found: {path}", 404
|
||||||
|
|
||||||
|
try:
|
||||||
|
user_platform = platform.system()
|
||||||
|
if user_platform == "Windows":
|
||||||
|
import ctypes
|
||||||
|
ctypes.windll.user32.SystemParametersInfoW(20, 0, str(path), 3)
|
||||||
|
elif user_platform == "Linux":
|
||||||
|
import subprocess
|
||||||
|
subprocess.run(["gsettings", "set", "org.gnome.desktop.background", "picture-uri", f"file://{path}"])
|
||||||
|
elif user_platform == "Darwin": # (Mac OS)
|
||||||
|
import subprocess
|
||||||
|
subprocess.run(
|
||||||
|
["osascript", "-e", f'tell application "Finder" to set desktop picture to POSIX file "{path}"'])
|
||||||
|
return "Wallpaper changed successfully"
|
||||||
|
except Exception as e:
|
||||||
|
return f"Failed to change wallpaper. Error: {e}", 500
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/setup/download_file", methods=['POST'])
|
||||||
|
def download_file():
|
||||||
|
data = request.json
|
||||||
|
url = data.get('url', None)
|
||||||
|
path = data.get('path', None)
|
||||||
|
|
||||||
|
if not url or not path:
|
||||||
|
return "Path or URL not supplied!", 400
|
||||||
|
|
||||||
|
path = Path(path)
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
max_retries = 3
|
||||||
|
for i in range(max_retries):
|
||||||
|
try:
|
||||||
|
response = requests.get(url, stream=True)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
with open(path, 'wb') as f:
|
||||||
|
for chunk in response.iter_content(chunk_size=8192):
|
||||||
|
if chunk:
|
||||||
|
f.write(chunk)
|
||||||
|
return "File downloaded successfully"
|
||||||
|
|
||||||
|
except requests.RequestException as e:
|
||||||
|
print(f"Failed to download {url}. Retrying... ({max_retries - i - 1} attempts left)")
|
||||||
|
|
||||||
|
return f"Failed to download {url}. No retries left. Error: {e}", 500
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/setup/open_file", methods=['POST'])
|
||||||
|
def open_file():
|
||||||
|
data = request.json
|
||||||
|
path = data.get('path', None)
|
||||||
|
|
||||||
|
if not path:
|
||||||
|
return "Path not supplied!", 400
|
||||||
|
|
||||||
|
path = Path(path)
|
||||||
|
|
||||||
|
if not path.exists():
|
||||||
|
return f"File not found: {path}", 404
|
||||||
|
|
||||||
|
try:
|
||||||
|
os.startfile(path)
|
||||||
|
return "File opened successfully"
|
||||||
|
except Exception as e:
|
||||||
|
return f"Failed to open {path}. Error: {e}", 500
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
app.run(debug=True, host="0.0.0.0")
|
||||||
6
desktop_env/server/requirements.txt
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
python3-xlib==0.15
|
||||||
|
PyAutoGUI==0.9.54
|
||||||
|
Pillow==10.1.0
|
||||||
|
git+https://github.com/moses-palmer/pynput.git@refs/pull/541/head # to make sure that it works on Apple Silicon
|
||||||
|
requests
|
||||||
|
flask
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
from flask import Flask, request, jsonify
|
|
||||||
import subprocess
|
|
||||||
|
|
||||||
app = Flask(__name__)
|
|
||||||
|
|
||||||
@app.route('/execute', methods=['POST'])
|
|
||||||
def execute_command():
|
|
||||||
data = request.json
|
|
||||||
# The 'command' key in the JSON request should contain the command to be executed.
|
|
||||||
command = data.get('command', '')
|
|
||||||
|
|
||||||
# Execute the command without any safety checks.
|
|
||||||
try:
|
|
||||||
process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
|
||||||
stdout, stderr = process.communicate()
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'status': 'success',
|
|
||||||
'output': stdout.decode(),
|
|
||||||
'error': stderr.decode()
|
|
||||||
})
|
|
||||||
except Exception as e:
|
|
||||||
return jsonify({
|
|
||||||
'status': 'error',
|
|
||||||
'message': str(e)
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
app.run(debug=True, host="0.0.0.0")
|
|
||||||
24
evaluation_examples/README.md
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
# Evaluation examples
|
||||||
|
|
||||||
|
Here we put the data examples to benchmark the ability of agents when interacting with GUI.
|
||||||
|
The examples are stored in `./examples` where each data item formatted as:
|
||||||
|
|
||||||
|
```
|
||||||
|
{
|
||||||
|
"id": "uid", # unique id
|
||||||
|
"snapshot": "snapshot_id", # the snapshot id of the environment, with some data already there and apps already opened, or just desktop
|
||||||
|
"instruction": "natural_language_instruction", # the natural language instruction of the task, what we want the agent to do
|
||||||
|
"source": "website_url", # where we know this example, some forum, or some website, or some paper
|
||||||
|
"config": {xxx}, # the scripts to setup the donwload and open files actions, as the initial state of a task
|
||||||
|
"trajectory": "trajectory_directory", # the trajectory directory, which contains the action sequence file, the screenshots and the recording video
|
||||||
|
"related_apps": ["app1", "app2", ...], # the related apps, which are opened during the task
|
||||||
|
"evaluator": "evaluation_dir", # the directory of the evaluator, which contains the evaluation script for this example
|
||||||
|
…
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
The `./trajectories` file contains the annotated trajectories for each data item in `./examples` for finishing the task.
|
||||||
|
|
||||||
|
For now, it is under construction, and only tested on Windows 10. Please:
|
||||||
|
- Modify the path accordingly to run the evaluation;
|
||||||
|
- Remind us if some parts are overfit to our environment.
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
{
|
||||||
|
"id": "0bf05a7d-b28b-44d2-955a-50b41e24012a",
|
||||||
|
"snapshot": "libreoffice_calc",
|
||||||
|
"instruction": "I would like to pad all the numbers in the 'Old ID' column with zeros in front, to fill them up to seven digits in the 'New 7 Digit ID' column.",
|
||||||
|
"source": "https://www.youtube.com/shorts/FPAQaDTS8VY",
|
||||||
|
"config": {
|
||||||
|
"download": [
|
||||||
|
[
|
||||||
|
"",
|
||||||
|
"C:\\Users\\tianbaox\\Desktop\\Customers_New_7digit_Id.xlsx"
|
||||||
|
]
|
||||||
|
],
|
||||||
|
"open": [
|
||||||
|
"C:\\Users\\tianbaox\\Desktop\\Customers_New_7digit_Id.xlsx"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"trajectory": "trajectories/0bf05a7d-b28b-44d2-955a-50b41e24012a",
|
||||||
|
"related_apps": [
|
||||||
|
"libreoffice calc"
|
||||||
|
],
|
||||||
|
"evaluator": "evaluation_dir"
|
||||||
|
}
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
{
|
||||||
|
"id": "2bd59342-0664-4ccb-ba87-79379096cc08",
|
||||||
|
"snapshot": "libreoffice_calc",
|
||||||
|
"instruction": "Make sparkline chart line by line",
|
||||||
|
"source": "https://www.youtube.com/shorts/L3Z-F1QTQFY",
|
||||||
|
"config": {
|
||||||
|
"download": [
|
||||||
|
[
|
||||||
|
"",
|
||||||
|
"C:\\Users\\tianbaox\\Desktop\\OrderId_Month_Chart.xlsx"
|
||||||
|
]
|
||||||
|
],
|
||||||
|
"open": [
|
||||||
|
"C:\\Users\\tianbaox\\Desktop\\OrderId_Month_Chart.xlsx"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"trajectory": "trajectories/2bd59342-0664-4ccb-ba87-79379096cc08",
|
||||||
|
"related_apps": [
|
||||||
|
"libreoffice calc"
|
||||||
|
],
|
||||||
|
"evaluator": "evaluation_dir"
|
||||||
|
}
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
{
|
||||||
|
"id": "37608790-6147-45d0-9f20-1137bb35703d",
|
||||||
|
"snapshot": "libreoffice_calc",
|
||||||
|
"instruction": "Help me fill the columns of First Name, Last Name and Rank",
|
||||||
|
"source": "https://www.youtube.com/shorts/uzPo_CPCHH8",
|
||||||
|
"config": {
|
||||||
|
"download": [
|
||||||
|
[
|
||||||
|
"https://drive.usercontent.google.com/download?id=1wDqap5cBfxnlqTNrZG61k_wDWTujl6AU&export=download&authuser=0&confirm=t&uuid=fd183b89-76b7-4dc5-880e-1045ed769562&at=APZUnTWp9RMafMg0xohhBWazN3YD:1701785710674",
|
||||||
|
"C:\\Users\\tianbaox\\Desktop\\Employee_Roles_and_Ranks.xlsx"
|
||||||
|
]
|
||||||
|
],
|
||||||
|
"open": [
|
||||||
|
"C:\\Users\\tianbaox\\Desktop\\Employee_Roles_and_Ranks.xlsx"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"trajectory": "trajectories/37608790-6147-45d0-9f20-1137bb35703d",
|
||||||
|
"related_apps": [
|
||||||
|
"libreoffice calc"
|
||||||
|
],
|
||||||
|
"evaluator": {
|
||||||
|
"func": "compare_table(expected, actual)",
|
||||||
|
"paths": {
|
||||||
|
"expected": {
|
||||||
|
"type": "cloud_file",
|
||||||
|
"path": "https://drive.usercontent.google.com/download?id=1dxpiUqP_CVvQp5tddxlwO3Cp1BqJ-ZDE&export=download&authuser=0&confirm=t&uuid=ccd204c7-07ce-4fdf-a5d4-a7e4f37b9ce6&at=APZUnTVBs7TgrVrDXpkiU8S7WbQo:1702360836747"
|
||||||
|
},
|
||||||
|
"actual": {
|
||||||
|
"type": "vm_file",
|
||||||
|
"path": "C:\\Users\\tianbaox\\Desktop\\Employee_Roles_and_Ranks.xlsx"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
{
|
||||||
|
"id": "7a4e4bc8-922c-4c84-865c-25ba34136be1",
|
||||||
|
"snapshot": "libreoffice_calc",
|
||||||
|
"instruction": "Reorder the columns to be \"Data\", \"First Name\", \"Last Name\", \"Order ID\", \"Sales\"",
|
||||||
|
"source": "https://www.youtube.com/shorts/bvUhr1AHs44",
|
||||||
|
"config": {
|
||||||
|
"download": [
|
||||||
|
[
|
||||||
|
"",
|
||||||
|
"C:\\Users\\tianbaox\\Desktop\\Name_Order_Id_move_column.xlsx"
|
||||||
|
]
|
||||||
|
],
|
||||||
|
"open": [
|
||||||
|
"C:\\Users\\tianbaox\\Desktop\\Name_Order_Id_move_column.xlsx"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"trajectory": "trajectories/7a4e4bc8-922c-4c84-865c-25ba34136be1",
|
||||||
|
"related_apps": [
|
||||||
|
"libreoffice calc"
|
||||||
|
],
|
||||||
|
"evaluator": "evaluation_dir"
|
||||||
|
}
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
{
|
||||||
|
"id": "7b802dad-6e0f-4204-9815-d4e3f57627d8",
|
||||||
|
"snapshot": "libreoffice_calc",
|
||||||
|
"instruction": "I would like to sort this table based on cell color, placing all the rows marked with pink at the beginning, while keeping their order among themselves unchanged.",
|
||||||
|
"source": "https://www.youtube.com/shorts/Of-lzeP1usE",
|
||||||
|
"config": {
|
||||||
|
"download": [
|
||||||
|
[
|
||||||
|
"",
|
||||||
|
"C:\\Users\\tianbaox\\Desktop\\Customer_Sort_by_cell_color.xlsx"
|
||||||
|
]
|
||||||
|
],
|
||||||
|
"open": [
|
||||||
|
"C:\\Users\\tianbaox\\Desktop\\Customer_Sort_by_cell_color.xlsx"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"trajectory": "trajectories/7b802dad-6e0f-4204-9815-d4e3f57627d8",
|
||||||
|
"related_apps": [
|
||||||
|
"libreoffice calc"
|
||||||
|
],
|
||||||
|
"evaluator": "evaluation_dir"
|
||||||
|
}
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
{
|
||||||
|
"id": "7efeb4b1-3d19-4762-b163-63328d66303b",
|
||||||
|
"snapshot": "libreoffice_calc",
|
||||||
|
"instruction": "Fill in the Serieal Numbers in \"Serial #\" column",
|
||||||
|
"source": "https://www.youtube.com/shorts/4jzXfZNhfmk",
|
||||||
|
"config": {
|
||||||
|
"download": [
|
||||||
|
[
|
||||||
|
"",
|
||||||
|
"C:\\Users\\tianbaox\\Desktop\\Order_Sales_Serial#.xlsx"
|
||||||
|
]
|
||||||
|
],
|
||||||
|
"open": [
|
||||||
|
"C:\\Users\\tianbaox\\Desktop\\Order_Sales_Serial#.xlsx"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"trajectory": "trajectories/",
|
||||||
|
"related_apps": [
|
||||||
|
"libreoffice calc"
|
||||||
|
],
|
||||||
|
"evaluator": "evaluation_dir"
|
||||||
|
}
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
{
|
||||||
|
"id": "a9f325aa-8c05-4e4f-8341-9e4358565f4f",
|
||||||
|
"snapshot": "libreoffice_calc",
|
||||||
|
"instruction": "Clean the messy movie titles and put them in the cleaned column",
|
||||||
|
"source": "https://www.youtube.com/shorts/A0gmEBRKXWs",
|
||||||
|
"config": {
|
||||||
|
"download": [
|
||||||
|
[
|
||||||
|
"",
|
||||||
|
"C:\\Users\\tianbaox\\Desktop\\"
|
||||||
|
]
|
||||||
|
],
|
||||||
|
"open": [
|
||||||
|
"C:\\Users\\tianbaox\\Desktop\\"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"trajectory": "trajectories/a9f325aa-8c05-4e4f-8341-9e4358565f4f",
|
||||||
|
"related_apps": [
|
||||||
|
"libreoffice calc"
|
||||||
|
],
|
||||||
|
"evaluator": "evaluation_dir"
|
||||||
|
}
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
{
|
||||||
|
"id": "d681960f-7bc3-4286-9913-a8812ba3261a",
|
||||||
|
"snapshot": "libreoffice_calc",
|
||||||
|
"instruction": "According to the green table shown above, calculate and give each student a grade",
|
||||||
|
"source": "https://www.youtube.com/shorts/d7U1S_IsTVM",
|
||||||
|
"config": {
|
||||||
|
"download": [
|
||||||
|
[
|
||||||
|
"https://drive.usercontent.google.com/download?id=1wodZjx1KjThUsrtF6ZJaCTy1fQX4E9vA&export=download&authuser=0&confirm=t&uuid=d07ca312-1abc-40f2-81cd-d06e27119854&at=APZUnTWwjnxsHQYapSvpLR8NmlfV:1701785087048",
|
||||||
|
"C:\\Users\\tianbaox\\Desktop\\Student_Grades_and_Remarks.xlsx"
|
||||||
|
]
|
||||||
|
],
|
||||||
|
"open": [
|
||||||
|
"C:\\Users\\tianbaox\\Desktop\\Student_Grades_and_Remarks.xlsx"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"trajectory": "trajectories/d681960f-7bc3-4286-9913-a8812ba3261a",
|
||||||
|
"related_apps": [
|
||||||
|
"libreoffice calc"
|
||||||
|
],
|
||||||
|
"evaluator": {
|
||||||
|
"func": "compare_table(expected, actual)",
|
||||||
|
"paths": {
|
||||||
|
"expected": {
|
||||||
|
"type": "cloud_file",
|
||||||
|
"path": "https://drive.usercontent.google.com/download?id=1kfEHJH1n0yCsQp443IIFvdD9uWv0DWMr&export=download&authuser=0&confirm=t&uuid=d9907f65-8d39-4ecc-8747-b4ed7e6011f5&at=APZUnTXpPAnlh5sD6q-R8oQtqL6g:1702362952170"
|
||||||
|
},
|
||||||
|
"actual": {
|
||||||
|
"type": "vm_file",
|
||||||
|
"path": "C:\\Users\\tianbaox\\Desktop\\Student_Grades_and_Remarks.xlsx"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
{
|
||||||
|
"id": "eb03d19a-b88d-4de4-8a64-ca0ac66f426b",
|
||||||
|
"snapshot": "libreoffice_calc",
|
||||||
|
"instruction": "Traverse the table and paste it below",
|
||||||
|
"source": "https://www.youtube.com/shorts/t9JLUaT55UQ",
|
||||||
|
"config": {
|
||||||
|
"download": [
|
||||||
|
[
|
||||||
|
"",
|
||||||
|
"C:\\Users\\tianbaox\\Desktop\\"
|
||||||
|
]
|
||||||
|
],
|
||||||
|
"open": [
|
||||||
|
"C:\\Users\\tianbaox\\Desktop\\"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"trajectory": "trajectories/eb03d19a-b88d-4de4-8a64-ca0ac66f426b",
|
||||||
|
"related_apps": [
|
||||||
|
"libreoffice calc"
|
||||||
|
],
|
||||||
|
"evaluator": "evaluation_dir"
|
||||||
|
}
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
{
|
||||||
|
"id": "ecb0df7a-4e8d-4a03-b162-053391d3afaf",
|
||||||
|
"snapshot": "libreoffice_calc",
|
||||||
|
"instruction": "Enable each cell in the column\"Pass/Fail/Held\" is a drop down list",
|
||||||
|
"source": "https://www.youtube.com/shorts/tXOovKn0H68",
|
||||||
|
"config": {
|
||||||
|
"download": [
|
||||||
|
[
|
||||||
|
"",
|
||||||
|
"C:\\Users\\tianbaox\\Desktop\\"
|
||||||
|
]
|
||||||
|
],
|
||||||
|
"open": [
|
||||||
|
"C:\\Users\\tianbaox\\Desktop\\"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"trajectory": "trajectories/ecb0df7a-4e8d-4a03-b162-053391d3afaf",
|
||||||
|
"related_apps": [
|
||||||
|
"libreoffice calc"
|
||||||
|
],
|
||||||
|
"evaluator": "evaluation_dir"
|
||||||
|
}
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
{
|
||||||
|
"id": "f9584479-3d0d-4c79-affa-9ad7afdd8850",
|
||||||
|
"snapshot": "libreoffice_calc",
|
||||||
|
"instruction": "Fill the missing row and column which show the total value",
|
||||||
|
"source": "https://youtube.com/shorts/feldd-Pn48c?si=9xJiem2uAHm6Jshb",
|
||||||
|
"config": {
|
||||||
|
"download": [
|
||||||
|
[
|
||||||
|
"https://drive.usercontent.google.com/download?id=1rwhniaClEkF8XFzdfaNUA6GmAiy4syMZ&export=download&authuser=0&confirm=t&uuid=6fdd5b04-85f4-45e1-ad74-368f8f2a82ab&at=APZUnTUP-JxPxLfNls6jXWghblQ5:1701766091851",
|
||||||
|
"C:\\Users\\tianbaox\\Desktop\\Quarterly_Product_Sales_by_Zone.xlsx"
|
||||||
|
]
|
||||||
|
],
|
||||||
|
"open": [
|
||||||
|
"C:\\Users\\tianbaox\\Desktop\\Quarterly_Product_Sales_by_Zone.xlsx"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"trajectory": "trajectories/f9584479-3d0d-4c79-affa-9ad7afdd8850",
|
||||||
|
"related_apps": [
|
||||||
|
"libreoffice calc"
|
||||||
|
],
|
||||||
|
"evaluator": {
|
||||||
|
"func": "compare_table(expected, actual)",
|
||||||
|
"paths": {
|
||||||
|
"expected": {
|
||||||
|
"type": "cloud_file",
|
||||||
|
"path": "https://drive.usercontent.google.com/download?id=17f1wZuJPvUEc5at_Fy3c18VFdOk0x7xz&export=download&authuser=0&confirm=t&uuid=6d2edffd-0ce0-426e-9820-8af25b4667f3&at=APZUnTVh7JS85dwZBaV2hytWQgDK:1702361510956"
|
||||||
|
},
|
||||||
|
"actual": {
|
||||||
|
"type": "vm_file",
|
||||||
|
"path": "C:\\Users\\tianbaox\\Desktop\\Quarterly_Product_Sales_by_Zone.xlsx"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
13
evaluation_examples/examples/template.json
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
{
|
||||||
|
"id": "",
|
||||||
|
"snapshot": "libreoffice_calc",
|
||||||
|
"instruction": "",
|
||||||
|
"source": "",
|
||||||
|
"config": {
|
||||||
|
},
|
||||||
|
"trajectory": "trajectories/",
|
||||||
|
"related_apps": [
|
||||||
|
"libreoffice calc"
|
||||||
|
],
|
||||||
|
"evaluator": "evaluation_dir"
|
||||||
|
}
|
||||||
81
main.py
@@ -1,59 +1,50 @@
|
|||||||
#from pprint import pprint
|
import json
|
||||||
from desktop_env.envs.desktop_env import DesktopEnv, Action, MouseClick
|
from desktop_env.envs.desktop_env import DesktopEnv
|
||||||
|
|
||||||
def get_human_action():
|
|
||||||
"""
|
|
||||||
Prompts the human player for an action and returns a structured action.
|
|
||||||
"""
|
|
||||||
print("\nAvailable actions:", [action.name for action in Action])
|
|
||||||
action_type = None
|
|
||||||
while action_type not in [action.value for action in Action]:
|
|
||||||
action_type = Action[input("Enter the type of action: ".strip())].value
|
|
||||||
|
|
||||||
action = {"action_type": action_type}
|
|
||||||
|
|
||||||
if action_type == Action.CLICK.value or action_type == Action.MOUSE_DOWN.value or action_type == Action.MOUSE_UP.value:
|
|
||||||
print("\n Available clicks:", [action.name for action in MouseClick])
|
|
||||||
click_type = input("Enter click type: ")
|
|
||||||
action["click_type"] = MouseClick[click_type].value
|
|
||||||
|
|
||||||
if action_type == Action.MOUSE_MOVE.value:
|
|
||||||
x = int(input("Enter x-coordinate for mouse move: "))
|
|
||||||
y = int(input("Enter y-coordinate for mouse move: "))
|
|
||||||
action["x"] = x
|
|
||||||
action["y"] = y
|
|
||||||
|
|
||||||
if action_type == Action.KEY.value:
|
|
||||||
key = input("Enter the key to press: ")
|
|
||||||
action["key"] = [ord(c) for c in key]
|
|
||||||
|
|
||||||
if action_type == Action.TYPE.value:
|
|
||||||
text = input("Enter the text to type: ")
|
|
||||||
action["text"] = [ord(c) for c in text]
|
|
||||||
|
|
||||||
return action
|
|
||||||
|
|
||||||
|
|
||||||
def human_agent():
|
def human_agent():
|
||||||
"""
|
"""
|
||||||
Runs the Gym environment with human input.
|
Runs the Gym environment with human input.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
with open("evaluation_examples/examples/37608790-6147-45d0-9f20-1137bb35703d.json", "r") as f:
|
||||||
|
example = json.load(f)
|
||||||
|
|
||||||
#env = DesktopEnv( path_to_vm="/home/yuri/vmware/Windows 10 x64/Windows 10 x64.vmx"
|
#env = DesktopEnv( path_to_vm="/home/yuri/vmware/Windows 10 x64/Windows 10 x64.vmx"
|
||||||
# path_to_vm="/home/yuri/vmware/Ubuntu 64-bit/Ubuntu 64-bit.vmx",
|
# path_to_vm="/home/yuri/vmware/Ubuntu 64-bit/Ubuntu 64-bit.vmx",
|
||||||
env = DesktopEnv( path_to_vm="/home/david/vmware/KUbuntu 64-bit/KUbuntu 64-bit.vmx"
|
env = DesktopEnv( path_to_vm="/home/david/vmware/KUbuntu 64-bit/KUbuntu 64-bit.vmx"
|
||||||
, username="david"
|
, action_space="computer_13"
|
||||||
, password="123456"
|
, snapshot_path="base_setup"
|
||||||
, host="192.168.174.129"
|
, instruction=example["instruction"]
|
||||||
#host="http://192.168.7.129:5000",
|
#, config=example["config"]
|
||||||
#vm_os="windows")
|
#, evaluator=example["evaluator"]
|
||||||
, vm_os="ubuntu"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# reset the environment to certain snapshot
|
||||||
observation = env.reset()
|
observation = env.reset()
|
||||||
done = False
|
done = False
|
||||||
|
|
||||||
while not done:
|
trajectory = [
|
||||||
action = get_human_action()
|
{
|
||||||
observation, reward, done, info = env.step(action)
|
"action_type": "MOVE_TO",
|
||||||
|
"parameters": {
|
||||||
|
"x": 754,
|
||||||
|
"y": 1057
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{"action_type": "CLICK", "parameters": {"button": "right", "num_clicks": 1}}
|
||||||
|
]
|
||||||
|
|
||||||
|
for i in range(len(trajectory)):
|
||||||
|
# action = get_human_action()
|
||||||
|
|
||||||
|
# action = {
|
||||||
|
# "action_type": 0,
|
||||||
|
# "click_type": 3,
|
||||||
|
# }
|
||||||
|
print(trajectory[i])
|
||||||
|
|
||||||
|
observation, reward, done, info = env.step(trajectory[i], pause=5)
|
||||||
print("Observation:", observation)
|
print("Observation:", observation)
|
||||||
print("Reward:", reward)
|
print("Reward:", reward)
|
||||||
print("Info:", info)
|
print("Info:", info)
|
||||||
@@ -64,8 +55,12 @@ def human_agent():
|
|||||||
print("The episode is done.")
|
print("The episode is done.")
|
||||||
break
|
break
|
||||||
|
|
||||||
|
result = env.evaluate()
|
||||||
|
print("Result:", result)
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
print("Environment closed.")
|
print("Environment closed.")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
human_agent()
|
human_agent()
|
||||||
|
|||||||
0
mm_agents/__init__.py
Normal file
BIN
mm_agents/chrome_start.png
Normal file
|
After Width: | Height: | Size: 16 MiB |
20
mm_agents/fuyu_test.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
from transformers import FuyuProcessor, FuyuForCausalLM
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
image = Image.open("stackoverflow.png").convert("RGB")
|
||||||
|
|
||||||
|
# load model and processor
|
||||||
|
model_id = "adept/fuyu-8b"
|
||||||
|
processor = FuyuProcessor.from_pretrained(model_id)
|
||||||
|
model = FuyuForCausalLM.from_pretrained(model_id, device_map="cuda:0")
|
||||||
|
|
||||||
|
# prepare inputs for the model
|
||||||
|
text_prompt = "Description:\n"
|
||||||
|
|
||||||
|
inputs = processor(text=text_prompt, images=image, return_tensors="pt").to("cuda:0")
|
||||||
|
|
||||||
|
# autoregressively generate text
|
||||||
|
generation_output = model.generate(**inputs, max_new_tokens=100)
|
||||||
|
generation_text = processor.batch_decode(generation_output[:, -100:], skip_special_tokens=True)
|
||||||
|
|
||||||
|
print(generation_text)
|
||||||
166
mm_agents/gpt_4v_agent.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
# fixme: Need to be rewrite on new action space
|
||||||
|
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import base64
|
||||||
|
from desktop_env.envs.desktop_env import Action, MouseClick
|
||||||
|
import json
|
||||||
|
import requests
|
||||||
|
from mm_agents.gpt_4v_prompt import SYS_PROMPT
|
||||||
|
|
||||||
|
|
||||||
|
# Function to encode the image
|
||||||
|
def encode_image(image_path):
|
||||||
|
with open(image_path, "rb") as image_file:
|
||||||
|
return base64.b64encode(image_file.read()).decode('utf-8')
|
||||||
|
|
||||||
|
|
||||||
|
def parse_actions_from_string(input_string):
|
||||||
|
# Search for a JSON string within the input string
|
||||||
|
actions = []
|
||||||
|
matches = re.findall(r'```json\s+(.*?)\s+```', input_string, re.DOTALL)
|
||||||
|
if matches:
|
||||||
|
# Assuming there's only one match, parse the JSON string into a dictionary
|
||||||
|
try:
|
||||||
|
for match in matches:
|
||||||
|
action_dict = json.loads(match)
|
||||||
|
actions.append(action_dict)
|
||||||
|
return actions
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
return f"Failed to parse JSON: {e}"
|
||||||
|
else:
|
||||||
|
matches = re.findall(r'```\s+(.*?)\s+```', input_string, re.DOTALL)
|
||||||
|
if matches:
|
||||||
|
# Assuming there's only one match, parse the JSON string into a dictionary
|
||||||
|
try:
|
||||||
|
for match in matches:
|
||||||
|
action_dict = json.loads(match)
|
||||||
|
actions.append(action_dict)
|
||||||
|
return actions
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
return f"Failed to parse JSON: {e}"
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
action_dict = json.loads(input_string)
|
||||||
|
return [action_dict]
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise ValueError("Invalid response format: " + input_string)
|
||||||
|
|
||||||
|
|
||||||
|
class GPT4v_Agent:
|
||||||
|
def __init__(self, api_key, instruction, model="gpt-4-vision-preview", max_tokens=300):
|
||||||
|
self.instruction = instruction
|
||||||
|
self.model = model
|
||||||
|
self.max_tokens = max_tokens
|
||||||
|
|
||||||
|
self.headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {api_key}"
|
||||||
|
}
|
||||||
|
|
||||||
|
self.trajectory = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": SYS_PROMPT
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
def predict(self, obs):
|
||||||
|
base64_image = encode_image(obs)
|
||||||
|
self.trajectory.append({
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "What's the next step for instruction '{}'?".format(self.instruction)
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": f"data:image/jpeg;base64,{base64_image}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
})
|
||||||
|
traj_to_show = []
|
||||||
|
for i in range(len(self.trajectory)):
|
||||||
|
traj_to_show.append(self.trajectory[i]["content"][0]["text"])
|
||||||
|
if len(self.trajectory[i]["content"]) > 1:
|
||||||
|
traj_to_show.append("screenshot_obs")
|
||||||
|
print("Trajectory:", traj_to_show)
|
||||||
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"messages": self.trajectory,
|
||||||
|
"max_tokens": self.max_tokens
|
||||||
|
}
|
||||||
|
response = requests.post("https://api.openai.com/v1/chat/completions", headers=self.headers, json=payload)
|
||||||
|
|
||||||
|
try:
|
||||||
|
actions = self.parse_actions(response.json()['choices'][0]['message']['content'])
|
||||||
|
except:
|
||||||
|
print("Failed to parse action from response:", response.json()['choices'][0]['message']['content'])
|
||||||
|
actions = None
|
||||||
|
|
||||||
|
return actions
|
||||||
|
|
||||||
|
def parse_actions(self, response: str):
|
||||||
|
# response example
|
||||||
|
"""
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"action_type": "CLICK",
|
||||||
|
"click_type": "RIGHT"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
# parse from the response
|
||||||
|
actions = parse_actions_from_string(response)
|
||||||
|
|
||||||
|
# add action into the trajectory
|
||||||
|
self.trajectory.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": response
|
||||||
|
},
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
# parse action
|
||||||
|
parsed_actions = []
|
||||||
|
for action in actions:
|
||||||
|
parsed_action = {}
|
||||||
|
action_type = Action[action['action_type']].value
|
||||||
|
parsed_action["action_type"] = action_type
|
||||||
|
|
||||||
|
if action_type == Action.CLICK.value or action_type == Action.MOUSE_DOWN.value or action_type == Action.MOUSE_UP.value:
|
||||||
|
parsed_action["click_type"] = MouseClick[action['click_type']].value
|
||||||
|
|
||||||
|
if action_type == Action.MOUSE_MOVE.value:
|
||||||
|
parsed_action["x"] = action["x"]
|
||||||
|
parsed_action["y"] = action["y"]
|
||||||
|
|
||||||
|
if action_type == Action.KEY.value:
|
||||||
|
parsed_action["key"] = action["key"] # handle the condition of single key and multiple keys
|
||||||
|
|
||||||
|
if action_type == Action.TYPE.value:
|
||||||
|
parsed_action["text"] = action["text"]
|
||||||
|
|
||||||
|
parsed_actions.append(parsed_action)
|
||||||
|
|
||||||
|
return parsed_actions
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# OpenAI API Key
|
||||||
|
api_key = os.environ.get("OPENAI_API_KEY")
|
||||||
|
|
||||||
|
agent = GPT4v_Agent(api_key=api_key, instruction="Open Google Sheet")
|
||||||
|
print(agent.predict(obs="stackoverflow.png"))
|
||||||
52
mm_agents/gpt_4v_prompt.txt
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
You will act as an agent which follow my instruction and perform desktop computer tasks as instructed. You must have good knowledge of computer and good internet connection.
|
||||||
|
For each step, you will get an observation of an image, which is the screenshot of the computer screen. And you will predict the action of the computer based on the image.
|
||||||
|
Here is the description of the action space:
|
||||||
|
|
||||||
|
Firstly you need to predict the class of your action, select from one below:
|
||||||
|
- **MOUSE_MOVE**: move the mouse to a specific position
|
||||||
|
- **CLICK**: click on the screen
|
||||||
|
- **MOUSE_DOWN**: press the mouse button
|
||||||
|
- **MOUSE_UP**: release the mouse button
|
||||||
|
- **KEY**: press a key on the keyboard
|
||||||
|
- **KEY_DOWN**: press a key on the keyboard
|
||||||
|
- **KEY_UP**: release a key on the keyboard
|
||||||
|
- **TYPE**: type a string on the keyboard
|
||||||
|
|
||||||
|
Then you need to predict the parameters of your action:
|
||||||
|
- For MOUSE_MOVE, you need to predict the x and y coordinate of the mouse cursor
|
||||||
|
for example, format as:
|
||||||
|
```
|
||||||
|
{
|
||||||
|
"action_type": "MOUSE_MOVE",
|
||||||
|
"x": 1319.11,
|
||||||
|
"y": 65.06
|
||||||
|
}
|
||||||
|
```
|
||||||
|
- For [CLICK, MOUSE_DOWN, MOUSE_UP], you need to specify the click_type as well, select from [LEFT, MIDDLE, RIGHT, WHEEL_UP, WHEEL_DOWN], which means you click the left button, middle button, right button, wheel up or wheel down of your mouse:
|
||||||
|
for example, format as:
|
||||||
|
```
|
||||||
|
{
|
||||||
|
"action_type": "CLICK",
|
||||||
|
"click_type": "LEFT"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
- For [KEY, KEY_DOWN, KEY_UP, TYPE], you need to choose a(multiple) key(s) from the keyboard, select from [A-Z, 0-9, F1-F12, ESC, TAB, ENTER, SPACE, BACKSPACE, SHIFT, CTRL, ALT, UP, DOWN, LEFT, RIGHT, CAPSLOCK, NUMLOCK, SCROLLLOCK, INSERT, DELETE, HOME, END, PAGEUP, PAGEDOWN]:
|
||||||
|
for example, format as:
|
||||||
|
```
|
||||||
|
{
|
||||||
|
"action_type": "TYPE",
|
||||||
|
"text": [
|
||||||
|
"w",
|
||||||
|
"i",
|
||||||
|
"k",
|
||||||
|
"i",
|
||||||
|
"p",
|
||||||
|
"e",
|
||||||
|
"d",
|
||||||
|
"i",
|
||||||
|
"a"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
For every setup, you should only return the action_type and the parameters of your action as a dict, without any other things.
|
||||||
54
mm_agents/gpt_4v_prompt_action.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
SYS_PROMPT = """
|
||||||
|
You will act as an agent which follow my instruction and perform desktop computer tasks as instructed. You must have good knowledge of computer and good internet connection.
|
||||||
|
For each step, you will get an observation of an image, which is the screenshot of the computer screen. And you will predict the action of the computer based on the image.
|
||||||
|
Here is the description of the action space:
|
||||||
|
|
||||||
|
Firstly you need to predict the class of your action, select from one below:
|
||||||
|
- **MOUSE_MOVE**: move the mouse to a specific position
|
||||||
|
- **CLICK**: click on the screen
|
||||||
|
- **MOUSE_DOWN**: press the mouse button
|
||||||
|
- **MOUSE_UP**: release the mouse button
|
||||||
|
- **KEY**: press a key on the keyboard
|
||||||
|
- **KEY_DOWN**: press a key on the keyboard
|
||||||
|
- **KEY_UP**: release a key on the keyboard
|
||||||
|
- **TYPE**: type a string on the keyboard
|
||||||
|
|
||||||
|
Then you need to predict the parameters of your action:
|
||||||
|
- For MOUSE_MOVE, you need to predict the x and y coordinate of the mouse cursor, the left top corner of the screen is (0, 0), the right bottom corner of the screen is (1920, 1080)
|
||||||
|
for example, format as:
|
||||||
|
```
|
||||||
|
{
|
||||||
|
"action_type": "MOUSE_MOVE",
|
||||||
|
"x": 1319.11,
|
||||||
|
"y": 65.06
|
||||||
|
}
|
||||||
|
```
|
||||||
|
- For [CLICK, MOUSE_DOWN, MOUSE_UP], you need to specify the click_type as well, select from [LEFT, MIDDLE, RIGHT, WHEEL_UP, WHEEL_DOWN], which means you click the left button, middle button, right button, wheel up or wheel down of your mouse:
|
||||||
|
for example, format as:
|
||||||
|
```
|
||||||
|
{
|
||||||
|
"action_type": "CLICK",
|
||||||
|
"click_type": "LEFT"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
- For [KEY, KEY_DOWN, KEY_UP], you need to choose a(multiple) key(s) from the keyboard
|
||||||
|
for example, format as:
|
||||||
|
```
|
||||||
|
{
|
||||||
|
"action_type": "KEY",
|
||||||
|
"key": "ctrl+c"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
- For TYPE, you need to specify the text you want to type
|
||||||
|
for example, format as:
|
||||||
|
```
|
||||||
|
{
|
||||||
|
"action_type": "TYPE",
|
||||||
|
"text": "hello world"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
For every step, you should only return the action_type and the parameters of your action as a dict, without any other things. You MUST wrap the dict with backticks (\`).
|
||||||
|
You can predict multiple actions at one step, but you should only return one action for each step.
|
||||||
|
You MUST choose and ONLY CHOOSE from the action space above, otherwise your action will be considered as invalid and you will get a penalty.
|
||||||
|
"""
|
||||||
8
mm_agents/gpt_4v_prompt_code.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
SYS_PROMPT = """
|
||||||
|
You will act as an agent which follow my instruction and perform desktop computer tasks as instructed. You must have good knowledge of computer and good internet connection.
|
||||||
|
For each step, you will get an observation of an image, which is the screenshot of the computer screen. And you will predict the action of the computer based on the image.
|
||||||
|
|
||||||
|
You are required to use `pyautogui` to perform the action.
|
||||||
|
Return one line or multiple lines of python code to perform the action each time, be time efficient.
|
||||||
|
Return `None` if you cannot perform the action.
|
||||||
|
"""
|
||||||
124
mm_agents/sam_test.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
import requests
|
||||||
|
from transformers import SamModel, SamProcessor
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import os
|
||||||
|
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
|
||||||
|
|
||||||
|
def show_mask(mask, ax, random_color=False):
|
||||||
|
if random_color:
|
||||||
|
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
||||||
|
else:
|
||||||
|
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
|
||||||
|
h, w = mask.shape[-2:]
|
||||||
|
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
||||||
|
ax.imshow(mask_image)
|
||||||
|
|
||||||
|
|
||||||
|
def show_box(box, ax):
|
||||||
|
x0, y0 = box[0], box[1]
|
||||||
|
w, h = box[2] - box[0], box[3] - box[1]
|
||||||
|
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
|
||||||
|
|
||||||
|
|
||||||
|
def show_boxes_on_image(raw_image, boxes):
|
||||||
|
plt.figure(figsize=(10, 10))
|
||||||
|
plt.imshow(raw_image)
|
||||||
|
for box in boxes:
|
||||||
|
show_box(box, plt.gca())
|
||||||
|
plt.axis('on')
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def show_points_on_image(raw_image, input_points, input_labels=None):
|
||||||
|
plt.figure(figsize=(10, 10))
|
||||||
|
plt.imshow(raw_image)
|
||||||
|
input_points = np.array(input_points)
|
||||||
|
if input_labels is None:
|
||||||
|
labels = np.ones_like(input_points[:, 0])
|
||||||
|
else:
|
||||||
|
labels = np.array(input_labels)
|
||||||
|
show_points(input_points, labels, plt.gca())
|
||||||
|
plt.axis('on')
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def show_points_and_boxes_on_image(raw_image, boxes, input_points, input_labels=None):
|
||||||
|
plt.figure(figsize=(10, 10))
|
||||||
|
plt.imshow(raw_image)
|
||||||
|
input_points = np.array(input_points)
|
||||||
|
if input_labels is None:
|
||||||
|
labels = np.ones_like(input_points[:, 0])
|
||||||
|
else:
|
||||||
|
labels = np.array(input_labels)
|
||||||
|
show_points(input_points, labels, plt.gca())
|
||||||
|
for box in boxes:
|
||||||
|
show_box(box, plt.gca())
|
||||||
|
plt.axis('on')
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def show_points_and_boxes_on_image(raw_image, boxes, input_points, input_labels=None):
|
||||||
|
plt.figure(figsize=(10, 10))
|
||||||
|
plt.imshow(raw_image)
|
||||||
|
input_points = np.array(input_points)
|
||||||
|
if input_labels is None:
|
||||||
|
labels = np.ones_like(input_points[:, 0])
|
||||||
|
else:
|
||||||
|
labels = np.array(input_labels)
|
||||||
|
show_points(input_points, labels, plt.gca())
|
||||||
|
for box in boxes:
|
||||||
|
show_box(box, plt.gca())
|
||||||
|
plt.axis('on')
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def show_points(coords, labels, ax, marker_size=375):
|
||||||
|
pos_points = coords[labels == 1]
|
||||||
|
neg_points = coords[labels == 0]
|
||||||
|
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white',
|
||||||
|
linewidth=1.25)
|
||||||
|
ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white',
|
||||||
|
linewidth=1.25)
|
||||||
|
|
||||||
|
|
||||||
|
def show_masks_on_image(raw_image, masks, scores):
|
||||||
|
if len(masks.shape) == 4:
|
||||||
|
masks = masks.squeeze()
|
||||||
|
if scores.shape[0] == 1:
|
||||||
|
scores = scores.squeeze()
|
||||||
|
|
||||||
|
nb_predictions = scores.shape[-1]
|
||||||
|
fig, axes = plt.subplots(1, nb_predictions, figsize=(15, 15))
|
||||||
|
|
||||||
|
for i, (mask, score) in enumerate(zip(masks, scores)):
|
||||||
|
mask = mask.cpu().detach()
|
||||||
|
axes[i].imshow(np.array(raw_image))
|
||||||
|
show_mask(mask, axes[i])
|
||||||
|
axes[i].title.set_text(f"Mask {i + 1}, Score: {score.item():.3f}")
|
||||||
|
axes[i].axis("off")
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
|
||||||
|
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
||||||
|
|
||||||
|
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
|
||||||
|
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
|
||||||
|
|
||||||
|
plt.imshow(raw_image)
|
||||||
|
|
||||||
|
inputs = processor(raw_image, return_tensors="pt").to(device)
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**inputs)
|
||||||
|
|
||||||
|
masks = processor.image_processor.post_process_masks(
|
||||||
|
outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
scores = outputs.iou_scores
|
||||||
|
show_masks_on_image(raw_image, masks[0], scores)
|
||||||
BIN
mm_agents/stackoverflow.png
Normal file
|
After Width: | Height: | Size: 1.0 MiB |
@@ -1,5 +1,16 @@
|
|||||||
numpy
|
numpy~=1.24.3
|
||||||
Pillow
|
Pillow~=10.1.0
|
||||||
fabric
|
fabric
|
||||||
gymnasium
|
gymnasium~=0.28.1
|
||||||
requests
|
requests~=2.31.0
|
||||||
|
transformers~=4.35.2
|
||||||
|
torch~=2.1.1+cu118
|
||||||
|
accelerate
|
||||||
|
opencv-python~=4.8.1.78
|
||||||
|
matplotlib~=3.7.4
|
||||||
|
pynput~=1.7.6
|
||||||
|
pyautogui~=0.9.54
|
||||||
|
psutil~=5.9.6
|
||||||
|
tqdm~=4.65.0
|
||||||
|
pandas~=2.0.3
|
||||||
|
flask~=3.0.0
|
||||||
3
resouce_collection/README.md
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# Resource Collection
|
||||||
|
|
||||||
|
Manually gain some insights, then scale with careful code.
|
||||||
52
resouce_collection/youtube/LbreOffice_Impress/get_impress.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
from googleapiclient.discovery import build
|
||||||
|
import socket
|
||||||
|
socket.setdefaulttimeout(500)
|
||||||
|
|
||||||
|
|
||||||
|
def search_youtube(api_key, query, max_results=50):
|
||||||
|
youtube = build('youtube', 'v3', developerKey=api_key)
|
||||||
|
|
||||||
|
search_response = youtube.search().list(
|
||||||
|
q=query,
|
||||||
|
part="id,snippet",
|
||||||
|
maxResults=max_results,
|
||||||
|
type="video"
|
||||||
|
).execute()
|
||||||
|
|
||||||
|
videos = []
|
||||||
|
|
||||||
|
for search_result in search_response.get("items", []):
|
||||||
|
if search_result["id"]["kind"] == "youtube#video":
|
||||||
|
video_id = search_result["id"]["videoId"]
|
||||||
|
video_metadata = get_video_metadata(api_key, video_id)
|
||||||
|
videos.append(video_metadata)
|
||||||
|
|
||||||
|
return videos
|
||||||
|
|
||||||
|
|
||||||
|
def get_video_metadata(api_key, video_id):
|
||||||
|
youtube = build('youtube', 'v3', developerKey=api_key)
|
||||||
|
|
||||||
|
request = youtube.videos().list(
|
||||||
|
part="snippet,contentDetails,statistics",
|
||||||
|
id=video_id
|
||||||
|
)
|
||||||
|
response = request.execute()
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
api_key = 'AIzaSyDI_BBExs-HypVZFxgnR5tj5S6-uKyU4vk' # Replace with your actual API key
|
||||||
|
|
||||||
|
# Search for videos related to "VLC player"
|
||||||
|
vlc_related_videos = search_youtube(api_key, "LibreOffice Impress Tutorial", max_results=10)
|
||||||
|
|
||||||
|
# create data folder if not exist
|
||||||
|
if not os.path.exists("data"):
|
||||||
|
os.makedirs("data")
|
||||||
|
|
||||||
|
for video in vlc_related_videos:
|
||||||
|
# store the video metadata into a json file
|
||||||
|
with open(f"data/{video['etag']}.json", "w") as f:
|
||||||
|
json.dump(video, f, indent=4)
|
||||||
52
resouce_collection/youtube/LibreOffice_Calc/get_calc.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
from googleapiclient.discovery import build
|
||||||
|
import socket
|
||||||
|
socket.setdefaulttimeout(500)
|
||||||
|
|
||||||
|
|
||||||
|
def search_youtube(api_key, query, max_results=50):
|
||||||
|
youtube = build('youtube', 'v3', developerKey=api_key)
|
||||||
|
|
||||||
|
search_response = youtube.search().list(
|
||||||
|
q=query,
|
||||||
|
part="id,snippet",
|
||||||
|
maxResults=max_results,
|
||||||
|
type="video"
|
||||||
|
).execute()
|
||||||
|
|
||||||
|
videos = []
|
||||||
|
|
||||||
|
for search_result in search_response.get("items", []):
|
||||||
|
if search_result["id"]["kind"] == "youtube#video":
|
||||||
|
video_id = search_result["id"]["videoId"]
|
||||||
|
video_metadata = get_video_metadata(api_key, video_id)
|
||||||
|
videos.append(video_metadata)
|
||||||
|
|
||||||
|
return videos
|
||||||
|
|
||||||
|
|
||||||
|
def get_video_metadata(api_key, video_id):
|
||||||
|
youtube = build('youtube', 'v3', developerKey=api_key)
|
||||||
|
|
||||||
|
request = youtube.videos().list(
|
||||||
|
part="snippet,contentDetails,statistics",
|
||||||
|
id=video_id
|
||||||
|
)
|
||||||
|
response = request.execute()
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
api_key = 'AIzaSyDI_BBExs-HypVZFxgnR5tj5S6-uKyU4vk' # Replace with your actual API key
|
||||||
|
|
||||||
|
# Search for videos related to "VLC player"
|
||||||
|
vlc_related_videos = search_youtube(api_key, "LibreOffice Calc Tutorial", max_results=10)
|
||||||
|
|
||||||
|
# create data folder if not exist
|
||||||
|
if not os.path.exists("data"):
|
||||||
|
os.makedirs("data")
|
||||||
|
|
||||||
|
for video in vlc_related_videos:
|
||||||
|
# store the video metadata into a json file
|
||||||
|
with open(f"data/{video['etag']}.json", "w") as f:
|
||||||
|
json.dump(video, f, indent=4)
|
||||||
52
resouce_collection/youtube/Thunderbird/get_thunderbird.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
from googleapiclient.discovery import build
|
||||||
|
import socket
|
||||||
|
socket.setdefaulttimeout(500)
|
||||||
|
|
||||||
|
|
||||||
|
def search_youtube(api_key, query, max_results=50):
|
||||||
|
youtube = build('youtube', 'v3', developerKey=api_key)
|
||||||
|
|
||||||
|
search_response = youtube.search().list(
|
||||||
|
q=query,
|
||||||
|
part="id,snippet",
|
||||||
|
maxResults=max_results,
|
||||||
|
type="video"
|
||||||
|
).execute()
|
||||||
|
|
||||||
|
videos = []
|
||||||
|
|
||||||
|
for search_result in search_response.get("items", []):
|
||||||
|
if search_result["id"]["kind"] == "youtube#video":
|
||||||
|
video_id = search_result["id"]["videoId"]
|
||||||
|
video_metadata = get_video_metadata(api_key, video_id)
|
||||||
|
videos.append(video_metadata)
|
||||||
|
|
||||||
|
return videos
|
||||||
|
|
||||||
|
|
||||||
|
def get_video_metadata(api_key, video_id):
|
||||||
|
youtube = build('youtube', 'v3', developerKey=api_key)
|
||||||
|
|
||||||
|
request = youtube.videos().list(
|
||||||
|
part="snippet,contentDetails,statistics",
|
||||||
|
id=video_id
|
||||||
|
)
|
||||||
|
response = request.execute()
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
api_key = 'AIzaSyDI_BBExs-HypVZFxgnR5tj5S6-uKyU4vk' # Replace with your actual API key
|
||||||
|
|
||||||
|
# Search for videos related to "VLC player"
|
||||||
|
vlc_related_videos = search_youtube(api_key, "Thunderbird Tutorial", max_results=10)
|
||||||
|
|
||||||
|
# create data folder if not exist
|
||||||
|
if not os.path.exists("data"):
|
||||||
|
os.makedirs("data")
|
||||||
|
|
||||||
|
for video in vlc_related_videos:
|
||||||
|
# store the video metadata into a json file
|
||||||
|
with open(f"data/{video['etag']}.json", "w") as f:
|
||||||
|
json.dump(video, f, indent=4)
|
||||||
52
resouce_collection/youtube/Ubuntu/get_ubuntu.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
from googleapiclient.discovery import build
|
||||||
|
import socket
|
||||||
|
socket.setdefaulttimeout(500)
|
||||||
|
|
||||||
|
|
||||||
|
def search_youtube(api_key, query, max_results=50):
|
||||||
|
youtube = build('youtube', 'v3', developerKey=api_key)
|
||||||
|
|
||||||
|
search_response = youtube.search().list(
|
||||||
|
q=query,
|
||||||
|
part="id,snippet",
|
||||||
|
maxResults=max_results,
|
||||||
|
type="video"
|
||||||
|
).execute()
|
||||||
|
|
||||||
|
videos = []
|
||||||
|
|
||||||
|
for search_result in search_response.get("items", []):
|
||||||
|
if search_result["id"]["kind"] == "youtube#video":
|
||||||
|
video_id = search_result["id"]["videoId"]
|
||||||
|
video_metadata = get_video_metadata(api_key, video_id)
|
||||||
|
videos.append(video_metadata)
|
||||||
|
|
||||||
|
return videos
|
||||||
|
|
||||||
|
|
||||||
|
def get_video_metadata(api_key, video_id):
|
||||||
|
youtube = build('youtube', 'v3', developerKey=api_key)
|
||||||
|
|
||||||
|
request = youtube.videos().list(
|
||||||
|
part="snippet,contentDetails,statistics",
|
||||||
|
id=video_id
|
||||||
|
)
|
||||||
|
response = request.execute()
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
api_key = 'AIzaSyDI_BBExs-HypVZFxgnR5tj5S6-uKyU4vk' # Replace with your actual API key
|
||||||
|
|
||||||
|
# Search for videos related to "VLC player"
|
||||||
|
vlc_related_videos = search_youtube(api_key, "Ubuntu Desktop Tutorial", max_results=10)
|
||||||
|
|
||||||
|
# create data folder if not exist
|
||||||
|
if not os.path.exists("data"):
|
||||||
|
os.makedirs("data")
|
||||||
|
|
||||||
|
for video in vlc_related_videos:
|
||||||
|
# store the video metadata into a json file
|
||||||
|
with open(f"data/{video['etag']}.json", "w") as f:
|
||||||
|
json.dump(video, f, indent=4)
|
||||||
0
resouce_collection/youtube/__init__.py
Normal file
0
resouce_collection/youtube/vlc_player/__init__.py
Normal file
65
resouce_collection/youtube/vlc_player/vlc_player.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
from googleapiclient.discovery import build
|
||||||
|
|
||||||
|
|
||||||
|
def search_youtube(api_key, query, max_results=50, language="en"):
|
||||||
|
youtube = build('youtube', 'v3', developerKey=api_key)
|
||||||
|
|
||||||
|
videos = []
|
||||||
|
next_page_token = None
|
||||||
|
total_results = 0
|
||||||
|
|
||||||
|
while True:
|
||||||
|
search_response = youtube.search().list(
|
||||||
|
q=query,
|
||||||
|
part="id,snippet",
|
||||||
|
maxResults=max_results,
|
||||||
|
pageToken=next_page_token,
|
||||||
|
type="video",
|
||||||
|
relevanceLanguage=language
|
||||||
|
).execute()
|
||||||
|
|
||||||
|
video_ids = [item['id']['videoId'] for item in search_response.get("items", []) if
|
||||||
|
item['id']['kind'] == 'youtube#video']
|
||||||
|
|
||||||
|
# Fetch metadata for each video
|
||||||
|
videos.extend([get_video_metadata(api_key, video_id) for video_id in video_ids])
|
||||||
|
|
||||||
|
total_results += len(video_ids)
|
||||||
|
next_page_token = search_response.get('nextPageToken')
|
||||||
|
|
||||||
|
if not next_page_token or total_results >= max_results:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Sort videos by view count
|
||||||
|
sorted_videos = sorted(videos, key=lambda x: int(x['items'][0]['statistics']['viewCount']), reverse=True)
|
||||||
|
|
||||||
|
return sorted_videos
|
||||||
|
|
||||||
|
|
||||||
|
def get_video_metadata(api_key, video_id):
|
||||||
|
youtube = build('youtube', 'v3', developerKey=api_key)
|
||||||
|
|
||||||
|
request = youtube.videos().list(
|
||||||
|
part="snippet,contentDetails,statistics",
|
||||||
|
id=video_id
|
||||||
|
)
|
||||||
|
response = request.execute()
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
api_key = 'API_KEY' # Replace with your actual API key
|
||||||
|
|
||||||
|
# Search for videos related to "VLC player"
|
||||||
|
vlc_related_videos = search_youtube(api_key, "VLC player", max_results=10)
|
||||||
|
|
||||||
|
# create data folder if not exist
|
||||||
|
if not os.path.exists("data"):
|
||||||
|
os.makedirs("data")
|
||||||
|
|
||||||
|
for video in vlc_related_videos:
|
||||||
|
# store the video metadata into a json file
|
||||||
|
with open(f"data/{video['etag']}.json", "w") as f:
|
||||||
|
json.dump(video, f, indent=4)
|
||||||
BIN
screenshot.png
|
Before Width: | Height: | Size: 356 KiB After Width: | Height: | Size: 826 KiB |