Signed-off-by: kenwoodjw <blackxin55+@gmail.com>
Co-authored-by: crazywoola <427733928@qq.com>
tags/0.12.0
| placeholder: | placeholder: | ||||
| zh_Hans: 在此输入Text Embedding Inference的服务器地址,如 http://192.168.1.100:8080 | zh_Hans: 在此输入Text Embedding Inference的服务器地址,如 http://192.168.1.100:8080 | ||||
| en_US: Enter the url of your Text Embedding Inference, e.g. http://192.168.1.100:8080 | en_US: Enter the url of your Text Embedding Inference, e.g. http://192.168.1.100:8080 | ||||
| - variable: api_key | |||||
| label: | |||||
| en_US: API Key | |||||
| type: secret-input | |||||
| required: false | |||||
| placeholder: | |||||
| zh_Hans: 在此输入您的 API Key | |||||
| en_US: Enter your API Key |
| server_url = server_url.removesuffix("/") | server_url = server_url.removesuffix("/") | ||||
| headers = {"Content-Type": "application/json"} | |||||
| api_key = credentials.get("api_key") | |||||
| if api_key: | |||||
| headers["Authorization"] = f"Bearer {api_key}" | |||||
| try: | try: | ||||
| results = TeiHelper.invoke_rerank(server_url, query, docs) | |||||
| results = TeiHelper.invoke_rerank(server_url, query, docs, headers) | |||||
| rerank_documents = [] | rerank_documents = [] | ||||
| for result in results: | for result in results: | ||||
| """ | """ | ||||
| try: | try: | ||||
| server_url = credentials["server_url"] | server_url = credentials["server_url"] | ||||
| extra_args = TeiHelper.get_tei_extra_parameter(server_url, model) | |||||
| headers = {"Content-Type": "application/json"} | |||||
| api_key = credentials.get("api_key") | |||||
| if api_key: | |||||
| headers["Authorization"] = f"Bearer {api_key}" | |||||
| extra_args = TeiHelper.get_tei_extra_parameter(server_url, model, headers) | |||||
| if extra_args.model_type != "reranker": | if extra_args.model_type != "reranker": | ||||
| raise CredentialsValidateFailedError("Current model is not a rerank model") | raise CredentialsValidateFailedError("Current model is not a rerank model") | ||||
| class TeiHelper: | class TeiHelper: | ||||
| @staticmethod | @staticmethod | ||||
| def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter: | |||||
| def get_tei_extra_parameter( | |||||
| server_url: str, model_name: str, headers: Optional[dict] = None | |||||
| ) -> TeiModelExtraParameter: | |||||
| TeiHelper._clean_cache() | TeiHelper._clean_cache() | ||||
| with cache_lock: | with cache_lock: | ||||
| if model_name not in cache: | if model_name not in cache: | ||||
| cache[model_name] = { | cache[model_name] = { | ||||
| "expires": time() + 300, | "expires": time() + 300, | ||||
| "value": TeiHelper._get_tei_extra_parameter(server_url), | |||||
| "value": TeiHelper._get_tei_extra_parameter(server_url, headers), | |||||
| } | } | ||||
| return cache[model_name]["value"] | return cache[model_name]["value"] | ||||
| pass | pass | ||||
| @staticmethod | @staticmethod | ||||
| def _get_tei_extra_parameter(server_url: str) -> TeiModelExtraParameter: | |||||
| def _get_tei_extra_parameter(server_url: str, headers: Optional[dict] = None) -> TeiModelExtraParameter: | |||||
| """ | """ | ||||
| get tei model extra parameter like model_type, max_input_length, max_batch_requests | get tei model extra parameter like model_type, max_input_length, max_batch_requests | ||||
| """ | """ | ||||
| session.mount("https://", HTTPAdapter(max_retries=3)) | session.mount("https://", HTTPAdapter(max_retries=3)) | ||||
| try: | try: | ||||
| response = session.get(url, timeout=10) | |||||
| response = session.get(url, headers=headers, timeout=10) | |||||
| except (MissingSchema, ConnectionError, Timeout) as e: | except (MissingSchema, ConnectionError, Timeout) as e: | ||||
| raise RuntimeError(f"get tei model extra parameter failed, url: {url}, error: {e}") | raise RuntimeError(f"get tei model extra parameter failed, url: {url}, error: {e}") | ||||
| if response.status_code != 200: | if response.status_code != 200: | ||||
| ) | ) | ||||
| @staticmethod | @staticmethod | ||||
| def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]: | |||||
| def invoke_tokenize(server_url: str, texts: list[str], headers: Optional[dict] = None) -> list[list[dict]]: | |||||
| """ | """ | ||||
| Invoke tokenize endpoint | Invoke tokenize endpoint | ||||
| :param server_url: server url | :param server_url: server url | ||||
| :param texts: texts to tokenize | :param texts: texts to tokenize | ||||
| """ | """ | ||||
| resp = httpx.post( | |||||
| f"{server_url}/tokenize", | |||||
| json={"inputs": texts}, | |||||
| ) | |||||
| url = f"{server_url}/tokenize" | |||||
| json_data = {"inputs": texts} | |||||
| resp = httpx.post(url, json=json_data, headers=headers) | |||||
| resp.raise_for_status() | resp.raise_for_status() | ||||
| return resp.json() | return resp.json() | ||||
| @staticmethod | @staticmethod | ||||
| def invoke_embeddings(server_url: str, texts: list[str]) -> dict: | |||||
| def invoke_embeddings(server_url: str, texts: list[str], headers: Optional[dict] = None) -> dict: | |||||
| """ | """ | ||||
| Invoke embeddings endpoint | Invoke embeddings endpoint | ||||
| :param texts: texts to embed | :param texts: texts to embed | ||||
| """ | """ | ||||
| # Use OpenAI compatible API here, which has usage tracking | # Use OpenAI compatible API here, which has usage tracking | ||||
| resp = httpx.post( | |||||
| f"{server_url}/v1/embeddings", | |||||
| json={"input": texts}, | |||||
| ) | |||||
| url = f"{server_url}/v1/embeddings" | |||||
| json_data = {"input": texts} | |||||
| resp = httpx.post(url, json=json_data, headers=headers) | |||||
| resp.raise_for_status() | resp.raise_for_status() | ||||
| return resp.json() | return resp.json() | ||||
| @staticmethod | @staticmethod | ||||
| def invoke_rerank(server_url: str, query: str, docs: list[str]) -> list[dict]: | |||||
| def invoke_rerank(server_url: str, query: str, docs: list[str], headers: Optional[dict] = None) -> list[dict]: | |||||
| """ | """ | ||||
| Invoke rerank endpoint | Invoke rerank endpoint | ||||
| :param candidates: candidates to rerank | :param candidates: candidates to rerank | ||||
| """ | """ | ||||
| params = {"query": query, "texts": docs, "return_text": True} | params = {"query": query, "texts": docs, "return_text": True} | ||||
| response = httpx.post( | |||||
| server_url + "/rerank", | |||||
| json=params, | |||||
| ) | |||||
| url = f"{server_url}/rerank" | |||||
| response = httpx.post(url, json=params, headers=headers) | |||||
| response.raise_for_status() | response.raise_for_status() | ||||
| return response.json() | return response.json() |
| server_url = server_url.removesuffix("/") | server_url = server_url.removesuffix("/") | ||||
| headers = {"Content-Type": "application/json"} | |||||
| api_key = credentials["api_key"] | |||||
| if api_key: | |||||
| headers["Authorization"] = f"Bearer {api_key}" | |||||
| # get model properties | # get model properties | ||||
| context_size = self._get_context_size(model, credentials) | context_size = self._get_context_size(model, credentials) | ||||
| max_chunks = self._get_max_chunks(model, credentials) | max_chunks = self._get_max_chunks(model, credentials) | ||||
| used_tokens = 0 | used_tokens = 0 | ||||
| # get tokenized results from TEI | # get tokenized results from TEI | ||||
| batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts) | |||||
| batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts, headers) | |||||
| for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)): | for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)): | ||||
| # Check if the number of tokens is larger than the context size | # Check if the number of tokens is larger than the context size | ||||
| used_tokens = 0 | used_tokens = 0 | ||||
| for i in _iter: | for i in _iter: | ||||
| iter_texts = inputs[i : i + max_chunks] | iter_texts = inputs[i : i + max_chunks] | ||||
| results = TeiHelper.invoke_embeddings(server_url, iter_texts) | |||||
| results = TeiHelper.invoke_embeddings(server_url, iter_texts, headers) | |||||
| embeddings = results["data"] | embeddings = results["data"] | ||||
| embeddings = [embedding["embedding"] for embedding in embeddings] | embeddings = [embedding["embedding"] for embedding in embeddings] | ||||
| batched_embeddings.extend(embeddings) | batched_embeddings.extend(embeddings) | ||||
| server_url = server_url.removesuffix("/") | server_url = server_url.removesuffix("/") | ||||
| batch_tokens = TeiHelper.invoke_tokenize(server_url, texts) | |||||
| headers = { | |||||
| "Authorization": f"Bearer {credentials.get('api_key')}", | |||||
| } | |||||
| batch_tokens = TeiHelper.invoke_tokenize(server_url, texts, headers) | |||||
| num_tokens = sum(len(tokens) for tokens in batch_tokens) | num_tokens = sum(len(tokens) for tokens in batch_tokens) | ||||
| return num_tokens | return num_tokens | ||||
| """ | """ | ||||
| try: | try: | ||||
| server_url = credentials["server_url"] | server_url = credentials["server_url"] | ||||
| extra_args = TeiHelper.get_tei_extra_parameter(server_url, model) | |||||
| headers = {"Content-Type": "application/json"} | |||||
| api_key = credentials.get("api_key") | |||||
| if api_key: | |||||
| headers["Authorization"] = f"Bearer {api_key}" | |||||
| extra_args = TeiHelper.get_tei_extra_parameter(server_url, model, headers) | |||||
| print(extra_args) | print(extra_args) | ||||
| if extra_args.model_type != "embedding": | if extra_args.model_type != "embedding": | ||||
| raise CredentialsValidateFailedError("Current model is not a embedding model") | raise CredentialsValidateFailedError("Current model is not a embedding model") |
| OPENAI_API_KEY = sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii | OPENAI_API_KEY = sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii | ||||
| TEI_EMBEDDING_SERVER_URL = http://a.abc.com:11451 | TEI_EMBEDDING_SERVER_URL = http://a.abc.com:11451 | ||||
| TEI_RERANK_SERVER_URL = http://a.abc.com:11451 | TEI_RERANK_SERVER_URL = http://a.abc.com:11451 | ||||
| TEI_API_KEY = ttttttttttttttt | |||||
| UPSTAGE_API_KEY = up-aaaaaaaaaaaaaaaaaaaa | UPSTAGE_API_KEY = up-aaaaaaaaaaaaaaaaaaaa | ||||
| VOYAGE_API_KEY = va-aaaaaaaaaaaaaaaaaaaa | VOYAGE_API_KEY = va-aaaaaaaaaaaaaaaaaaaa | ||||
| XINFERENCE_CHAT_MODEL_UID = chat | XINFERENCE_CHAT_MODEL_UID = chat |
| model="reranker", | model="reranker", | ||||
| credentials={ | credentials={ | ||||
| "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""), | "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""), | ||||
| "api_key": os.environ.get("TEI_API_KEY", ""), | |||||
| }, | }, | ||||
| ) | ) | ||||
| model=model_name, | model=model_name, | ||||
| credentials={ | credentials={ | ||||
| "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""), | "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""), | ||||
| "api_key": os.environ.get("TEI_API_KEY", ""), | |||||
| }, | }, | ||||
| ) | ) | ||||
| model=model_name, | model=model_name, | ||||
| credentials={ | credentials={ | ||||
| "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""), | "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""), | ||||
| "api_key": os.environ.get("TEI_API_KEY", ""), | |||||
| }, | }, | ||||
| texts=["hello", "world"], | texts=["hello", "world"], | ||||
| user="abc-123", | user="abc-123", |
| model="embedding", | model="embedding", | ||||
| credentials={ | credentials={ | ||||
| "server_url": os.environ.get("TEI_RERANK_SERVER_URL"), | "server_url": os.environ.get("TEI_RERANK_SERVER_URL"), | ||||
| "api_key": os.environ.get("TEI_API_KEY", ""), | |||||
| }, | }, | ||||
| ) | ) | ||||
| model=model_name, | model=model_name, | ||||
| credentials={ | credentials={ | ||||
| "server_url": os.environ.get("TEI_RERANK_SERVER_URL"), | "server_url": os.environ.get("TEI_RERANK_SERVER_URL"), | ||||
| "api_key": os.environ.get("TEI_API_KEY", ""), | |||||
| }, | }, | ||||
| ) | ) | ||||
| model=model_name, | model=model_name, | ||||
| credentials={ | credentials={ | ||||
| "server_url": os.environ.get("TEI_RERANK_SERVER_URL"), | "server_url": os.environ.get("TEI_RERANK_SERVER_URL"), | ||||
| "api_key": os.environ.get("TEI_API_KEY", ""), | |||||
| }, | }, | ||||
| query="Who is Kasumi?", | query="Who is Kasumi?", | ||||
| docs=[ | docs=[ |