|
|
|
@@ -32,11 +32,6 @@ from api.utils.file_utils import get_home_cache_dir |
|
|
|
from api.utils.log_utils import log_exception |
|
|
|
from rag.utils import num_tokens_from_string, truncate |
|
|
|
|
|
|
|
|
|
|
|
def sigmoid(x): |
|
|
|
return 1 / (1 + np.exp(-x)) |
|
|
|
|
|
|
|
|
|
|
|
class Base(ABC): |
|
|
|
def __init__(self, key, model_name): |
|
|
|
pass |
|
|
|
@@ -133,10 +128,9 @@ class DefaultRerank(Base): |
|
|
|
|
|
|
|
def _compute_batch_scores(self, batch_pairs, max_length=None): |
|
|
|
if max_length is None: |
|
|
|
scores = self._model.compute_score(batch_pairs) |
|
|
|
scores = self._model.compute_score(batch_pairs, normalize=True) |
|
|
|
else: |
|
|
|
scores = self._model.compute_score(batch_pairs, max_length=max_length) |
|
|
|
scores = sigmoid(np.array(scores)) |
|
|
|
scores = self._model.compute_score(batch_pairs, max_length=max_length, normalize=True) |
|
|
|
if not isinstance(scores, Iterable): |
|
|
|
scores = [scores] |
|
|
|
return scores |