|
|
|
@@ -47,6 +47,7 @@ class Base(ABC): |
|
|
|
|
|
|
|
class DefaultEmbedding(Base): |
|
|
|
_model = None |
|
|
|
_model_name = "" |
|
|
|
_model_lock = threading.Lock() |
|
|
|
def __init__(self, key, model_name, **kwargs): |
|
|
|
""" |
|
|
|
@@ -69,6 +70,7 @@ class DefaultEmbedding(Base): |
|
|
|
DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), |
|
|
|
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", |
|
|
|
use_fp16=torch.cuda.is_available()) |
|
|
|
DefaultEmbedding._model_name = model_name |
|
|
|
except Exception: |
|
|
|
model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5", |
|
|
|
local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), |
|
|
|
@@ -77,6 +79,7 @@ class DefaultEmbedding(Base): |
|
|
|
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", |
|
|
|
use_fp16=torch.cuda.is_available()) |
|
|
|
self._model = DefaultEmbedding._model |
|
|
|
self._model_name = DefaultEmbedding._model_name |
|
|
|
|
|
|
|
def encode(self, texts: list): |
|
|
|
batch_size = 16 |
|
|
|
@@ -250,6 +253,8 @@ class OllamaEmbed(Base): |
|
|
|
|
|
|
|
class FastEmbed(Base): |
|
|
|
_model = None |
|
|
|
_model_name = "" |
|
|
|
_model_lock = threading.Lock() |
|
|
|
|
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
@@ -260,8 +265,20 @@ class FastEmbed(Base): |
|
|
|
**kwargs, |
|
|
|
): |
|
|
|
if not settings.LIGHTEN and not FastEmbed._model: |
|
|
|
from fastembed import TextEmbedding |
|
|
|
self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs) |
|
|
|
with FastEmbed._model_lock: |
|
|
|
from fastembed import TextEmbedding |
|
|
|
if not FastEmbed._model or model_name != FastEmbed._model_name: |
|
|
|
try: |
|
|
|
FastEmbed._model = TextEmbedding(model_name, cache_dir, threads, **kwargs) |
|
|
|
FastEmbed._model_name = model_name |
|
|
|
except Exception: |
|
|
|
cache_dir = snapshot_download(repo_id="BAAI/bge-small-en-v1.5", |
|
|
|
local_dir=os.path.join(get_home_cache_dir(), |
|
|
|
re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), |
|
|
|
local_dir_use_symlinks=False) |
|
|
|
FastEmbed._model = TextEmbedding(model_name, cache_dir, threads, **kwargs) |
|
|
|
self._model = FastEmbed._model |
|
|
|
self._model_name = model_name |
|
|
|
|
|
|
|
def encode(self, texts: list): |
|
|
|
# Using the internal tokenizer to encode the texts and get the total |