@@ -1,13 +1,13 @@
 from typing import List, Optional, Any

 from langchain.callbacks.manager import Callbacks
-from langchain.llms import Xinference
 from langchain.schema import LLMResult

 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.models.llm.base import BaseLLM
 from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.xinference_llm import XinferenceLLM


 class XinferenceModel(BaseLLM):
@@ -16,8 +16,9 @@ class XinferenceModel(BaseLLM):
     def _init_client(self) -> Any:
         self.provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)

-        client = Xinference(
-            **self.credentials,
+        client = XinferenceLLM(
+            server_url=self.credentials['server_url'],
+            model_uid=self.credentials['model_uid'],
         )

         client.callbacks = self.callbacks
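
One likely reason for the explicit keyword arguments: after this change the stored credentials also carry model_format and model_handle_type (added in the provider below), which the upstream langchain Xinference wrapper does not declare as fields. A minimal construction sketch, assuming a locally running Xinference server (the endpoint and model UID are placeholders):

    from core.third_party.langchain.llms.xinference_llm import XinferenceLLM

    client = XinferenceLLM(
        server_url="http://127.0.0.1:9997",  # placeholder endpoint
        model_uid="my-model-uid",            # placeholder UID of a launched model
    )
    print(client("ping"))                    # same call style the credential validator below uses
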
@@ -1,7 +1,8 @@
 import json
 from typing import Type

-from langchain.llms import Xinference
+import requests
+from xinference.client import RESTfulGenerateModelHandle, RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle

 from core.helper import encrypter
 from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
@@ -10,6 +11,7 @@ from core.model_providers.models.llm.xinference_model import XinferenceModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 from core.model_providers.models.base import BaseProviderModel
+from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
 from models.provider import ProviderType
@@ -48,13 +50,32 @@ class XinferenceProvider(BaseModelProvider):
         :param model_type:
         :return:
         """
-        return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=2, default=1),
-            top_p=KwargRule[float](min=0, max=1, default=0.7),
-            presence_penalty=KwargRule[float](min=-2, max=2, default=0),
-            frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
-            max_tokens=KwargRule[int](min=10, max=4000, default=256),
-        )
+        credentials = self.get_model_credentials(model_name, model_type)
+        if credentials['model_format'] == "ggmlv3" and credentials["model_handle_type"] == "chatglm":
+            return ModelKwargsRules(
+                temperature=KwargRule[float](min=0.01, max=2, default=1),
+                top_p=KwargRule[float](min=0, max=1, default=0.7),
+                presence_penalty=KwargRule[float](enabled=False),
+                frequency_penalty=KwargRule[float](enabled=False),
+                max_tokens=KwargRule[int](min=10, max=4000, default=256),
+            )
+        elif credentials['model_format'] == "ggmlv3":
+            return ModelKwargsRules(
+                temperature=KwargRule[float](min=0.01, max=2, default=1),
+                top_p=KwargRule[float](min=0, max=1, default=0.7),
+                presence_penalty=KwargRule[float](min=-2, max=2, default=0),
+                frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
+                max_tokens=KwargRule[int](min=10, max=4000, default=256),
+            )
+        else:
+            return ModelKwargsRules(
+                temperature=KwargRule[float](min=0.01, max=2, default=1),
+                top_p=KwargRule[float](min=0, max=1, default=0.7),
+                presence_penalty=KwargRule[float](enabled=False),
+                frequency_penalty=KwargRule[float](enabled=False),
+                max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=256),
+            )

     @classmethod
     def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
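
The three rule sets above follow a simple precedence. A condensed sketch of the same decision, using the credential fields produced by _get_extra_credentials further down (the helper name is illustrative only):

    def pick_rule_profile(credentials: dict) -> str:
        # ggmlv3 + chatglm handle: penalties disabled, plain max_tokens
        if credentials['model_format'] == 'ggmlv3' and credentials['model_handle_type'] == 'chatglm':
            return 'chatglm-ggmlv3'
        # other ggmlv3 models: full presence/frequency penalty range
        if credentials['model_format'] == 'ggmlv3':
            return 'ggmlv3'
        # pytorch and other formats: penalties disabled, max_tokens aliased to max_new_tokens
        return 'other'
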
@@ -77,11 +98,11 @@ class XinferenceProvider(BaseModelProvider):
                 'model_uid': credentials['model_uid'],
             }

-            llm = Xinference(
+            llm = XinferenceLLM(
                 **credential_kwargs
             )

-            llm("ping", generate_config={'max_tokens': 10})
+            llm("ping")
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
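
With the switch to XinferenceLLM, a bare llm("ping") round-trip is enough to verify a model. A hedged usage sketch of the validator (all values are placeholders):

    from core.model_providers.models.entity.model_params import ModelType
    from core.model_providers.providers.xinference_provider import XinferenceProvider

    XinferenceProvider.is_model_credentials_valid_or_raise(
        model_name="my-model",                      # placeholder
        model_type=ModelType.TEXT_GENERATION,       # assumed enum member
        credentials={
            "server_url": "http://127.0.0.1:9997",  # placeholder endpoint
            "model_uid": "my-model-uid",            # placeholder UID
        },
    )  # raises CredentialsValidateFailedError if the ping fails
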
@@ -97,7 +118,11 @@ class XinferenceProvider(BaseModelProvider):
         :param credentials:
         :return:
         """
+        extra_credentials = cls._get_extra_credentials(credentials)
+        credentials.update(extra_credentials)
+
         credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
         return credentials

     def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
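
After the update above, the persisted credentials carry the probed metadata next to the encrypted endpoint. An illustrative shape only (all values are placeholders):

    stored_credentials = {
        'server_url': '<encrypted token>',   # output of encrypter.encrypt_token
        'model_uid': 'my-model-uid',
        'model_format': 'ggmlv3',
        'model_handle_type': 'generate',
    }
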
@@ -132,6 +157,30 @@ class XinferenceProvider(BaseModelProvider):
         return credentials

+    @classmethod
+    def _get_extra_credentials(cls, credentials: dict) -> dict:
+        url = f"{credentials['server_url']}/v1/models/{credentials['model_uid']}"
+        response = requests.get(url)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to get the model description, detail: {response.json()['detail']}"
+            )
+        desc = response.json()
+
+        extra_credentials = {
+            'model_format': desc['model_format'],
+        }
+        if desc["model_format"] == "ggmlv3" and "chatglm" in desc["model_name"]:
+            extra_credentials['model_handle_type'] = 'chatglm'
+        elif "generate" in desc["model_ability"]:
+            extra_credentials['model_handle_type'] = 'generate'
+        elif "chat" in desc["model_ability"]:
+            extra_credentials['model_handle_type'] = 'chat'
+        else:
+            raise NotImplementedError("Model handle type not supported.")
+
+        return extra_credentials
+
     @classmethod
     def is_provider_credentials_valid_or_raise(cls, credentials: dict):
         return
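
The probe above only reads three fields from GET {server_url}/v1/models/{model_uid}. A hypothetical response and the extras it yields (field values are illustrative):

    desc = {
        "model_name": "llama-2-chat",
        "model_format": "ggmlv3",
        "model_ability": ["embed", "generate", "chat"],
    }
    # -> {'model_format': 'ggmlv3', 'model_handle_type': 'generate'}
    #    ("generate" wins because it is checked before "chat" and the name does not contain "chatglm")
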
@@ -0,0 +1,132 @@
+from typing import Optional, List, Any, Union, Generator
+
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.llms import Xinference
+from langchain.llms.utils import enforce_stop_tokens
+from xinference.client import RESTfulChatglmCppChatModelHandle, \
+    RESTfulChatModelHandle, RESTfulGenerateModelHandle
+
+
+class XinferenceLLM(Xinference):
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Call the xinference model and return the output.
+
+        Args:
+            prompt: The prompt to use for generation.
+            stop: Optional list of stop words to use when generating.
+            generate_config: Optional dictionary for the configuration used for
+                generation.
+
+        Returns:
+            The generated string by the model.
+        """
+        model = self.client.get_model(self.model_uid)
+        if isinstance(model, RESTfulChatModelHandle):
+            generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
+
+            if stop:
+                generate_config["stop"] = stop
+
+            if generate_config and generate_config.get("stream"):
+                combined_text_output = ""
+                for token in self._stream_generate(
+                    model=model,
+                    prompt=prompt,
+                    run_manager=run_manager,
+                    generate_config=generate_config,
+                ):
+                    combined_text_output += token
+                return combined_text_output
+            else:
+                completion = model.chat(prompt=prompt, generate_config=generate_config)
+                return completion["choices"][0]["text"]
+        elif isinstance(model, RESTfulGenerateModelHandle):
+            generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
+
+            if stop:
+                generate_config["stop"] = stop
+
+            if generate_config and generate_config.get("stream"):
+                combined_text_output = ""
+                for token in self._stream_generate(
+                    model=model,
+                    prompt=prompt,
+                    run_manager=run_manager,
+                    generate_config=generate_config,
+                ):
+                    combined_text_output += token
+                return combined_text_output
+            else:
+                completion = model.generate(prompt=prompt, generate_config=generate_config)
+                return completion["choices"][0]["text"]
+        elif isinstance(model, RESTfulChatglmCppChatModelHandle):
+            generate_config: "ChatglmCppGenerateConfig" = kwargs.get("generate_config", {})
+
+            if generate_config and generate_config.get("stream"):
+                combined_text_output = ""
+                for token in self._stream_generate(
+                    model=model,
+                    prompt=prompt,
+                    run_manager=run_manager,
+                    generate_config=generate_config,
+                ):
+                    combined_text_output += token
+                completion = combined_text_output
+            else:
+                completion = model.chat(prompt=prompt, generate_config=generate_config)
+                completion = completion["choices"][0]["text"]
+
+            if stop is not None:
+                completion = enforce_stop_tokens(completion, stop)
+
+            return completion
+
+    def _stream_generate(
+        self,
+        model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
+        prompt: str,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        generate_config: Optional[Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
+    ) -> Generator[str, None, None]:
+        """
+        Args:
+            prompt: The prompt to use for generation.
+            model: The model used for generation.
+            generate_config: Optional dictionary for the configuration used for
+                generation.
+
+        Yields:
+            A string token.
+        """
+        if isinstance(model, RESTfulGenerateModelHandle):
+            streaming_response = model.generate(
+                prompt=prompt, generate_config=generate_config
+            )
+        else:
+            streaming_response = model.chat(
+                prompt=prompt, generate_config=generate_config
+            )
+
+        for chunk in streaming_response:
+            if isinstance(chunk, dict):
+                choices = chunk.get("choices", [])
+                if choices:
+                    choice = choices[0]
+                    if isinstance(choice, dict):
+                        token = choice.get("text", "")
+                        log_probs = choice.get("logprobs")
+                        if run_manager:
+                            run_manager.on_llm_new_token(
+                                token=token, verbose=self.verbose, log_probs=log_probs
+                            )
+                        yield token
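
A minimal usage sketch of the class above, assuming a reachable Xinference server (endpoint, UID and generate_config values are placeholders):

    llm = XinferenceLLM(server_url="http://127.0.0.1:9997", model_uid="my-model-uid")

    # Blocking call: returns the first choice's text.
    text = llm("Say hi", generate_config={"max_tokens": 32})

    # Streaming call: _call() drains the stream, forwards each token to the
    # callback manager via on_llm_new_token, and returns the concatenated text.
    text = llm("Say hi", generate_config={"stream": True, "max_tokens": 32})
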
@@ -4,7 +4,6 @@ import json
 from core.model_providers.models.entity.model_params import ModelType
 from core.model_providers.providers.base import CredentialsValidateFailedError
-from core.model_providers.providers.replicate_provider import ReplicateProvider
 from core.model_providers.providers.xinference_provider import XinferenceProvider
 from models.provider import ProviderType, Provider, ProviderModel
@@ -25,7 +24,7 @@ def decrypt_side_effect(tenant_id, encrypted_key):
 def test_is_credentials_valid_or_raise_valid(mocker):
-    mocker.patch('langchain.llms.xinference.Xinference._call',
+    mocker.patch('core.third_party.langchain.llms.xinference_llm.XinferenceLLM._call',
                  return_value="abc")

     MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
@@ -53,8 +52,15 @@ def test_is_credentials_valid_or_raise_invalid():
 @patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
-def test_encrypt_model_credentials(mock_encrypt):
+def test_encrypt_model_credentials(mock_encrypt, mocker):
     api_key = 'http://127.0.0.1:9997/'
+
+    mocker.patch('core.model_providers.providers.xinference_provider.XinferenceProvider._get_extra_credentials',
+                 return_value={
+                     'model_handle_type': 'generate',
+                     'model_format': 'ggmlv3'
+                 })
+
     result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
         tenant_id='tenant_id',
         model_name='test_model_name',