### What problem does this PR solve? Fixes an issue where RAGFlow may run out of memory (OOM) when there are a lot of conversations. #1288 ### Type of change - [ ] Bug Fix (non-breaking change which fixes an issue) Co-authored-by: zhuhao <zhuhao@linklogis.com> tags/v0.8.0
| @@ -15,6 +15,7 @@ | |||
| # | |||
| import re | |||
| from typing import Optional | |||
| import threading | |||
| import requests | |||
| from huggingface_hub import snapshot_download | |||
| from zhipuai import ZhipuAI | |||
| @@ -44,7 +45,7 @@ class Base(ABC): | |||
| class DefaultEmbedding(Base): | |||
| _model = None | |||
| _model_lock = threading.Lock() | |||
| def __init__(self, key, model_name, **kwargs): | |||
| """ | |||
| If you have trouble downloading HuggingFace models, -_^ this might help!! | |||
| @@ -58,17 +59,20 @@ class DefaultEmbedding(Base): | |||
| """ | |||
| if not DefaultEmbedding._model: | |||
| try: | |||
| self._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), | |||
| query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", | |||
| use_fp16=torch.cuda.is_available()) | |||
| except Exception as e: | |||
| model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5", | |||
| local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), | |||
| local_dir_use_symlinks=False) | |||
| self._model = FlagModel(model_dir, | |||
| query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", | |||
| use_fp16=torch.cuda.is_available()) | |||
| with DefaultEmbedding._model_lock: | |||
| if not DefaultEmbedding._model: | |||
| try: | |||
| DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), | |||
| query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", | |||
| use_fp16=torch.cuda.is_available()) | |||
| except Exception as e: | |||
| model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5", | |||
| local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), | |||
| local_dir_use_symlinks=False) | |||
| DefaultEmbedding._model = FlagModel(model_dir, | |||
| query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", | |||
| use_fp16=torch.cuda.is_available()) | |||
| self._model = DefaultEmbedding._model | |||
| def encode(self, texts: list, batch_size=32): | |||
| texts = [truncate(t, 2048) for t in texts] | |||
| @@ -14,6 +14,7 @@ | |||
| # limitations under the License. | |||
| # | |||
| import re | |||
| import threading | |||
| import requests | |||
| import torch | |||
| from FlagEmbedding import FlagReranker | |||
| @@ -37,7 +38,7 @@ class Base(ABC): | |||
| class DefaultRerank(Base): | |||
| _model = None | |||
| _model_lock = threading.Lock() | |||
| def __init__(self, key, model_name, **kwargs): | |||
| """ | |||
| If you have trouble downloading HuggingFace models, -_^ this might help!! | |||
| @@ -51,16 +52,16 @@ class DefaultRerank(Base): | |||
| """ | |||
| if not DefaultRerank._model: | |||
| try: | |||
| self._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), | |||
| use_fp16=torch.cuda.is_available()) | |||
| except Exception as e: | |||
| self._model = snapshot_download(repo_id=model_name, | |||
| local_dir=os.path.join(get_home_cache_dir(), | |||
| re.sub(r"^[a-zA-Z]+/", "", model_name)), | |||
| local_dir_use_symlinks=False) | |||
| self._model = FlagReranker(os.path.join(get_home_cache_dir(), model_name), | |||
| use_fp16=torch.cuda.is_available()) | |||
| with DefaultRerank._model_lock: | |||
| if not DefaultRerank._model: | |||
| try: | |||
| DefaultRerank._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), use_fp16=torch.cuda.is_available()) | |||
| except Exception as e: | |||
| model_dir = snapshot_download(repo_id= model_name, | |||
| local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), | |||
| local_dir_use_symlinks=False) | |||
| DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available()) | |||
| self._model = DefaultRerank._model | |||
| def similarity(self, query: str, texts: list): | |||
| pairs = [(query,truncate(t, 2048)) for t in texts] | |||