@@ -26,19 +26,17 @@ from FlagEmbedding import FlagModel
 import torch
 import numpy as np
 
-from api.utils.file_utils import get_project_base_directory
+from api.utils.file_utils import get_project_base_directory, get_home_cache_dir
 from rag.utils import num_tokens_from_string
 
 try:
-    flag_model = FlagModel(os.path.join(
-        get_project_base_directory(),
-        "rag/res/bge-large-zh-v1.5"),
+    flag_model = FlagModel(os.path.join(get_home_cache_dir(), "bge-large-zh-v1.5"),
                            query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
                            use_fp16=torch.cuda.is_available())
 except Exception as e:
     model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
-                                  local_dir=os.path.join(get_project_base_directory(), "rag/res/bge-large-zh-v1.5"),
+                                  local_dir=os.path.join(get_home_cache_dir(), "bge-large-zh-v1.5"),
                                   local_dir_use_symlinks=False)
     flag_model = FlagModel(model_dir,
                            query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",