
Update embedding_model.py (#9083)

### What problem does this PR solve?

Reduce the scope of the `CUDA_VISIBLE_DEVICES` override in `DefaultEmbedding`: instead of setting it unconditionally in the class body at import time, apply it only while the embedding model is being initialized and restore the previous value afterwards.
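As a rough before/after sketch of that scope change (simplified: locking, model download, and error handling omitted; only the environment-variable handling is shown):

```python
import os

# Before: the override sat in the class body, so it ran at import time and
# pinned CUDA_VISIBLE_DEVICES="0" for the whole process, whether or not the
# default embedding model was ever used.
class DefaultEmbeddingBefore:
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# After: the override is applied only around model initialization, and the
# previous value, if any, is written back afterwards.
class DefaultEmbeddingAfter:
    def __init__(self):
        saved = os.environ.get("CUDA_VISIBLE_DEVICES")
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        try:
            ...  # initialize the embedding model here
        finally:
            if saved:
                os.environ["CUDA_VISIBLE_DEVICES"] = saved
```

As in the diff, the variable is simply left at `"0"` when it was not set before.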

### Type of change

- [x] Refactoring
tags/v0.20.0
Stephen Hu, 3 months ago
parent commit ba563f8095
1 file changed, with 8 additions and 1 deletion

rag/llm/embedding_model.py  (+8, −1)

@@ -60,7 +60,6 @@ class Base(ABC):
 
 class DefaultEmbedding(Base):
     _FACTORY_NAME = "BAAI"
-    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
     _model = None
     _model_name = ""
     _model_lock = threading.Lock()
@@ -78,9 +77,13 @@ class DefaultEmbedding(Base):
 
         """
         if not settings.LIGHTEN:
+            input_cuda_visible_devices = None
             with DefaultEmbedding._model_lock:
                 import torch
                 from FlagEmbedding import FlagModel
+                if "CUDA_VISIBLE_DEVICES" in os.environ:
+                    input_cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
+                os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # handle some issues with multiple GPUs when initializing the model
 
                 if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
                     try:
@@ -95,6 +98,10 @@ class DefaultEmbedding(Base):
                             repo_id="BAAI/bge-large-zh-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False
                         )
                         DefaultEmbedding._model = FlagModel(model_dir, query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", use_fp16=torch.cuda.is_available())
+                    finally:
+                        if input_cuda_visible_devices:
+                            # restore CUDA_VISIBLE_DEVICES
+                            os.environ["CUDA_VISIBLE_DEVICES"] = input_cuda_visible_devices
         self._model = DefaultEmbedding._model
         self._model_name = DefaultEmbedding._model_name
 
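The added `try/finally` is a plain save-and-restore of `CUDA_VISIBLE_DEVICES`. If the same behaviour were ever needed elsewhere, it could be wrapped in a small context manager along these lines (a hypothetical helper, not part of this PR; like the diff, it leaves the override in place when the variable was previously unset):

```python
import os
from contextlib import contextmanager

@contextmanager
def pinned_cuda_visible_devices(value: str = "0"):
    """Temporarily set CUDA_VISIBLE_DEVICES, restoring the previous value on exit.

    Mirrors the logic added in this diff: the previous value is captured before
    the override and, if it was non-empty, written back in a finally block.
    """
    saved = os.environ.get("CUDA_VISIBLE_DEVICES")
    os.environ["CUDA_VISIBLE_DEVICES"] = value
    try:
        yield
    finally:
        if saved:
            os.environ["CUDA_VISIBLE_DEVICES"] = saved

# Roughly how it would correspond to the __init__ change above:
# with pinned_cuda_visible_devices("0"):
#     DefaultEmbedding._model = FlagModel(model_dir, use_fp16=torch.cuda.is_available())
```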

