| @@ -60,6 +60,9 @@ class ModelProviderFactory: | |||
| elif provider_name == 'xinference': | |||
| from core.model_providers.providers.xinference_provider import XinferenceProvider | |||
| return XinferenceProvider | |||
| elif provider_name == 'openllm': | |||
| from core.model_providers.providers.openllm_provider import OpenLLMProvider | |||
| return OpenLLMProvider | |||
| else: | |||
| raise NotImplementedError | |||
| @@ -0,0 +1,60 @@ | |||
| from typing import List, Optional, Any | |||
| from langchain.callbacks.manager import Callbacks | |||
| from langchain.llms import OpenLLM | |||
| from langchain.schema import LLMResult | |||
| from core.model_providers.error import LLMBadRequestError | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.model_providers.models.entity.message import PromptMessage | |||
| from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs | |||
class OpenLLMModel(BaseLLM):
    """Completion-mode LLM backed by a self-hosted OpenLLM server."""

    model_mode: ModelMode = ModelMode.COMPLETION

    def _init_client(self) -> Any:
        """Build the langchain OpenLLM client from credentials and model kwargs."""
        self.provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        return OpenLLM(
            server_url=self.credentials.get('server_url'),
            callbacks=self.callbacks,
            **self.provider_model_kwargs
        )

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        run predict by prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        prompt = self._get_prompt_from_messages(messages)
        return self._client.generate([prompt], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        get num tokens of prompt messages.

        :param messages:
        :return:
        """
        prompt = self._get_prompt_from_messages(messages)
        token_count = self._client.get_num_tokens(prompt)
        # clamp: never report a negative token count from the client
        return token_count if token_count > 0 else 0

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        # kwargs take effect when the client is (re)built; nothing to update here
        pass

    def handle_exceptions(self, ex: Exception) -> Exception:
        """Wrap any provider-side error into a uniform LLMBadRequestError."""
        return LLMBadRequestError(f"OpenLLM: {str(ex)}")

    @classmethod
    def support_streaming(cls):
        # streaming is not supported by this integration
        return False
| @@ -0,0 +1,137 @@ | |||
| import json | |||
| from typing import Type | |||
| from langchain.llms import OpenLLM | |||
| from core.helper import encrypter | |||
| from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType | |||
| from core.model_providers.models.llm.openllm_model import OpenLLMModel | |||
| from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError | |||
| from core.model_providers.models.base import BaseProviderModel | |||
| from models.provider import ProviderType | |||
class OpenLLMProvider(BaseModelProvider):
    """Model provider for self-hosted OpenLLM servers.

    Only the 'custom' provider type is supported; every model is configured by
    the user with its own `server_url` credential.
    """

    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'openllm'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        # models are user-configured, so there is no fixed catalogue
        return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        :raises NotImplementedError: for any type other than text generation
        """
        if model_type == ModelType.TEXT_GENERATION:
            return OpenLLMModel

        raise NotImplementedError

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        :param model_name:
        :param model_type:
        :return:
        """
        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=2, default=1),
            top_p=KwargRule[float](min=0, max=1, default=0.7),
            presence_penalty=KwargRule[float](min=-2, max=2, default=0),
            frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
            max_tokens=KwargRule[int](min=10, max=4000, default=128),
        )

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        :param model_name:
        :param model_type:
        :param credentials:
        :raises CredentialsValidateFailedError: if server_url is missing or the
            server does not answer a probe completion
        """
        if 'server_url' not in credentials:
            raise CredentialsValidateFailedError('OpenLLM Server URL must be provided.')

        try:
            credential_kwargs = {
                'server_url': credentials['server_url']
            }

            llm = OpenLLM(
                max_tokens=10,
                **credential_kwargs
            )

            # cheap round-trip to prove the server is reachable and answering
            llm("ping")
        except Exception as ex:
            # chain the original exception so the root cause stays visible
            raise CredentialsValidateFailedError(str(ex)) from ex

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return: the same dict with server_url replaced by its encrypted form
        """
        # NOTE: mutates the incoming dict in place and returns it, as callers expect
        credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
        return credentials

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name:
        :param model_type:
        :param obfuscated: when True, mask the decrypted server_url for display
        :return:
        """
        if self.provider.provider_type != ProviderType.CUSTOM.value:
            raise NotImplementedError

        provider_model = self._get_provider_model(model_name, model_type)

        if not provider_model.encrypted_config:
            return {
                'server_url': None
            }

        credentials = json.loads(provider_model.encrypted_config)
        if credentials['server_url']:
            credentials['server_url'] = encrypter.decrypt_token(
                self.provider.tenant_id,
                credentials['server_url']
            )

            if obfuscated:
                credentials['server_url'] = encrypter.obfuscated_token(credentials['server_url'])

        return credentials

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        # provider-level credentials are unused; each model carries its own
        return

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        # nothing to encrypt at provider level
        return {}

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        # no provider-level credentials exist
        return {}
| @@ -9,5 +9,6 @@ | |||
| "chatglm", | |||
| "replicate", | |||
| "huggingface_hub", | |||
"xinference",
| "openllm" | |||
| ] | |||
| @@ -0,0 +1,7 @@ | |||
| { | |||
| "support_provider_types": [ | |||
| "custom" | |||
| ], | |||
| "system_config": null, | |||
| "model_flexibility": "configurable" | |||
| } | |||
| @@ -49,4 +49,5 @@ huggingface_hub~=0.16.4 | |||
| transformers~=4.31.0 | |||
| stripe~=5.5.0 | |||
| pandas==1.5.3 | |||
xinference==0.2.0
| openllm~=0.2.26 | |||
| @@ -36,4 +36,7 @@ CHATGLM_API_BASE= | |||
| # Xinference Credentials | |||
| XINFERENCE_SERVER_URL= | |||
XINFERENCE_MODEL_UID=
| # OpenLLM Credentials | |||
| OPENLLM_SERVER_URL= | |||
| @@ -0,0 +1,72 @@ | |||
| import json | |||
| import os | |||
| from unittest.mock import patch, MagicMock | |||
| from core.model_providers.models.entity.message import PromptMessage, MessageType | |||
| from core.model_providers.models.entity.model_params import ModelKwargs, ModelType | |||
| from core.model_providers.models.llm.openllm_model import OpenLLMModel | |||
| from core.model_providers.providers.openllm_provider import OpenLLMProvider | |||
| from models.provider import Provider, ProviderType, ProviderModel | |||
def get_mock_provider():
    """Build an in-memory custom-type Provider row for the openllm provider."""
    return Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='openllm',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config='',
        is_valid=True,
    )
def get_mock_model(model_name, mocker):
    """Create an OpenLLMModel wired to a mocked ProviderModel DB record.

    Reads OPENLLM_SERVER_URL from the environment and patches the db session
    query so credential lookup returns the fake record.
    """
    server_url = os.environ['OPENLLM_SERVER_URL']
    model_provider = OpenLLMProvider(provider=get_mock_provider())

    record = ProviderModel(
        provider_name='openllm',
        model_name=model_name,
        model_type=ModelType.TEXT_GENERATION.value,
        encrypted_config=json.dumps({
            'server_url': server_url
        }),
        is_valid=True,
    )
    mock_query = MagicMock()
    mock_query.filter.return_value.first.return_value = record
    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)

    return OpenLLMModel(
        model_provider=model_provider,
        name=model_name,
        model_kwargs=ModelKwargs(
            max_tokens=10,
            temperature=0.01
        )
    )
def decrypt_side_effect(tenant_id, encrypted_api_key):
    """Identity stand-in for encrypter.decrypt_token in these tests."""
    return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt, mocker):
    """Token counting for a single human prompt should report 5 tokens."""
    model = get_mock_model('facebook/opt-125m', mocker)
    prompt_messages = [
        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
    ]
    token_count = model.get_num_tokens(prompt_messages)
    assert token_count == 5
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt, mocker):
    """A completion call against the mocked model yields non-empty text."""
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)

    model = get_mock_model('facebook/opt-125m', mocker)
    prompt = [PromptMessage(content='Human: who are you? \nAnswer: ')]

    result = model.run(prompt)
    assert len(result.content) > 0
| @@ -0,0 +1,125 @@ | |||
| import pytest | |||
| from unittest.mock import patch, MagicMock | |||
| import json | |||
| from core.model_providers.models.entity.model_params import ModelType | |||
| from core.model_providers.providers.base import CredentialsValidateFailedError | |||
| from core.model_providers.providers.openllm_provider import OpenLLMProvider | |||
| from models.provider import ProviderType, Provider, ProviderModel | |||
# Provider under test and the shared credential fixture used by every case.
PROVIDER_NAME = 'openllm'
MODEL_PROVIDER_CLASS = OpenLLMProvider
VALIDATE_CREDENTIAL = {
    'server_url': 'http://127.0.0.1:3333/'
}
def encrypt_side_effect(tenant_id, encrypt_key):
    """Fake encrypter.encrypt_token: mark the value instead of encrypting it."""
    return 'encrypted_' + encrypt_key
def decrypt_side_effect(tenant_id, encrypted_key):
    """Fake encrypter.decrypt_token: strip the 'encrypted_' markers again."""
    plain_key = encrypted_key.replace('encrypted_', '')
    return plain_key
def test_is_credentials_valid_or_raise_valid(mocker):
    """Validation passes when the (mocked) OpenLLM server answers a probe."""
    mocker.patch('langchain.llms.openllm.OpenLLM._identifying_params', return_value=None)
    mocker.patch('langchain.llms.openllm.OpenLLM._call', return_value="abc")

    # must not raise
    MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
        model_name='username/test_model_name',
        model_type=ModelType.TEXT_GENERATION,
        credentials=VALIDATE_CREDENTIAL.copy()
    )
def test_is_credentials_valid_or_raise_invalid(mocker):
    """Missing or unreachable server_url must raise CredentialsValidateFailedError."""
    mocker.patch('langchain.llms.openllm.OpenLLM._identifying_params', return_value=None)

    # first case: credential key absent; second case: server unreachable
    for bad_credentials in ({}, {'server_url': 'invalid'}):
        with pytest.raises(CredentialsValidateFailedError):
            MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
                model_name='test_model_name',
                model_type=ModelType.TEXT_GENERATION,
                credentials=bad_credentials
            )
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_model_credentials(mock_encrypt):
    """server_url must be routed through encrypt_token and stored encrypted."""
    server_url = 'http://127.0.0.1:3333/'

    result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
        tenant_id='tenant_id',
        model_name='test_model_name',
        model_type=ModelType.TEXT_GENERATION,
        credentials=VALIDATE_CREDENTIAL.copy()
    )

    mock_encrypt.assert_called_with('tenant_id', server_url)
    assert result['server_url'] == f'encrypted_{server_url}'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_model_credentials_custom(mock_decrypt, mocker):
    """Stored encrypted credentials are decrypted on the way out."""
    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name=PROVIDER_NAME,
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=None,
        is_valid=True,
    )

    stored = VALIDATE_CREDENTIAL.copy()
    stored['server_url'] = 'encrypted_' + stored['server_url']

    mock_query = MagicMock()
    mock_query.filter.return_value.first.return_value = ProviderModel(
        encrypted_config=json.dumps(stored)
    )
    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)

    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
    credentials = model_provider.get_model_credentials(
        model_name='test_model_name',
        model_type=ModelType.TEXT_GENERATION
    )
    assert credentials['server_url'] == 'http://127.0.0.1:3333/'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_model_credentials_obfuscated(mock_decrypt, mocker):
    """Obfuscated retrieval masks the middle of the server_url with '*'."""
    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name=PROVIDER_NAME,
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=None,
        is_valid=True,
    )

    stored = VALIDATE_CREDENTIAL.copy()
    stored['server_url'] = 'encrypted_' + stored['server_url']

    mock_query = MagicMock()
    mock_query.filter.return_value.first.return_value = ProviderModel(
        encrypted_config=json.dumps(stored)
    )
    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)

    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
    credentials = model_provider.get_model_credentials(
        model_name='test_model_name',
        model_type=ModelType.TEXT_GENERATION,
        obfuscated=True
    )

    # obfuscation keeps the first 6 and last 2 chars; the rest become '*'
    plain_url = VALIDATE_CREDENTIAL['server_url']
    masked = credentials['server_url'][6:-2]
    assert len(masked) == max(len(plain_url) - 8, 0)
    assert set(masked) <= {'*'}