| if credentials['server_url'].endswith('/'): | if credentials['server_url'].endswith('/'): | ||||
| credentials['server_url'] = credentials['server_url'][:-1] | credentials['server_url'] = credentials['server_url'][:-1] | ||||
| # initialize client | |||||
| client = Client( | |||||
| base_url=credentials['server_url'] | |||||
| ) | |||||
| xinference_client = client.get_model(model_uid=credentials['model_uid']) | |||||
| if not isinstance(xinference_client, RESTfulRerankModelHandle): | |||||
| raise InvokeBadRequestError('please check model type, the model you want to invoke is not a rerank model') | |||||
| response = xinference_client.rerank( | |||||
| handle = RESTfulRerankModelHandle(credentials['model_uid'], credentials['server_url'],auth_headers={}) | |||||
| response = handle.rerank( | |||||
| documents=docs, | documents=docs, | ||||
| query=query, | query=query, | ||||
| top_n=top_n, | top_n=top_n, | ||||
| try: | try: | ||||
| if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: | if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: | ||||
| raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") | raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") | ||||
| if credentials['server_url'].endswith('/'): | |||||
| credentials['server_url'] = credentials['server_url'][:-1] | |||||
| # initialize client | |||||
| client = Client( | |||||
| base_url=credentials['server_url'] | |||||
| ) | |||||
| xinference_client = client.get_model(model_uid=credentials['model_uid']) | |||||
| if not isinstance(xinference_client, RESTfulRerankModelHandle): | |||||
| raise InvokeBadRequestError( | |||||
| 'please check model type, the model you want to invoke is not a rerank model') | |||||
| self.invoke( | self.invoke( | ||||
| model=model, | model=model, | ||||
| parameter_rules=[] | parameter_rules=[] | ||||
| ) | ) | ||||
| return entity | |||||
| return entity | 
| if server_url.endswith('/'): | if server_url.endswith('/'): | ||||
| server_url = server_url[:-1] | server_url = server_url[:-1] | ||||
| client = Client(base_url=server_url) | |||||
| try: | |||||
| handle = client.get_model(model_uid=model_uid) | |||||
| except RuntimeError as e: | |||||
| raise InvokeAuthorizationError(e) | |||||
| if not isinstance(handle, RESTfulEmbeddingModelHandle): | |||||
| raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model') | |||||
| try: | try: | ||||
| handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers={}) | |||||
| embeddings = handle.create_embedding(input=texts) | embeddings = handle.create_embedding(input=texts) | ||||
| except RuntimeError as e: | except RuntimeError as e: | ||||
| raise InvokeServerUnavailableError(e) | raise InvokeServerUnavailableError(e) | ||||
| if extra_args.max_tokens: | if extra_args.max_tokens: | ||||
| credentials['max_tokens'] = extra_args.max_tokens | credentials['max_tokens'] = extra_args.max_tokens | ||||
| if server_url.endswith('/'): | |||||
| server_url = server_url[:-1] | |||||
| client = Client(base_url=server_url) | |||||
| try: | |||||
| handle = client.get_model(model_uid=model_uid) | |||||
| except RuntimeError as e: | |||||
| raise InvokeAuthorizationError(e) | |||||
| if not isinstance(handle, RESTfulEmbeddingModelHandle): | |||||
| raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model') | |||||
| self._invoke(model=model, credentials=credentials, texts=['ping']) | self._invoke(model=model, credentials=credentials, texts=['ping']) | ||||
| except InvokeAuthorizationError as e: | except InvokeAuthorizationError as e: | ||||
| parameter_rules=[] | parameter_rules=[] | ||||
| ) | ) | ||||
| return entity | |||||
| return entity |