```diff
@@ -47,6 +47,7 @@ class Base(ABC):
 
 class DefaultEmbedding(Base):
     _model = None
+    _model_name = ""
     _model_lock = threading.Lock()
     def __init__(self, key, model_name, **kwargs):
         """
@@ -69,6 +70,7 @@ class DefaultEmbedding(Base):
                     DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
                                                         query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
                                                         use_fp16=torch.cuda.is_available())
+                    DefaultEmbedding._model_name = model_name
                 except Exception:
                     model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
                                                   local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
@@ -77,6 +79,7 @@ class DefaultEmbedding(Base):
                                                         query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
                                                         use_fp16=torch.cuda.is_available())
         self._model = DefaultEmbedding._model
+        self._model_name = DefaultEmbedding._model_name
 
     def encode(self, texts: list):
         batch_size = 16
```
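The three `DefaultEmbedding` hunks work together: the heavyweight model is cached in a class-level attribute, `_model_lock` serializes the expensive load, and the new `_model_name` bookkeeping records which model the cache currently holds, so instances copy a consistent (model, name) pair. A minimal, self-contained sketch of this lock-guarded lazy-singleton pattern (the `CachedEmbedder` class and `_load` helper are illustrative stand-ins, not code from the patch):

```python
import threading


class CachedEmbedder:
    """Illustrative stand-in for DefaultEmbedding's caching scheme."""

    _model = None
    _model_name = ""
    _model_lock = threading.Lock()

    def __init__(self, model_name: str):
        # Serialize loading so concurrent constructors cannot race on the
        # class-level cache; reload only when the cache is empty or a
        # different model was requested.
        with CachedEmbedder._model_lock:
            if CachedEmbedder._model is None or model_name != CachedEmbedder._model_name:
                CachedEmbedder._model = self._load(model_name)
                CachedEmbedder._model_name = model_name
        # Every instance shares the class-level singleton.
        self._model = CachedEmbedder._model
        self._model_name = CachedEmbedder._model_name

    @staticmethod
    def _load(model_name: str):
        # Stand-in for the expensive FlagModel(...) construction.
        return object()
```

The explicit `model_name != _model_name` re-check appears verbatim in the `FastEmbed` hunk below; for `DefaultEmbedding` the diff only shows the bookkeeping assignments, so treat that condition in the sketch as an assumption about the surrounding code.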
		
	
	
		
			
```diff
@@ -250,6 +253,8 @@ class OllamaEmbed(Base):
 
 class FastEmbed(Base):
     _model = None
+    _model_name = ""
+    _model_lock = threading.Lock()
 
     def __init__(
         self,
@@ -260,8 +265,20 @@ class FastEmbed(Base):
         **kwargs,
     ):
-        if not settings.LIGHTEN and not FastEmbed._model:
-            from fastembed import TextEmbedding
-            self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
+        if not settings.LIGHTEN:
+            with FastEmbed._model_lock:
+                from fastembed import TextEmbedding
+                if not FastEmbed._model or model_name != FastEmbed._model_name:
+                    try:
+                        FastEmbed._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
+                        FastEmbed._model_name = model_name
+                    except Exception:
+                        cache_dir = snapshot_download(repo_id="BAAI/bge-small-en-v1.5",
+                                                      local_dir=os.path.join(get_home_cache_dir(),
+                                                                             re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
+                                                      local_dir_use_symlinks=False)
+                        FastEmbed._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
+        self._model = FastEmbed._model
+        self._model_name = model_name
 
     def encode(self, texts: list):
         # Using the internal tokenizer to encode the texts and get the total
```
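Beyond taking the lock, the `FastEmbed` change adds a download-and-retry fallback: if constructing `TextEmbedding` fails (typically because the weights are not available locally), the repository is pulled from the Hugging Face Hub with `snapshot_download` and the load is retried. A sketch of that shape, with a hypothetical `_load_local` standing in for `TextEmbedding` (only the `snapshot_download` call mirrors the patch):

```python
import os
import re

from huggingface_hub import snapshot_download


def _load_local(local_dir: str):
    # Hypothetical stand-in for fastembed.TextEmbedding(...): just demand
    # that the weights exist on disk so the fallback branch is exercised.
    if not os.path.isdir(local_dir):
        raise FileNotFoundError(local_dir)
    return object()


def load_with_hub_fallback(model_name: str, cache_root: str):
    """Load a model, downloading it from the Hub on the first failure."""
    local_dir = os.path.join(cache_root, re.sub(r"^[a-zA-Z0-9]+/", "", model_name))
    try:
        return _load_local(local_dir)
    except Exception:
        # First attempt failed: fetch the repo into the local cache dir,
        # then retry the load against the downloaded copy.
        snapshot_download(repo_id=model_name, local_dir=local_dir,
                          local_dir_use_symlinks=False)
        return _load_local(local_dir)
```

One thing worth confirming in review: the patch hard-codes `repo_id="BAAI/bge-small-en-v1.5"` in the fallback while `model_name` is caller-supplied, so the retry only fetches the correct weights when that default model is the one requested; the sketch parameterizes `repo_id` instead.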