| @@ -0,0 +1,11 @@ | |||
| import logging | |||
| from core.model_runtime.model_providers.__base.model_provider import ModelProvider | |||
| logger = logging.getLogger(__name__) | |||
| class HuggingfaceTeiProvider(ModelProvider): | |||
| def validate_provider_credentials(self, credentials: dict) -> None: | |||
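| """ | |||
| Validate provider credentials. | |||
| This provider only exposes customizable models, so credentials are validated per model | |||
| (see validate_credentials in the rerank and text_embedding model classes); nothing to do here. | |||
| """ | |||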
| pass | |||
| @@ -0,0 +1,36 @@ | |||
| provider: huggingface_tei | |||
| label: | |||
| en_US: Text Embedding Inference | |||
| description: | |||
| en_US: A blazing fast inference solution for text embeddings models. | |||
| zh_Hans: 用于文本嵌入模型的超快速推理解决方案。 | |||
| background: "#FFF8DC" | |||
| help: | |||
| title: | |||
| en_US: How to deploy Text Embedding Inference | |||
| zh_Hans: 如何部署 Text Embedding Inference | |||
| url: | |||
| en_US: https://github.com/huggingface/text-embeddings-inference | |||
| supported_model_types: | |||
| - text-embedding | |||
| - rerank | |||
| configurate_methods: | |||
| - customizable-model | |||
| model_credential_schema: | |||
| model: | |||
| label: | |||
| en_US: Model Name | |||
| zh_Hans: 模型名称 | |||
| placeholder: | |||
| en_US: Enter your model name | |||
| zh_Hans: 输入模型名称 | |||
| credential_form_schemas: | |||
| - variable: server_url | |||
| label: | |||
| zh_Hans: 服务器URL | |||
| en_US: Server URL | |||
| type: secret-input | |||
| required: true | |||
| placeholder: | |||
| zh_Hans: 在此输入Text Embedding Inference的服务器地址,如 http://192.168.1.100:8080 | |||
| en_US: Enter the URL of your Text Embedding Inference server, e.g. http://192.168.1.100:8080 | |||
| @@ -0,0 +1,137 @@ | |||
| from typing import Optional | |||
| import httpx | |||
| from core.model_runtime.entities.common_entities import I18nObject | |||
| from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType | |||
| from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult | |||
| from core.model_runtime.errors.invoke import ( | |||
| InvokeAuthorizationError, | |||
| InvokeBadRequestError, | |||
| InvokeConnectionError, | |||
| InvokeError, | |||
| InvokeRateLimitError, | |||
| InvokeServerUnavailableError, | |||
| ) | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.model_providers.__base.rerank_model import RerankModel | |||
| from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiHelper | |||
| class HuggingfaceTeiRerankModel(RerankModel): | |||
| """ | |||
| Model class for Text Embedding Inference rerank model. | |||
| """ | |||
| def _invoke( | |||
| self, | |||
| model: str, | |||
| credentials: dict, | |||
| query: str, | |||
| docs: list[str], | |||
| score_threshold: Optional[float] = None, | |||
| top_n: Optional[int] = None, | |||
| user: Optional[str] = None, | |||
| ) -> RerankResult: | |||
| """ | |||
| Invoke rerank model | |||
| :param model: model name | |||
| :param credentials: model credentials | |||
| :param query: search query | |||
| :param docs: docs for reranking | |||
| :param score_threshold: score threshold | |||
| :param top_n: top n | |||
| :param user: unique user id | |||
| :return: rerank result | |||
| """ | |||
| if len(docs) == 0: | |||
| return RerankResult(model=model, docs=[]) | |||
| server_url = credentials['server_url'] | |||
| if server_url.endswith('/'): | |||
| server_url = server_url[:-1] | |||
| try: | |||
| results = TeiHelper.invoke_rerank(server_url, query, docs) | |||
| rerank_documents = [] | |||
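| # TEI's /rerank endpoint returns results ordered by relevance score (highest first), | |||
| # so we can stop collecting once top_n documents have passed the score threshold | |||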
| for result in results: | |||
| rerank_document = RerankDocument( | |||
| index=result['index'], | |||
| text=result['text'], | |||
| score=result['score'], | |||
| ) | |||
| if score_threshold is None or result['score'] >= score_threshold: | |||
| rerank_documents.append(rerank_document) | |||
| if top_n is not None and len(rerank_documents) >= top_n: | |||
| break | |||
| return RerankResult(model=model, docs=rerank_documents) | |||
| except httpx.HTTPStatusError as e: | |||
| raise InvokeServerUnavailableError(str(e)) | |||
| def validate_credentials(self, model: str, credentials: dict) -> None: | |||
| """ | |||
| Validate model credentials | |||
| :param model: model name | |||
| :param credentials: model credentials | |||
| :return: | |||
| """ | |||
| try: | |||
| server_url = credentials['server_url'] | |||
| extra_args = TeiHelper.get_tei_extra_parameter(server_url, model) | |||
| if extra_args.model_type != 'reranker': | |||
| raise CredentialsValidateFailedError('Current model is not a rerank model') | |||
| credentials['context_size'] = extra_args.max_input_length | |||
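| # Persist the server-reported max_input_length so get_customizable_model_schema exposes the correct context size | |||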
| self.invoke( | |||
| model=model, | |||
| credentials=credentials, | |||
| query='Who is Kasumi?', | |||
| docs=[ | |||
| 'Kasumi is a girl\'s name of Japanese origin meaning "mist".', | |||
| 'Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ', | |||
| 'and she leads a team named PopiParty.', | |||
| ], | |||
| score_threshold=0.8, | |||
| ) | |||
| except Exception as ex: | |||
| raise CredentialsValidateFailedError(str(ex)) | |||
| @property | |||
| def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: | |||
| """ | |||
| Map model invoke error to unified error | |||
| The key is the error type thrown to the caller | |||
| The value is the error type thrown by the model, | |||
| which needs to be converted into a unified error type for the caller. | |||
| :return: Invoke error mapping | |||
| """ | |||
| return { | |||
| InvokeConnectionError: [InvokeConnectionError], | |||
| InvokeServerUnavailableError: [InvokeServerUnavailableError], | |||
| InvokeRateLimitError: [InvokeRateLimitError], | |||
| InvokeAuthorizationError: [InvokeAuthorizationError], | |||
| InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], | |||
| } | |||
| def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: | |||
| """ | |||
| used to define customizable model schema | |||
| """ | |||
| entity = AIModelEntity( | |||
| model=model, | |||
| label=I18nObject(en_US=model), | |||
| fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, | |||
| model_type=ModelType.RERANK, | |||
| model_properties={ | |||
| ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)), | |||
| }, | |||
| parameter_rules=[], | |||
| ) | |||
| return entity | |||
| @@ -0,0 +1,183 @@ | |||
| from threading import Lock | |||
| from time import time | |||
| from typing import Optional | |||
| import httpx | |||
| from requests.adapters import HTTPAdapter | |||
| from requests.exceptions import ConnectionError, MissingSchema, Timeout | |||
| from requests.sessions import Session | |||
| from yarl import URL | |||
| class TeiModelExtraParameter: | |||
| model_type: str | |||
| max_input_length: int | |||
| max_client_batch_size: int | |||
| def __init__(self, model_type: str, max_input_length: int, max_client_batch_size: Optional[int] = None) -> None: | |||
| self.model_type = model_type | |||
| self.max_input_length = max_input_length | |||
| self.max_client_batch_size = max_client_batch_size | |||
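| # Module-level cache of TEI /info results, keyed by model name and guarded by a lock; entries expire after 300 seconds | |||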
| cache = {} | |||
| cache_lock = Lock() | |||
| class TeiHelper: | |||
| @staticmethod | |||
| def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter: | |||
| TeiHelper._clean_cache() | |||
| with cache_lock: | |||
| if model_name not in cache: | |||
| cache[model_name] = { | |||
| 'expires': time() + 300, | |||
| 'value': TeiHelper._get_tei_extra_parameter(server_url), | |||
| } | |||
| return cache[model_name]['value'] | |||
| @staticmethod | |||
| def _clean_cache() -> None: | |||
| try: | |||
| with cache_lock: | |||
| expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()] | |||
| for model_uid in expired_keys: | |||
| del cache[model_uid] | |||
| except RuntimeError: | |||
| # Ignore cache mutation races; stale entries will be cleaned up on a later call | |||
| pass | |||
| @staticmethod | |||
| def _get_tei_extra_parameter(server_url: str) -> TeiModelExtraParameter: | |||
| """ | |||
| Get TEI model extra parameters (model_type, max_input_length, max_client_batch_size) from the /info endpoint | |||
| """ | |||
| url = str(URL(server_url) / 'info') | |||
| # this method is called while holding the cache lock, and a default requests session may hang forever, so mount an HTTPAdapter with max_retries=3 | |||
| session = Session() | |||
| session.mount('http://', HTTPAdapter(max_retries=3)) | |||
| session.mount('https://', HTTPAdapter(max_retries=3)) | |||
| try: | |||
| response = session.get(url, timeout=10) | |||
| except (MissingSchema, ConnectionError, Timeout) as e: | |||
| raise RuntimeError(f'get tei model extra parameter failed, url: {url}, error: {e}') | |||
| if response.status_code != 200: | |||
| raise RuntimeError( | |||
| f'get tei model extra parameter failed, status code: {response.status_code}, response: {response.text}' | |||
| ) | |||
| response_json = response.json() | |||
| model_type = response_json.get('model_type', {}) | |||
| if len(model_type.keys()) < 1: | |||
| raise RuntimeError('model_type is empty') | |||
| model_type = list(model_type.keys())[0] | |||
| if model_type not in ['embedding', 'reranker']: | |||
| raise RuntimeError(f'invalid model_type: {model_type}') | |||
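| # Fall back to conservative defaults when the /info response does not report these fields | |||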
| max_input_length = response_json.get('max_input_length', 512) | |||
| max_client_batch_size = response_json.get('max_client_batch_size', 1) | |||
| return TeiModelExtraParameter( | |||
| model_type=model_type, | |||
| max_input_length=max_input_length, | |||
| max_client_batch_size=max_client_batch_size | |||
| ) | |||
| @staticmethod | |||
| def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]: | |||
| """ | |||
| Invoke tokenize endpoint | |||
| Example response: | |||
| [ | |||
| [ | |||
| { | |||
| "id": 0, | |||
| "text": "<s>", | |||
| "special": true, | |||
| "start": null, | |||
| "stop": null | |||
| }, | |||
| { | |||
| "id": 7704, | |||
| "text": "str", | |||
| "special": false, | |||
| "start": 0, | |||
| "stop": 3 | |||
| }, | |||
| < MORE TOKENS > | |||
| ] | |||
| ] | |||
| :param server_url: server url | |||
| :param texts: texts to tokenize | |||
| """ | |||
| resp = httpx.post( | |||
| f'{server_url}/tokenize', | |||
| json={'inputs': texts}, | |||
| ) | |||
| resp.raise_for_status() | |||
| return resp.json() | |||
| @staticmethod | |||
| def invoke_embeddings(server_url: str, texts: list[str]) -> dict: | |||
| """ | |||
| Invoke embeddings endpoint | |||
| Example response: | |||
| { | |||
| "object": "list", | |||
| "data": [ | |||
| { | |||
| "object": "embedding", | |||
| "embedding": [...], | |||
| "index": 0 | |||
| } | |||
| ], | |||
| "model": "MODEL_NAME", | |||
| "usage": { | |||
| "prompt_tokens": 3, | |||
| "total_tokens": 3 | |||
| } | |||
| } | |||
| :param server_url: server url | |||
| :param texts: texts to embed | |||
| """ | |||
| # Use the OpenAI-compatible /v1/embeddings API here, which includes usage tracking | |||
| resp = httpx.post( | |||
| f'{server_url}/v1/embeddings', | |||
| json={'input': texts}, | |||
| ) | |||
| resp.raise_for_status() | |||
| return resp.json() | |||
| @staticmethod | |||
| def invoke_rerank(server_url: str, query: str, docs: list[str]) -> list[dict]: | |||
| """ | |||
| Invoke rerank endpoint | |||
| Example response: | |||
| [ | |||
| { | |||
| "index": 0, | |||
| "text": "Deep Learning is ...", | |||
| "score": 0.9950755 | |||
| } | |||
| ] | |||
| :param server_url: server url | |||
| :param query: search query | |||
| :param docs: documents to rerank | |||
| """ | |||
| params = {'query': query, 'texts': docs, 'return_text': True} | |||
| response = httpx.post( | |||
| server_url + '/rerank', | |||
| json=params, | |||
| ) | |||
| response.raise_for_status() | |||
| return response.json() | |||
| @@ -0,0 +1,204 @@ | |||
| import time | |||
| from typing import Optional | |||
| from core.model_runtime.entities.common_entities import I18nObject | |||
| from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType | |||
| from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult | |||
| from core.model_runtime.errors.invoke import ( | |||
| InvokeAuthorizationError, | |||
| InvokeBadRequestError, | |||
| InvokeConnectionError, | |||
| InvokeError, | |||
| InvokeRateLimitError, | |||
| InvokeServerUnavailableError, | |||
| ) | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel | |||
| from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiHelper | |||
| class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel): | |||
| """ | |||
| Model class for Text Embedding Inference text embedding model. | |||
| """ | |||
| def _invoke( | |||
| self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None | |||
| ) -> TextEmbeddingResult: | |||
| """ | |||
| Invoke text embedding model | |||
| credentials should be like: | |||
| { | |||
| 'server_url': 'server url', | |||
| } | |||
| :param model: model name | |||
| :param credentials: model credentials | |||
| :param texts: texts to embed | |||
| :param user: unique user id | |||
| :return: embeddings result | |||
| """ | |||
| server_url = credentials['server_url'] | |||
| if server_url.endswith('/'): | |||
| server_url = server_url[:-1] | |||
| # get model properties | |||
| context_size = self._get_context_size(model, credentials) | |||
| max_chunks = self._get_max_chunks(model, credentials) | |||
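| # context_size and max_chunks come from the model schema, which reads the values stored in | |||
| # credentials by validate_credentials (based on the TEI server's /info response) | |||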
| inputs = [] | |||
| indices = [] | |||
| used_tokens = 0 | |||
| # get tokenized results from TEI | |||
| batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts) | |||
| for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)): | |||
| # Check if the number of tokens is larger than the context size | |||
| num_tokens = len(tokenize_result) | |||
| if num_tokens >= context_size: | |||
| # Find the best cutoff point | |||
| pre_special_token_count = 0 | |||
| for token in tokenize_result: | |||
| if token['special']: | |||
| pre_special_token_count += 1 | |||
| else: | |||
| break | |||
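| # Count the remaining (trailing) special tokens so the cutoff leaves room for them | |||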
| rest_special_token_count = len([token for token in tokenize_result if token['special']]) - pre_special_token_count | |||
| # Calculate the cutoff point, leaving 20 tokens of headroom to avoid exceeding the limit | |||
| token_cutoff = context_size - rest_special_token_count - 20 | |||
| # Find the cutoff index | |||
| cutpoint_token = tokenize_result[token_cutoff] | |||
| cutoff = cutpoint_token['start'] | |||
| inputs.append(text[0: cutoff]) | |||
| else: | |||
| inputs.append(text) | |||
| indices += [i] | |||
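| # Embed the (possibly truncated) inputs in batches of max_chunks, which mirrors the server's max_client_batch_size | |||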
| batched_embeddings = [] | |||
| _iter = range(0, len(inputs), max_chunks) | |||
| try: | |||
| used_tokens = 0 | |||
| for i in _iter: | |||
| iter_texts = inputs[i : i + max_chunks] | |||
| results = TeiHelper.invoke_embeddings(server_url, iter_texts) | |||
| embeddings = results['data'] | |||
| embeddings = [embedding['embedding'] for embedding in embeddings] | |||
| batched_embeddings.extend(embeddings) | |||
| usage = results['usage'] | |||
| used_tokens += usage['total_tokens'] | |||
| except RuntimeError as e: | |||
| raise InvokeServerUnavailableError(str(e)) | |||
| usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) | |||
| result = TextEmbeddingResult(model=model, embeddings=batched_embeddings, usage=usage) | |||
| return result | |||
| def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: | |||
| """ | |||
| Get number of tokens for given prompt messages | |||
| :param model: model name | |||
| :param credentials: model credentials | |||
| :param texts: texts to embed | |||
| :return: | |||
| """ | |||
| num_tokens = 0 | |||
| server_url = credentials['server_url'] | |||
| if server_url.endswith('/'): | |||
| server_url = server_url[:-1] | |||
| batch_tokens = TeiHelper.invoke_tokenize(server_url, texts) | |||
| num_tokens = sum(len(tokens) for tokens in batch_tokens) | |||
| return num_tokens | |||
| def validate_credentials(self, model: str, credentials: dict) -> None: | |||
| """ | |||
| Validate model credentials | |||
| :param model: model name | |||
| :param credentials: model credentials | |||
| :return: | |||
| """ | |||
| try: | |||
| server_url = credentials['server_url'] | |||
| extra_args = TeiHelper.get_tei_extra_parameter(server_url, model) | |||
| if extra_args.model_type != 'embedding': | |||
| raise CredentialsValidateFailedError('Current model is not an embedding model') | |||
| credentials['context_size'] = extra_args.max_input_length | |||
| credentials['max_chunks'] = extra_args.max_client_batch_size | |||
| self._invoke(model=model, credentials=credentials, texts=['ping']) | |||
| except Exception as ex: | |||
| raise CredentialsValidateFailedError(str(ex)) | |||
| @property | |||
| def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: | |||
| return { | |||
| InvokeConnectionError: [InvokeConnectionError], | |||
| InvokeServerUnavailableError: [InvokeServerUnavailableError], | |||
| InvokeRateLimitError: [InvokeRateLimitError], | |||
| InvokeAuthorizationError: [InvokeAuthorizationError], | |||
| InvokeBadRequestError: [KeyError], | |||
| } | |||
| def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: | |||
| """ | |||
| Calculate response usage | |||
| :param model: model name | |||
| :param credentials: model credentials | |||
| :param tokens: input tokens | |||
| :return: usage | |||
| """ | |||
| # get input price info | |||
| input_price_info = self.get_price( | |||
| model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens | |||
| ) | |||
| # transform usage | |||
| usage = EmbeddingUsage( | |||
| tokens=tokens, | |||
| total_tokens=tokens, | |||
| unit_price=input_price_info.unit_price, | |||
| price_unit=input_price_info.unit, | |||
| total_price=input_price_info.total_amount, | |||
| currency=input_price_info.currency, | |||
| latency=time.perf_counter() - self.started_at, | |||
| ) | |||
| return usage | |||
| def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: | |||
| """ | |||
| used to define customizable model schema | |||
| """ | |||
| entity = AIModelEntity( | |||
| model=model, | |||
| label=I18nObject(en_US=model), | |||
| fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| model_properties={ | |||
| ModelPropertyKey.MAX_CHUNKS: int(credentials.get('max_chunks', 1)), | |||
| ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)), | |||
| }, | |||
| parameter_rules=[], | |||
| ) | |||
| return entity | |||
| @@ -93,6 +93,8 @@ CODE_MAX_STRING_LENGTH = "80000" | |||
| CODE_EXECUTION_ENDPOINT = "http://127.0.0.1:8194" | |||
| CODE_EXECUTION_API_KEY = "dify-sandbox" | |||
| FIRECRAWL_API_KEY = "fc-" | |||
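| # Endpoints of a running Text Embedding Inference deployment, used by the huggingface_tei integration tests | |||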
| TEI_EMBEDDING_SERVER_URL = "http://a.abc.com:11451" | |||
| TEI_RERANK_SERVER_URL = "http://a.abc.com:11451" | |||
| [tool.poetry] | |||
| name = "dify-api" | |||
| @@ -0,0 +1,94 @@ | |||
| from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiModelExtraParameter | |||
| class MockTEIClass: | |||
| @staticmethod | |||
| def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter: | |||
| # During mock, we don't have a real server to query, so we just return a dummy value | |||
| if 'rerank' in model_name: | |||
| model_type = 'reranker' | |||
| else: | |||
| model_type = 'embedding' | |||
| return TeiModelExtraParameter(model_type=model_type, max_input_length=512, max_client_batch_size=1) | |||
| @staticmethod | |||
| def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]: | |||
| # Use space as token separator, and split the text into tokens | |||
| tokenized_texts = [] | |||
| for text in texts: | |||
| tokens = text.split(' ') | |||
| current_index = 0 | |||
| tokenized_text = [] | |||
| for idx, token in enumerate(tokens): | |||
| s_token = { | |||
| 'id': idx, | |||
| 'text': token, | |||
| 'special': False, | |||
| 'start': current_index, | |||
| 'stop': current_index + len(token), | |||
| } | |||
| current_index += len(token) + 1 | |||
| tokenized_text.append(s_token) | |||
| tokenized_texts.append(tokenized_text) | |||
| return tokenized_texts | |||
| @staticmethod | |||
| def invoke_embeddings(server_url: str, texts: list[str]) -> dict: | |||
| # { | |||
| # "object": "list", | |||
| # "data": [ | |||
| # { | |||
| # "object": "embedding", | |||
| # "embedding": [...], | |||
| # "index": 0 | |||
| # } | |||
| # ], | |||
| # "model": "MODEL_NAME", | |||
| # "usage": { | |||
| # "prompt_tokens": 3, | |||
| # "total_tokens": 3 | |||
| # } | |||
| # } | |||
| embeddings = [] | |||
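| # Return a fixed 768-dimensional dummy vector for every input text | |||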
| for idx, text in enumerate(texts): | |||
| embedding = [0.1] * 768 | |||
| embeddings.append( | |||
| { | |||
| 'object': 'embedding', | |||
| 'embedding': embedding, | |||
| 'index': idx, | |||
| } | |||
| ) | |||
| return { | |||
| 'object': 'list', | |||
| 'data': embeddings, | |||
| 'model': 'MODEL_NAME', | |||
| 'usage': { | |||
| 'prompt_tokens': sum(len(text.split(' ')) for text in texts), | |||
| 'total_tokens': sum(len(text.split(' ')) for text in texts), | |||
| }, | |||
| } | |||
| @staticmethod | |||
| def invoke_rerank(server_url: str, query: str, texts: list[str]) -> list[dict]: | |||
| # Example response: | |||
| # [ | |||
| # { | |||
| # "index": 0, | |||
| # "text": "Deep Learning is ...", | |||
| # "score": 0.9950755 | |||
| # } | |||
| # ] | |||
| reranked_docs = [] | |||
| for idx, text in enumerate(texts): | |||
| reranked_docs.append( | |||
| { | |||
| 'index': idx, | |||
| 'text': text, | |||
| 'score': 0.9, | |||
| } | |||
| ) | |||
| # For mock, only return the first document | |||
| break | |||
| return reranked_docs | |||
| @@ -0,0 +1,72 @@ | |||
| import os | |||
| import pytest | |||
| from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import TeiHelper | |||
| from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import ( | |||
| HuggingfaceTeiTextEmbeddingModel, | |||
| ) | |||
| from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass | |||
| MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true' | |||
| @pytest.fixture | |||
| def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch): | |||
| if MOCK: | |||
| monkeypatch.setattr(TeiHelper, 'get_tei_extra_parameter', MockTEIClass.get_tei_extra_parameter) | |||
| monkeypatch.setattr(TeiHelper, 'invoke_tokenize', MockTEIClass.invoke_tokenize) | |||
| monkeypatch.setattr(TeiHelper, 'invoke_embeddings', MockTEIClass.invoke_embeddings) | |||
| monkeypatch.setattr(TeiHelper, 'invoke_rerank', MockTEIClass.invoke_rerank) | |||
| yield | |||
| if MOCK: | |||
| monkeypatch.undo() | |||
| @pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True) | |||
| def test_validate_credentials(setup_tei_mock): | |||
| model = HuggingfaceTeiTextEmbeddingModel() | |||
| # model name is only used in mock | |||
| model_name = 'embedding' | |||
| if MOCK: | |||
| # The TEI provider determines the model type via the server's API endpoint; on a real server the type is always correct, | |||
| # so the model-type mismatch is only checked when mocking. | |||
| with pytest.raises(CredentialsValidateFailedError): | |||
| model.validate_credentials( | |||
| model='reranker', | |||
| credentials={ | |||
| 'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""), | |||
| } | |||
| ) | |||
| model.validate_credentials( | |||
| model=model_name, | |||
| credentials={ | |||
| 'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""), | |||
| } | |||
| ) | |||
| @pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True) | |||
| def test_invoke_model(setup_tei_mock): | |||
| model = HuggingfaceTeiTextEmbeddingModel() | |||
| model_name = 'embedding' | |||
| result = model.invoke( | |||
| model=model_name, | |||
| credentials={ | |||
| 'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""), | |||
| }, | |||
| texts=[ | |||
| "hello", | |||
| "world" | |||
| ], | |||
| user="abc-123" | |||
| ) | |||
| assert isinstance(result, TextEmbeddingResult) | |||
| assert len(result.embeddings) == 2 | |||
| assert result.usage.total_tokens > 0 | |||
| @@ -0,0 +1,76 @@ | |||
| import os | |||
| import pytest | |||
| from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.model_providers.huggingface_tei.rerank.rerank import ( | |||
| HuggingfaceTeiRerankModel, | |||
| ) | |||
| from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import TeiHelper | |||
| from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass | |||
| MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true' | |||
| @pytest.fixture | |||
| def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch): | |||
| if MOCK: | |||
| monkeypatch.setattr(TeiHelper, 'get_tei_extra_parameter', MockTEIClass.get_tei_extra_parameter) | |||
| monkeypatch.setattr(TeiHelper, 'invoke_tokenize', MockTEIClass.invoke_tokenize) | |||
| monkeypatch.setattr(TeiHelper, 'invoke_embeddings', MockTEIClass.invoke_embeddings) | |||
| monkeypatch.setattr(TeiHelper, 'invoke_rerank', MockTEIClass.invoke_rerank) | |||
| yield | |||
| if MOCK: | |||
| monkeypatch.undo() | |||
| @pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True) | |||
| def test_validate_credentials(setup_tei_mock): | |||
| model = HuggingfaceTeiRerankModel() | |||
| # model name is only used in mock | |||
| model_name = 'reranker' | |||
| if MOCK: | |||
| # The TEI provider determines the model type via the server's API endpoint; on a real server the type is always correct, | |||
| # so the model-type mismatch is only checked when mocking. | |||
| with pytest.raises(CredentialsValidateFailedError): | |||
| model.validate_credentials( | |||
| model='embedding', | |||
| credentials={ | |||
| 'server_url': os.environ.get('TEI_RERANK_SERVER_URL'), | |||
| } | |||
| ) | |||
| model.validate_credentials( | |||
| model=model_name, | |||
| credentials={ | |||
| 'server_url': os.environ.get('TEI_RERANK_SERVER_URL'), | |||
| } | |||
| ) | |||
| @pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True) | |||
| def test_invoke_model(setup_tei_mock): | |||
| model = HuggingfaceTeiRerankModel() | |||
| # model name is only used in mock | |||
| model_name = 'reranker' | |||
| result = model.invoke( | |||
| model=model_name, | |||
| credentials={ | |||
| 'server_url': os.environ.get('TEI_RERANK_SERVER_URL'), | |||
| }, | |||
| query="Who is Kasumi?", | |||
| docs=[ | |||
| "Kasumi is a girl's name of Japanese origin meaning \"mist\".", | |||
| "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ", | |||
| "and she leads a team named PopiParty." | |||
| ], | |||
| score_threshold=0.8 | |||
| ) | |||
| assert isinstance(result, RerankResult) | |||
| assert len(result.docs) == 1 | |||
| assert result.docs[0].index == 0 | |||
| assert result.docs[0].score >= 0.8 | |||