|
|
|
|
|
|
|
|
def set_api_key():
|
|
|
def set_api_key():
|
|
|
req = request.json
|
|
|
req = request.json
|
|
|
# test if api key works
|
|
|
# test if api key works
|
|
|
chat_passed = False
|
|
|
|
|
|
|
|
|
chat_passed, embd_passed, rerank_passed = False, False, False
|
|
|
factory = req["llm_factory"]
|
|
|
factory = req["llm_factory"]
|
|
|
msg = ""
|
|
|
msg = ""
|
|
|
for llm in LLMService.query(fid=factory):
|
|
|
for llm in LLMService.query(fid=factory):
|
|
|
if llm.model_type == LLMType.EMBEDDING.value:
|
|
|
|
|
|
|
|
|
if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
|
|
|
mdl = EmbeddingModel[factory](
|
|
|
mdl = EmbeddingModel[factory](
|
|
|
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
|
|
|
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
|
|
|
try:
|
|
|
try:
|
|
|
arr, tc = mdl.encode(["Test if the api key is available"])
|
|
|
arr, tc = mdl.encode(["Test if the api key is available"])
|
|
|
if len(arr[0]) == 0 or tc == 0:
|
|
|
if len(arr[0]) == 0 or tc == 0:
|
|
|
raise Exception("Fail")
|
|
|
raise Exception("Fail")
|
|
|
|
|
|
embd_passed = True
|
|
|
except Exception as e:
|
|
|
except Exception as e:
|
|
|
msg += f"\nFail to access embedding model({llm.llm_name}) using this api key." + str(e)
|
|
|
msg += f"\nFail to access embedding model({llm.llm_name}) using this api key." + str(e)
|
|
|
elif not chat_passed and llm.model_type == LLMType.CHAT.value:
|
|
|
elif not chat_passed and llm.model_type == LLMType.CHAT.value:
|
|
|
|
|
|
|
|
|
"temperature": 0.9})
|
|
|
"temperature": 0.9})
|
|
|
if not tc:
|
|
|
if not tc:
|
|
|
raise Exception(m)
|
|
|
raise Exception(m)
|
|
|
chat_passed = True
|
|
|
|
|
|
except Exception as e:
|
|
|
except Exception as e:
|
|
|
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
|
|
|
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
|
|
|
e)
|
|
|
e)
|
|
|
elif llm.model_type == LLMType.RERANK:
|
|
|
|
|
|
|
|
|
chat_passed = True
|
|
|
|
|
|
elif not rerank_passed and llm.model_type == LLMType.RERANK:
|
|
|
mdl = RerankModel[factory](
|
|
|
mdl = RerankModel[factory](
|
|
|
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
|
|
|
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
|
|
|
try:
|
|
|
try:
|
|
|
m, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
|
|
|
|
|
|
if len(arr[0]) == 0 or tc == 0:
|
|
|
|
|
|
|
|
|
arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
|
|
|
|
|
|
if len(arr) == 0 or tc == 0:
|
|
|
raise Exception("Fail")
|
|
|
raise Exception("Fail")
|
|
|
except Exception as e:
|
|
|
except Exception as e:
|
|
|
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
|
|
|
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
|
|
|
e)
|
|
|
e)
|
|
|
|
|
|
rerank_passed = True
|
|
|
|
|
|
|
|
|
if msg:
|
|
|
if msg:
|
|
|
return get_data_error_result(retmsg=msg)
|
|
|
return get_data_error_result(retmsg=msg)
|