Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
| @@ -63,6 +63,9 @@ class ModelProviderFactory: | |||
| elif provider_name == 'openllm': | |||
| from core.model_providers.providers.openllm_provider import OpenLLMProvider | |||
| return OpenLLMProvider | |||
| elif provider_name == 'localai': | |||
| from core.model_providers.providers.localai_provider import LocalAIProvider | |||
| return LocalAIProvider | |||
| else: | |||
| raise NotImplementedError | |||
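The hunk above only shows the new `elif` branch inside ModelProviderFactory, not the method that contains it, so the following is a minimal sketch that inlines the same dispatch to show what the branch resolves to; the function name `resolve_provider_class` is hypothetical.

    # Sketch only: mirrors the elif chain added above.
    from core.model_providers.providers.localai_provider import LocalAIProvider

    def resolve_provider_class(provider_name: str):
        if provider_name == 'localai':
            return LocalAIProvider
        raise NotImplementedError(f"unknown provider: {provider_name}")

    assert resolve_provider_class('localai') is LocalAIProvider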
| @@ -0,0 +1,29 @@ | |||
| from langchain.embeddings import LocalAIEmbeddings | |||
| from replicate.exceptions import ModelError, ReplicateError | |||
| from core.model_providers.error import LLMBadRequestError | |||
| from core.model_providers.providers.base import BaseModelProvider | |||
| from core.model_providers.models.embedding.base import BaseEmbedding | |||
| class LocalAIEmbedding(BaseEmbedding): | |||
| def __init__(self, model_provider: BaseModelProvider, name: str): | |||
| credentials = model_provider.get_model_credentials( | |||
| model_name=name, | |||
| model_type=self.type | |||
| ) | |||
| client = LocalAIEmbeddings( | |||
| model=name, | |||
| openai_api_key="1", | |||
| openai_api_base=credentials['server_url'], | |||
| ) | |||
| super().__init__(model_provider, client, name) | |||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||
| if isinstance(ex, (ModelError, ReplicateError)): | |||
| return LLMBadRequestError(f"LocalAI embedding: {str(ex)}") | |||
| else: | |||
| return ex | |||
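A hedged usage sketch of the wrapped LangChain client, mirroring what `LocalAIEmbedding.__init__` builds above and how the unit tests later in this diff call it; the model name and server URL are placeholders.

    from langchain.embeddings import LocalAIEmbeddings

    client = LocalAIEmbeddings(
        model='text-embedding-ada-002',           # placeholder: any embedding model served by LocalAI
        openai_api_key='1',                       # dummy value; LocalAI does not check the key
        openai_api_base='http://127.0.0.1:8080',  # placeholder LocalAI server URL
    )
    vector = client.embed_query('hello')          # returns a list[float]
    vectors = client.embed_documents(['a', 'b'])  # returns one vector per input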
| @@ -0,0 +1,131 @@ | |||
| import logging | |||
| from typing import List, Optional, Any | |||
| import openai | |||
| from langchain.callbacks.manager import Callbacks | |||
| from langchain.schema import LLMResult, get_buffer_string | |||
| from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \ | |||
| LLMRateLimitError, LLMAuthorizationError | |||
| from core.model_providers.providers.base import BaseModelProvider | |||
| from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI | |||
| from core.third_party.langchain.llms.open_ai import EnhanceOpenAI | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.model_providers.models.entity.message import PromptMessage | |||
| from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs | |||
| class LocalAIModel(BaseLLM): | |||
| def __init__(self, model_provider: BaseModelProvider, | |||
| name: str, | |||
| model_kwargs: ModelKwargs, | |||
| streaming: bool = False, | |||
| callbacks: Callbacks = None): | |||
| credentials = model_provider.get_model_credentials( | |||
| model_name=name, | |||
| model_type=self.type | |||
| ) | |||
| if credentials['completion_type'] == 'chat_completion': | |||
| self.model_mode = ModelMode.CHAT | |||
| else: | |||
| self.model_mode = ModelMode.COMPLETION | |||
| super().__init__(model_provider, name, model_kwargs, streaming, callbacks) | |||
| def _init_client(self) -> Any: | |||
| provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) | |||
| if self.model_mode == ModelMode.COMPLETION: | |||
| client = EnhanceOpenAI( | |||
| model_name=self.name, | |||
| streaming=self.streaming, | |||
| callbacks=self.callbacks, | |||
| request_timeout=60, | |||
| openai_api_key="1", | |||
| openai_api_base=self.credentials['server_url'] + '/v1', | |||
| **provider_model_kwargs | |||
| ) | |||
| else: | |||
| extra_model_kwargs = { | |||
| 'top_p': provider_model_kwargs.get('top_p') | |||
| } | |||
| client = EnhanceChatOpenAI( | |||
| model_name=self.name, | |||
| temperature=provider_model_kwargs.get('temperature'), | |||
| max_tokens=provider_model_kwargs.get('max_tokens'), | |||
| model_kwargs=extra_model_kwargs, | |||
| streaming=self.streaming, | |||
| callbacks=self.callbacks, | |||
| request_timeout=60, | |||
| openai_api_key="1", | |||
| openai_api_base=self.credentials['server_url'] + '/v1' | |||
| ) | |||
| return client | |||
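The client class follows the stored `completion_type`: `'completion'` builds EnhanceOpenAI against the completions endpoint, anything else builds EnhanceChatOpenAI. A sketch of the credential payload `_init_client` expects, with placeholder values; the shape matches `get_model_credentials` in `localai_provider.py` further down in this diff.

    # Placeholder credentials, as decrypted by LocalAIProvider.get_model_credentials:
    credentials = {
        'server_url': 'http://127.0.0.1:8080',  # '/v1' is appended before use
        'completion_type': 'chat_completion',   # 'completion' -> EnhanceOpenAI, otherwise EnhanceChatOpenAI
    }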
| def _run(self, messages: List[PromptMessage], | |||
| stop: Optional[List[str]] = None, | |||
| callbacks: Callbacks = None, | |||
| **kwargs) -> LLMResult: | |||
| """ | |||
| Run prediction with the given prompt messages and stop words. | |||
| :param messages: | |||
| :param stop: | |||
| :param callbacks: | |||
| :return: | |||
| """ | |||
| prompts = self._get_prompt_from_messages(messages) | |||
| return self._client.generate([prompts], stop, callbacks) | |||
| def get_num_tokens(self, messages: List[PromptMessage]) -> int: | |||
| """ | |||
| Get the number of tokens in the prompt messages. | |||
| :param messages: | |||
| :return: | |||
| """ | |||
| prompts = self._get_prompt_from_messages(messages) | |||
| if isinstance(prompts, str): | |||
| return self._client.get_num_tokens(prompts) | |||
| else: | |||
| return max(sum([self._client.get_num_tokens(get_buffer_string([m])) for m in prompts]) - len(prompts), 0) | |||
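For chat-style prompts the count sums the buffer string of each message and then subtracts one per message, presumably to offset the separator added by `get_buffer_string`. A self-contained sketch of the same arithmetic with a stand-in whitespace tokenizer; the production code delegates the counting to `self._client.get_num_tokens` instead.

    def count_tokens(text: str) -> int:
        # Stand-in tokenizer: whitespace split.
        return len(text.split())

    def buffer_string_stub(role: str, content: str) -> str:
        return f"{role}: {content}"

    messages = [("Human", "Hello there"), ("AI", "Hi")]
    total = max(
        sum(count_tokens(buffer_string_stub(r, c)) for r, c in messages) - len(messages),
        0,
    )
    print(total)  # "Human: Hello there" -> 3, "AI: Hi" -> 2; 5 minus 2 messages = 3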
| def _set_model_kwargs(self, model_kwargs: ModelKwargs): | |||
| provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) | |||
| if self.model_mode == ModelMode.COMPLETION: | |||
| for k, v in provider_model_kwargs.items(): | |||
| if hasattr(self.client, k): | |||
| setattr(self.client, k, v) | |||
| else: | |||
| extra_model_kwargs = { | |||
| 'top_p': provider_model_kwargs.get('top_p') | |||
| } | |||
| self.client.temperature = provider_model_kwargs.get('temperature') | |||
| self.client.max_tokens = provider_model_kwargs.get('max_tokens') | |||
| self.client.model_kwargs = extra_model_kwargs | |||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||
| if isinstance(ex, openai.error.InvalidRequestError): | |||
| logging.warning("Invalid request to LocalAI API.") | |||
| return LLMBadRequestError(str(ex)) | |||
| elif isinstance(ex, openai.error.APIConnectionError): | |||
| logging.warning("Failed to connect to LocalAI API.") | |||
| return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex)) | |||
| elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)): | |||
| logging.warning("LocalAI service unavailable.") | |||
| return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex)) | |||
| elif isinstance(ex, openai.error.RateLimitError): | |||
| return LLMRateLimitError(str(ex)) | |||
| elif isinstance(ex, openai.error.AuthenticationError): | |||
| return LLMAuthorizationError(str(ex)) | |||
| elif isinstance(ex, openai.error.OpenAIError): | |||
| return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex)) | |||
| else: | |||
| return ex | |||
| @classmethod | |||
| def support_streaming(cls): | |||
| return True | |||
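A sketch of how a caller can funnel provider errors through `handle_exceptions` to get the unified `LLM*` error types; the real call path goes through the `BaseLLM` run plumbing, which is assumed to do the equivalent.

    def safe_generate(model, messages):
        try:
            return model._run(messages)
        except Exception as ex:
            # Normalize openai.error.* exceptions into LLMBadRequestError,
            # LLMAPIConnectionError, LLMRateLimitError, etc.
            raise model.handle_exceptions(ex)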
| @@ -0,0 +1,164 @@ | |||
| import json | |||
| from typing import Type | |||
| from langchain.embeddings import LocalAIEmbeddings | |||
| from langchain.schema import HumanMessage | |||
| from core.helper import encrypter | |||
| from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding | |||
| from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule | |||
| from core.model_providers.models.llm.localai_model import LocalAIModel | |||
| from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError | |||
| from core.model_providers.models.base import BaseProviderModel | |||
| from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI | |||
| from core.third_party.langchain.llms.open_ai import EnhanceOpenAI | |||
| from models.provider import ProviderType | |||
| class LocalAIProvider(BaseModelProvider): | |||
| @property | |||
| def provider_name(self): | |||
| """ | |||
| Returns the name of a provider. | |||
| """ | |||
| return 'localai' | |||
| def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: | |||
| return [] | |||
| def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: | |||
| """ | |||
| Returns the model class. | |||
| :param model_type: | |||
| :return: | |||
| """ | |||
| if model_type == ModelType.TEXT_GENERATION: | |||
| model_class = LocalAIModel | |||
| elif model_type == ModelType.EMBEDDINGS: | |||
| model_class = LocalAIEmbedding | |||
| else: | |||
| raise NotImplementedError | |||
| return model_class | |||
| def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules: | |||
| """ | |||
| get model parameter rules. | |||
| :param model_name: | |||
| :param model_type: | |||
| :return: | |||
| """ | |||
| return ModelKwargsRules( | |||
| temperature=KwargRule[float](min=0, max=2, default=0.7), | |||
| top_p=KwargRule[float](min=0, max=1, default=1), | |||
| max_tokens=KwargRule[int](min=10, max=4097, default=16), | |||
| ) | |||
| @classmethod | |||
| def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict): | |||
| """ | |||
| Check whether the model credentials are valid. | |||
| :param model_name: | |||
| :param model_type: | |||
| :param credentials: | |||
| """ | |||
| if 'server_url' not in credentials: | |||
| raise CredentialsValidateFailedError('LocalAI Server URL must be provided.') | |||
| try: | |||
| if model_type == ModelType.EMBEDDINGS: | |||
| model = LocalAIEmbeddings( | |||
| model=model_name, | |||
| openai_api_key='1', | |||
| openai_api_base=credentials['server_url'] | |||
| ) | |||
| model.embed_query("ping") | |||
| else: | |||
| if ('completion_type' not in credentials | |||
| or credentials['completion_type'] not in ['completion', 'chat_completion']): | |||
| raise CredentialsValidateFailedError('LocalAI Completion Type must be provided.') | |||
| if credentials['completion_type'] == 'chat_completion': | |||
| model = EnhanceChatOpenAI( | |||
| model_name=model_name, | |||
| openai_api_key='1', | |||
| openai_api_base=credentials['server_url'] + '/v1', | |||
| max_tokens=10, | |||
| request_timeout=60, | |||
| ) | |||
| model([HumanMessage(content='ping')]) | |||
| else: | |||
| model = EnhanceOpenAI( | |||
| model_name=model_name, | |||
| openai_api_key='1', | |||
| openai_api_base=credentials['server_url'] + '/v1', | |||
| max_tokens=10, | |||
| request_timeout=60, | |||
| ) | |||
| model('ping') | |||
| except Exception as ex: | |||
| raise CredentialsValidateFailedError(str(ex)) | |||
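A sketch of calling the validator for a chat-completion model; the server URL and model name are placeholders, and the call raises `CredentialsValidateFailedError` when the LocalAI endpoint cannot be reached or the payload is incomplete.

    from core.model_providers.models.entity.model_params import ModelType
    from core.model_providers.providers.base import CredentialsValidateFailedError
    from core.model_providers.providers.localai_provider import LocalAIProvider

    try:
        LocalAIProvider.is_model_credentials_valid_or_raise(
            model_name='ggml-gpt4all-j',                # placeholder model name
            model_type=ModelType.TEXT_GENERATION,
            credentials={
                'server_url': 'http://127.0.0.1:8080',  # placeholder LocalAI endpoint
                'completion_type': 'chat_completion',
            },
        )
    except CredentialsValidateFailedError as e:
        print(f"credentials rejected: {e}")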
| @classmethod | |||
| def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType, | |||
| credentials: dict) -> dict: | |||
| """ | |||
| Encrypt model credentials for saving. | |||
| :param tenant_id: | |||
| :param model_name: | |||
| :param model_type: | |||
| :param credentials: | |||
| :return: | |||
| """ | |||
| credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url']) | |||
| return credentials | |||
| def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict: | |||
| """ | |||
| Get credentials for LLM use. | |||
| :param model_name: | |||
| :param model_type: | |||
| :param obfuscated: | |||
| :return: | |||
| """ | |||
| if self.provider.provider_type != ProviderType.CUSTOM.value: | |||
| raise NotImplementedError | |||
| provider_model = self._get_provider_model(model_name, model_type) | |||
| if not provider_model.encrypted_config: | |||
| return { | |||
| 'server_url': None, | |||
| } | |||
| credentials = json.loads(provider_model.encrypted_config) | |||
| if credentials['server_url']: | |||
| credentials['server_url'] = encrypter.decrypt_token( | |||
| self.provider.tenant_id, | |||
| credentials['server_url'] | |||
| ) | |||
| if obfuscated: | |||
| credentials['server_url'] = encrypter.obfuscated_token(credentials['server_url']) | |||
| return credentials | |||
| @classmethod | |||
| def is_provider_credentials_valid_or_raise(cls, credentials: dict): | |||
| return | |||
| @classmethod | |||
| def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict: | |||
| return {} | |||
| def get_provider_credentials(self, obfuscated: bool = False) -> dict: | |||
| return {} | |||
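Only `server_url` is encrypted at rest; a sketch of the save-side call with placeholder tenant and model values. `encrypter.encrypt_token(tenant_id, token)` is the project helper used in the code above, and `get_model_credentials` reverses it with `encrypter.decrypt_token` (masking the value when `obfuscated=True`).

    from core.model_providers.models.entity.model_params import ModelType
    from core.model_providers.providers.localai_provider import LocalAIProvider

    stored = LocalAIProvider.encrypt_model_credentials(
        tenant_id='tenant_id',                    # placeholder tenant
        model_name='ggml-gpt4all-j',              # placeholder model
        model_type=ModelType.TEXT_GENERATION,
        credentials={
            'server_url': 'http://127.0.0.1:8080',
            'completion_type': 'completion',
        },
    )
    # stored['server_url'] now holds the encrypted token ready to persist.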
| @@ -10,5 +10,6 @@ | |||
| "replicate", | |||
| "huggingface_hub", | |||
| "xinference", | |||
| "openllm" | |||
| "openllm", | |||
| "localai" | |||
| ] | |||
| @@ -0,0 +1,7 @@ | |||
| { | |||
| "support_provider_types": [ | |||
| "custom" | |||
| ], | |||
| "system_config": null, | |||
| "model_flexibility": "configurable" | |||
| } | |||
| @@ -42,7 +42,8 @@ class EnhanceChatOpenAI(ChatOpenAI): | |||
| return { | |||
| **super()._default_params, | |||
| "api_type": 'openai', | |||
| "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"), | |||
| "api_base": self.openai_api_base if self.openai_api_base | |||
| else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"), | |||
| "api_version": None, | |||
| "api_key": self.openai_api_key, | |||
| "organization": self.openai_organization if self.openai_organization else None, | |||
| @@ -1,7 +1,10 @@ | |||
| import os | |||
| from typing import Dict, Any, Mapping, Optional, Union, Tuple | |||
| from typing import Dict, Any, Mapping, Optional, Union, Tuple, List, Iterator | |||
| from langchain import OpenAI | |||
| from langchain.callbacks.manager import CallbackManagerForLLMRun | |||
| from langchain.llms.openai import completion_with_retry, _stream_response_to_generation_chunk | |||
| from langchain.schema.output import GenerationChunk | |||
| from pydantic import root_validator | |||
| @@ -33,7 +36,8 @@ class EnhanceOpenAI(OpenAI): | |||
| def _invocation_params(self) -> Dict[str, Any]: | |||
| return {**super()._invocation_params, **{ | |||
| "api_type": 'openai', | |||
| "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"), | |||
| "api_base": self.openai_api_base if self.openai_api_base | |||
| else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"), | |||
| "api_version": None, | |||
| "api_key": self.openai_api_key, | |||
| "organization": self.openai_organization if self.openai_organization else None, | |||
| @@ -43,8 +47,33 @@ class EnhanceOpenAI(OpenAI): | |||
| def _identifying_params(self) -> Mapping[str, Any]: | |||
| return {**super()._identifying_params, **{ | |||
| "api_type": 'openai', | |||
| "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"), | |||
| "api_base": self.openai_api_base if self.openai_api_base | |||
| else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"), | |||
| "api_version": None, | |||
| "api_key": self.openai_api_key, | |||
| "organization": self.openai_organization if self.openai_organization else None, | |||
| }} | |||
| def _stream( | |||
| self, | |||
| prompt: str, | |||
| stop: Optional[List[str]] = None, | |||
| run_manager: Optional[CallbackManagerForLLMRun] = None, | |||
| **kwargs: Any, | |||
| ) -> Iterator[GenerationChunk]: | |||
| params = {**self._invocation_params, **kwargs, "stream": True} | |||
| self.get_sub_prompts(params, [prompt], stop) # this mutates params | |||
| for stream_resp in completion_with_retry( | |||
| self, prompt=prompt, run_manager=run_manager, **params | |||
| ): | |||
| if 'text' in stream_resp["choices"][0]: | |||
| chunk = _stream_response_to_generation_chunk(stream_resp) | |||
| yield chunk | |||
| if run_manager: | |||
| run_manager.on_llm_new_token( | |||
| chunk.text, | |||
| verbose=self.verbose, | |||
| logprobs=chunk.generation_info["logprobs"] | |||
| if chunk.generation_info | |||
| else None, | |||
| ) | |||
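`_stream` yields `GenerationChunk` objects as responses arrive. A sketch of consuming it directly, assuming `llm` is an EnhanceOpenAI instance configured with `streaming=True` and a reachable server; normally this path is driven through the LLM's public generate/stream entry points rather than called by hand.

    def stream_text(llm, prompt: str) -> str:
        pieces = []
        for chunk in llm._stream(prompt, stop=['\n\n']):
            pieces.append(chunk.text)   # each chunk carries the incremental text
        return ''.join(pieces)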
| @@ -39,4 +39,7 @@ XINFERENCE_SERVER_URL= | |||
| XINFERENCE_MODEL_UID= | |||
| # OpenLLM Credentials | |||
| OPENLLM_SERVER_URL= | |||
| OPENLLM_SERVER_URL= | |||
| # LocalAI Credentials | |||
| LOCALAI_SERVER_URL= | |||
| @@ -0,0 +1,61 @@ | |||
| import json | |||
| import os | |||
| from unittest.mock import patch, MagicMock | |||
| from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding | |||
| from core.model_providers.models.entity.model_params import ModelType | |||
| from core.model_providers.providers.localai_provider import LocalAIProvider | |||
| from models.provider import Provider, ProviderType, ProviderModel | |||
| def get_mock_provider(): | |||
| return Provider( | |||
| id='provider_id', | |||
| tenant_id='tenant_id', | |||
| provider_name='localai', | |||
| provider_type=ProviderType.CUSTOM.value, | |||
| encrypted_config='', | |||
| is_valid=True, | |||
| ) | |||
| def get_mock_embedding_model(mocker): | |||
| model_name = 'text-embedding-ada-002' | |||
| server_url = os.environ['LOCALAI_SERVER_URL'] | |||
| model_provider = LocalAIProvider(provider=get_mock_provider()) | |||
| mock_query = MagicMock() | |||
| mock_query.filter.return_value.first.return_value = ProviderModel( | |||
| provider_name='localai', | |||
| model_name=model_name, | |||
| model_type=ModelType.EMBEDDINGS.value, | |||
| encrypted_config=json.dumps({ | |||
| 'server_url': server_url, | |||
| }), | |||
| is_valid=True, | |||
| ) | |||
| mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query) | |||
| return LocalAIEmbedding( | |||
| model_provider=model_provider, | |||
| name=model_name | |||
| ) | |||
| def decrypt_side_effect(tenant_id, encrypted_api_key): | |||
| return encrypted_api_key | |||
| @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) | |||
| def test_embed_documents(mock_decrypt, mocker): | |||
| embedding_model = get_mock_embedding_model(mocker) | |||
| rst = embedding_model.client.embed_documents(['test', 'test1']) | |||
| assert isinstance(rst, list) | |||
| assert len(rst) == 2 | |||
| @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) | |||
| def test_embed_query(mock_decrypt, mocker): | |||
| embedding_model = get_mock_embedding_model(mocker) | |||
| rst = embedding_model.client.embed_query('test') | |||
| assert isinstance(rst, list) | |||
| @@ -0,0 +1,68 @@ | |||
| import json | |||
| import os | |||
| from unittest.mock import patch, MagicMock | |||
| from core.model_providers.models.llm.localai_model import LocalAIModel | |||
| from core.model_providers.providers.localai_provider import LocalAIProvider | |||
| from core.model_providers.models.entity.message import PromptMessage | |||
| from core.model_providers.models.entity.model_params import ModelKwargs, ModelType | |||
| from models.provider import Provider, ProviderType, ProviderModel | |||
| def get_mock_provider(server_url): | |||
| return Provider( | |||
| id='provider_id', | |||
| tenant_id='tenant_id', | |||
| provider_name='localai', | |||
| provider_type=ProviderType.CUSTOM.value, | |||
| encrypted_config=json.dumps({}), | |||
| is_valid=True, | |||
| ) | |||
| def get_mock_model(model_name, mocker): | |||
| model_kwargs = ModelKwargs( | |||
| max_tokens=10, | |||
| temperature=0 | |||
| ) | |||
| server_url = os.environ['LOCALAI_SERVER_URL'] | |||
| mock_query = MagicMock() | |||
| mock_query.filter.return_value.first.return_value = ProviderModel( | |||
| provider_name='localai', | |||
| model_name=model_name, | |||
| model_type=ModelType.TEXT_GENERATION.value, | |||
| encrypted_config=json.dumps({'server_url': server_url, 'completion_type': 'completion'}), | |||
| is_valid=True, | |||
| ) | |||
| mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query) | |||
| openai_provider = LocalAIProvider(provider=get_mock_provider(server_url)) | |||
| return LocalAIModel( | |||
| model_provider=openai_provider, | |||
| name=model_name, | |||
| model_kwargs=model_kwargs | |||
| ) | |||
| def decrypt_side_effect(tenant_id, encrypted_openai_api_key): | |||
| return encrypted_openai_api_key | |||
| @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) | |||
| def test_get_num_tokens(mock_decrypt, mocker): | |||
| openai_model = get_mock_model('ggml-gpt4all-j', mocker) | |||
| rst = openai_model.get_num_tokens([PromptMessage(content='you are a kind Assistant.')]) | |||
| assert rst > 0 | |||
| @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) | |||
| def test_run(mock_decrypt, mocker): | |||
| mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None) | |||
| openai_model = get_mock_model('ggml-gpt4all-j', mocker) | |||
| rst = openai_model.run( | |||
| [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')], | |||
| stop=['\nHuman:'], | |||
| ) | |||
| assert len(rst.content) > 0 | |||
| @@ -0,0 +1,116 @@ | |||
| import pytest | |||
| from unittest.mock import patch, MagicMock | |||
| import json | |||
| from core.model_providers.models.entity.model_params import ModelType | |||
| from core.model_providers.providers.base import CredentialsValidateFailedError | |||
| from core.model_providers.providers.localai_provider import LocalAIProvider | |||
| from models.provider import ProviderType, Provider, ProviderModel | |||
| PROVIDER_NAME = 'localai' | |||
| MODEL_PROVIDER_CLASS = LocalAIProvider | |||
| VALIDATE_CREDENTIAL = { | |||
| 'server_url': 'http://127.0.0.1:8080/' | |||
| } | |||
| def encrypt_side_effect(tenant_id, encrypt_key): | |||
| return f'encrypted_{encrypt_key}' | |||
| def decrypt_side_effect(tenant_id, encrypted_key): | |||
| return encrypted_key.replace('encrypted_', '') | |||
| def test_is_credentials_valid_or_raise_valid(mocker): | |||
| mocker.patch('langchain.embeddings.localai.LocalAIEmbeddings.embed_query', | |||
| return_value="abc") | |||
| MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise( | |||
| model_name='username/test_model_name', | |||
| model_type=ModelType.EMBEDDINGS, | |||
| credentials=VALIDATE_CREDENTIAL.copy() | |||
| ) | |||
| def test_is_credentials_valid_or_raise_invalid(): | |||
| # raise CredentialsValidateFailedError if server_url is not in credentials | |||
| with pytest.raises(CredentialsValidateFailedError): | |||
| MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise( | |||
| model_name='test_model_name', | |||
| model_type=ModelType.EMBEDDINGS, | |||
| credentials={} | |||
| ) | |||
| @patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect) | |||
| def test_encrypt_model_credentials(mock_encrypt, mocker): | |||
| server_url = 'http://127.0.0.1:8080/' | |||
| result = MODEL_PROVIDER_CLASS.encrypt_model_credentials( | |||
| tenant_id='tenant_id', | |||
| model_name='test_model_name', | |||
| model_type=ModelType.EMBEDDINGS, | |||
| credentials=VALIDATE_CREDENTIAL.copy() | |||
| ) | |||
| mock_encrypt.assert_called_with('tenant_id', server_url) | |||
| assert result['server_url'] == f'encrypted_{server_url}' | |||
| @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) | |||
| def test_get_model_credentials_custom(mock_decrypt, mocker): | |||
| provider = Provider( | |||
| id='provider_id', | |||
| tenant_id='tenant_id', | |||
| provider_name=PROVIDER_NAME, | |||
| provider_type=ProviderType.CUSTOM.value, | |||
| encrypted_config=None, | |||
| is_valid=True, | |||
| ) | |||
| encrypted_credential = VALIDATE_CREDENTIAL.copy() | |||
| encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url'] | |||
| mock_query = MagicMock() | |||
| mock_query.filter.return_value.first.return_value = ProviderModel( | |||
| encrypted_config=json.dumps(encrypted_credential) | |||
| ) | |||
| mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query) | |||
| model_provider = MODEL_PROVIDER_CLASS(provider=provider) | |||
| result = model_provider.get_model_credentials( | |||
| model_name='test_model_name', | |||
| model_type=ModelType.EMBEDDINGS | |||
| ) | |||
| assert result['server_url'] == 'http://127.0.0.1:8080/' | |||
| @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) | |||
| def test_get_model_credentials_obfuscated(mock_decrypt, mocker): | |||
| provider = Provider( | |||
| id='provider_id', | |||
| tenant_id='tenant_id', | |||
| provider_name=PROVIDER_NAME, | |||
| provider_type=ProviderType.CUSTOM.value, | |||
| encrypted_config=None, | |||
| is_valid=True, | |||
| ) | |||
| encrypted_credential = VALIDATE_CREDENTIAL.copy() | |||
| encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url'] | |||
| mock_query = MagicMock() | |||
| mock_query.filter.return_value.first.return_value = ProviderModel( | |||
| encrypted_config=json.dumps(encrypted_credential) | |||
| ) | |||
| mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query) | |||
| model_provider = MODEL_PROVIDER_CLASS(provider=provider) | |||
| result = model_provider.get_model_credentials( | |||
| model_name='test_model_name', | |||
| model_type=ModelType.EMBEDDINGS, | |||
| obfuscated=True | |||
| ) | |||
| middle_token = result['server_url'][6:-2] | |||
| assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['server_url']) - 8, 0) | |||
| assert all(char == '*' for char in middle_token) | |||
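The assertions above imply that `obfuscated_token` keeps the first six and last two characters and masks the middle with `*`. The following is a sketch of an equivalent helper under that assumption only; the real implementation lives in `core.helper.encrypter` and is not shown in this diff.

    def obfuscated_token_sketch(token: str) -> str:
        # Assumption drawn from the test above: keep 6 leading and 2 trailing
        # characters, replace the rest with '*'.
        if len(token) <= 8:
            return token
        return token[:6] + '*' * (len(token) - 8) + token[-2:]

    url = 'http://127.0.0.1:8080/'
    masked = obfuscated_token_sketch(url)
    assert len(masked[6:-2]) == len(url) - 8
    assert set(masked[6:-2]) == {'*'}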
| @@ -0,0 +1,14 @@ | |||
| // GENERATED BY script | |||
| // DO NOT EDIT IT MANUALLY | |||
| import * as React from 'react' | |||
| import data from './Localai.json' | |||
| import IconBase from '@/app/components/base/icons/IconBase' | |||
| import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase' | |||
| const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseProps, 'data'>>(( | |||
| props, | |||
| ref, | |||
| ) => <IconBase {...props} ref={ref} data={data as IconData} />) | |||
| export default Icon | |||
| @@ -0,0 +1,14 @@ | |||
| // GENERATED BY script | |||
| // DO NOT EDIT IT MANUALLY | |||
| import * as React from 'react' | |||
| import data from './LocalaiText.json' | |||
| import IconBase from '@/app/components/base/icons/IconBase' | |||
| import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase' | |||
| const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseProps, 'data'>>(( | |||
| props, | |||
| ref, | |||
| ) => <IconBase {...props} ref={ref} data={data as IconData} />) | |||
| export default Icon | |||
| @@ -14,6 +14,8 @@ export { default as Huggingface } from './Huggingface' | |||
| export { default as IflytekSparkTextCn } from './IflytekSparkTextCn' | |||
| export { default as IflytekSparkText } from './IflytekSparkText' | |||
| export { default as IflytekSpark } from './IflytekSpark' | |||
| export { default as LocalaiText } from './LocalaiText' | |||
| export { default as Localai } from './Localai' | |||
| export { default as Microsoft } from './Microsoft' | |||
| export { default as OpenaiBlack } from './OpenaiBlack' | |||
| export { default as OpenaiBlue } from './OpenaiBlue' | |||
| @@ -10,6 +10,7 @@ import minimax from './minimax' | |||
| import chatglm from './chatglm' | |||
| import xinference from './xinference' | |||
| import openllm from './openllm' | |||
| import localai from './localai' | |||
| export default { | |||
| openai, | |||
| @@ -24,4 +25,5 @@ export default { | |||
| chatglm, | |||
| xinference, | |||
| openllm, | |||
| localai, | |||
| } | |||
| @@ -0,0 +1,176 @@ | |||
| import { ProviderEnum } from '../declarations' | |||
| import type { FormValue, ProviderConfig } from '../declarations' | |||
| import { Localai, LocalaiText } from '@/app/components/base/icons/src/public/llm' | |||
| const config: ProviderConfig = { | |||
| selector: { | |||
| name: { | |||
| 'en': 'LocalAI', | |||
| 'zh-Hans': 'LocalAI', | |||
| }, | |||
| icon: <Localai className='w-full h-full' />, | |||
| }, | |||
| item: { | |||
| key: ProviderEnum.localai, | |||
| titleIcon: { | |||
| 'en': <LocalaiText className='h-6' />, | |||
| 'zh-Hans': <LocalaiText className='h-6' />, | |||
| }, | |||
| disable: { | |||
| tip: { | |||
| 'en': 'Only supports the ', | |||
| 'zh-Hans': '仅支持', | |||
| }, | |||
| link: { | |||
| href: { | |||
| 'en': 'https://docs.dify.ai/getting-started/install-self-hosted', | |||
| 'zh-Hans': 'https://docs.dify.ai/v/zh-hans/getting-started/install-self-hosted', | |||
| }, | |||
| label: { | |||
| 'en': 'community open-source version', | |||
| 'zh-Hans': '社区开源版本', | |||
| }, | |||
| }, | |||
| }, | |||
| }, | |||
| modal: { | |||
| key: ProviderEnum.localai, | |||
| title: { | |||
| 'en': 'LocalAI', | |||
| 'zh-Hans': 'LocalAI', | |||
| }, | |||
| icon: <Localai className='h-6' />, | |||
| link: { | |||
| href: 'https://github.com/go-skynet/LocalAI', | |||
| label: { | |||
| 'en': 'How to deploy LocalAI', | |||
| 'zh-Hans': '如何部署 LocalAI', | |||
| }, | |||
| }, | |||
| defaultValue: { | |||
| model_type: 'text-generation', | |||
| completion_type: 'completion', | |||
| }, | |||
| validateKeys: (v?: FormValue) => { | |||
| if (v?.model_type === 'text-generation') { | |||
| return [ | |||
| 'model_type', | |||
| 'model_name', | |||
| 'server_url', | |||
| 'completion_type', | |||
| ] | |||
| } | |||
| if (v?.model_type === 'embeddings') { | |||
| return [ | |||
| 'model_type', | |||
| 'model_name', | |||
| 'server_url', | |||
| ] | |||
| } | |||
| return [] | |||
| }, | |||
| filterValue: (v?: FormValue) => { | |||
| let filteredKeys: string[] = [] | |||
| if (v?.model_type === 'text-generation') { | |||
| filteredKeys = [ | |||
| 'model_type', | |||
| 'model_name', | |||
| 'server_url', | |||
| 'completion_type', | |||
| ] | |||
| } | |||
| if (v?.model_type === 'embeddings') { | |||
| filteredKeys = [ | |||
| 'model_type', | |||
| 'model_name', | |||
| 'server_url', | |||
| ] | |||
| } | |||
| return filteredKeys.reduce((prev: FormValue, next: string) => { | |||
| prev[next] = v?.[next] || '' | |||
| return prev | |||
| }, {}) | |||
| }, | |||
| fields: [ | |||
| { | |||
| type: 'radio', | |||
| key: 'model_type', | |||
| required: true, | |||
| label: { | |||
| 'en': 'Model Type', | |||
| 'zh-Hans': '模型类型', | |||
| }, | |||
| options: [ | |||
| { | |||
| key: 'text-generation', | |||
| label: { | |||
| 'en': 'Text Generation', | |||
| 'zh-Hans': '文本生成', | |||
| }, | |||
| }, | |||
| { | |||
| key: 'embeddings', | |||
| label: { | |||
| 'en': 'Embeddings', | |||
| 'zh-Hans': 'Embeddings', | |||
| }, | |||
| }, | |||
| ], | |||
| }, | |||
| { | |||
| type: 'text', | |||
| key: 'model_name', | |||
| required: true, | |||
| label: { | |||
| 'en': 'Model Name', | |||
| 'zh-Hans': '模型名称', | |||
| }, | |||
| placeholder: { | |||
| 'en': 'Enter your Model Name here', | |||
| 'zh-Hans': '在此输入您的模型名称', | |||
| }, | |||
| }, | |||
| { | |||
| hidden: (value?: FormValue) => value?.model_type === 'embeddings', | |||
| type: 'radio', | |||
| key: 'completion_type', | |||
| required: true, | |||
| label: { | |||
| 'en': 'Completion Type', | |||
| 'zh-Hans': 'Completion Type', | |||
| }, | |||
| options: [ | |||
| { | |||
| key: 'completion', | |||
| label: { | |||
| 'en': 'Completion', | |||
| 'zh-Hans': 'Completion', | |||
| }, | |||
| }, | |||
| { | |||
| key: 'chat_completion', | |||
| label: { | |||
| 'en': 'Chat Completion', | |||
| 'zh-Hans': 'Chat Completion', | |||
| }, | |||
| }, | |||
| ], | |||
| }, | |||
| { | |||
| type: 'text', | |||
| key: 'server_url', | |||
| required: true, | |||
| label: { | |||
| 'en': 'Server URL', | |||
| 'zh-Hans': 'Server URL', | |||
| }, | |||
| placeholder: { | |||
| 'en': 'Enter your Server URL, e.g. https://example.com/xxx', | |||
| 'zh-Hans': '在此输入您的 Server URL,如:https://example.com/xxx', | |||
| }, | |||
| }, | |||
| ], | |||
| }, | |||
| } | |||
| export default config | |||
| @@ -41,6 +41,7 @@ export enum ProviderEnum { | |||
| 'chatglm' = 'chatglm', | |||
| 'xinference' = 'xinference', | |||
| 'openllm' = 'openllm', | |||
| 'localai' = 'localai', | |||
| } | |||
| export type ProviderConfigItem = { | |||
| @@ -99,6 +99,7 @@ const ModelPage = () => { | |||
| config.chatglm, | |||
| config.xinference, | |||
| config.openllm, | |||
| config.localai, | |||
| ] | |||
| } | |||
| @@ -2,7 +2,7 @@ import { ValidatedStatus } from '../key-validator/declarations' | |||
| import { ProviderEnum } from './declarations' | |||
| import { validateModelProvider } from '@/service/common' | |||
| export const ConfigurableProviders = [ProviderEnum.azure_openai, ProviderEnum.replicate, ProviderEnum.huggingface_hub, ProviderEnum.xinference, ProviderEnum.openllm] | |||
| export const ConfigurableProviders = [ProviderEnum.azure_openai, ProviderEnum.replicate, ProviderEnum.huggingface_hub, ProviderEnum.xinference, ProviderEnum.openllm, ProviderEnum.localai] | |||
| export const validateModelProviderFn = async (providerName: ProviderEnum, v: any) => { | |||
| let body, url | |||