Signed-off-by: kenwoodjw <blackxin55+@gmail.com>
Co-authored-by: crazywoola <427733928@qq.com>
@@ -34,3 +34,11 @@ model_credential_schema:
       placeholder:
         zh_Hans: 在此输入Text Embedding Inference的服务器地址,如 http://192.168.1.100:8080
         en_US: Enter the url of your Text Embedding Inference, e.g. http://192.168.1.100:8080
+    - variable: api_key
+      label:
+        en_US: API Key
+      type: secret-input
+      required: false
+      placeholder:
+        zh_Hans: 在此输入您的 API Key
+        en_US: Enter your API Key
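The added credential is optional (required: false) and rendered as a secret input; every call site below only attaches an Authorization header when a key is actually present. As a minimal illustration of how such a key could be smoke-tested against a server before being saved — the check_tei_key helper is hypothetical, and /info is the TEI metadata route the validation code below appears to rely on:

    from typing import Optional

    import httpx

    def check_tei_key(server_url: str, api_key: Optional[str]) -> bool:
        # Hypothetical pre-save probe; mirrors the optional-key handling in this diff.
        headers = {"Content-Type": "application/json"}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        resp = httpx.get(f"{server_url.removesuffix('/')}/info", headers=headers, timeout=10)
        return resp.status_code == 200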
@@ -51,8 +51,13 @@ class HuggingfaceTeiRerankModel(RerankModel):
         server_url = server_url.removesuffix("/")

+        headers = {"Content-Type": "application/json"}
+        api_key = credentials.get("api_key")
+        if api_key:
+            headers["Authorization"] = f"Bearer {api_key}"
+
         try:
-            results = TeiHelper.invoke_rerank(server_url, query, docs)
+            results = TeiHelper.invoke_rerank(server_url, query, docs, headers)
             rerank_documents = []
             for result in results:
@@ -80,7 +85,11 @@ class HuggingfaceTeiRerankModel(RerankModel):
         """
         try:
             server_url = credentials["server_url"]
-            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
+            headers = {"Content-Type": "application/json"}
+            api_key = credentials.get("api_key")
+            if api_key:
+                headers["Authorization"] = f"Bearer {api_key}"
+            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model, headers)
             if extra_args.model_type != "reranker":
                 raise CredentialsValidateFailedError("Current model is not a rerank model")
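The same four header-building lines now appear at every call site in this diff. A hypothetical helper (not part of the change) could centralize the pattern:

    from typing import Optional

    def build_headers(api_key: Optional[str]) -> dict:
        # Identical logic to the inline blocks above and below.
        headers = {"Content-Type": "application/json"}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        return headers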
@@ -26,13 +26,15 @@ cache_lock = Lock()

 class TeiHelper:
     @staticmethod
-    def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
+    def get_tei_extra_parameter(
+        server_url: str, model_name: str, headers: Optional[dict] = None
+    ) -> TeiModelExtraParameter:
         TeiHelper._clean_cache()
         with cache_lock:
             if model_name not in cache:
                 cache[model_name] = {
                     "expires": time() + 300,
-                    "value": TeiHelper._get_tei_extra_parameter(server_url),
+                    "value": TeiHelper._get_tei_extra_parameter(server_url, headers),
                 }
             return cache[model_name]["value"]
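For context, the surrounding cache is a 300-second TTL map keyed by model name, so the headers passed on the first call are the ones that populate the cached value. A self-contained sketch of the same pattern, with illustrative names only:

    from threading import Lock
    from time import time

    cache: dict = {}
    cache_lock = Lock()

    def get_cached(model_name: str, compute):
        # Recompute when missing; entries expire 300 seconds after creation.
        with cache_lock:
            entry = cache.get(model_name)
            if entry is None or entry["expires"] < time():
                cache[model_name] = {"expires": time() + 300, "value": compute()}
            return cache[model_name]["value"]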
@@ -47,7 +49,7 @@ class TeiHelper:
             pass

     @staticmethod
-    def _get_tei_extra_parameter(server_url: str) -> TeiModelExtraParameter:
+    def _get_tei_extra_parameter(server_url: str, headers: Optional[dict] = None) -> TeiModelExtraParameter:
         """
         get tei model extra parameter like model_type, max_input_length, max_batch_requests
         """
@@ -61,7 +63,7 @@
         session.mount("https://", HTTPAdapter(max_retries=3))

         try:
-            response = session.get(url, timeout=10)
+            response = session.get(url, headers=headers, timeout=10)
         except (MissingSchema, ConnectionError, Timeout) as e:
             raise RuntimeError(f"get tei model extra parameter failed, url: {url}, error: {e}")
         if response.status_code != 200:
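Only the changed line of the enclosing method appears in the hunk; a hedged reconstruction of the surrounding request, assuming url points at TEI's /info metadata endpoint:

    from typing import Optional

    import requests
    from requests.adapters import HTTPAdapter

    def fetch_tei_info(server_url: str, headers: Optional[dict] = None) -> dict:
        session = requests.Session()
        session.mount("http://", HTTPAdapter(max_retries=3))
        session.mount("https://", HTTPAdapter(max_retries=3))
        # With this change, the bearer header (when set) reaches the metadata probe too.
        response = session.get(f"{server_url}/info", headers=headers, timeout=10)
        response.raise_for_status()
        return response.json()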
@@ -86,7 +88,7 @@
         )

     @staticmethod
-    def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
+    def invoke_tokenize(server_url: str, texts: list[str], headers: Optional[dict] = None) -> list[list[dict]]:
         """
         Invoke tokenize endpoint
@@ -114,15 +116,15 @@
         :param server_url: server url
         :param texts: texts to tokenize
         """
-        resp = httpx.post(
-            f"{server_url}/tokenize",
-            json={"inputs": texts},
-        )
+        url = f"{server_url}/tokenize"
+        json_data = {"inputs": texts}
+        resp = httpx.post(url, json=json_data, headers=headers)
         resp.raise_for_status()
         return resp.json()

     @staticmethod
-    def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
+    def invoke_embeddings(server_url: str, texts: list[str], headers: Optional[dict] = None) -> dict:
         """
         Invoke embeddings endpoint
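What the refactored /tokenize call exchanges, assuming a reachable TEI server (the URL is the example from the YAML placeholder, the key is made up); the list-of-token-dicts-per-text response shape is inferred from how the result is consumed in text_embedding.py further down:

    import httpx

    resp = httpx.post(
        "http://192.168.1.100:8080/tokenize",
        json={"inputs": ["hello", "world"]},
        headers={"Authorization": "Bearer my-api-key"},  # hypothetical key
    )
    resp.raise_for_status()
    batched = resp.json()  # one list of token dicts per input text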
@@ -147,15 +149,14 @@
         :param texts: texts to embed
         """
         # Use OpenAI compatible API here, which has usage tracking
-        resp = httpx.post(
-            f"{server_url}/v1/embeddings",
-            json={"input": texts},
-        )
+        url = f"{server_url}/v1/embeddings"
+        json_data = {"input": texts}
+        resp = httpx.post(url, json=json_data, headers=headers)
         resp.raise_for_status()
         return resp.json()

     @staticmethod
-    def invoke_rerank(server_url: str, query: str, docs: list[str]) -> list[dict]:
+    def invoke_rerank(server_url: str, query: str, docs: list[str], headers: Optional[dict] = None) -> list[dict]:
         """
         Invoke rerank endpoint
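The /v1/embeddings route is OpenAI-compatible, which is why it carries usage tracking (per the comment above); a sketch of reading a response, again with a made-up server and key:

    import httpx

    resp = httpx.post(
        "http://192.168.1.100:8080/v1/embeddings",
        json={"input": ["hello"]},
        headers={"Authorization": "Bearer my-api-key"},  # hypothetical key
    )
    resp.raise_for_status()
    data = resp.json()
    vectors = [item["embedding"] for item in data["data"]]
    used_tokens = data["usage"]["total_tokens"]  # usage block in the OpenAI schema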
@@ -173,10 +174,7 @@
         :param candidates: candidates to rerank
         """
         params = {"query": query, "texts": docs, "return_text": True}

-        response = httpx.post(
-            server_url + "/rerank",
-            json=params,
-        )
+        url = f"{server_url}/rerank"
+        response = httpx.post(url, json=params, headers=headers)
         response.raise_for_status()
         return response.json()
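A sketch of the /rerank exchange matching the params built above; the index/score/text response fields follow TEI's rerank API (text is included because return_text is set):

    import httpx

    resp = httpx.post(
        "http://192.168.1.100:8080/rerank",
        json={"query": "Who is Kasumi?", "texts": ["doc one", "doc two"], "return_text": True},
        headers={"Authorization": "Bearer my-api-key"},  # hypothetical key
    )
    resp.raise_for_status()
    results = resp.json()  # e.g. [{"index": 0, "score": 0.98, "text": "doc one"}, ...]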
@@ -51,6 +51,10 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
         server_url = server_url.removesuffix("/")

+        headers = {"Content-Type": "application/json"}
+        api_key = credentials.get("api_key")
+        if api_key:
+            headers["Authorization"] = f"Bearer {api_key}"
         # get model properties
         context_size = self._get_context_size(model, credentials)
         max_chunks = self._get_max_chunks(model, credentials)
@@ -60,7 +64,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
         used_tokens = 0

         # get tokenized results from TEI
-        batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts)
+        batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts, headers)

         for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)):
             # Check if the number of tokens is larger than the context size
@@ -97,7 +101,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
             used_tokens = 0
             for i in _iter:
                 iter_texts = inputs[i : i + max_chunks]
-                results = TeiHelper.invoke_embeddings(server_url, iter_texts)
+                results = TeiHelper.invoke_embeddings(server_url, iter_texts, headers)
                 embeddings = results["data"]
                 embeddings = [embedding["embedding"] for embedding in embeddings]
                 batched_embeddings.extend(embeddings)
@@ -127,7 +131,11 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
         server_url = server_url.removesuffix("/")

-        batch_tokens = TeiHelper.invoke_tokenize(server_url, texts)
+        headers = {"Content-Type": "application/json"}
+        api_key = credentials.get("api_key")
+        if api_key:
+            headers["Authorization"] = f"Bearer {api_key}"
+        batch_tokens = TeiHelper.invoke_tokenize(server_url, texts, headers)
         num_tokens = sum(len(tokens) for tokens in batch_tokens)
         return num_tokens
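A worked instance of the counting line above, with made-up token ids:

    batched_tokens = [[{"id": 101}, {"id": 102}], [{"id": 103}]]  # two texts: 2 + 1 tokens
    num_tokens = sum(len(tokens) for tokens in batched_tokens)    # -> 3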
@@ -141,7 +149,14 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
         """
         try:
             server_url = credentials["server_url"]
-            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
+            headers = {"Content-Type": "application/json"}
+            api_key = credentials.get("api_key")
+            if api_key:
+                headers["Authorization"] = f"Bearer {api_key}"
+            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model, headers)
             if extra_args.model_type != "embedding":
                 raise CredentialsValidateFailedError("Current model is not an embedding model")
@@ -20,6 +20,7 @@ env =
     OPENAI_API_KEY = sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii
     TEI_EMBEDDING_SERVER_URL = http://a.abc.com:11451
     TEI_RERANK_SERVER_URL = http://a.abc.com:11451
+    TEI_API_KEY = ttttttttttttttt
     UPSTAGE_API_KEY = up-aaaaaaaaaaaaaaaaaaaa
     VOYAGE_API_KEY = va-aaaaaaaaaaaaaaaaaaaa
     XINFERENCE_CHAT_MODEL_UID = chat
@@ -40,6 +40,7 @@ def test_validate_credentials(setup_tei_mock):
             model="reranker",
             credentials={
                 "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
+                "api_key": os.environ.get("TEI_API_KEY", ""),
             },
         )
@@ -47,6 +48,7 @@ def test_validate_credentials(setup_tei_mock):
             model=model_name,
             credentials={
                 "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
+                "api_key": os.environ.get("TEI_API_KEY", ""),
             },
         )
@@ -60,6 +62,7 @@ def test_invoke_model(setup_tei_mock):
             model=model_name,
             credentials={
                 "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
+                "api_key": os.environ.get("TEI_API_KEY", ""),
             },
             texts=["hello", "world"],
             user="abc-123",
@@ -40,6 +40,7 @@ def test_validate_credentials(setup_tei_mock):
             model="embedding",
             credentials={
                 "server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
+                "api_key": os.environ.get("TEI_API_KEY", ""),
             },
         )
@@ -47,6 +48,7 @@ def test_validate_credentials(setup_tei_mock):
             model=model_name,
             credentials={
                 "server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
+                "api_key": os.environ.get("TEI_API_KEY", ""),
             },
         )
@@ -61,6 +63,7 @@ def test_invoke_model(setup_tei_mock):
             model=model_name,
             credentials={
                 "server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
+                "api_key": os.environ.get("TEI_API_KEY", ""),
             },
             query="Who is Kasumi?",
             docs=[