 from typing import List, Optional, Any

 from langchain.callbacks.manager import Callbacks
-from langchain.llms import Xinference
 from langchain.schema import LLMResult

 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.models.llm.base import BaseLLM
 from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.xinference_llm import XinferenceLLM

...

 class XinferenceModel(BaseLLM):
     def _init_client(self) -> Any:
         self.provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
-        client = Xinference(
-            **self.credentials,
+        client = XinferenceLLM(
+            server_url=self.credentials['server_url'],
+            model_uid=self.credentials['model_uid'],
         )
         client.callbacks = self.callbacks
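# Illustrative sketch (not part of the diff above): the client is now built from the two
# connection fields only, presumably because the stored credentials also carry extra keys
# such as 'model_format' and 'model_handle_type' that the LLM wrapper does not expect as
# kwargs. The credential values below are made-up placeholders.
from core.third_party.langchain.llms.xinference_llm import XinferenceLLM

credentials = {
    'server_url': 'http://127.0.0.1:9997',
    'model_uid': 'my-model-uid',
    'model_format': 'ggmlv3',         # extra key used for parameter-rule selection
    'model_handle_type': 'generate',  # extra key used for parameter-rule selection
}

client = XinferenceLLM(
    server_url=credentials['server_url'],
    model_uid=credentials['model_uid'],
)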
 import json
 from typing import Type

-from langchain.llms import Xinference
+import requests
+from xinference.client import RESTfulGenerateModelHandle, RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle

 from core.helper import encrypter
 from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 from core.model_providers.models.base import BaseProviderModel
+from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
 from models.provider import ProviderType
...

         :param model_type:
         :return:
         """
-        return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=2, default=1),
-            top_p=KwargRule[float](min=0, max=1, default=0.7),
-            presence_penalty=KwargRule[float](min=-2, max=2, default=0),
-            frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
-            max_tokens=KwargRule[int](min=10, max=4000, default=256),
-        )
+        credentials = self.get_model_credentials(model_name, model_type)
+        if credentials['model_format'] == "ggmlv3" and credentials["model_handle_type"] == "chatglm":
+            return ModelKwargsRules(
+                temperature=KwargRule[float](min=0.01, max=2, default=1),
+                top_p=KwargRule[float](min=0, max=1, default=0.7),
+                presence_penalty=KwargRule[float](enabled=False),
+                frequency_penalty=KwargRule[float](enabled=False),
+                max_tokens=KwargRule[int](min=10, max=4000, default=256),
+            )
+        elif credentials['model_format'] == "ggmlv3":
+            return ModelKwargsRules(
+                temperature=KwargRule[float](min=0.01, max=2, default=1),
+                top_p=KwargRule[float](min=0, max=1, default=0.7),
+                presence_penalty=KwargRule[float](min=-2, max=2, default=0),
+                frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
+                max_tokens=KwargRule[int](min=10, max=4000, default=256),
+            )
+        else:
+            return ModelKwargsRules(
+                temperature=KwargRule[float](min=0.01, max=2, default=1),
+                top_p=KwargRule[float](min=0, max=1, default=0.7),
+                presence_penalty=KwargRule[float](enabled=False),
+                frequency_penalty=KwargRule[float](enabled=False),
+                max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=256),
+            )
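# Illustrative sketch (not part of the diff): the branching above, reduced to a standalone
# helper so the mapping from the stored extra credentials to a rule profile is easy to follow.
# The profile strings are invented for illustration; the real method returns ModelKwargsRules.
def _rule_profile(credentials: dict) -> str:
    if credentials['model_format'] == 'ggmlv3' and credentials['model_handle_type'] == 'chatglm':
        return 'ggmlv3 chatglm: penalties disabled'
    elif credentials['model_format'] == 'ggmlv3':
        return 'ggmlv3: penalties allowed in [-2, 2]'
    else:
        return "other formats: penalties disabled, max_tokens aliased to 'max_new_tokens'"

assert _rule_profile({'model_format': 'ggmlv3', 'model_handle_type': 'chatglm'}).startswith('ggmlv3 chatglm')
assert _rule_profile({'model_format': 'pytorch', 'model_handle_type': 'chat'}).startswith('other')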
     @classmethod
     def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):

...

                 'model_uid': credentials['model_uid'],
             }

-            llm = Xinference(
+            llm = XinferenceLLM(
                 **credential_kwargs
             )

-            llm("ping", generate_config={'max_tokens': 10})
+            llm("ping")
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))

...

         :param credentials:
         :return:
         """
+        extra_credentials = cls._get_extra_credentials(credentials)
+        credentials.update(extra_credentials)
         credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
         return credentials

     def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:

...

         return credentials
+    @classmethod
+    def _get_extra_credentials(cls, credentials: dict) -> dict:
+        # Ask the Xinference server for the model description and derive the extra
+        # credential fields used for parameter-rule selection.
+        url = f"{credentials['server_url']}/v1/models/{credentials['model_uid']}"
+        response = requests.get(url)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to get the model description, detail: {response.json()['detail']}"
+            )
+
+        desc = response.json()
+        extra_credentials = {
+            'model_format': desc['model_format'],
+        }
+
+        if desc["model_format"] == "ggmlv3" and "chatglm" in desc["model_name"]:
+            extra_credentials['model_handle_type'] = 'chatglm'
+        elif "generate" in desc["model_ability"]:
+            extra_credentials['model_handle_type'] = 'generate'
+        elif "chat" in desc["model_ability"]:
+            extra_credentials['model_handle_type'] = 'chat'
+        else:
+            raise NotImplementedError("Model handle type not supported.")
+
+        return extra_credentials

     @classmethod
     def is_provider_credentials_valid_or_raise(cls, credentials: dict):
         return
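# Illustrative sketch (not part of the diff above): how _get_extra_credentials turns the model
# description returned by GET {server_url}/v1/models/{model_uid} into the two extra credential
# keys. The payload below is a hypothetical example showing only the fields the code reads
# ('model_format', 'model_name', 'model_ability').
desc = {
    'model_format': 'ggmlv3',
    'model_name': 'chatglm2',
    'model_ability': ['chat'],
}

if desc['model_format'] == 'ggmlv3' and 'chatglm' in desc['model_name']:
    handle_type = 'chatglm'
elif 'generate' in desc['model_ability']:
    handle_type = 'generate'
elif 'chat' in desc['model_ability']:
    handle_type = 'chat'
else:
    raise NotImplementedError("Model handle type not supported.")

extra_credentials = {'model_format': desc['model_format'], 'model_handle_type': handle_type}
# -> {'model_format': 'ggmlv3', 'model_handle_type': 'chatglm'}, which selects the chatglm
#    parameter rules and is stored alongside the encrypted server_url.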
+from typing import Optional, List, Any, Union, Generator
+
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.llms import Xinference
+from langchain.llms.utils import enforce_stop_tokens
+from xinference.client import RESTfulChatglmCppChatModelHandle, \
+    RESTfulChatModelHandle, RESTfulGenerateModelHandle
+
+
+class XinferenceLLM(Xinference):
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Call the xinference model and return the output.
+
+        Args:
+            prompt: The prompt to use for generation.
+            stop: Optional list of stop words to use when generating.
+            generate_config: Optional dictionary for the configuration used for
+                generation.
+
+        Returns:
+            The generated string by the model.
+        """
+        model = self.client.get_model(self.model_uid)
+
+        if isinstance(model, RESTfulChatModelHandle):
+            generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
+
+            if stop:
+                generate_config["stop"] = stop
+
+            if generate_config and generate_config.get("stream"):
+                combined_text_output = ""
+                for token in self._stream_generate(
+                    model=model,
+                    prompt=prompt,
+                    run_manager=run_manager,
+                    generate_config=generate_config,
+                ):
+                    combined_text_output += token
+                return combined_text_output
+            else:
+                completion = model.chat(prompt=prompt, generate_config=generate_config)
+                # Chat completions follow the OpenAI response schema, so the generated
+                # text is under message.content rather than a top-level "text" field.
+                return completion["choices"][0]["message"]["content"]
+        elif isinstance(model, RESTfulGenerateModelHandle):
+            generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
+
+            if stop:
+                generate_config["stop"] = stop
+
+            if generate_config and generate_config.get("stream"):
+                combined_text_output = ""
+                for token in self._stream_generate(
+                    model=model,
+                    prompt=prompt,
+                    run_manager=run_manager,
+                    generate_config=generate_config,
+                ):
+                    combined_text_output += token
+                return combined_text_output
+            else:
+                completion = model.generate(prompt=prompt, generate_config=generate_config)
+                return completion["choices"][0]["text"]
+        elif isinstance(model, RESTfulChatglmCppChatModelHandle):
+            generate_config: "ChatglmCppGenerateConfig" = kwargs.get("generate_config", {})
+
+            if generate_config and generate_config.get("stream"):
+                combined_text_output = ""
+                for token in self._stream_generate(
+                    model=model,
+                    prompt=prompt,
+                    run_manager=run_manager,
+                    generate_config=generate_config,
+                ):
+                    combined_text_output += token
+                completion = combined_text_output
+            else:
+                completion = model.chat(prompt=prompt, generate_config=generate_config)
+                # Chat responses use the OpenAI schema: the text is under message.content.
+                completion = completion["choices"][0]["message"]["content"]
+
+            if stop is not None:
+                completion = enforce_stop_tokens(completion, stop)
+
+            return completion
+    def _stream_generate(
+        self,
+        model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
+        prompt: str,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        generate_config: Optional[Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
+    ) -> Generator[str, None, None]:
+        """
+        Args:
+            prompt: The prompt to use for generation.
+            model: The model used for generation.
+            generate_config: Optional dictionary for the configuration used for
+                generation.
+
+        Yields:
+            A string token.
+        """
+        if isinstance(model, RESTfulGenerateModelHandle):
+            streaming_response = model.generate(
+                prompt=prompt, generate_config=generate_config
+            )
+        else:
+            streaming_response = model.chat(
+                prompt=prompt, generate_config=generate_config
+            )
+
+        for chunk in streaming_response:
+            if isinstance(chunk, dict):
+                choices = chunk.get("choices", [])
+                if choices:
+                    choice = choices[0]
+                    if isinstance(choice, dict):
+                        # Completion chunks carry the token under "text"; chat completion
+                        # chunks carry it under delta.content (OpenAI-style schema).
+                        if "text" in choice:
+                            token = choice["text"]
+                        elif "delta" in choice and "content" in choice["delta"]:
+                            token = choice["delta"]["content"]
+                        else:
+                            continue
+                        log_probs = choice.get("logprobs")
+                        if run_manager:
+                            run_manager.on_llm_new_token(
+                                token=token, verbose=self.verbose, log_probs=log_probs
+                            )
+                        yield token
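# Illustrative usage sketch (not part of the diff above): calling the new wrapper directly.
# The server URL and model UID are placeholders, and a running Xinference server with that
# model deployed is assumed; generate_config is forwarded to the underlying model handle.
from core.third_party.langchain.llms.xinference_llm import XinferenceLLM

llm = XinferenceLLM(
    server_url='http://127.0.0.1:9997',
    model_uid='my-model-uid',
)

# Non-streaming call: returns the completed text in one piece.
print(llm("Q: What is the capital of France?\nA:", generate_config={'max_tokens': 32}))

# Streaming call: _call still returns a single concatenated string, but each token is also
# pushed to the attached callback manager via on_llm_new_token while it streams.
print(llm("Tell me a short joke.", generate_config={'stream': True, 'max_tokens': 64}))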
 from core.model_providers.models.entity.model_params import ModelType
 from core.model_providers.providers.base import CredentialsValidateFailedError
-from core.model_providers.providers.replicate_provider import ReplicateProvider
 from core.model_providers.providers.xinference_provider import XinferenceProvider
 from models.provider import ProviderType, Provider, ProviderModel
...

 def test_is_credentials_valid_or_raise_valid(mocker):
-    mocker.patch('langchain.llms.xinference.Xinference._call',
+    mocker.patch('core.third_party.langchain.llms.xinference_llm.XinferenceLLM._call',
                  return_value="abc")

     MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(

...

 @patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
-def test_encrypt_model_credentials(mock_encrypt):
+def test_encrypt_model_credentials(mock_encrypt, mocker):
     api_key = 'http://127.0.0.1:9997/'

+    mocker.patch('core.model_providers.providers.xinference_provider.XinferenceProvider._get_extra_credentials',
+                 return_value={
+                     'model_handle_type': 'generate',
+                     'model_format': 'ggmlv3'
+                 })
+
     result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
         tenant_id='tenant_id',
         model_name='test_model_name',