import logging

from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class HuggingfaceTeiProvider(ModelProvider):
    def validate_provider_credentials(self, credentials: dict) -> None:
        # This provider only exposes customizable models, so credentials are
        # validated at the model level and there is nothing to check here.
        pass
provider: huggingface_tei
label:
  en_US: Text Embedding Inference
description:
  en_US: A blazing fast inference solution for text embeddings models.
  zh_Hans: 用于文本嵌入模型的超快速推理解决方案。
background: "#FFF8DC"
help:
  title:
    en_US: How to deploy Text Embedding Inference
    zh_Hans: 如何部署 Text Embedding Inference
  url:
    en_US: https://github.com/huggingface/text-embeddings-inference
supported_model_types:
  - text-embedding
  - rerank
configurate_methods:
  - customizable-model
model_credential_schema:
  model:
    label:
      en_US: Model Name
      zh_Hans: 模型名称
    placeholder:
      en_US: Enter your model name
      zh_Hans: 输入模型名称
  credential_form_schemas:
    - variable: server_url
      label:
        zh_Hans: 服务器URL
        en_US: Server URL
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入Text Embedding Inference的服务器地址,如 http://192.168.1.100:8080
        en_US: Enter the URL of your Text Embedding Inference server, e.g. http://192.168.1.100:8080
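# A hypothetical example of what a user might enter in the resulting credential
# form for a locally deployed TEI container (both values below are placeholders,
# not defaults of this provider):
#   Model Name: BAAI/bge-reranker-base
#   Server URL: http://192.168.1.100:8080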
from typing import Optional

import httpx

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiHelper


class HuggingfaceTeiRerankModel(RerankModel):
    """
    Model class for Text Embedding Inference rerank model.
    """

    def _invoke(
        self,
        model: str,
        credentials: dict,
        query: str,
        docs: list[str],
        score_threshold: Optional[float] = None,
        top_n: Optional[int] = None,
        user: Optional[str] = None,
    ) -> RerankResult:
        """
        Invoke rerank model

        :param model: model name
        :param credentials: model credentials
        :param query: search query
        :param docs: docs for reranking
        :param score_threshold: score threshold
        :param top_n: top n
        :param user: unique user id
        :return: rerank result
        """
        if len(docs) == 0:
            return RerankResult(model=model, docs=[])

        server_url = credentials['server_url']
        if server_url.endswith('/'):
            server_url = server_url[:-1]

        try:
            results = TeiHelper.invoke_rerank(server_url, query, docs)

            rerank_documents = []
            for result in results:
                rerank_document = RerankDocument(
                    index=result['index'],
                    text=result['text'],
                    score=result['score'],
                )
                if score_threshold is None or result['score'] >= score_threshold:
                    rerank_documents.append(rerank_document)
                    if top_n is not None and len(rerank_documents) >= top_n:
                        break

            return RerankResult(model=model, docs=rerank_documents)
        except httpx.HTTPStatusError as e:
            raise InvokeServerUnavailableError(str(e))

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            server_url = credentials['server_url']
            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
            if extra_args.model_type != 'reranker':
                raise CredentialsValidateFailedError('Current model is not a rerank model')

            credentials['context_size'] = extra_args.max_input_length

            self.invoke(
                model=model,
                credentials=credentials,
                query='Who is Kasumi?',
                docs=[
                    'Kasumi is a girl\'s name of Japanese origin meaning "mist".',
                    'Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ',
                    'and she leads a team named PopiParty.',
                ],
                score_threshold=0.8,
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [InvokeConnectionError],
            InvokeServerUnavailableError: [InvokeServerUnavailableError],
            InvokeRateLimitError: [InvokeRateLimitError],
            InvokeAuthorizationError: [InvokeAuthorizationError],
            InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError],
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        used to define customizable model schema
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.RERANK,
            model_properties={
                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)),
            },
            parameter_rules=[],
        )

        return entity
from threading import Lock
from time import time
from typing import Optional

import httpx
from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError, MissingSchema, Timeout
from requests.sessions import Session
from yarl import URL


class TeiModelExtraParameter:
    model_type: str
    max_input_length: int
    max_client_batch_size: int

    def __init__(self, model_type: str, max_input_length: int, max_client_batch_size: Optional[int] = None) -> None:
        self.model_type = model_type
        self.max_input_length = max_input_length
        self.max_client_batch_size = max_client_batch_size


cache = {}
cache_lock = Lock()


class TeiHelper:
    @staticmethod
    def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
        TeiHelper._clean_cache()
        with cache_lock:
            if model_name not in cache:
                cache[model_name] = {
                    'expires': time() + 300,
                    'value': TeiHelper._get_tei_extra_parameter(server_url),
                }
            return cache[model_name]['value']

    @staticmethod
    def _clean_cache() -> None:
        try:
            with cache_lock:
                expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()]
                for model_uid in expired_keys:
                    del cache[model_uid]
        except RuntimeError:
            pass

    @staticmethod
    def _get_tei_extra_parameter(server_url: str) -> TeiModelExtraParameter:
        """
        get tei model extra parameter like model_type, max_input_length, max_batch_requests
        """
        url = str(URL(server_url) / 'info')

        # this method is surrounded by a lock, and the default requests client may hang forever,
        # so mount an HTTPAdapter with max_retries=3
        session = Session()
        session.mount('http://', HTTPAdapter(max_retries=3))
        session.mount('https://', HTTPAdapter(max_retries=3))

        try:
            response = session.get(url, timeout=10)
        except (MissingSchema, ConnectionError, Timeout) as e:
            raise RuntimeError(f'get tei model extra parameter failed, url: {url}, error: {e}')
        if response.status_code != 200:
            raise RuntimeError(
                f'get tei model extra parameter failed, status code: {response.status_code}, response: {response.text}'
            )

        response_json = response.json()

        model_type = response_json.get('model_type', {})
        if len(model_type.keys()) < 1:
            raise RuntimeError('model_type is empty')
        model_type = list(model_type.keys())[0]
        if model_type not in ['embedding', 'reranker']:
            raise RuntimeError(f'invalid model_type: {model_type}')

        max_input_length = response_json.get('max_input_length', 512)
        max_client_batch_size = response_json.get('max_client_batch_size', 1)

        return TeiModelExtraParameter(
            model_type=model_type,
            max_input_length=max_input_length,
            max_client_batch_size=max_client_batch_size,
        )

    @staticmethod
    def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
        """
        Invoke tokenize endpoint

        Example response:
        [
            [
                {
                    "id": 0,
                    "text": "<s>",
                    "special": true,
                    "start": null,
                    "stop": null
                },
                {
                    "id": 7704,
                    "text": "str",
                    "special": false,
                    "start": 0,
                    "stop": 3
                },
                < MORE TOKENS >
            ]
        ]

        :param server_url: server url
        :param texts: texts to tokenize
        """
        resp = httpx.post(
            f'{server_url}/tokenize',
            json={'inputs': texts},
        )
        resp.raise_for_status()
        return resp.json()

    @staticmethod
    def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
        """
        Invoke embeddings endpoint

        Example response:
        {
            "object": "list",
            "data": [
                {
                    "object": "embedding",
                    "embedding": [...],
                    "index": 0
                }
            ],
            "model": "MODEL_NAME",
            "usage": {
                "prompt_tokens": 3,
                "total_tokens": 3
            }
        }

        :param server_url: server url
        :param texts: texts to embed
        """
        # Use the OpenAI-compatible API here, which reports token usage
        resp = httpx.post(
            f'{server_url}/v1/embeddings',
            json={'input': texts},
        )
        resp.raise_for_status()
        return resp.json()

    @staticmethod
    def invoke_rerank(server_url: str, query: str, docs: list[str]) -> list[dict]:
        """
        Invoke rerank endpoint

        Example response:
        [
            {
                "index": 0,
                "text": "Deep Learning is ...",
                "score": 0.9950755
            }
        ]

        :param server_url: server url
        :param query: query text
        :param docs: documents to rerank against the query
        """
        params = {'query': query, 'texts': docs, 'return_text': True}

        response = httpx.post(
            server_url + '/rerank',
            json=params,
        )
        response.raise_for_status()
        return response.json()
import time
from typing import Optional

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiHelper


class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
    """
    Model class for Text Embedding Inference text embedding model.
    """

    def _invoke(
        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
    ) -> TextEmbeddingResult:
        """
        Invoke text embedding model

        credentials should be like:
        {
            'server_url': 'server url',
        }

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :param user: unique user id
        :return: embeddings result
        """
        server_url = credentials['server_url']
        if server_url.endswith('/'):
            server_url = server_url[:-1]

        # get model properties
        context_size = self._get_context_size(model, credentials)
        max_chunks = self._get_max_chunks(model, credentials)

        inputs = []
        indices = []
        used_tokens = 0

        # get tokenized results from TEI
        batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts)

        for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)):
            # Check if the number of tokens is larger than the context size
            num_tokens = len(tokenize_result)

            if num_tokens >= context_size:
                # Find the best cutoff point
                pre_special_token_count = 0
                for token in tokenize_result:
                    if token['special']:
                        pre_special_token_count += 1
                    else:
                        break
                rest_special_token_count = (
                    len([token for token in tokenize_result if token['special']]) - pre_special_token_count
                )

                # Calculate the cutoff point, leave 20 extra space to avoid exceeding the limit
                token_cutoff = context_size - rest_special_token_count - 20
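                # Worked example (hypothetical numbers): with context_size=512 and a
                # tokenization like [<s>, t1, ..., tn, </s>], one special token remains
                # after the leading ones, so token_cutoff = 512 - 1 - 20 = 491 and the
                # text is cut at the character offset where token 491 starts.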
                # Find the cutoff index
                cutpoint_token = tokenize_result[token_cutoff]
                cutoff = cutpoint_token['start']

                inputs.append(text[0:cutoff])
            else:
                inputs.append(text)
            indices += [i]

        batched_embeddings = []
        _iter = range(0, len(inputs), max_chunks)

        try:
            used_tokens = 0
            for i in _iter:
                iter_texts = inputs[i : i + max_chunks]
                results = TeiHelper.invoke_embeddings(server_url, iter_texts)
                embeddings = results['data']
                embeddings = [embedding['embedding'] for embedding in embeddings]
                batched_embeddings.extend(embeddings)

                usage = results['usage']
                used_tokens += usage['total_tokens']
        except RuntimeError as e:
            raise InvokeServerUnavailableError(str(e))

        usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)

        result = TextEmbeddingResult(model=model, embeddings=batched_embeddings, usage=usage)

        return result

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return:
        """
        num_tokens = 0
        server_url = credentials['server_url']
        if server_url.endswith('/'):
            server_url = server_url[:-1]

        batch_tokens = TeiHelper.invoke_tokenize(server_url, texts)
        num_tokens = sum(len(tokens) for tokens in batch_tokens)
        return num_tokens

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            server_url = credentials['server_url']
            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
            if extra_args.model_type != 'embedding':
                raise CredentialsValidateFailedError('Current model is not an embedding model')

            credentials['context_size'] = extra_args.max_input_length
            credentials['max_chunks'] = extra_args.max_client_batch_size
            self._invoke(model=model, credentials=credentials, texts=['ping'])
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        return {
            InvokeConnectionError: [InvokeConnectionError],
            InvokeServerUnavailableError: [InvokeServerUnavailableError],
            InvokeRateLimitError: [InvokeRateLimitError],
            InvokeAuthorizationError: [InvokeAuthorizationError],
            InvokeBadRequestError: [KeyError],
        }

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage

        :param model: model name
        :param credentials: model credentials
        :param tokens: input tokens
        :return: usage
        """
        # get input price info
        input_price_info = self.get_price(
            model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
        )

        # transform usage
        usage = EmbeddingUsage(
            tokens=tokens,
            total_tokens=tokens,
            unit_price=input_price_info.unit_price,
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            latency=time.perf_counter() - self.started_at,
        )

        return usage

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        used to define customizable model schema
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.TEXT_EMBEDDING,
            model_properties={
                ModelPropertyKey.MAX_CHUNKS: int(credentials.get('max_chunks', 1)),
                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)),
            },
            parameter_rules=[],
        )

        return entity
| CODE_EXECUTION_ENDPOINT = "http://127.0.0.1:8194" | CODE_EXECUTION_ENDPOINT = "http://127.0.0.1:8194" | ||||
| CODE_EXECUTION_API_KEY = "dify-sandbox" | CODE_EXECUTION_API_KEY = "dify-sandbox" | ||||
| FIRECRAWL_API_KEY = "fc-" | FIRECRAWL_API_KEY = "fc-" | ||||
| TEI_EMBEDDING_SERVER_URL = "http://a.abc.com:11451" | |||||
| TEI_RERANK_SERVER_URL = "http://a.abc.com:11451" | |||||
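# TEI endpoints consumed by the huggingface_tei integration tests; the hosts above
# are placeholders and need to point at a reachable Text Embedding Inference
# deployment unless the tests are run with MOCK_SWITCH=true.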

[tool.poetry]
name = "dify-api"
from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiModelExtraParameter


class MockTEIClass:
    @staticmethod
    def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
        # During mock, we don't have a real server to query, so we just return a dummy value
        if 'rerank' in model_name:
            model_type = 'reranker'
        else:
            model_type = 'embedding'

        return TeiModelExtraParameter(model_type=model_type, max_input_length=512, max_client_batch_size=1)

    @staticmethod
    def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
        # Use space as token separator, and split the text into tokens
        tokenized_texts = []
        for text in texts:
            tokens = text.split(' ')
            current_index = 0
            tokenized_text = []
            for idx, token in enumerate(tokens):
                s_token = {
                    'id': idx,
                    'text': token,
                    'special': False,
                    'start': current_index,
                    'stop': current_index + len(token),
                }
                current_index += len(token) + 1
                tokenized_text.append(s_token)
            tokenized_texts.append(tokenized_text)
        return tokenized_texts

    @staticmethod
    def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
        # Example response:
        # {
        #     "object": "list",
        #     "data": [
        #         {
        #             "object": "embedding",
        #             "embedding": [...],
        #             "index": 0
        #         }
        #     ],
        #     "model": "MODEL_NAME",
        #     "usage": {
        #         "prompt_tokens": 3,
        #         "total_tokens": 3
        #     }
        # }
        embeddings = []
        for idx, text in enumerate(texts):
            embedding = [0.1] * 768
            embeddings.append(
                {
                    'object': 'embedding',
                    'embedding': embedding,
                    'index': idx,
                }
            )
        return {
            'object': 'list',
            'data': embeddings,
            'model': 'MODEL_NAME',
            'usage': {
                'prompt_tokens': sum(len(text.split(' ')) for text in texts),
                'total_tokens': sum(len(text.split(' ')) for text in texts),
            },
        }

    @staticmethod
    def invoke_rerank(server_url: str, query: str, texts: list[str]) -> list[dict]:
        # Example response:
        # [
        #     {
        #         "index": 0,
        #         "text": "Deep Learning is ...",
        #         "score": 0.9950755
        #     }
        # ]
        reranked_docs = []
        for idx, text in enumerate(texts):
            reranked_docs.append(
                {
                    'index': idx,
                    'text': text,
                    'score': 0.9,
                }
            )
            # For mock, only return the first document
            break
        return reranked_docs
import os

import pytest

from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiHelper
from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import (
    HuggingfaceTeiTextEmbeddingModel,
)
from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass

MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'


@pytest.fixture
def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch):
    if MOCK:
        monkeypatch.setattr(TeiHelper, 'get_tei_extra_parameter', MockTEIClass.get_tei_extra_parameter)
        monkeypatch.setattr(TeiHelper, 'invoke_tokenize', MockTEIClass.invoke_tokenize)
        monkeypatch.setattr(TeiHelper, 'invoke_embeddings', MockTEIClass.invoke_embeddings)
        monkeypatch.setattr(TeiHelper, 'invoke_rerank', MockTEIClass.invoke_rerank)
    yield

    if MOCK:
        monkeypatch.undo()


@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
def test_validate_credentials(setup_tei_mock):
    model = HuggingfaceTeiTextEmbeddingModel()
    # model name is only used in mock
    model_name = 'embedding'

    if MOCK:
        # The TEI provider checks the model type via the API endpoint; on a real server the
        # model type is always correct, so the mismatched type is only checked in mock mode.
        with pytest.raises(CredentialsValidateFailedError):
            model.validate_credentials(
                model='reranker',
                credentials={
                    'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
                },
            )

    model.validate_credentials(
        model=model_name,
        credentials={
            'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
        },
    )


@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
def test_invoke_model(setup_tei_mock):
    model = HuggingfaceTeiTextEmbeddingModel()
    model_name = 'embedding'

    result = model.invoke(
        model=model_name,
        credentials={
            'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
        },
        texts=[
            "hello",
            "world",
        ],
        user="abc-123",
    )

    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 2
    assert result.usage.total_tokens > 0
import os

import pytest

from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.huggingface_tei.rerank.rerank import (
    HuggingfaceTeiRerankModel,
)
from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiHelper
from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass

MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'


@pytest.fixture
def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch):
    if MOCK:
        monkeypatch.setattr(TeiHelper, 'get_tei_extra_parameter', MockTEIClass.get_tei_extra_parameter)
        monkeypatch.setattr(TeiHelper, 'invoke_tokenize', MockTEIClass.invoke_tokenize)
        monkeypatch.setattr(TeiHelper, 'invoke_embeddings', MockTEIClass.invoke_embeddings)
        monkeypatch.setattr(TeiHelper, 'invoke_rerank', MockTEIClass.invoke_rerank)
    yield

    if MOCK:
        monkeypatch.undo()


@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
def test_validate_credentials(setup_tei_mock):
    model = HuggingfaceTeiRerankModel()
    # model name is only used in mock
    model_name = 'reranker'

    if MOCK:
        # The TEI provider checks the model type via the API endpoint; on a real server the
        # model type is always correct, so the mismatched type is only checked in mock mode.
        with pytest.raises(CredentialsValidateFailedError):
            model.validate_credentials(
                model='embedding',
                credentials={
                    'server_url': os.environ.get('TEI_RERANK_SERVER_URL', ""),
                },
            )

    model.validate_credentials(
        model=model_name,
        credentials={
            'server_url': os.environ.get('TEI_RERANK_SERVER_URL', ""),
        },
    )


@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
def test_invoke_model(setup_tei_mock):
    model = HuggingfaceTeiRerankModel()
    # model name is only used in mock
    model_name = 'reranker'

    result = model.invoke(
        model=model_name,
        credentials={
            'server_url': os.environ.get('TEI_RERANK_SERVER_URL', ""),
        },
        query="Who is Kasumi?",
        docs=[
            "Kasumi is a girl's name of Japanese origin meaning \"mist\".",
            "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ",
            "and she leads a team named PopiParty.",
        ],
        score_threshold=0.8,
    )

    assert isinstance(result, RerankResult)
    assert len(result.docs) == 1
    assert result.docs[0].index == 0
    assert result.docs[0].score >= 0.8