| @@ -0,0 +1,49 @@ | |||
| import json | |||
| from enum import Enum | |||
| from json import JSONDecodeError | |||
| from typing import Optional | |||
| from extensions.ext_redis import redis_client | |||
| class ToolProviderCredentialsCacheType(Enum): | |||
| PROVIDER = "tool_provider" | |||
| class ToolProviderCredentialsCache: | |||
| def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType): | |||
| self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}" | |||
| def get(self) -> Optional[dict]: | |||
| """ | |||
| Get cached model provider credentials. | |||
| :return: | |||
| """ | |||
| cached_provider_credentials = redis_client.get(self.cache_key) | |||
| if cached_provider_credentials: | |||
| try: | |||
| cached_provider_credentials = cached_provider_credentials.decode('utf-8') | |||
| cached_provider_credentials = json.loads(cached_provider_credentials) | |||
| except JSONDecodeError: | |||
| return None | |||
| return cached_provider_credentials | |||
| else: | |||
| return None | |||
| def set(self, credentials: dict) -> None: | |||
| """ | |||
| Cache model provider credentials. | |||
| :param credentials: provider credentials | |||
| :return: | |||
| """ | |||
| redis_client.setex(self.cache_key, 86400, json.dumps(credentials)) | |||
| def delete(self) -> None: | |||
| """ | |||
| Delete cached model provider credentials. | |||
| :return: | |||
| """ | |||
| redis_client.delete(self.cache_key) | |||
| @@ -0,0 +1,15 @@ | |||
| - bing | |||
| - wikipedia | |||
| - dalle | |||
| - azuredalle | |||
| - webscraper | |||
| - wolframalpha | |||
| - github | |||
| - chart | |||
| - time | |||
| - yahoo | |||
| - stablediffusion | |||
| - vectorizer | |||
| - youtube | |||
| - gaode | |||
| @@ -1,31 +1,29 @@ | |||
| from typing import List | |||
| from core.tools.entities.user_entities import UserToolProvider | |||
| from core.tools.entities.tool_entities import ToolProviderType | |||
| from typing import List | |||
| from yaml import load, FullLoader | |||
| position = { | |||
| 'google': 1, | |||
| 'bing': 2, | |||
| 'wikipedia': 2, | |||
| 'dalle': 3, | |||
| 'webscraper': 4, | |||
| 'wolframalpha': 5, | |||
| 'chart': 6, | |||
| 'time': 7, | |||
| 'yahoo': 8, | |||
| 'stablediffusion': 9, | |||
| 'vectorizer': 10, | |||
| 'youtube': 11, | |||
| 'github': 12, | |||
| 'gaode': 13 | |||
| } | |||
| import os.path | |||
| position = {} | |||
| class BuiltinToolProviderSort: | |||
| @staticmethod | |||
| def sort(providers: List[UserToolProvider]) -> List[UserToolProvider]: | |||
| global position | |||
| if not position: | |||
| tmp_position = {} | |||
| file_path = os.path.join(os.path.dirname(__file__), '..', '_position.yaml') | |||
| with open(file_path, 'r') as f: | |||
| for pos, val in enumerate(load(f, Loader=FullLoader)): | |||
| tmp_position[val] = pos | |||
| position = tmp_position | |||
| def sort_compare(provider: UserToolProvider) -> int: | |||
| # if provider.type == UserToolProvider.ProviderType.MODEL: | |||
| # return position.get(f'model_provider.{provider.name}', 10000) | |||
| return position.get(provider.name, 10000) | |||
| sorted_providers = sorted(providers, key=sort_compare) | |||
| return sorted_providers | |||
| return sorted_providers | |||
| @@ -1,10 +1,10 @@ | |||
| from typing import Any, Dict | |||
| from typing import Dict, Any | |||
| from pydantic import BaseModel | |||
| from core.helper import encrypter | |||
| from core.tools.entities.tool_entities import ToolProviderCredentials | |||
| from core.tools.provider.tool_provider import ToolProviderController | |||
| from pydantic import BaseModel | |||
| from core.helper import encrypter | |||
| from core.helper.tool_provider_cache import ToolProviderCredentialsCacheType, ToolProviderCredentialsCache | |||
| class ToolConfiguration(BaseModel): | |||
| tenant_id: str | |||
| @@ -63,8 +63,15 @@ class ToolConfiguration(BaseModel): | |||
| return a deep copy of credentials with decrypted values | |||
| """ | |||
| cache = ToolProviderCredentialsCache( | |||
| tenant_id=self.tenant_id, | |||
| identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}', | |||
| cache_type=ToolProviderCredentialsCacheType.PROVIDER | |||
| ) | |||
| cached_credentials = cache.get() | |||
| if cached_credentials: | |||
| return cached_credentials | |||
| credentials = self._deep_copy(credentials) | |||
| # get fields need to be decrypted | |||
| fields = self.provider_controller.get_credentials_schema() | |||
| for field_name, field in fields.items(): | |||
| @@ -74,5 +81,6 @@ class ToolConfiguration(BaseModel): | |||
| credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name]) | |||
| except: | |||
| pass | |||
| cache.set(credentials) | |||
| return credentials | |||