|
|
|
@@ -39,6 +39,7 @@ class Base(ABC): |
|
|
|
class DefaultRerank(Base): |
|
|
|
_model = None |
|
|
|
_model_lock = threading.Lock() |
|
|
|
|
|
|
|
def __init__(self, key, model_name, **kwargs): |
|
|
|
""" |
|
|
|
If you have trouble downloading HuggingFace models, -_^ this might help!! |
|
|
|
@@ -102,19 +103,24 @@ class JinaRerank(Base): |
|
|
|
|
|
|
|
class YoudaoRerank(DefaultRerank): |
|
|
|
_model = None |
|
|
|
_model_lock = threading.Lock() |
|
|
|
|
|
|
|
def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs): |
|
|
|
from BCEmbedding import RerankerModel |
|
|
|
if not YoudaoRerank._model: |
|
|
|
try: |
|
|
|
print("LOADING BCE...") |
|
|
|
YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join( |
|
|
|
get_home_cache_dir(), |
|
|
|
re.sub(r"^[a-zA-Z]+/", "", model_name))) |
|
|
|
except Exception as e: |
|
|
|
YoudaoRerank._model = RerankerModel( |
|
|
|
model_name_or_path=model_name.replace( |
|
|
|
"maidalun1020", "InfiniFlow")) |
|
|
|
with YoudaoRerank._model_lock: |
|
|
|
if not YoudaoRerank._model: |
|
|
|
try: |
|
|
|
print("LOADING BCE...") |
|
|
|
YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join( |
|
|
|
get_home_cache_dir(), |
|
|
|
re.sub(r"^[a-zA-Z]+/", "", model_name))) |
|
|
|
except Exception as e: |
|
|
|
YoudaoRerank._model = RerankerModel( |
|
|
|
model_name_or_path=model_name.replace( |
|
|
|
"maidalun1020", "InfiniFlow")) |
|
|
|
|
|
|
|
self._model = YoudaoRerank._model |
|
|
|
|
|
|
|
def similarity(self, query: str, texts: list): |
|
|
|
pairs = [(query, truncate(t, self._model.max_length)) for t in texts] |