145 lines
5.6 KiB
Python
145 lines
5.6 KiB
Python
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
|
#
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
|
# SPDX-License-Identifier: MIT
|
|
# Install the Azure Cosmos DB SDK if it is not already installed
|
|
|
|
import pickle
|
|
from typing import Any, Optional, TypedDict, Union
|
|
|
|
from ..import_utils import optional_import_block, require_optional_import
|
|
from .abstract_cache_base import AbstractCache
|
|
|
|
with optional_import_block():
|
|
from azure.cosmos import CosmosClient, PartitionKey
|
|
from azure.cosmos.exceptions import CosmosResourceNotFoundError
|
|
|
|
|
|
@require_optional_import("azure", "cosmosdb")
class CosmosDBConfig(TypedDict, total=False):
    """Configuration mapping accepted by the Cosmos DB cache.

    The mapping is declared with ``total=False``, so every key may be omitted.
    """

    # Passed to CosmosClient.from_connection_string when no ready-made client is supplied.
    connection_string: str
    # Id of the database handed to the client's get_database_client.
    database_id: str
    # Id of the container that holds the cached items.
    container_id: str
    # NOTE(review): not read by CosmosDBCache in this file; presumably consumed by callers - confirm.
    cache_seed: Union[str, int, None]
    # Pre-built client; when present and truthy it is used instead of the connection string.
    client: Optional["CosmosClient"]
|
|
|
|
|
|
@require_optional_import("azure", "cosmosdb")
class CosmosDBCache(AbstractCache):
    """Synchronous implementation of AbstractCache using Azure Cosmos DB NoSQL API.

    Values are pickled and stored as items of a single container. Every item
    carries the cache seed in its ``partitionKey`` field, which is also the
    partition key path the container is created with.

    Attributes:
        seed (str): The seed or namespace, stored as a string and used as the
            partition key value.
        client (CosmosClient): The Cosmos DB client used for caching.
        database: The database client obtained from ``client``.
        container: The container instance used for caching.
    """

    def __init__(self, seed: Union[str, int], cosmosdb_config: CosmosDBConfig):
        """Initialize the CosmosDBCache instance.

        Args:
            seed: A seed or namespace for the cache, used as a partition key.
            cosmosdb_config: The configuration for the Cosmos DB cache. A truthy
                ``client`` entry is used as-is; otherwise a client is built from
                ``connection_string``. ``database_id`` falls back to
                ``"autogen_cache"`` when absent.

        Raises:
            KeyError: If the configuration has neither a usable ``client`` nor a
                ``connection_string``.
        """
        self.seed = str(seed)
        self.client = cosmosdb_config.get("client") or CosmosClient.from_connection_string(
            cosmosdb_config["connection_string"]
        )
        database_id = cosmosdb_config.get("database_id", "autogen_cache")
        self.database = self.client.get_database_client(database_id)
        # NOTE(review): no fallback here, so a missing "container_id" reaches the SDK as None - confirm intended.
        container_id = cosmosdb_config.get("container_id")
        self.container = self.database.create_container_if_not_exists(
            id=container_id, partition_key=PartitionKey(path="/partitionKey")
        )

    @classmethod
    def create_cache(cls, seed: Union[str, int], cosmosdb_config: CosmosDBConfig) -> "CosmosDBCache":
        """Factory method to create a CosmosDBCache instance based on the provided configuration.

        Routes to ``from_existing_client`` when the configuration carries a
        ``CosmosClient`` instance under ``client``, and to ``from_config`` otherwise.

        Args:
            seed: A seed or namespace for the cache.
            cosmosdb_config: The configuration for the Cosmos DB cache.

        Returns:
            The newly created cache instance.
        """
        client = cosmosdb_config.get("client")
        if isinstance(client, CosmosClient):
            # Forward only the arguments from_existing_client accepts. Unpacking the
            # whole mapping raised TypeError whenever it also held
            # "connection_string"/"cache_seed" or lacked one of the ids.
            return cls.from_existing_client(
                seed,
                client=client,
                # Same fallback as __init__ applies.
                database_id=cosmosdb_config.get("database_id", "autogen_cache"),
                container_id=cosmosdb_config.get("container_id"),
            )
        return cls.from_config(seed, cosmosdb_config)

    @classmethod
    def from_config(cls, seed: Union[str, int], cosmosdb_config: CosmosDBConfig) -> "CosmosDBCache":
        """Create a cache directly from a configuration mapping.

        Args:
            seed: A seed or namespace for the cache.
            cosmosdb_config: The configuration for the Cosmos DB cache.

        Returns:
            The newly created cache instance.
        """
        return cls(str(seed), cosmosdb_config)

    @classmethod
    def from_connection_string(
        cls, seed: Union[str, int], connection_string: str, database_id: str, container_id: str
    ) -> "CosmosDBCache":
        """Create a cache that builds its own client from a connection string.

        Args:
            seed: A seed or namespace for the cache.
            connection_string: The Cosmos DB connection string.
            database_id: Id of the database to use.
            container_id: Id of the container to use.

        Returns:
            The newly created cache instance.
        """
        config: CosmosDBConfig = {
            "connection_string": connection_string,
            "database_id": database_id,
            "container_id": container_id,
        }
        return cls(str(seed), config)

    @classmethod
    def from_existing_client(
        cls, seed: Union[str, int], client: "CosmosClient", database_id: str, container_id: str
    ) -> "CosmosDBCache":
        """Create a cache that reuses an already constructed client.

        Args:
            seed: A seed or namespace for the cache.
            client: The Cosmos DB client to reuse.
            database_id: Id of the database to use.
            container_id: Id of the container to use.

        Returns:
            The newly created cache instance.
        """
        config: CosmosDBConfig = {
            "client": client,
            "database_id": database_id,
            "container_id": container_id,
        }
        return cls(str(seed), config)

    def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
        """Retrieve an item from the Cosmos DB cache.

        Args:
            key (str): The key identifying the item in the cache.
            default (optional): The default value to return if the key is not found.

        Returns:
            The deserialized value associated with the key if found, else the default value.

        Raises:
            Exception: Any error other than ``CosmosResourceNotFoundError`` raised
                while reading or unpickling the item propagates unchanged.
        """
        try:
            response = self.container.read_item(item=key, partition_key=str(self.seed))
        except CosmosResourceNotFoundError:
            return default
        # SECURITY: unpickling executes arbitrary code embedded in the payload;
        # only point this cache at a container whose contents are trusted.
        return pickle.loads(response["data"])

    def set(self, key: str, value: Any) -> None:
        """Set an item in the Cosmos DB cache.

        Args:
            key (str): The key under which the item is to be stored.
            value: The value to be stored in the cache.

        Notes:
            The value is serialized using pickle before being stored.

        Raises:
            Exception: Any error raised while pickling or upserting propagates unchanged.
        """
        serialized_value = pickle.dumps(value)
        # NOTE(review): "data" holds raw pickle bytes - confirm the SDK can encode
        # bytes inside an item body, otherwise an encoding step is needed here.
        item = {"id": key, "partitionKey": str(self.seed), "data": serialized_value}
        self.container.upsert_item(item)

    def close(self) -> None:
        """Close the Cosmos DB client.

        Currently a no-op: nothing is released here.
        """
        # CosmosClient doesn't require explicit close in the current SDK
        # If you created the client inside this class, you should close it if necessary

    def __enter__(self) -> "CosmosDBCache":
        """Context management entry.

        Returns:
            self: The instance itself.
        """
        return self

    def __exit__(self, exc_type: Optional[type], exc_value: Optional[Exception], traceback: Optional[Any]) -> None:
        """Context management exit.

        Calls ``close()``; returns None, so an exception raised inside the
        ``with`` body is not suppressed.
        """
        self.close()
|