Переглянути джерело

feat: tool credentials cache and introduce _position.yaml (#2386)

tags/0.5.4
Yeuoly 1 рік тому
джерело
коміт
5010706d8b
Аккаунт користувача з таким Email не знайдено

+ 49
- 0
api/core/helper/tool_provider_cache.py Переглянути файл

@@ -0,0 +1,49 @@
import json
from enum import Enum
from json import JSONDecodeError
from typing import Optional

from extensions.ext_redis import redis_client


class ToolProviderCredentialsCacheType(Enum):
PROVIDER = "tool_provider"

class ToolProviderCredentialsCache:
def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType):
self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"

def get(self) -> Optional[dict]:
"""
Get cached model provider credentials.

:return:
"""
cached_provider_credentials = redis_client.get(self.cache_key)
if cached_provider_credentials:
try:
cached_provider_credentials = cached_provider_credentials.decode('utf-8')
cached_provider_credentials = json.loads(cached_provider_credentials)
except JSONDecodeError:
return None

return cached_provider_credentials
else:
return None

def set(self, credentials: dict) -> None:
"""
Cache model provider credentials.

:param credentials: provider credentials
:return:
"""
redis_client.setex(self.cache_key, 86400, json.dumps(credentials))

def delete(self) -> None:
"""
Delete cached model provider credentials.

:return:
"""
redis_client.delete(self.cache_key)

+ 15
- 0
api/core/tools/provider/_position.yaml Переглянути файл

@@ -0,0 +1,15 @@
- google
- bing
- wikipedia
- dalle
- azuredalle
- webscraper
- wolframalpha
- github
- chart
- time
- yahoo
- stablediffusion
- vectorizer
- youtube
- gaode

+ 17
- 19
api/core/tools/provider/builtin/_positions.py Переглянути файл

@@ -1,31 +1,29 @@
from typing import List

from core.tools.entities.user_entities import UserToolProvider
from core.tools.entities.tool_entities import ToolProviderType
from typing import List
from yaml import load, FullLoader

position = {
'google': 1,
'bing': 2,
'wikipedia': 2,
'dalle': 3,
'webscraper': 4,
'wolframalpha': 5,
'chart': 6,
'time': 7,
'yahoo': 8,
'stablediffusion': 9,
'vectorizer': 10,
'youtube': 11,
'github': 12,
'gaode': 13
}
import os.path

position = {}

class BuiltinToolProviderSort:
@staticmethod
def sort(providers: List[UserToolProvider]) -> List[UserToolProvider]:
global position
if not position:
tmp_position = {}
file_path = os.path.join(os.path.dirname(__file__), '..', '_position.yaml')
with open(file_path, 'r') as f:
for pos, val in enumerate(load(f, Loader=FullLoader)):
tmp_position[val] = pos
position = tmp_position

def sort_compare(provider: UserToolProvider) -> int:
# if provider.type == UserToolProvider.ProviderType.MODEL:
# return position.get(f'model_provider.{provider.name}', 10000)
return position.get(provider.name, 10000)
sorted_providers = sorted(providers, key=sort_compare)

return sorted_providers
return sorted_providers

+ 14
- 6
api/core/tools/utils/configuration.py Переглянути файл

@@ -1,10 +1,10 @@
from typing import Any, Dict
from typing import Dict, Any
from pydantic import BaseModel

from core.helper import encrypter
from core.tools.entities.tool_entities import ToolProviderCredentials
from core.tools.provider.tool_provider import ToolProviderController
from pydantic import BaseModel
from core.helper import encrypter
from core.helper.tool_provider_cache import ToolProviderCredentialsCacheType, ToolProviderCredentialsCache

class ToolConfiguration(BaseModel):
tenant_id: str
@@ -63,8 +63,15 @@ class ToolConfiguration(BaseModel):

return a deep copy of credentials with decrypted values
"""
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
cache_type=ToolProviderCredentialsCacheType.PROVIDER
)
cached_credentials = cache.get()
if cached_credentials:
return cached_credentials
credentials = self._deep_copy(credentials)

# get fields need to be decrypted
fields = self.provider_controller.get_credentials_schema()
for field_name, field in fields.items():
@@ -74,5 +81,6 @@ class ToolConfiguration(BaseModel):
credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
except:
pass

cache.set(credentials)
return credentials

Завантаження…
Відмінити
Зберегти