浏览代码

Refactor improve codes for ranker (#8936)

### What problem does this PR solve?
Use the normalize method directly

### Type of change

- [x] Refactoring
tags/v0.20.0
Stephen Hu 3 个月前
父节点
当前提交
46caf6ae72
没有帐户链接到提交者的电子邮件
共有 1 个文件被更改,包括 2 次插入8 次删除
  1. 2
    8
      rag/llm/rerank_model.py

+ 2
- 8
rag/llm/rerank_model.py 查看文件

from api.utils.log_utils import log_exception from api.utils.log_utils import log_exception
from rag.utils import num_tokens_from_string, truncate from rag.utils import num_tokens_from_string, truncate



def sigmoid(x):
return 1 / (1 + np.exp(-x))


class Base(ABC): class Base(ABC):
def __init__(self, key, model_name): def __init__(self, key, model_name):
pass pass


def _compute_batch_scores(self, batch_pairs, max_length=None): def _compute_batch_scores(self, batch_pairs, max_length=None):
if max_length is None: if max_length is None:
scores = self._model.compute_score(batch_pairs)
scores = self._model.compute_score(batch_pairs, normalize=True)
else: else:
scores = self._model.compute_score(batch_pairs, max_length=max_length)
scores = sigmoid(np.array(scores))
scores = self._model.compute_score(batch_pairs, max_length=max_length, normalize=True)
if not isinstance(scores, Iterable): if not isinstance(scores, Iterable):
scores = [scores] scores = [scores]
return scores return scores

正在加载...
取消
保存