
fix: xinference chat support (#939)

tags/0.3.15
takatost, 2 years ago
Parent
Current commit
e0a48c4972
No account is linked to the committer's email address

api/core/model_providers/models/llm/xinference_model.py (+4, -3)

from typing import List, Optional, Any

from langchain.callbacks.manager import Callbacks
- from langchain.llms import Xinference
from langchain.schema import LLMResult

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+ from core.third_party.langchain.llms.xinference_llm import XinferenceLLM


class XinferenceModel(BaseLLM):
    def _init_client(self) -> Any:
        self.provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)

-         client = Xinference(
-             **self.credentials,
+         client = XinferenceLLM(
+             server_url=self.credentials['server_url'],
+             model_uid=self.credentials['model_uid'],
        )

        client.callbacks = self.callbacks
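Note that the new client is built from explicit server_url and model_uid keyword arguments instead of **self.credentials, because the provider (see the next file) now attaches extra keys to the stored credentials. A minimal sketch of the assumed credential dict, with placeholder values and keys inferred from the provider changes below:

# Illustrative only; not taken from a real deployment.
credentials = {
    'server_url': 'http://127.0.0.1:9997',
    'model_uid': 'my-model-uid',
    'model_format': 'ggmlv3',       # added by _get_extra_credentials
    'model_handle_type': 'chat',    # added by _get_extra_credentials
}
# The LLM wrapper only needs server_url and model_uid, so only those are passed through.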

api/core/model_providers/providers/xinference_provider.py (+59, -10)

import json
from typing import Type

- from langchain.llms import Xinference
+ import requests
+ from xinference.client import RESTfulGenerateModelHandle, RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle

from core.helper import encrypter
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError

from core.model_providers.models.base import BaseProviderModel
+ from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
from models.provider import ProviderType




        :param model_type:
        :return:
        """
-         return ModelKwargsRules(
-             temperature=KwargRule[float](min=0, max=2, default=1),
-             top_p=KwargRule[float](min=0, max=1, default=0.7),
-             presence_penalty=KwargRule[float](min=-2, max=2, default=0),
-             frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
-             max_tokens=KwargRule[int](min=10, max=4000, default=256),
-         )
+         credentials = self.get_model_credentials(model_name, model_type)
+         if credentials['model_format'] == "ggmlv3" and credentials["model_handle_type"] == "chatglm":
+             return ModelKwargsRules(
+                 temperature=KwargRule[float](min=0.01, max=2, default=1),
+                 top_p=KwargRule[float](min=0, max=1, default=0.7),
+                 presence_penalty=KwargRule[float](enabled=False),
+                 frequency_penalty=KwargRule[float](enabled=False),
+                 max_tokens=KwargRule[int](min=10, max=4000, default=256),
+             )
+         elif credentials['model_format'] == "ggmlv3":
+             return ModelKwargsRules(
+                 temperature=KwargRule[float](min=0.01, max=2, default=1),
+                 top_p=KwargRule[float](min=0, max=1, default=0.7),
+                 presence_penalty=KwargRule[float](min=-2, max=2, default=0),
+                 frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
+                 max_tokens=KwargRule[int](min=10, max=4000, default=256),
+             )
+         else:
+             return ModelKwargsRules(
+                 temperature=KwargRule[float](min=0.01, max=2, default=1),
+                 top_p=KwargRule[float](min=0, max=1, default=0.7),
+                 presence_penalty=KwargRule[float](enabled=False),
+                 frequency_penalty=KwargRule[float](enabled=False),
+                 max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=256),
+             )



    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
                'model_uid': credentials['model_uid'],
            }

-             llm = Xinference(
+             llm = XinferenceLLM(
                **credential_kwargs
            )

-             llm("ping", generate_config={'max_tokens': 10})
+             llm("ping")
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))


        :param credentials:
        :return:
        """
+         extra_credentials = cls._get_extra_credentials(credentials)
+         credentials.update(extra_credentials)

        credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])

        return credentials

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:

        return credentials


+     @classmethod
+     def _get_extra_credentials(cls, credentials: dict) -> dict:
+         url = f"{credentials['server_url']}/v1/models/{credentials['model_uid']}"
+         response = requests.get(url)
+         if response.status_code != 200:
+             raise RuntimeError(
+                 f"Failed to get the model description, detail: {response.json()['detail']}"
+             )
+         desc = response.json()
+
+         extra_credentials = {
+             'model_format': desc['model_format'],
+         }
+         if desc["model_format"] == "ggmlv3" and "chatglm" in desc["model_name"]:
+             extra_credentials['model_handle_type'] = 'chatglm'
+         elif "generate" in desc["model_ability"]:
+             extra_credentials['model_handle_type'] = 'generate'
+         elif "chat" in desc["model_ability"]:
+             extra_credentials['model_handle_type'] = 'chat'
+         else:
+             raise NotImplementedError("Model handle type not supported.")
+
+         return extra_credentials

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        return
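For reference, a minimal sketch of the model description that _get_extra_credentials expects from GET {server_url}/v1/models/{model_uid}; only the fields the method reads are shown, and the values are illustrative rather than taken from a real deployment:

desc = {
    "model_name": "chatglm2-ggmlv3",
    "model_format": "ggmlv3",
    "model_ability": ["chat"],
}
# "ggmlv3" format with "chatglm" in the name -> model_handle_type = 'chatglm'
# otherwise, "generate" in model_ability     -> model_handle_type = 'generate'
# otherwise, "chat" in model_ability         -> model_handle_type = 'chat'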

api/core/third_party/langchain/llms/xinference_llm.py (+132, -0)

from typing import Optional, List, Any, Union, Generator

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms import Xinference
from langchain.llms.utils import enforce_stop_tokens
from xinference.client import RESTfulChatglmCppChatModelHandle, \
    RESTfulChatModelHandle, RESTfulGenerateModelHandle


class XinferenceLLM(Xinference):
    def _call(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> str:
        """Call the xinference model and return the output.

        Args:
            prompt: The prompt to use for generation.
            stop: Optional list of stop words to use when generating.
            generate_config: Optional dictionary for the configuration used for
                generation.

        Returns:
            The generated string by the model.
        """
        model = self.client.get_model(self.model_uid)

        if isinstance(model, RESTfulChatModelHandle):
            generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})

            if stop:
                generate_config["stop"] = stop

            if generate_config and generate_config.get("stream"):
                combined_text_output = ""
                for token in self._stream_generate(
                        model=model,
                        prompt=prompt,
                        run_manager=run_manager,
                        generate_config=generate_config,
                ):
                    combined_text_output += token
                return combined_text_output
            else:
                completion = model.chat(prompt=prompt, generate_config=generate_config)
                return completion["choices"][0]["text"]
        elif isinstance(model, RESTfulGenerateModelHandle):
            generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})

            if stop:
                generate_config["stop"] = stop

            if generate_config and generate_config.get("stream"):
                combined_text_output = ""
                for token in self._stream_generate(
                        model=model,
                        prompt=prompt,
                        run_manager=run_manager,
                        generate_config=generate_config,
                ):
                    combined_text_output += token
                return combined_text_output
            else:
                completion = model.generate(prompt=prompt, generate_config=generate_config)
                return completion["choices"][0]["text"]
        elif isinstance(model, RESTfulChatglmCppChatModelHandle):
            generate_config: "ChatglmCppGenerateConfig" = kwargs.get("generate_config", {})

            if generate_config and generate_config.get("stream"):
                combined_text_output = ""
                for token in self._stream_generate(
                        model=model,
                        prompt=prompt,
                        run_manager=run_manager,
                        generate_config=generate_config,
                ):
                    combined_text_output += token
                completion = combined_text_output
            else:
                completion = model.chat(prompt=prompt, generate_config=generate_config)
                completion = completion["choices"][0]["text"]

            if stop is not None:
                completion = enforce_stop_tokens(completion, stop)

            return completion

    def _stream_generate(
            self,
            model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
            prompt: str,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            generate_config: Optional[Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
    ) -> Generator[str, None, None]:
        """
        Args:
            prompt: The prompt to use for generation.
            model: The model used for generation.
            stop: Optional list of stop words to use when generating.
            generate_config: Optional dictionary for the configuration used for
                generation.

        Yields:
            A string token.
        """
        if isinstance(model, RESTfulGenerateModelHandle):
            streaming_response = model.generate(
                prompt=prompt, generate_config=generate_config
            )
        else:
            streaming_response = model.chat(
                prompt=prompt, generate_config=generate_config
            )

        for chunk in streaming_response:
            if isinstance(chunk, dict):
                choices = chunk.get("choices", [])
                if choices:
                    choice = choices[0]
                    if isinstance(choice, dict):
                        token = choice.get("text", "")
                        log_probs = choice.get("logprobs")
                        if run_manager:
                            run_manager.on_llm_new_token(
                                token=token, verbose=self.verbose, log_probs=log_probs
                            )
                        yield token
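A minimal usage sketch for the new wrapper (illustrative only: it assumes a running Xinference server, and the server URL and model UID below are placeholders):

from core.third_party.langchain.llms.xinference_llm import XinferenceLLM

llm = XinferenceLLM(server_url="http://127.0.0.1:9997", model_uid="my-model-uid")

# Non-streaming call; generate_config keys follow the Xinference generate/chat API.
print(llm("ping", generate_config={"max_tokens": 10}))

# Streaming call; _call concatenates the streamed tokens and returns a single string.
print(llm("Tell me a joke.", generate_config={"stream": True, "max_tokens": 64}))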

api/tests/unit_tests/model_providers/test_xinference_provider.py (+9, -3)



from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import CredentialsValidateFailedError
- from core.model_providers.providers.replicate_provider import ReplicateProvider
from core.model_providers.providers.xinference_provider import XinferenceProvider
from models.provider import ProviderType, Provider, ProviderModel


def test_is_credentials_valid_or_raise_valid(mocker):
-     mocker.patch('langchain.llms.xinference.Xinference._call',
+     mocker.patch('core.third_party.langchain.llms.xinference_llm.XinferenceLLM._call',
                  return_value="abc")

    MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(


@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
- def test_encrypt_model_credentials(mock_encrypt):
+ def test_encrypt_model_credentials(mock_encrypt, mocker):
    api_key = 'http://127.0.0.1:9997/'

+     mocker.patch('core.model_providers.providers.xinference_provider.XinferenceProvider._get_extra_credentials',
+                  return_value={
+                      'model_handle_type': 'generate',
+                      'model_format': 'ggmlv3'
+                  })

    result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
        tenant_id='tenant_id',
        model_name='test_model_name',
