Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
tags/0.3.19
elif provider_name == 'openllm':
    from core.model_providers.providers.openllm_provider import OpenLLMProvider
    return OpenLLMProvider
elif provider_name == 'localai':
    from core.model_providers.providers.localai_provider import LocalAIProvider
    return LocalAIProvider
else:
    raise NotImplementedError
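For orientation, a sketch of how this branch is exercised at call time. The enclosing factory function is not shown in this hunk, so `get_model_provider_class` below is an assumed name for it; the provider record mirrors the test fixtures further down.

# Hypothetical driver for the factory branch above; `get_model_provider_class`
# stands in for the enclosing (unshown) factory function.
from models.provider import Provider, ProviderType

provider_class = get_model_provider_class('localai')  # -> LocalAIProvider
model_provider = provider_class(provider=Provider(
    id='provider_id',
    tenant_id='tenant_id',
    provider_name='localai',
    provider_type=ProviderType.CUSTOM.value,
    is_valid=True,
))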
import openai
from langchain.embeddings import LocalAIEmbeddings

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.providers.base import BaseModelProvider


class LocalAIEmbedding(BaseEmbedding):
    def __init__(self, model_provider: BaseModelProvider, name: str):
        credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )

        # LocalAI speaks the OpenAI protocol; the key is unused but must be non-empty.
        client = LocalAIEmbeddings(
            model=name,
            openai_api_key="1",
            openai_api_base=credentials['server_url'],
        )

        super().__init__(model_provider, client, name)

    def handle_exceptions(self, ex: Exception) -> Exception:
        # The client raises OpenAI-protocol errors; map them to a bad request error.
        if isinstance(ex, openai.error.OpenAIError):
            return LLMBadRequestError(f"LocalAI embedding: {str(ex)}")
        else:
            return ex
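For comparison, a minimal standalone sketch of the LangChain client this class wraps, pointed at a local server. The server URL and model name are placeholders:

from langchain.embeddings import LocalAIEmbeddings

# Placeholder server and model; LocalAI ignores the API key but requires a non-empty one.
client = LocalAIEmbeddings(
    model='text-embedding-ada-002',
    openai_api_key='1',
    openai_api_base='http://127.0.0.1:8080',
)
vectors = client.embed_documents(['hello', 'world'])  # one vector per input text
query_vector = client.embed_query('hello')            # a single vector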
import logging
from typing import List, Optional, Any

import openai
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult, get_buffer_string

from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
    LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
from core.third_party.langchain.llms.open_ai import EnhanceOpenAI


class LocalAIModel(BaseLLM):
    def __init__(self, model_provider: BaseModelProvider,
                 name: str,
                 model_kwargs: ModelKwargs,
                 streaming: bool = False,
                 callbacks: Callbacks = None):
        credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )

        if credentials['completion_type'] == 'chat_completion':
            self.model_mode = ModelMode.CHAT
        else:
            self.model_mode = ModelMode.COMPLETION

        super().__init__(model_provider, name, model_kwargs, streaming, callbacks)

    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)

        if self.model_mode == ModelMode.COMPLETION:
            client = EnhanceOpenAI(
                model_name=self.name,
                streaming=self.streaming,
                callbacks=self.callbacks,
                request_timeout=60,
                openai_api_key="1",
                openai_api_base=self.credentials['server_url'] + '/v1',
                **provider_model_kwargs
            )
        else:
            extra_model_kwargs = {
                'top_p': provider_model_kwargs.get('top_p')
            }

            client = EnhanceChatOpenAI(
                model_name=self.name,
                temperature=provider_model_kwargs.get('temperature'),
                max_tokens=provider_model_kwargs.get('max_tokens'),
                model_kwargs=extra_model_kwargs,
                streaming=self.streaming,
                callbacks=self.callbacks,
                request_timeout=60,
                openai_api_key="1",
                openai_api_base=self.credentials['server_url'] + '/v1'
            )

        return client

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        Run prediction for the given prompt messages and stop words.

        :param messages: prompt messages
        :param stop: stop words
        :param callbacks: callbacks for the run
        :return: the LLM result
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        Get the number of tokens in the prompt messages.

        :param messages: prompt messages
        :return: token count
        """
        prompts = self._get_prompt_from_messages(messages)
        if isinstance(prompts, str):
            return self._client.get_num_tokens(prompts)
        else:
            return max(sum([self._client.get_num_tokens(get_buffer_string([m])) for m in prompts]) - len(prompts), 0)

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)

        if self.model_mode == ModelMode.COMPLETION:
            for k, v in provider_model_kwargs.items():
                if hasattr(self.client, k):
                    setattr(self.client, k, v)
        else:
            extra_model_kwargs = {
                'top_p': provider_model_kwargs.get('top_p')
            }

            self.client.temperature = provider_model_kwargs.get('temperature')
            self.client.max_tokens = provider_model_kwargs.get('max_tokens')
            self.client.model_kwargs = extra_model_kwargs

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, openai.error.InvalidRequestError):
            logging.warning("Invalid request to LocalAI API.")
            return LLMBadRequestError(str(ex))
        elif isinstance(ex, openai.error.APIConnectionError):
            logging.warning("Failed to connect to LocalAI API.")
            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
            logging.warning("LocalAI service unavailable.")
            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, openai.error.RateLimitError):
            return LLMRateLimitError(str(ex))
        elif isinstance(ex, openai.error.AuthenticationError):
            return LLMAuthorizationError(str(ex))
        elif isinstance(ex, openai.error.OpenAIError):
            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
        else:
            return ex

    @classmethod
    def support_streaming(cls):
        return True
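A hedged usage sketch, mirroring the integration test further down. It assumes a LocalAIProvider already wired to a provider record holding server_url and completion_type:

from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelKwargs

# `localai_provider` is assumed to be a configured LocalAIProvider instance.
model = LocalAIModel(
    model_provider=localai_provider,
    name='ggml-gpt4all-j',  # placeholder model name
    model_kwargs=ModelKwargs(max_tokens=10, temperature=0),
)
result = model.run([PromptMessage(content='ping')], stop=['\nHuman:'])
print(result.content)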
import json
from typing import Type

from langchain.embeddings import LocalAIEmbeddings
from langchain.schema import HumanMessage

from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule
from core.model_providers.models.llm.localai_model import LocalAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
from core.third_party.langchain.llms.open_ai import EnhanceOpenAI
from models.provider import ProviderType


class LocalAIProvider(BaseModelProvider):
    @property
    def provider_name(self):
        """
        Returns the name of the provider.
        """
        return 'localai'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class for the given model type.

        :param model_type: the model type
        :return: the model class
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = LocalAIModel
        elif model_type == ModelType.EMBEDDINGS:
            model_class = LocalAIEmbedding
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        Get the parameter rules for a model.

        :param model_name: the model name
        :param model_type: the model type
        :return: the parameter rules
        """
        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=2, default=0.7),
            top_p=KwargRule[float](min=0, max=1, default=1),
            max_tokens=KwargRule[int](min=10, max=4097, default=16),
        )

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        Check that the model credentials are valid; raise otherwise.

        :param model_name: the model name
        :param model_type: the model type
        :param credentials: the credentials to validate
        """
        if 'server_url' not in credentials:
            raise CredentialsValidateFailedError('LocalAI Server URL must be provided.')

        try:
            if model_type == ModelType.EMBEDDINGS:
                model = LocalAIEmbeddings(
                    model=model_name,
                    openai_api_key='1',
                    openai_api_base=credentials['server_url']
                )

                model.embed_query("ping")
            else:
                if ('completion_type' not in credentials
                        or credentials['completion_type'] not in ['completion', 'chat_completion']):
                    raise CredentialsValidateFailedError('LocalAI Completion Type must be provided.')

                if credentials['completion_type'] == 'chat_completion':
                    model = EnhanceChatOpenAI(
                        model_name=model_name,
                        openai_api_key='1',
                        openai_api_base=credentials['server_url'] + '/v1',
                        max_tokens=10,
                        request_timeout=60,
                    )

                    model([HumanMessage(content='ping')])
                else:
                    model = EnhanceOpenAI(
                        model_name=model_name,
                        openai_api_key='1',
                        openai_api_base=credentials['server_url'] + '/v1',
                        max_tokens=10,
                        request_timeout=60,
                    )

                    model('ping')
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        Encrypt model credentials for saving.

        :param tenant_id: the tenant id
        :param model_name: the model name
        :param model_type: the model type
        :param credentials: the credentials to encrypt
        :return: the encrypted credentials
        """
        credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
        return credentials

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        Get the stored credentials for model use.

        :param model_name: the model name
        :param model_type: the model type
        :param obfuscated: whether to mask the decrypted values for display
        :return: the credentials
        """
        if self.provider.provider_type != ProviderType.CUSTOM.value:
            raise NotImplementedError

        provider_model = self._get_provider_model(model_name, model_type)

        if not provider_model.encrypted_config:
            return {
                'server_url': None,
            }

        credentials = json.loads(provider_model.encrypted_config)

        if credentials['server_url']:
            credentials['server_url'] = encrypter.decrypt_token(
                self.provider.tenant_id,
                credentials['server_url']
            )

            if obfuscated:
                credentials['server_url'] = encrypter.obfuscated_token(credentials['server_url'])

        return credentials

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        return

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        return {}

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        return {}
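To make the credential round trip concrete, a sketch under the same assumptions the tests below rely on (there, encrypt_token/decrypt_token are mocked; the identifiers here are placeholders):

from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.localai_provider import LocalAIProvider

# Encrypts server_url in place for storage; get_model_credentials() later
# reverses this via encrypter.decrypt_token, and obfuscated=True masks the
# middle of the URL with '*' for display.
saved = LocalAIProvider.encrypt_model_credentials(
    tenant_id='tenant_id',
    model_name='ggml-gpt4all-j',
    model_type=ModelType.TEXT_GENERATION,
    credentials={'server_url': 'http://127.0.0.1:8080'},
)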
| "replicate", | "replicate", | ||||
| "huggingface_hub", | "huggingface_hub", | ||||
| "xinference", | "xinference", | ||||
| "openllm" | |||||
| "openllm", | |||||
| "localai" | |||||
| ] | ] | 
{
    "support_provider_types": [
        "custom"
    ],
    "system_config": null,
    "model_flexibility": "configurable"
}
return {
    **super()._default_params,
    "api_type": 'openai',
    "api_base": self.openai_api_base if self.openai_api_base
    else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
    "api_version": None,
    "api_key": self.openai_api_key,
    "organization": self.openai_organization if self.openai_organization else None,
import os
from typing import Dict, Any, Mapping, Optional, Union, Tuple, List, Iterator

from langchain import OpenAI
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.openai import completion_with_retry, _stream_response_to_generation_chunk
from langchain.schema.output import GenerationChunk
from pydantic import root_validator

    def _invocation_params(self) -> Dict[str, Any]:
        return {**super()._invocation_params, **{
            "api_type": 'openai',
            "api_base": self.openai_api_base if self.openai_api_base
            else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
            "api_version": None,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
        }}

    def _identifying_params(self) -> Mapping[str, Any]:
        return {**super()._identifying_params, **{
            "api_type": 'openai',
            "api_base": self.openai_api_base if self.openai_api_base
            else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
            "api_version": None,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
        }}

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        params = {**self._invocation_params, **kwargs, "stream": True}
        self.get_sub_prompts(params, [prompt], stop)  # this mutates params
        for stream_resp in completion_with_retry(
            self, prompt=prompt, run_manager=run_manager, **params
        ):
            if 'text' in stream_resp["choices"][0]:
                chunk = _stream_response_to_generation_chunk(stream_resp)
                yield chunk
                if run_manager:
                    run_manager.on_llm_new_token(
                        chunk.text,
                        verbose=self.verbose,
                        logprobs=chunk.generation_info["logprobs"]
                        if chunk.generation_info
                        else None,
                    )
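A sketch of driving the generator directly. LangChain normally invokes `_stream` internally when `streaming=True`; direct iteration here is just for illustration, against a placeholder server:

from core.third_party.langchain.llms.open_ai import EnhanceOpenAI

llm = EnhanceOpenAI(
    model_name='ggml-gpt4all-j',  # placeholder model name
    streaming=True,
    openai_api_key='1',
    openai_api_base='http://127.0.0.1:8080/v1',
)
for chunk in llm._stream('Say hi'):
    print(chunk.text, end='')  # chunks arrive as GenerationChunk objects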
XINFERENCE_MODEL_UID=

# OpenLLM Credentials
OPENLLM_SERVER_URL=

# LocalAI Credentials
LOCALAI_SERVER_URL=
import json
import os
from unittest.mock import patch, MagicMock

from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.localai_provider import LocalAIProvider
from models.provider import Provider, ProviderType, ProviderModel


def get_mock_provider():
    return Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='localai',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config='',
        is_valid=True,
    )


def get_mock_embedding_model(mocker):
    model_name = 'text-embedding-ada-002'
    server_url = os.environ['LOCALAI_SERVER_URL']
    model_provider = LocalAIProvider(provider=get_mock_provider())

    mock_query = MagicMock()
    mock_query.filter.return_value.first.return_value = ProviderModel(
        provider_name='localai',
        model_name=model_name,
        model_type=ModelType.EMBEDDINGS.value,
        encrypted_config=json.dumps({
            'server_url': server_url,
        }),
        is_valid=True,
    )
    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)

    return LocalAIEmbedding(
        model_provider=model_provider,
        name=model_name
    )


def decrypt_side_effect(tenant_id, encrypted_api_key):
    return encrypted_api_key


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embed_documents(mock_decrypt, mocker):
    embedding_model = get_mock_embedding_model(mocker)
    rst = embedding_model.client.embed_documents(['test', 'test1'])
    assert isinstance(rst, list)
    assert len(rst) == 2


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embed_query(mock_decrypt, mocker):
    embedding_model = get_mock_embedding_model(mocker)
    rst = embedding_model.client.embed_query('test')
    assert isinstance(rst, list)
import json
import os
from unittest.mock import patch, MagicMock

from core.model_providers.models.llm.localai_model import LocalAIModel
from core.model_providers.providers.localai_provider import LocalAIProvider
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from models.provider import Provider, ProviderType, ProviderModel


def get_mock_provider(server_url):
    return Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='localai',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=json.dumps({}),
        is_valid=True,
    )


def get_mock_model(model_name, mocker):
    model_kwargs = ModelKwargs(
        max_tokens=10,
        temperature=0
    )
    server_url = os.environ['LOCALAI_SERVER_URL']

    mock_query = MagicMock()
    mock_query.filter.return_value.first.return_value = ProviderModel(
        provider_name='localai',
        model_name=model_name,
        model_type=ModelType.TEXT_GENERATION.value,
        encrypted_config=json.dumps({'server_url': server_url, 'completion_type': 'completion'}),
        is_valid=True,
    )
    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)

    localai_provider = LocalAIProvider(provider=get_mock_provider(server_url))

    return LocalAIModel(
        model_provider=localai_provider,
        name=model_name,
        model_kwargs=model_kwargs
    )


def decrypt_side_effect(tenant_id, encrypted_openai_api_key):
    return encrypted_openai_api_key


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt, mocker):
    localai_model = get_mock_model('ggml-gpt4all-j', mocker)
    rst = localai_model.get_num_tokens([PromptMessage(content='you are a kind Assistant.')])
    assert rst > 0


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt, mocker):
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)

    localai_model = get_mock_model('ggml-gpt4all-j', mocker)
    rst = localai_model.run(
        [PromptMessage(content='Human: Are you Human? You MUST only answer `y` or `n`. \nAssistant: ')],
        stop=['\nHuman:'],
    )
    assert len(rst.content) > 0
import json
from unittest.mock import patch, MagicMock

import pytest

from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.localai_provider import LocalAIProvider
from models.provider import ProviderType, Provider, ProviderModel

PROVIDER_NAME = 'localai'
MODEL_PROVIDER_CLASS = LocalAIProvider
VALIDATE_CREDENTIAL = {
    'server_url': 'http://127.0.0.1:8080/'
}


def encrypt_side_effect(tenant_id, encrypt_key):
    return f'encrypted_{encrypt_key}'


def decrypt_side_effect(tenant_id, encrypted_key):
    return encrypted_key.replace('encrypted_', '')


def test_is_credentials_valid_or_raise_valid(mocker):
    mocker.patch('langchain.embeddings.localai.LocalAIEmbeddings.embed_query',
                 return_value="abc")

    MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
        model_name='username/test_model_name',
        model_type=ModelType.EMBEDDINGS,
        credentials=VALIDATE_CREDENTIAL.copy()
    )


def test_is_credentials_valid_or_raise_invalid():
    # raise CredentialsValidateFailedError if server_url is missing from the credentials
    with pytest.raises(CredentialsValidateFailedError):
        MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
            model_name='test_model_name',
            model_type=ModelType.EMBEDDINGS,
            credentials={}
        )


@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_model_credentials(mock_encrypt, mocker):
    server_url = 'http://127.0.0.1:8080/'

    result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
        tenant_id='tenant_id',
        model_name='test_model_name',
        model_type=ModelType.EMBEDDINGS,
        credentials=VALIDATE_CREDENTIAL.copy()
    )
    mock_encrypt.assert_called_with('tenant_id', server_url)
    assert result['server_url'] == f'encrypted_{server_url}'


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_model_credentials_custom(mock_decrypt, mocker):
    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name=PROVIDER_NAME,
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=None,
        is_valid=True,
    )

    encrypted_credential = VALIDATE_CREDENTIAL.copy()
    encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']

    mock_query = MagicMock()
    mock_query.filter.return_value.first.return_value = ProviderModel(
        encrypted_config=json.dumps(encrypted_credential)
    )
    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)

    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
    result = model_provider.get_model_credentials(
        model_name='test_model_name',
        model_type=ModelType.EMBEDDINGS
    )
    assert result['server_url'] == 'http://127.0.0.1:8080/'


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_model_credentials_obfuscated(mock_decrypt, mocker):
    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name=PROVIDER_NAME,
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=None,
        is_valid=True,
    )

    encrypted_credential = VALIDATE_CREDENTIAL.copy()
    encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']

    mock_query = MagicMock()
    mock_query.filter.return_value.first.return_value = ProviderModel(
        encrypted_config=json.dumps(encrypted_credential)
    )
    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)

    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
    result = model_provider.get_model_credentials(
        model_name='test_model_name',
        model_type=ModelType.EMBEDDINGS,
        obfuscated=True
    )

    middle_token = result['server_url'][6:-2]
    assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['server_url']) - 8, 0)
    assert all(char == '*' for char in middle_token)
// GENERATED BY script
// DO NOT EDIT IT MANUALLY

import * as React from 'react'
import data from './Localai.json'
import IconBase from '@/app/components/base/icons/IconBase'
import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase'

const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseProps, 'data'>>((
  props,
  ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)

export default Icon
// GENERATED BY script
// DO NOT EDIT IT MANUALLY

import * as React from 'react'
import data from './LocalaiText.json'
import IconBase from '@/app/components/base/icons/IconBase'
import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase'

const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseProps, 'data'>>((
  props,
  ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)

export default Icon
export { default as IflytekSparkTextCn } from './IflytekSparkTextCn'
export { default as IflytekSparkText } from './IflytekSparkText'
export { default as IflytekSpark } from './IflytekSpark'
export { default as LocalaiText } from './LocalaiText'
export { default as Localai } from './Localai'
export { default as Microsoft } from './Microsoft'
export { default as OpenaiBlack } from './OpenaiBlack'
export { default as OpenaiBlue } from './OpenaiBlue'
import chatglm from './chatglm'
import xinference from './xinference'
import openllm from './openllm'
import localai from './localai'

export default {
  openai,
  chatglm,
  xinference,
  openllm,
  localai,
}
import { ProviderEnum } from '../declarations'
import type { FormValue, ProviderConfig } from '../declarations'
import { Localai, LocalaiText } from '@/app/components/base/icons/src/public/llm'

const config: ProviderConfig = {
  selector: {
    name: {
      'en': 'LocalAI',
      'zh-Hans': 'LocalAI',
    },
    icon: <Localai className='w-full h-full' />,
  },
  item: {
    key: ProviderEnum.localai,
    titleIcon: {
      'en': <LocalaiText className='h-6' />,
      'zh-Hans': <LocalaiText className='h-6' />,
    },
    disable: {
      tip: {
        'en': 'Only supports the ',
        'zh-Hans': '仅支持',
      },
      link: {
        href: {
          'en': 'https://docs.dify.ai/getting-started/install-self-hosted',
          'zh-Hans': 'https://docs.dify.ai/v/zh-hans/getting-started/install-self-hosted',
        },
        label: {
          'en': 'community open-source version',
          'zh-Hans': '社区开源版本',
        },
      },
    },
  },
  modal: {
    key: ProviderEnum.localai,
    title: {
      'en': 'LocalAI',
      'zh-Hans': 'LocalAI',
    },
    icon: <Localai className='h-6' />,
    link: {
      href: 'https://github.com/go-skynet/LocalAI',
      label: {
        'en': 'How to deploy LocalAI',
        'zh-Hans': '如何部署 LocalAI',
      },
    },
    defaultValue: {
      model_type: 'text-generation',
      completion_type: 'completion',
    },
    validateKeys: (v?: FormValue) => {
      if (v?.model_type === 'text-generation') {
        return [
          'model_type',
          'model_name',
          'server_url',
          'completion_type',
        ]
      }
      if (v?.model_type === 'embeddings') {
        return [
          'model_type',
          'model_name',
          'server_url',
        ]
      }
      return []
    },
    filterValue: (v?: FormValue) => {
      let filteredKeys: string[] = []
      if (v?.model_type === 'text-generation') {
        filteredKeys = [
          'model_type',
          'model_name',
          'server_url',
          'completion_type',
        ]
      }
      if (v?.model_type === 'embeddings') {
        filteredKeys = [
          'model_type',
          'model_name',
          'server_url',
        ]
      }
      return filteredKeys.reduce((prev: FormValue, next: string) => {
        prev[next] = v?.[next] || ''
        return prev
      }, {})
    },
    fields: [
      {
        type: 'radio',
        key: 'model_type',
        required: true,
        label: {
          'en': 'Model Type',
          'zh-Hans': '模型类型',
        },
        options: [
          {
            key: 'text-generation',
            label: {
              'en': 'Text Generation',
              'zh-Hans': '文本生成',
            },
          },
          {
            key: 'embeddings',
            label: {
              'en': 'Embeddings',
              'zh-Hans': 'Embeddings',
            },
          },
        ],
      },
      {
        type: 'text',
        key: 'model_name',
        required: true,
        label: {
          'en': 'Model Name',
          'zh-Hans': '模型名称',
        },
        placeholder: {
          'en': 'Enter your Model Name here',
          'zh-Hans': '在此输入您的模型名称',
        },
      },
      {
        hidden: (value?: FormValue) => value?.model_type === 'embeddings',
        type: 'radio',
        key: 'completion_type',
        required: true,
        label: {
          'en': 'Completion Type',
          'zh-Hans': 'Completion Type',
        },
        options: [
          {
            key: 'completion',
            label: {
              'en': 'Completion',
              'zh-Hans': 'Completion',
            },
          },
          {
            key: 'chat_completion',
            label: {
              'en': 'Chat Completion',
              'zh-Hans': 'Chat Completion',
            },
          },
        ],
      },
      {
        type: 'text',
        key: 'server_url',
        required: true,
        label: {
          'en': 'Server URL',
          'zh-Hans': 'Server URL',
        },
        placeholder: {
          'en': 'Enter your Server URL, e.g. https://example.com/xxx',
          'zh-Hans': '在此输入您的 Server URL,如:https://example.com/xxx',
        },
      },
    ],
  },
}
export default config
  'chatglm' = 'chatglm',
  'xinference' = 'xinference',
  'openllm' = 'openllm',
  'localai' = 'localai',
}

export type ProviderConfigItem = {
    config.chatglm,
    config.xinference,
    config.openllm,
    config.localai,
  ]
}

import { ProviderEnum } from './declarations'
import { validateModelProvider } from '@/service/common'

export const ConfigurableProviders = [ProviderEnum.azure_openai, ProviderEnum.replicate, ProviderEnum.huggingface_hub, ProviderEnum.xinference, ProviderEnum.openllm, ProviderEnum.localai]

export const validateModelProviderFn = async (providerName: ProviderEnum, v: any) => {
  let body, url