Browse Source

Refactor rerank model with dynamic batch processing and memory manage… (#5273)

…ment

### What problem does this PR solve?
Issue:https://github.com/infiniflow/ragflow/issues/5262
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Co-authored-by: wenju.li <wenju.li@deepctr.cn>
tags/v0.17.0
liwenju0 8 months ago
parent
commit
569e40544d
No account linked to committer's email address
1 changed files with 54 additions and 16 deletions
  1. 54
    16
      rag/llm/rerank_model.py

+ 54
- 16
rag/llm/rerank_model.py View File

@@ -31,6 +31,7 @@ from rag.utils import num_tokens_from_string, truncate
import json



def sigmoid(x):
return 1 / (1 + np.exp(-x))

@@ -86,6 +87,57 @@ class DefaultRerank(Base):
local_dir_use_symlinks=False)
DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
self._model = DefaultRerank._model
self._dynamic_batch_size = 8
self._min_batch_size = 1

def torch_empty_cache(self):
try:
import torch
torch.cuda.empty_cache()
except Exception as e:
print(f"Error emptying cache: {e}")

def _process_batch(self, pairs, max_batch_size=None):
"""template method for subclass call"""
old_dynamic_batch_size = self._dynamic_batch_size
if max_batch_size is not None:
self._dynamic_batch_size = max_batch_size
res = []
i = 0
while i < len(pairs):
current_batch = self._dynamic_batch_size
max_retries = 5
retry_count = 0
while retry_count < max_retries:
try:
# call subclass implemented batch processing calculation
batch_scores = self._compute_batch_scores(pairs[i:i+current_batch])
res.extend(batch_scores)
i += current_batch
self._dynamic_batch_size = min(self._dynamic_batch_size * 2, 8)
break
except RuntimeError as e:
if "CUDA out of memory" in str(e) and current_batch > self._min_batch_size:
current_batch = max(current_batch // 2, self._min_batch_size)
self.torch_empty_cache()
retry_count += 1
else:
raise
if retry_count >= max_retries:
raise RuntimeError("max retry times, still cannot process batch, please check your GPU memory")
self.torch_empty_cache()

self._dynamic_batch_size = old_dynamic_batch_size
return np.array(res)

def _compute_batch_scores(self, batch_pairs, max_length=None):
if max_length is None:
max_length = self._model.max_length
scores = self._model.compute_score(batch_pairs, max_length=max_length)
scores = sigmoid(np.array(scores)).tolist()
return scores

def similarity(self, query: str, texts: list):
pairs = [(query, truncate(t, 2048)) for t in texts]
@@ -93,14 +145,7 @@ class DefaultRerank(Base):
for _, t in pairs:
token_count += num_tokens_from_string(t)
batch_size = 4096
res = []
for i in range(0, len(pairs), batch_size):
scores = self._model.compute_score(pairs[i:i + batch_size], max_length=2048)
scores = sigmoid(np.array(scores)).tolist()
if isinstance(scores, float):
res.append(scores)
else:
res.extend(scores)
res = self._process_batch(pairs, max_batch_size=batch_size)
return np.array(res), token_count


@@ -155,14 +200,7 @@ class YoudaoRerank(DefaultRerank):
for _, t in pairs:
token_count += num_tokens_from_string(t)
batch_size = 8
res = []
for i in range(0, len(pairs), batch_size):
scores = self._model.compute_score(pairs[i:i + batch_size], max_length=self._model.max_length)
scores = sigmoid(np.array(scores)).tolist()
if isinstance(scores, float):
res.append(scores)
else:
res.extend(scores)
res = self._process_batch(pairs, max_batch_size=batch_size)
return np.array(res), token_count



Loading…
Cancel
Save