Quellcode durchsuchen

add Moonshot, debug my_llm (#126)

tags/v0.1.0
KevinHuSh vor 1 Jahr
Ursprung
Commit
2447f95629
Es ist kein Account mit der E-Mail-Adresse des Committers verbunden

+ 1
- 1
api/apps/conversation_app.py Datei anzeigen

@@ -309,13 +309,13 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
# compose markdown table
clmns = "|"+"|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|原文|" if docid_idx and docid_idx else "|")
line = "|"+"|".join(["------" for _ in range(len(clmn_idx))]) + ("|------|" if docid_idx and docid_idx else "")
line = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}\|", "|", line)
rows = ["|"+"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
if not docid_idx or not docnm_idx:
chat_logger.warning("SQL missing field: " + sql)
return "\n".join([clmns, line, "\n".join(rows)]), []
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
docid_idx = list(docid_idx)[0]
docnm_idx = list(docnm_idx)[0]
return "\n".join([clmns, line, rows]), [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]]

+ 11
- 7
api/apps/llm_app.py Datei anzeigen

@@ -39,36 +39,40 @@ def factories():
def set_api_key():
req = request.json
# test if api key works
chat_passed = False
factory = req["llm_factory"]
msg = ""
for llm in LLMService.query(fid=req["llm_factory"]):
for llm in LLMService.query(fid=factory):
if llm.model_type == LLMType.EMBEDDING.value:
mdl = EmbeddingModel[req["llm_factory"]](
mdl = EmbeddingModel[factory](
req["api_key"], llm.llm_name)
try:
arr, tc = mdl.encode(["Test if the api key is available"])
if len(arr[0]) == 0 or tc ==0: raise Exception("Fail")
except Exception as e:
msg += f"\nFail to access embedding model({llm.llm_name}) using this api key."
elif llm.model_type == LLMType.CHAT.value:
mdl = ChatModel[req["llm_factory"]](
elif not chat_passed and llm.model_type == LLMType.CHAT.value:
mdl = ChatModel[factory](
req["api_key"], llm.llm_name)
try:
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
if not tc: raise Exception(m)
chat_passed = True
except Exception as e:
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(e)
if msg: return get_data_error_result(retmsg=msg)
llm = {
"tenant_id": current_user.id,
"llm_factory": req["llm_factory"],
"api_key": req["api_key"]
}
for n in ["model_type", "llm_name"]:
if n in req: llm[n] = req[n]
TenantLLMService.filter_update([TenantLLM.tenant_id==llm["tenant_id"], TenantLLM.llm_factory==llm["llm_factory"]], llm)
if not TenantLLMService.filter_update([TenantLLM.tenant_id==current_user.id, TenantLLM.llm_factory==factory], llm):
for llm in LLMService.query(fid=factory):
TenantLLMService.save(tenant_id=current_user.id, llm_factory=factory, llm_name=llm.llm_name, model_type=llm.model_type, api_key=req["api_key"])
return get_json_result(data=True)

+ 1
- 1
api/db/db_models.py Datei anzeigen

@@ -429,7 +429,7 @@ class LLMFactories(DataBaseModel):
class LLM(DataBaseModel):
# LLMs dictionary
llm_name = CharField(max_length=128, null=False, help_text="LLM name", index=True)
llm_name = CharField(max_length=128, null=False, help_text="LLM name", index=True, primary_key=True)
model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
fid = CharField(max_length=128, null=False, help_text="LLM factory id")
max_tokens = IntegerField(default=0)

+ 47
- 38
api/db/init_data.py Datei anzeigen

@@ -73,41 +73,41 @@ def init_superuser():
print("\33[91m【ERROR】\33[0m:", " '{}' dosen't work!".format(tenant["embd_id"]))
factory_infos = [{
"name": "OpenAI",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
},{
"name": "通义千问",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
},{
"name": "智谱AI",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
},
{
"name": "Local",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
},{
"name": "Moonshot",
"logo": "",
"tags": "LLM,TEXT EMBEDDING",
"status": "1",
}
# {
# "name": "文心一言",
# "logo": "",
# "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
# "status": "1",
# },
]
def init_llm_factory():
factory_infos = [{
"name": "OpenAI",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
},{
"name": "通义千问",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
},{
"name": "智谱AI",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
},
{
"name": "Local",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
},{
"name": "Moonshot",
"logo": "",
"tags": "LLM,TEXT EMBEDDING",
"status": "1",
}
# {
# "name": "文心一言",
# "logo": "",
# "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
# "status": "1",
# },
]
llm_infos = [
# ---------------------- OpenAI ------------------------
{
@@ -260,21 +260,30 @@ def init_llm_factory():
},
]
for info in factory_infos:
LLMFactoriesService.save(**info)
try:
LLMFactoriesService.save(**info)
except Exception as e:
pass
for info in llm_infos:
LLMService.save(**info)
try:
LLMService.save(**info)
except Exception as e:
pass
def init_web_data():
start_time = time.time()
if not LLMService.get_all().count():init_llm_factory()
if LLMFactoriesService.get_all().count() != len(factory_infos):
init_llm_factory()
if not UserService.get_all().count():
init_superuser()
print("init web data success:{}".format(time.time() - start_time))
if __name__ == '__main__':
init_web_db()
init_web_data()
init_web_data()
add_tenant_llm()

+ 1
- 1
api/db/services/llm_service.py Datei anzeigen

@@ -53,7 +53,7 @@ class TenantLLMService(CommonService):
cls.model.used_tokens
]
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(
cls.model.tenant_id == tenant_id).dicts()
cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
return list(objs)

+ 15
- 0
rag/llm/chat_model.py Datei anzeigen

@@ -54,6 +54,21 @@ class MoonshotChat(GptTurbo):
self.client = OpenAI(api_key=key, base_url="https://api.moonshot.cn/v1",)
self.model_name = model_name

def chat(self, system, history, gen_conf):
if system: history.insert(0, {"role": "system", "content": system})
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=history,
**gen_conf)
ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
return ans, response.usage.completion_tokens
except openai.APIError as e:
return "**ERROR**: "+str(e), 0


from dashscope import Generation
class QWenChat(Base):

Laden…
Abbrechen
Speichern