CoACT initialize (#292)
This commit is contained in:
@@ -0,0 +1,14 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .authentication import GoogleCredentialsLocalProvider, GoogleCredentialsProvider
|
||||
from .drive import GoogleDriveToolkit
|
||||
from .toolkit_protocol import GoogleToolkitProtocol
|
||||
|
||||
__all__ = [
|
||||
"GoogleCredentialsLocalProvider",
|
||||
"GoogleCredentialsProvider",
|
||||
"GoogleDriveToolkit",
|
||||
"GoogleToolkitProtocol",
|
||||
]
|
||||
@@ -0,0 +1,11 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .credentials_local_provider import GoogleCredentialsLocalProvider
|
||||
from .credentials_provider import GoogleCredentialsProvider
|
||||
|
||||
__all__ = [
|
||||
"GoogleCredentialsLocalProvider",
|
||||
"GoogleCredentialsProvider",
|
||||
]
|
||||
@@ -0,0 +1,43 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
|
||||
from .....doc_utils import export_module
|
||||
from .....import_utils import optional_import_block
|
||||
from .credentials_provider import GoogleCredentialsProvider
|
||||
|
||||
with optional_import_block():
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
|
||||
__all__ = ["GoogleCredenentialsHostedProvider"]
|
||||
|
||||
|
||||
@export_module("autogen.tools.experimental.google.authentication")
|
||||
class GoogleCredenentialsHostedProvider(GoogleCredentialsProvider):
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
port: int = 8080,
|
||||
*,
|
||||
kwargs: dict[str, str],
|
||||
) -> None:
|
||||
self._host = host
|
||||
self._port = port
|
||||
self._kwargs = kwargs
|
||||
|
||||
raise NotImplementedError("This class is not implemented yet.")
|
||||
|
||||
@property
|
||||
def host(self) -> str:
|
||||
"""The host from which to get the credentials."""
|
||||
return self._host
|
||||
|
||||
@property
|
||||
def port(self) -> int:
|
||||
"""The port from which to get the credentials."""
|
||||
return self._port
|
||||
|
||||
def get_credentials(self) -> "Credentials": # type: ignore[no-any-unimported]
|
||||
raise NotImplementedError("This class is not implemented yet.")
|
||||
@@ -0,0 +1,91 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from .....doc_utils import export_module
|
||||
from .....import_utils import optional_import_block, require_optional_import
|
||||
from .credentials_provider import GoogleCredentialsProvider
|
||||
|
||||
with optional_import_block():
|
||||
from google.auth.transport.requests import Request
|
||||
from google.oauth2.credentials import Credentials
|
||||
from google_auth_oauthlib.flow import InstalledAppFlow
|
||||
|
||||
|
||||
__all__ = ["GoogleCredentialsLocalProvider"]
|
||||
|
||||
|
||||
@export_module("autogen.tools.experimental.google.authentication")
|
||||
class GoogleCredentialsLocalProvider(GoogleCredentialsProvider):
|
||||
def __init__(
|
||||
self,
|
||||
client_secret_file: str,
|
||||
scopes: list[str], # e.g. ['https://www.googleapis.com/auth/drive/readonly']
|
||||
token_file: Optional[str] = None,
|
||||
port: int = 8080,
|
||||
) -> None:
|
||||
"""A Google credentials provider that gets the credentials locally.
|
||||
|
||||
Args:
|
||||
client_secret_file (str): The path to the client secret file.
|
||||
scopes (list[str]): The scopes to request.
|
||||
token_file (str): Optional path to the token file. If not provided, the token will not be saved.
|
||||
port (int): The port from which to get the credentials.
|
||||
"""
|
||||
self.client_secret_file = client_secret_file
|
||||
self.scopes = scopes
|
||||
self.token_file = token_file
|
||||
self._port = port
|
||||
|
||||
@property
|
||||
def host(self) -> str:
|
||||
"""Localhost is the default host."""
|
||||
return "localhost"
|
||||
|
||||
@property
|
||||
def port(self) -> int:
|
||||
"""The port from which to get the credentials."""
|
||||
return self._port
|
||||
|
||||
@require_optional_import(
|
||||
[
|
||||
"google_auth_httplib2",
|
||||
"google_auth_oauthlib",
|
||||
],
|
||||
"google-api",
|
||||
)
|
||||
def _refresh_or_get_new_credentials(self, creds: Optional["Credentials"]) -> "Credentials": # type: ignore[no-any-unimported]
|
||||
if creds and creds.expired and creds.refresh_token:
|
||||
creds.refresh(Request()) # type: ignore[no-untyped-call]
|
||||
else:
|
||||
flow = InstalledAppFlow.from_client_secrets_file(self.client_secret_file, self.scopes)
|
||||
creds = flow.run_local_server(host=self.host, port=self.port)
|
||||
return creds # type: ignore[return-value]
|
||||
|
||||
@require_optional_import(
|
||||
[
|
||||
"google_auth_httplib2",
|
||||
"google_auth_oauthlib",
|
||||
],
|
||||
"google-api",
|
||||
)
|
||||
def get_credentials(self) -> "Credentials": # type: ignore[no-any-unimported]
|
||||
"""Get the Google credentials."""
|
||||
creds = None
|
||||
if self.token_file and os.path.exists(self.token_file):
|
||||
creds = Credentials.from_authorized_user_file(self.token_file) # type: ignore[no-untyped-call]
|
||||
|
||||
# If there are no (valid) credentials available, let the user log in.
|
||||
if not creds or not creds.valid:
|
||||
creds = self._refresh_or_get_new_credentials(creds)
|
||||
|
||||
if self.token_file:
|
||||
# Save the credentials for the next run
|
||||
with open(self.token_file, "w") as token:
|
||||
token.write(creds.to_json()) # type: ignore[no-untyped-call]
|
||||
|
||||
return creds # type: ignore[no-any-return]
|
||||
@@ -0,0 +1,35 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
|
||||
from typing import Optional, Protocol, runtime_checkable
|
||||
|
||||
from .....doc_utils import export_module
|
||||
from .....import_utils import optional_import_block
|
||||
|
||||
with optional_import_block():
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
|
||||
__all__ = ["GoogleCredentialsProvider"]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@export_module("autogen.tools.experimental.google.authentication")
|
||||
class GoogleCredentialsProvider(Protocol):
|
||||
"""A protocol for Google credentials provider."""
|
||||
|
||||
def get_credentials(self) -> Optional["Credentials"]: # type: ignore[no-any-unimported]
|
||||
"""Get the Google credentials."""
|
||||
...
|
||||
|
||||
@property
|
||||
def host(self) -> str:
|
||||
"""The host from which to get the credentials."""
|
||||
...
|
||||
|
||||
@property
|
||||
def port(self) -> int:
|
||||
"""The port from which to get the credentials."""
|
||||
...
|
||||
@@ -0,0 +1,9 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .toolkit import GoogleDriveToolkit
|
||||
|
||||
__all__ = [
|
||||
"GoogleDriveToolkit",
|
||||
]
|
||||
@@ -0,0 +1,124 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import io
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
from .....import_utils import optional_import_block, require_optional_import
|
||||
from ..model import GoogleFileInfo
|
||||
|
||||
with optional_import_block():
|
||||
from googleapiclient.http import MediaIoBaseDownload
|
||||
|
||||
|
||||
__all__ = [
|
||||
"download_file",
|
||||
"list_files_and_folders",
|
||||
]
|
||||
|
||||
|
||||
@require_optional_import(
|
||||
[
|
||||
"googleapiclient",
|
||||
],
|
||||
"google-api",
|
||||
)
|
||||
def list_files_and_folders(service: Any, page_size: int, folder_id: Optional[str]) -> list[GoogleFileInfo]:
|
||||
kwargs = {
|
||||
"pageSize": page_size,
|
||||
"fields": "nextPageToken, files(id, name, mimeType)",
|
||||
}
|
||||
if folder_id:
|
||||
kwargs["q"] = f"'{folder_id}' in parents and trashed=false" # Search for files in the folder
|
||||
response = service.files().list(**kwargs).execute()
|
||||
result = response.get("files", [])
|
||||
if not isinstance(result, list):
|
||||
raise ValueError(f"Expected a list of files, but got {result}")
|
||||
result = [GoogleFileInfo(**file_info) for file_info in result]
|
||||
return result
|
||||
|
||||
|
||||
def _get_file_extension(mime_type: str) -> Optional[str]:
|
||||
"""Returns the correct file extension for a given MIME type."""
|
||||
mime_extensions = {
|
||||
"application/vnd.google-apps.document": "docx", # Google Docs → Word
|
||||
"application/vnd.google-apps.spreadsheet": "csv", # Google Sheets → CSV
|
||||
"application/vnd.google-apps.presentation": "pptx", # Google Slides → PowerPoint
|
||||
"video/quicktime": "mov",
|
||||
"application/vnd.google.colaboratory": "ipynb",
|
||||
"application/pdf": "pdf",
|
||||
"image/jpeg": "jpg",
|
||||
"image/png": "png",
|
||||
"text/plain": "txt",
|
||||
"application/zip": "zip",
|
||||
}
|
||||
|
||||
return mime_extensions.get(mime_type)
|
||||
|
||||
|
||||
@require_optional_import(
|
||||
[
|
||||
"googleapiclient",
|
||||
],
|
||||
"google-api",
|
||||
)
|
||||
def download_file(
|
||||
service: Any,
|
||||
file_id: str,
|
||||
file_name: str,
|
||||
mime_type: str,
|
||||
download_folder: Path,
|
||||
subfolder_path: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Download or export file based on its MIME type, optionally saving to a subfolder."""
|
||||
file_extension = _get_file_extension(mime_type)
|
||||
if file_extension and (not file_name.lower().endswith(file_extension.lower())):
|
||||
file_name = f"{file_name}.{file_extension}"
|
||||
|
||||
# Define export formats for Google Docs, Sheets, and Slides
|
||||
export_mime_types = {
|
||||
"application/vnd.google-apps.document": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", # Google Docs → Word
|
||||
"application/vnd.google-apps.spreadsheet": "text/csv", # Google Sheets → CSV
|
||||
"application/vnd.google-apps.presentation": "application/vnd.openxmlformats-officedocument.presentationml.presentation", # Google Slides → PowerPoint
|
||||
}
|
||||
|
||||
# Google Docs, Sheets, and Slides cannot be downloaded directly using service.files().get_media() because they are Google-native files
|
||||
if mime_type in export_mime_types:
|
||||
request = service.files().export(fileId=file_id, mimeType=export_mime_types[mime_type])
|
||||
else:
|
||||
# Download normal files (videos, images, etc.)
|
||||
request = service.files().get_media(fileId=file_id)
|
||||
|
||||
# Determine the final destination directory
|
||||
destination_dir = download_folder
|
||||
if subfolder_path:
|
||||
destination_dir = download_folder / subfolder_path
|
||||
# Ensure the subfolder exists, create it if necessary
|
||||
destination_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Construct the full path for the file
|
||||
file_path = destination_dir / file_name
|
||||
|
||||
# Save file
|
||||
try:
|
||||
with io.BytesIO() as buffer:
|
||||
downloader = MediaIoBaseDownload(buffer, request)
|
||||
done = False
|
||||
while not done:
|
||||
_, done = downloader.next_chunk()
|
||||
|
||||
buffer.seek(0)
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(buffer.getvalue())
|
||||
|
||||
# Print out the relative path of the downloaded file
|
||||
relative_path = Path(subfolder_path) / file_name if subfolder_path else Path(file_name)
|
||||
return f"✅ Downloaded: {relative_path}"
|
||||
|
||||
except Exception as e:
|
||||
# Error message if unable to download
|
||||
relative_path = Path(subfolder_path) / file_name if subfolder_path else Path(file_name)
|
||||
return f"❌ FAILED to download {relative_path}: {e}"
|
||||
@@ -0,0 +1,88 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Literal, Optional, Union
|
||||
|
||||
from .....doc_utils import export_module
|
||||
from .....import_utils import optional_import_block
|
||||
from .... import Toolkit, tool
|
||||
from ..model import GoogleFileInfo
|
||||
from ..toolkit_protocol import GoogleToolkitProtocol
|
||||
from .drive_functions import download_file, list_files_and_folders
|
||||
|
||||
with optional_import_block():
|
||||
from google.oauth2.credentials import Credentials
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
__all__ = [
|
||||
"GoogleDriveToolkit",
|
||||
]
|
||||
|
||||
|
||||
@export_module("autogen.tools.experimental.google.drive")
|
||||
class GoogleDriveToolkit(Toolkit, GoogleToolkitProtocol):
|
||||
"""A tool map for Google Drive."""
|
||||
|
||||
def __init__( # type: ignore[no-any-unimported]
|
||||
self,
|
||||
*,
|
||||
credentials: "Credentials",
|
||||
download_folder: Union[Path, str],
|
||||
exclude: Optional[list[Literal["list_drive_files_and_folders", "download_file_from_drive"]]] = None,
|
||||
api_version: str = "v3",
|
||||
) -> None:
|
||||
"""Initialize the Google Drive tool map.
|
||||
|
||||
Args:
|
||||
credentials: The Google OAuth2 credentials.
|
||||
download_folder: The folder to download files to.
|
||||
exclude: The tool names to exclude.
|
||||
api_version: The Google Drive API version to use."
|
||||
"""
|
||||
self.service = build(serviceName="drive", version=api_version, credentials=credentials)
|
||||
|
||||
if isinstance(download_folder, str):
|
||||
download_folder = Path(download_folder)
|
||||
download_folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@tool(description="List files and folders in a Google Drive")
|
||||
def list_drive_files_and_folders(
|
||||
page_size: Annotated[int, "The number of files to list per page."] = 10,
|
||||
folder_id: Annotated[
|
||||
Optional[str],
|
||||
"The ID of the folder to list files from. If not provided, lists all files in the root folder.",
|
||||
] = None,
|
||||
) -> list[GoogleFileInfo]:
|
||||
return list_files_and_folders(service=self.service, page_size=page_size, folder_id=folder_id)
|
||||
|
||||
@tool(description="download a file from Google Drive")
|
||||
def download_file_from_drive(
|
||||
file_info: Annotated[GoogleFileInfo, "The file info to download."],
|
||||
subfolder_path: Annotated[
|
||||
Optional[str],
|
||||
"The subfolder path to save the file in. If not provided, saves in the main download folder.",
|
||||
] = None,
|
||||
) -> str:
|
||||
return download_file(
|
||||
service=self.service,
|
||||
file_id=file_info.id,
|
||||
file_name=file_info.name,
|
||||
mime_type=file_info.mime_type,
|
||||
download_folder=download_folder,
|
||||
subfolder_path=subfolder_path,
|
||||
)
|
||||
|
||||
if exclude is None:
|
||||
exclude = []
|
||||
|
||||
tools = [tool for tool in [list_drive_files_and_folders, download_file_from_drive] if tool.name not in exclude]
|
||||
super().__init__(tools=tools)
|
||||
|
||||
@classmethod
|
||||
def recommended_scopes(cls) -> list[str]:
|
||||
"""Return the recommended scopes manatory for using tools from this tool map."""
|
||||
return [
|
||||
"https://www.googleapis.com/auth/drive.readonly",
|
||||
]
|
||||
17
mm_agents/coact/autogen/tools/experimental/google/model.py
Normal file
17
mm_agents/coact/autogen/tools/experimental/google/model.py
Normal file
@@ -0,0 +1,17 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
__all__ = [
|
||||
"GoogleFileInfo",
|
||||
]
|
||||
|
||||
|
||||
class GoogleFileInfo(BaseModel):
|
||||
name: Annotated[str, Field(description="The name of the file.")]
|
||||
id: Annotated[str, Field(description="The ID of the file.")]
|
||||
mime_type: Annotated[str, Field(alias="mimeType", description="The MIME type of the file.")]
|
||||
@@ -0,0 +1,19 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
__all__ = [
|
||||
"GoogleToolkitProtocol",
|
||||
]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class GoogleToolkitProtocol(Protocol):
|
||||
"""A protocol for Google tool maps."""
|
||||
|
||||
@classmethod
|
||||
def recommended_scopes(cls) -> list[str]:
|
||||
"""Defines a required static method without implementation."""
|
||||
...
|
||||
Reference in New Issue
Block a user