浏览代码

Refa: Remove useless conver and fix a bug for DefaultRerank (#8887)

### What problem does this PR solve?

1. bug when re-try, we need to reset i.
2. remove useless convert

### Type of change

- [x] Refactoring
tags/v0.20.0
Stephen Hu 3 个月前
父节点
当前提交
38b34116dd
没有帐户链接到提交者的电子邮件
共有 1 个文件被更改,包括 5 次插入3 次删除
  1. 5
    3
      rag/llm/rerank_model.py

+ 5
- 3
rag/llm/rerank_model.py 查看文件

old_dynamic_batch_size = self._dynamic_batch_size old_dynamic_batch_size = self._dynamic_batch_size
if max_batch_size is not None: if max_batch_size is not None:
self._dynamic_batch_size = max_batch_size self._dynamic_batch_size = max_batch_size
res = []
res = np.array([], dtype=float)
i = 0 i = 0
while i < len(pairs): while i < len(pairs):
cur_i = i
current_batch = self._dynamic_batch_size current_batch = self._dynamic_batch_size
max_retries = 5 max_retries = 5
retry_count = 0 retry_count = 0
try: try:
# call subclass implemented batch processing calculation # call subclass implemented batch processing calculation
batch_scores = self._compute_batch_scores(pairs[i : i + current_batch]) batch_scores = self._compute_batch_scores(pairs[i : i + current_batch])
res.extend(batch_scores)
res = np.append(res, batch_scores)
i += current_batch i += current_batch
self._dynamic_batch_size = min(self._dynamic_batch_size * 2, 8) self._dynamic_batch_size = min(self._dynamic_batch_size * 2, 8)
break break
if "CUDA out of memory" in str(e) and current_batch > self._min_batch_size: if "CUDA out of memory" in str(e) and current_batch > self._min_batch_size:
current_batch = max(current_batch // 2, self._min_batch_size) current_batch = max(current_batch // 2, self._min_batch_size)
self.torch_empty_cache() self.torch_empty_cache()
i = cur_i # reset i to the start of the current batch
retry_count += 1 retry_count += 1
else: else:
raise raise
scores = self._model.compute_score(batch_pairs) scores = self._model.compute_score(batch_pairs)
else: else:
scores = self._model.compute_score(batch_pairs, max_length=max_length) scores = self._model.compute_score(batch_pairs, max_length=max_length)
scores = sigmoid(np.array(scores)).tolist()
scores = sigmoid(np.array(scores))
if not isinstance(scores, Iterable): if not isinstance(scores, Iterable):
scores = [scores] scores = [scores]
return scores return scores

正在加载...
取消
保存