### What problem does this PR solve?

Fix an OOM (Out Of Memory) error that ragflow may encounter when there are a lot of conversations. #1288

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Co-authored-by: zhuhao <zhuhao@linklogis.com>
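**Why the OOM happens.** As the diffs below show, `DefaultEmbedding.__init__` (and likewise `DefaultRerank.__init__`) guarded model loading with `if not DefaultEmbedding._model:` but then assigned the loaded model to the *instance* attribute `self._model`. The class-level cache therefore stayed `None`, every new conversation loaded another full copy of the weights, and with no lock, concurrent requests could each start their own load. The fix caches the model on the class and wraps the load in double-checked locking. A minimal sketch of that pattern (illustrative names, not ragflow's API):

```python
import threading


class ModelHolder:
    # Class-level cache shared by all instances.
    _model = None
    _model_lock = threading.Lock()

    def __init__(self):
        # Fast path: once the model is cached, no locking at all.
        if ModelHolder._model is None:
            with ModelHolder._model_lock:
                # Re-check under the lock: another thread may have finished
                # loading while we were waiting to acquire it.
                if ModelHolder._model is None:
                    ModelHolder._model = self._load()  # expensive; runs once
        # Every instance points at the single shared model.
        self._model = ModelHolder._model

    @staticmethod
    def _load():
        # Stand-in for the real FlagModel(...) / FlagReranker(...) load.
        return object()
```

The inner `if` is what makes the pattern safe; the unlocked outer check merely keeps lock contention off the hot path. The two hunks below apply this to `DefaultEmbedding` and `DefaultRerank`.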
```diff
 #
 import re
 from typing import Optional
+import threading
 import requests
 from huggingface_hub import snapshot_download
 from zhipuai import ZhipuAI
@@ ... @@
 class DefaultEmbedding(Base):
     _model = None
+    _model_lock = threading.Lock()

     def __init__(self, key, model_name, **kwargs):
         """
         If you have trouble downloading HuggingFace models, -_^ this might help!!
         """
         if not DefaultEmbedding._model:
-            try:
-                self._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
-                                        query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
-                                        use_fp16=torch.cuda.is_available())
-            except Exception as e:
-                model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
-                                              local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
-                                              local_dir_use_symlinks=False)
-                self._model = FlagModel(model_dir,
-                                        query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
-                                        use_fp16=torch.cuda.is_available())
+            with DefaultEmbedding._model_lock:
+                if not DefaultEmbedding._model:
+                    try:
+                        DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
+                                                            query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
+                                                            use_fp16=torch.cuda.is_available())
+                    except Exception as e:
+                        model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
+                                                      local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
+                                                      local_dir_use_symlinks=False)
+                        DefaultEmbedding._model = FlagModel(model_dir,
+                                                            query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
+                                                            use_fp16=torch.cuda.is_available())
+        self._model = DefaultEmbedding._model

     def encode(self, texts: list, batch_size=32):
         texts = [truncate(t, 2048) for t in texts]
```
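To convince yourself that the lock plus the class-level assignment is what prevents the pile-up, here is a self-contained harness (not part of the PR; `FakeEmbedding` is a hypothetical stand-in) that counts real "model loads" while many threads construct instances at once. With the old `self._model = ...` pattern the count could reach the thread count, which is exactly the many-conversations OOM:

```python
import threading

load_count = 0


class FakeEmbedding:
    # Same caching scheme as DefaultEmbedding after this PR.
    _model = None
    _model_lock = threading.Lock()

    def __init__(self):
        global load_count
        if not FakeEmbedding._model:
            with FakeEmbedding._model_lock:
                if not FakeEmbedding._model:
                    load_count += 1  # one real "model load" (safe: under the lock)
                    FakeEmbedding._model = object()
        self._model = FakeEmbedding._model


threads = [threading.Thread(target=FakeEmbedding) for _ in range(32)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert load_count == 1  # the weights were loaded exactly once
```

The same change applied to `DefaultRerank`: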
```diff
 # limitations under the License.
 #
 import re
+import threading
 import requests
 import torch
 from FlagEmbedding import FlagReranker
@@ ... @@
 class DefaultRerank(Base):
     _model = None
+    _model_lock = threading.Lock()

     def __init__(self, key, model_name, **kwargs):
         """
         If you have trouble downloading HuggingFace models, -_^ this might help!!
         """
         if not DefaultRerank._model:
-            try:
-                self._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
-                                           use_fp16=torch.cuda.is_available())
-            except Exception as e:
-                self._model = snapshot_download(repo_id=model_name,
-                                                local_dir=os.path.join(get_home_cache_dir(),
-                                                                       re.sub(r"^[a-zA-Z]+/", "", model_name)),
-                                                local_dir_use_symlinks=False)
-                self._model = FlagReranker(os.path.join(get_home_cache_dir(), model_name),
-                                           use_fp16=torch.cuda.is_available())
+            with DefaultRerank._model_lock:
+                if not DefaultRerank._model:
+                    try:
+                        DefaultRerank._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
+                                                            use_fp16=torch.cuda.is_available())
+                    except Exception as e:
+                        model_dir = snapshot_download(repo_id=model_name,
+                                                      local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
+                                                      local_dir_use_symlinks=False)
+                        DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
+        self._model = DefaultRerank._model

     def similarity(self, query: str, texts: list):
         pairs = [(query, truncate(t, 2048)) for t in texts]
```
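The rerank hunk quietly fixes a second bug as well: in the old fallback, `snapshot_download` wrote the files under the namespace-stripped directory, but the reranker was then rebuilt from `os.path.join(get_home_cache_dir(), model_name)`, which keeps the `BAAI/`-style prefix and so points at a directory the download never populated (it also briefly bound `self._model` to the download path string). The new code reuses `model_dir` instead. A small sketch of the prefix-stripping rule, with a hypothetical helper name and an example model name:

```python
import re


def cache_dirname(model_name: str) -> str:
    # Mirrors the re.sub(...) in the diff: drop a leading
    # "<namespace>/" segment such as "BAAI/".
    return re.sub(r"^[a-zA-Z]+/", "", model_name)


# snapshot_download targets the stripped name ...
assert cache_dirname("BAAI/bge-reranker-v2-m3") == "bge-reranker-v2-m3"
# ... so joining the cache dir with the unstripped name (the old
# fallback) produced a path that was never created.
```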