소스 검색

support gpt-4o (#773)

### What problem does this PR solve?
#771 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
tags/v0.6.0
KevinHuSh 1 년 전
부모
커밋
aa1c915d6e
No account linked to committer's email address
6개의 변경된 파일, 20개의 추가 그리고 7개의 삭제
  1. 1
    1
      api/apps/llm_app.py
  2. 6
    0
      api/db/init_data.py
  3. 1
    1
      api/db/services/llm_service.py
  4. 6
    0
      api/settings.py
  5. 2
    1
      rag/llm/__init__.py
  6. 4
    4
      rag/svr/task_executor.py

+ 1
- 1
api/apps/llm_app.py 파일 보기

res = {} res = {}
for m in llms: for m in llms:
if model_type and m["model_type"] != model_type:
if model_type and m["model_type"].find(model_type)<0:
continue continue
if m["fid"] not in res: if m["fid"] not in res:
res[m["fid"]] = [] res[m["fid"]] = []

+ 6
- 0
api/db/init_data.py 파일 보기

llm_infos = [ llm_infos = [
# ---------------------- OpenAI ------------------------ # ---------------------- OpenAI ------------------------
{ {
"fid": factory_infos[0]["name"],
"llm_name": "gpt-4o",
"tags": "LLM,CHAT,128K",
"max_tokens": 128000,
"model_type": LLMType.CHAT.value + "," + LLMType.IMAGE2TEXT.value
}, {
"fid": factory_infos[0]["name"], "fid": factory_infos[0]["name"],
"llm_name": "gpt-3.5-turbo", "llm_name": "gpt-3.5-turbo",
"tags": "LLM,CHAT,4K", "tags": "LLM,CHAT,4K",

+ 1
- 1
api/db/services/llm_service.py 파일 보기

if not model_config: if not model_config:
if llm_type == LLMType.EMBEDDING.value: if llm_type == LLMType.EMBEDDING.value:
llm = LLMService.query(llm_name=llm_name) llm = LLMService.query(llm_name=llm_name)
if llm and llm[0].fid in ["Youdao", "FastEmbed"]:
if llm and llm[0].fid in ["Youdao", "FastEmbed", "DeepSeek"]:
model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name, "api_base": ""} model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name, "api_base": ""}
if not model_config: if not model_config:
if llm_name == "flag-embedding": if llm_name == "flag-embedding":

+ 6
- 0
api/settings.py 파일 보기

"embedding_model": "", "embedding_model": "",
"image2text_model": "", "image2text_model": "",
"asr_model": "", "asr_model": "",
},
"DeepSeek": {
"chat_model": "deepseek-chat",
"embedding_model": "BAAI/bge-large-zh-v1.5",
"image2text_model": "",
"asr_model": "",
} }
} }
LLM = get_base_config("user_default_llm", {}) LLM = get_base_config("user_default_llm", {})

+ 2
- 1
rag/llm/__init__.py 파일 보기

"Tongyi-Qianwen": DefaultEmbedding, #QWenEmbed, "Tongyi-Qianwen": DefaultEmbedding, #QWenEmbed,
"ZHIPU-AI": ZhipuEmbed, "ZHIPU-AI": ZhipuEmbed,
"FastEmbed": FastEmbed, "FastEmbed": FastEmbed,
"Youdao": YoudaoEmbed
"Youdao": YoudaoEmbed,
"DeepSeek": DefaultEmbedding
} }





+ 4
- 4
rag/svr/task_executor.py 파일 보기



st = timer() st = timer()
cks = build(r) cks = build(r)
cron_logger.info("Build chunks({}): {}".format(r["name"], timer()-st))
cron_logger.info("Build chunks({}): {:.2f}".format(r["name"], timer()-st))
if cks is None: if cks is None:
continue continue
if not cks: if not cks:
callback(-1, "Embedding error:{}".format(str(e))) callback(-1, "Embedding error:{}".format(str(e)))
cron_logger.error(str(e)) cron_logger.error(str(e))
tk_count = 0 tk_count = 0
cron_logger.info("Embedding elapsed({}): {}".format(r["name"], timer()-st))
cron_logger.info("Embedding elapsed({:.2f}): {}".format(r["name"], timer()-st))


callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer()-st)) callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer()-st))
init_kb(r) init_kb(r)
if b % 128 == 0: if b % 128 == 0:
callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="") callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")


cron_logger.info("Indexing elapsed({}): {}".format(r["name"], timer()-st))
cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer()-st))
if es_r: if es_r:
callback(-1, "Index failure!") callback(-1, "Index failure!")
ELASTICSEARCH.deleteByQuery( ELASTICSEARCH.deleteByQuery(
DocumentService.increment_chunk_num( DocumentService.increment_chunk_num(
r["doc_id"], r["kb_id"], tk_count, chunk_count, 0) r["doc_id"], r["kb_id"], tk_count, chunk_count, 0)
cron_logger.info( cron_logger.info(
"Chunk doc({}), token({}), chunks({}), elapsed:{}".format(
"Chunk doc({}), token({}), chunks({}), elapsed:{:.2f}".format(
r["id"], tk_count, len(cks), timer()-st)) r["id"], tk_count, len(cks), timer()-st))





Loading…
취소
저장