浏览代码

enhance:speedup xinference embedding & rerank (#3587)

tags/0.6.4
呆萌闷油瓶 1年前
父节点
当前提交
4365843c20
没有帐户链接到提交者的电子邮件

+ 17
- 12
api/core/model_runtime/model_providers/xinference/rerank/rerank.py 查看文件

if credentials['server_url'].endswith('/'): if credentials['server_url'].endswith('/'):
credentials['server_url'] = credentials['server_url'][:-1] credentials['server_url'] = credentials['server_url'][:-1]


# initialize client
client = Client(
base_url=credentials['server_url']
)

xinference_client = client.get_model(model_uid=credentials['model_uid'])

if not isinstance(xinference_client, RESTfulRerankModelHandle):
raise InvokeBadRequestError('please check model type, the model you want to invoke is not a rerank model')

response = xinference_client.rerank(
handle = RESTfulRerankModelHandle(credentials['model_uid'], credentials['server_url'],auth_headers={})
response = handle.rerank(
documents=docs, documents=docs,
query=query, query=query,
top_n=top_n, top_n=top_n,
try: try:
if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")

if credentials['server_url'].endswith('/'):
credentials['server_url'] = credentials['server_url'][:-1]

# initialize client
client = Client(
base_url=credentials['server_url']
)

xinference_client = client.get_model(model_uid=credentials['model_uid'])

if not isinstance(xinference_client, RESTfulRerankModelHandle):
raise InvokeBadRequestError(
'please check model type, the model you want to invoke is not a rerank model')
self.invoke( self.invoke(
model=model, model=model,
parameter_rules=[] parameter_rules=[]
) )


return entity
return entity

+ 14
- 11
api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py 查看文件

if server_url.endswith('/'): if server_url.endswith('/'):
server_url = server_url[:-1] server_url = server_url[:-1]


client = Client(base_url=server_url)
try:
handle = client.get_model(model_uid=model_uid)
except RuntimeError as e:
raise InvokeAuthorizationError(e)

if not isinstance(handle, RESTfulEmbeddingModelHandle):
raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model')

try: try:
handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers={})
embeddings = handle.create_embedding(input=texts) embeddings = handle.create_embedding(input=texts)
except RuntimeError as e: except RuntimeError as e:
raise InvokeServerUnavailableError(e) raise InvokeServerUnavailableError(e)


if extra_args.max_tokens: if extra_args.max_tokens:
credentials['max_tokens'] = extra_args.max_tokens credentials['max_tokens'] = extra_args.max_tokens
if server_url.endswith('/'):
server_url = server_url[:-1]

client = Client(base_url=server_url)
try:
handle = client.get_model(model_uid=model_uid)
except RuntimeError as e:
raise InvokeAuthorizationError(e)

if not isinstance(handle, RESTfulEmbeddingModelHandle):
raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model')


self._invoke(model=model, credentials=credentials, texts=['ping']) self._invoke(model=model, credentials=credentials, texts=['ping'])
except InvokeAuthorizationError as e: except InvokeAuthorizationError as e:
parameter_rules=[] parameter_rules=[]
) )


return entity
return entity

正在加载...
取消
保存