| @@ -115,7 +115,7 @@ class ModelProviderModelValidateApi(Resource): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('model_name', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('model_type', type=str, required=True, nullable=False, | |||
| choices=['text-generation', 'embeddings', 'speech2text'], location='json') | |||
| choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json') | |||
| parser.add_argument('config', type=dict, required=True, nullable=False, location='json') | |||
| args = parser.parse_args() | |||
| @@ -155,7 +155,7 @@ class ModelProviderModelUpdateApi(Resource): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('model_name', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('model_type', type=str, required=True, nullable=False, | |||
| choices=['text-generation', 'embeddings', 'speech2text'], location='json') | |||
| choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json') | |||
| parser.add_argument('config', type=dict, required=True, nullable=False, location='json') | |||
| args = parser.parse_args() | |||
| @@ -184,7 +184,7 @@ class ModelProviderModelUpdateApi(Resource): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('model_name', type=str, required=True, nullable=False, location='args') | |||
| parser.add_argument('model_type', type=str, required=True, nullable=False, | |||
| choices=['text-generation', 'embeddings', 'speech2text'], location='args') | |||
| choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args') | |||
| args = parser.parse_args() | |||
| provider_service = ProviderService() | |||
| @@ -0,0 +1,58 @@ | |||
| import logging | |||
| from typing import Optional, List | |||
| from langchain.schema import Document | |||
| from xinference_client.client.restful.restful_client import Client | |||
| from core.model_providers.error import LLMBadRequestError | |||
| from core.model_providers.models.reranking.base import BaseReranking | |||
| from core.model_providers.providers.base import BaseModelProvider | |||
| class XinferenceReranking(BaseReranking): | |||
| def __init__(self, model_provider: BaseModelProvider, name: str): | |||
| self.credentials = model_provider.get_model_credentials( | |||
| model_name=name, | |||
| model_type=self.type | |||
| ) | |||
| client = Client(self.credentials['server_url']) | |||
| super().__init__(model_provider, client, name) | |||
| def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]: | |||
| docs = [] | |||
| doc_id = [] | |||
| for document in documents: | |||
| if document.metadata['doc_id'] not in doc_id: | |||
| doc_id.append(document.metadata['doc_id']) | |||
| docs.append(document.page_content) | |||
| model = self.client.get_model(self.credentials['model_uid']) | |||
| response = model.rerank(query=query, documents=docs, top_n=top_k) | |||
| rerank_documents = [] | |||
| for idx, result in enumerate(response['results']): | |||
| # format document | |||
| index = result['index'] | |||
| rerank_document = Document( | |||
| page_content=result['document'], | |||
| metadata={ | |||
| "doc_id": documents[index].metadata['doc_id'], | |||
| "doc_hash": documents[index].metadata['doc_hash'], | |||
| "document_id": documents[index].metadata['document_id'], | |||
| "dataset_id": documents[index].metadata['dataset_id'], | |||
| 'score': result['relevance_score'] | |||
| } | |||
| ) | |||
| # score threshold check | |||
| if score_threshold is not None: | |||
| if result.relevance_score >= score_threshold: | |||
| rerank_documents.append(rerank_document) | |||
| else: | |||
| rerank_documents.append(rerank_document) | |||
| return rerank_documents | |||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||
| return LLMBadRequestError(f"Xinference rerank: {str(ex)}") | |||
| @@ -2,11 +2,13 @@ import json | |||
| from typing import Type | |||
| import requests | |||
| from xinference_client.client.restful.restful_client import Client | |||
| from core.helper import encrypter | |||
| from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding | |||
| from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode | |||
| from core.model_providers.models.llm.xinference_model import XinferenceModel | |||
| from core.model_providers.models.reranking.xinference_reranking import XinferenceReranking | |||
| from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError | |||
| from core.model_providers.models.base import BaseProviderModel | |||
| @@ -40,6 +42,8 @@ class XinferenceProvider(BaseModelProvider): | |||
| model_class = XinferenceModel | |||
| elif model_type == ModelType.EMBEDDINGS: | |||
| model_class = XinferenceEmbedding | |||
| elif model_type == ModelType.RERANKING: | |||
| model_class = XinferenceReranking | |||
| else: | |||
| raise NotImplementedError | |||
| @@ -113,6 +117,10 @@ class XinferenceProvider(BaseModelProvider): | |||
| ) | |||
| embedding.embed_query("ping") | |||
| elif model_type == ModelType.RERANKING: | |||
| rerank_client = Client(credential_kwargs['server_url']) | |||
| model = rerank_client.get_model(credential_kwargs['model_uid']) | |||
| model.rerank(query="ping", documents=["ping", "pong"], top_n=2) | |||
| except Exception as ex: | |||
| raise CredentialsValidateFailedError(str(ex)) | |||
| @@ -6,6 +6,7 @@ | |||
| "model_flexibility": "configurable", | |||
| "supported_model_types": [ | |||
| "text-generation", | |||
| "embeddings" | |||
| "embeddings", | |||
| "reranking" | |||
| ] | |||
| } | |||
| @@ -48,7 +48,7 @@ huggingface_hub~=0.16.4 | |||
| transformers~=4.31.0 | |||
| stripe~=5.5.0 | |||
| pandas==1.5.3 | |||
| xinference-client~=0.5.4 | |||
| xinference-client~=0.6.4 | |||
| safetensors==0.3.2 | |||
| zhipuai==1.0.7 | |||
| werkzeug==2.3.7 | |||
| @@ -50,4 +50,7 @@ XINFERENCE_MODEL_UID= | |||
| OPENLLM_SERVER_URL= | |||
| # LocalAI Credentials | |||
| LOCALAI_SERVER_URL= | |||
| LOCALAI_SERVER_URL= | |||
| # Cohere Credentials | |||
| COHERE_API_KEY= | |||
| @@ -0,0 +1,61 @@ | |||
| import json | |||
| import os | |||
| from unittest.mock import patch | |||
| from langchain.schema import Document | |||
| from core.model_providers.models.reranking.cohere_reranking import CohereReranking | |||
| from core.model_providers.providers.cohere_provider import CohereProvider | |||
| from models.provider import Provider, ProviderType | |||
| def get_mock_provider(valid_api_key): | |||
| return Provider( | |||
| id='provider_id', | |||
| tenant_id='tenant_id', | |||
| provider_name='cohere', | |||
| provider_type=ProviderType.CUSTOM.value, | |||
| encrypted_config=json.dumps({'api_key': valid_api_key}), | |||
| is_valid=True, | |||
| ) | |||
| def get_mock_model(): | |||
| valid_api_key = os.environ['COHERE_API_KEY'] | |||
| provider = CohereProvider(provider=get_mock_provider(valid_api_key)) | |||
| return CohereReranking( | |||
| model_provider=provider, | |||
| name='rerank-english-v2.0' | |||
| ) | |||
| def decrypt_side_effect(tenant_id, encrypted_api_key): | |||
| return encrypted_api_key | |||
| @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) | |||
| def test_run(mock_decrypt): | |||
| model = get_mock_model() | |||
| docs = [] | |||
| docs.append(Document( | |||
| page_content='bye', | |||
| metadata={ | |||
| "doc_id": 'a', | |||
| "doc_hash": 'doc_hash', | |||
| "document_id": 'document_id', | |||
| "dataset_id": 'dataset_id', | |||
| } | |||
| )) | |||
| docs.append(Document( | |||
| page_content='hello', | |||
| metadata={ | |||
| "doc_id": 'b', | |||
| "doc_hash": 'doc_hash', | |||
| "document_id": 'document_id', | |||
| "dataset_id": 'dataset_id', | |||
| } | |||
| )) | |||
| rst = model.rerank('hello', docs, None, 2) | |||
| assert rst[0].page_content == 'hello' | |||
| @@ -0,0 +1,78 @@ | |||
| import json | |||
| import os | |||
| from unittest.mock import patch, MagicMock | |||
| from langchain.schema import Document | |||
| from core.model_providers.models.entity.model_params import ModelType | |||
| from core.model_providers.models.reranking.xinference_reranking import XinferenceReranking | |||
| from core.model_providers.providers.xinference_provider import XinferenceProvider | |||
| from models.provider import Provider, ProviderType, ProviderModel | |||
| def get_mock_provider(valid_server_url, valid_model_uid): | |||
| return Provider( | |||
| id='provider_id', | |||
| tenant_id='tenant_id', | |||
| provider_name='xinference', | |||
| provider_type=ProviderType.CUSTOM.value, | |||
| encrypted_config=json.dumps({'server_url': valid_server_url, 'model_uid': valid_model_uid}), | |||
| is_valid=True, | |||
| ) | |||
| def get_mock_model(mocker): | |||
| valid_server_url = os.environ['XINFERENCE_SERVER_URL'] | |||
| valid_model_uid = os.environ['XINFERENCE_MODEL_UID'] | |||
| model_name = 'bge-reranker-base' | |||
| provider = XinferenceProvider(provider=get_mock_provider(valid_server_url, valid_model_uid)) | |||
| mock_query = MagicMock() | |||
| mock_query.filter.return_value.first.return_value = ProviderModel( | |||
| provider_name='xinference', | |||
| model_name=model_name, | |||
| model_type=ModelType.RERANKING.value, | |||
| encrypted_config=json.dumps({ | |||
| 'server_url': valid_server_url, | |||
| 'model_uid': valid_model_uid | |||
| }), | |||
| is_valid=True, | |||
| ) | |||
| mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query) | |||
| return XinferenceReranking( | |||
| model_provider=provider, | |||
| name=model_name | |||
| ) | |||
| def decrypt_side_effect(tenant_id, encrypted_api_key): | |||
| return encrypted_api_key | |||
| @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) | |||
| def test_run(mock_decrypt, mocker): | |||
| model = get_mock_model(mocker) | |||
| docs = [] | |||
| docs.append(Document( | |||
| page_content='bye', | |||
| metadata={ | |||
| "doc_id": 'a', | |||
| "doc_hash": 'doc_hash', | |||
| "document_id": 'document_id', | |||
| "dataset_id": 'dataset_id', | |||
| } | |||
| )) | |||
| docs.append(Document( | |||
| page_content='hello', | |||
| metadata={ | |||
| "doc_id": 'b', | |||
| "doc_hash": 'doc_hash', | |||
| "document_id": 'document_id', | |||
| "dataset_id": 'dataset_id', | |||
| } | |||
| )) | |||
| rst = model.rerank('hello', docs, None, 2) | |||
| assert rst[0].page_content == 'hello' | |||