|
|
|
@@ -60,7 +60,6 @@ class Base(ABC): |
|
|
|
|
|
|
|
class DefaultEmbedding(Base): |
|
|
|
_FACTORY_NAME = "BAAI" |
|
|
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0" |
|
|
|
_model = None |
|
|
|
_model_name = "" |
|
|
|
_model_lock = threading.Lock() |
|
|
|
@@ -78,9 +77,13 @@ class DefaultEmbedding(Base): |
|
|
|
|
|
|
|
""" |
|
|
|
if not settings.LIGHTEN: |
|
|
|
input_cuda_visible_devices = None |
|
|
|
with DefaultEmbedding._model_lock: |
|
|
|
import torch |
|
|
|
from FlagEmbedding import FlagModel |
|
|
|
if "CUDA_VISIBLE_DEVICES" in os.environ: |
|
|
|
input_cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"] |
|
|
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # handle some issues with multiple GPUs when initializing the model |
|
|
|
|
|
|
|
if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name: |
|
|
|
try: |
|
|
|
@@ -95,6 +98,10 @@ class DefaultEmbedding(Base): |
|
|
|
repo_id="BAAI/bge-large-zh-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False |
|
|
|
) |
|
|
|
DefaultEmbedding._model = FlagModel(model_dir, query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", use_fp16=torch.cuda.is_available()) |
|
|
|
finally: |
|
|
|
if input_cuda_visible_devices: |
|
|
|
# restore CUDA_VISIBLE_DEVICES |
|
|
|
os.environ["CUDA_VISIBLE_DEVICES"] = input_cuda_visible_devices |
|
|
|
self._model = DefaultEmbedding._model |
|
|
|
self._model_name = DefaultEmbedding._model_name |
|
|
|
|