Procházet zdrojové kódy

fix: replace os.path.join with yarl (#2690)

tags/0.5.9
Yeuoly před 1 rokem
rodič
revize
95733796f0
Žádný účet není propojen s e-mailovou adresou tvůrce revize

+ 5
- 3
api/core/model_runtime/model_providers/xinference/xinference_helper.py Zobrazit soubor

from os import path
from threading import Lock from threading import Lock
from time import time from time import time


from requests.adapters import HTTPAdapter from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError, MissingSchema, Timeout from requests.exceptions import ConnectionError, MissingSchema, Timeout
from requests.sessions import Session from requests.sessions import Session
from yarl import URL




class XinferenceModelExtraParameter: class XinferenceModelExtraParameter:
get xinference model extra parameter like model_format and model_handle_type get xinference model extra parameter like model_format and model_handle_type
""" """


url = path.join(server_url, 'v1/models', model_uid)
if not model_uid or not model_uid.strip() or not server_url or not server_url.strip():
raise RuntimeError('model_uid is empty')

url = str(URL(server_url) / 'v1' / 'models' / model_uid)


# this method is surrounded by a lock, and default requests may hang forever, so we just set an Adapter with max_retries=3 # this method is surrounded by a lock, and default requests may hang forever, so we just set an Adapter with max_retries=3
session = Session() session = Session()
response = session.get(url, timeout=10) response = session.get(url, timeout=10)
except (MissingSchema, ConnectionError, Timeout) as e: except (MissingSchema, ConnectionError, Timeout) as e:
raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}') raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}')

if response.status_code != 200: if response.status_code != 200:
raise RuntimeError(f'get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}') raise RuntimeError(f'get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}')

+ 2
- 1
api/requirements.txt Zobrazit soubor

gmpy2~=2.1.5 gmpy2~=2.1.5
numexpr~=2.9.0 numexpr~=2.9.0
duckduckgo-search==4.4.3 duckduckgo-search==4.4.3
arxiv==2.1.0
arxiv==2.1.0
yarl~=1.9.4

+ 41
- 39
api/tests/integration_tests/model_runtime/__mock/xinference.py Zobrazit soubor

response = Response() response = Response()
if 'v1/models/' in url: if 'v1/models/' in url:
# get model uid # get model uid
model_uid = url.split('/')[-1]
model_uid = url.split('/')[-1] or ''
if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \ if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \
model_uid not in ['generate', 'chat', 'embedding', 'rerank']: model_uid not in ['generate', 'chat', 'embedding', 'rerank']:
response.status_code = 404 response.status_code = 404
response._content = b'{}'
return response return response


# check if url is valid # check if url is valid
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url): if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url):
response.status_code = 404 response.status_code = 404
response._content = b'{}'
return response return response
if model_uid in ['generate', 'chat']: if model_uid in ['generate', 'chat']:
response.status_code = 200 response.status_code = 200
response._content = b'''{ response._content = b'''{
"model_type": "LLM",
"address": "127.0.0.1:43877",
"accelerators": [
"0",
"1"
],
"model_name": "chatglm3-6b",
"model_lang": [
"en"
],
"model_ability": [
"generate",
"chat"
],
"model_description": "latest chatglm3",
"model_format": "pytorch",
"model_size_in_billions": 7,
"quantization": "none",
"model_hub": "huggingface",
"revision": null,
"context_length": 2048,
"replica": 1
}'''
"model_type": "LLM",
"address": "127.0.0.1:43877",
"accelerators": [
"0",
"1"
],
"model_name": "chatglm3-6b",
"model_lang": [
"en"
],
"model_ability": [
"generate",
"chat"
],
"model_description": "latest chatglm3",
"model_format": "pytorch",
"model_size_in_billions": 7,
"quantization": "none",
"model_hub": "huggingface",
"revision": null,
"context_length": 2048,
"replica": 1
}'''
return response return response
elif model_uid == 'embedding': elif model_uid == 'embedding':
response.status_code = 200 response.status_code = 200
response._content = b'''{ response._content = b'''{
"model_type": "embedding",
"address": "127.0.0.1:43877",
"accelerators": [
"0",
"1"
],
"model_name": "bge",
"model_lang": [
"en"
],
"revision": null,
"max_tokens": 512
}'''
"model_type": "embedding",
"address": "127.0.0.1:43877",
"accelerators": [
"0",
"1"
],
"model_name": "bge",
"model_lang": [
"en"
],
"revision": null,
"max_tokens": 512
}'''
return response return response
elif 'v1/cluster/auth' in url: elif 'v1/cluster/auth' in url:
response.status_code = 200 response.status_code = 200
response._content = b'''{ response._content = b'''{
"auth": true
}'''
"auth": true
}'''
return response return response
def _check_cluster_authenticated(self): def _check_cluster_authenticated(self):

Načítá se…
Zrušit
Uložit