瀏覽代碼

make sure the models will not be load twice (#422)

### What problem does this PR solve?

#381 
### Type of change

- [x] Refactoring
tags/v0.3.0
KevinHuSh 1 年之前
父節點
當前提交
453c29170f
沒有連結到貢獻者的電子郵件帳戶。

+ 4
- 5
api/apps/api_app.py 查看文件

@@ -105,8 +105,8 @@ def stats():
res = {
"pv": [(o["dt"], o["pv"]) for o in objs],
"uv": [(o["dt"], o["uv"]) for o in objs],
"speed": [(o["dt"], o["tokens"]/o["duration"]) for o in objs],
"tokens": [(o["dt"], o["tokens"]/1000.) for o in objs],
"speed": [(o["dt"], float(o["tokens"])/float(o["duration"])) for o in objs],
"tokens": [(o["dt"], float(o["tokens"])/1000.) for o in objs],
"round": [(o["dt"], o["round"]) for o in objs],
"thumb_up": [(o["dt"], o["thumb_up"]) for o in objs]
}
@@ -115,8 +115,7 @@ def stats():
return server_error_response(e)
@manager.route('/new_conversation', methods=['POST'])
@validate_request("user_id")
@manager.route('/new_conversation', methods=['GET'])
def set_conversation():
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
@@ -131,7 +130,7 @@ def set_conversation():
conv = {
"id": get_uuid(),
"dialog_id": dia.id,
"user_id": req["user_id"],
"user_id": request.args.get("user_id", ""),
"message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]
}
API4ConversationService.save(**conv)

+ 1
- 1
api/db/db_models.py 查看文件

@@ -629,7 +629,7 @@ class Document(DataBaseModel):
max_length=128,
null=False,
default="local",
help_text="where dose this document from")
help_text="where dose this document come from")
type = CharField(max_length=32, null=False, help_text="file extension")
created_by = CharField(
max_length=32,

+ 3
- 1
deepdoc/parser/pdf_parser.py 查看文件

@@ -43,7 +43,9 @@ class HuParser:
model_dir, "updown_concat_xgb.model"))
except Exception as e:
model_dir = snapshot_download(
repo_id="InfiniFlow/text_concat_xgb_v1.0")
repo_id="InfiniFlow/text_concat_xgb_v1.0",
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
local_dir_use_symlinks=False)
self.updown_cnt_mdl.load_model(os.path.join(
model_dir, "updown_concat_xgb.model"))


+ 3
- 1
deepdoc/vision/layout_recognizer.py 查看文件

@@ -43,7 +43,9 @@ class LayoutRecognizer(Recognizer):
"rag/res/deepdoc")
super().__init__(self.labels, domain, model_dir)
except Exception as e:
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
local_dir_use_symlinks=False)
super().__init__(self.labels, domain, model_dir)
self.garbage_layouts = ["footer", "header", "reference"]

+ 3
- 1
deepdoc/vision/ocr.py 查看文件

@@ -486,7 +486,9 @@ class OCR(object):
self.text_detector = TextDetector(model_dir)
self.text_recognizer = TextRecognizer(model_dir)
except Exception as e:
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
local_dir_use_symlinks=False)
self.text_detector = TextDetector(model_dir)
self.text_recognizer = TextRecognizer(model_dir)


+ 3
- 1
deepdoc/vision/recognizer.py 查看文件

@@ -41,7 +41,9 @@ class Recognizer(object):
"rag/res/deepdoc")
model_file_path = os.path.join(model_dir, task_name + ".onnx")
if not os.path.exists(model_file_path):
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
local_dir_use_symlinks=False)
model_file_path = os.path.join(model_dir, task_name + ".onnx")
else:
model_file_path = os.path.join(model_dir, task_name + ".onnx")

+ 3
- 1
deepdoc/vision/table_structure_recognizer.py 查看文件

@@ -39,7 +39,9 @@ class TableStructureRecognizer(Recognizer):
get_project_base_directory(),
"rag/res/deepdoc"))
except Exception as e:
super().__init__(self.labels, "tsr", snapshot_download(repo_id="InfiniFlow/deepdoc"))
super().__init__(self.labels, "tsr", snapshot_download(repo_id="InfiniFlow/deepdoc",
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
local_dir_use_symlinks=False))
def __call__(self, images, thr=0.2):
tbls = super().__call__(images, thr)

+ 6
- 1
rag/llm/embedding_model.py 查看文件

@@ -14,6 +14,8 @@
# limitations under the License.
#
from typing import Optional

from huggingface_hub import snapshot_download
from zhipuai import ZhipuAI
import os
from abc import ABC
@@ -35,7 +37,10 @@ try:
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
use_fp16=torch.cuda.is_available())
except Exception as e:
flag_model = FlagModel("BAAI/bge-large-zh-v1.5",
model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
local_dir=os.path.join(get_project_base_directory(), "rag/res/bge-large-zh-v1.5"),
local_dir_use_symlinks=False)
flag_model = FlagModel(model_dir,
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
use_fp16=torch.cuda.is_available())


Loading…
取消
儲存