|
|
|
@@ -31,6 +31,7 @@ from rag.utils import num_tokens_from_string, truncate |
|
|
|
import json |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sigmoid(x):
    """Logistic function: map x (scalar or ndarray) into (0, 1).

    Computed as 1 / (1 + e^-x); vectorizes over numpy arrays.
    """
    denom = 1.0 + np.exp(-x)
    return 1.0 / denom
|
|
|
|
|
|
|
@@ -86,6 +87,57 @@ class DefaultRerank(Base): |
|
|
|
local_dir_use_symlinks=False) |
|
|
|
DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available()) |
|
|
|
self._model = DefaultRerank._model |
|
|
|
self._dynamic_batch_size = 8 |
|
|
|
self._min_batch_size = 1 |
|
|
|
|
|
|
|
|
|
|
|
def torch_empty_cache(self):
    """Best-effort release of cached CUDA memory.

    Any failure (torch unavailable, no CUDA device, driver error) is
    swallowed and reported on stdout so callers may invoke this
    unconditionally between batches.
    """
    try:
        import torch

        torch.cuda.empty_cache()
    except Exception as e:
        print(f"Error emptying cache: {e}")
|
|
|
|
|
|
|
def _process_batch(self, pairs, max_batch_size=None): |
|
|
|
"""template method for subclass call""" |
|
|
|
old_dynamic_batch_size = self._dynamic_batch_size |
|
|
|
if max_batch_size is not None: |
|
|
|
self._dynamic_batch_size = max_batch_size |
|
|
|
res = [] |
|
|
|
i = 0 |
|
|
|
while i < len(pairs): |
|
|
|
current_batch = self._dynamic_batch_size |
|
|
|
max_retries = 5 |
|
|
|
retry_count = 0 |
|
|
|
while retry_count < max_retries: |
|
|
|
try: |
|
|
|
# call subclass implemented batch processing calculation |
|
|
|
batch_scores = self._compute_batch_scores(pairs[i:i+current_batch]) |
|
|
|
res.extend(batch_scores) |
|
|
|
i += current_batch |
|
|
|
self._dynamic_batch_size = min(self._dynamic_batch_size * 2, 8) |
|
|
|
break |
|
|
|
except RuntimeError as e: |
|
|
|
if "CUDA out of memory" in str(e) and current_batch > self._min_batch_size: |
|
|
|
current_batch = max(current_batch // 2, self._min_batch_size) |
|
|
|
self.torch_empty_cache() |
|
|
|
retry_count += 1 |
|
|
|
else: |
|
|
|
raise |
|
|
|
if retry_count >= max_retries: |
|
|
|
raise RuntimeError("max retry times, still cannot process batch, please check your GPU memory") |
|
|
|
self.torch_empty_cache() |
|
|
|
|
|
|
|
self._dynamic_batch_size = old_dynamic_batch_size |
|
|
|
return np.array(res) |
|
|
|
|
|
|
|
|
|
|
|
def _compute_batch_scores(self, batch_pairs, max_length=None):
    """Score one batch of (query, text) pairs with the reranker model.

    Uses the model's own `max_length` when none is supplied; raw model
    scores are squashed through `sigmoid` into (0, 1) and returned as a
    plain list.
    """
    effective_len = self._model.max_length if max_length is None else max_length
    raw_scores = self._model.compute_score(batch_pairs, max_length=effective_len)
    return sigmoid(np.array(raw_scores)).tolist()
|
|
|
|
|
|
|
def similarity(self, query: str, texts: list):
    """Rerank `texts` by relevance to `query`.

    Each text is truncated to 2048 tokens before being paired with the
    query. Scoring is delegated to `_process_batch`, which handles
    OOM-aware adaptive batching.

    Returns:
        (scores, token_count): `scores` is a numpy array of sigmoid
        relevance scores, one per text in input order; `token_count` is
        the total token count of the truncated texts.
    """
    pairs = [(query, truncate(t, 2048)) for t in texts]
    token_count = 0
    for _, t in pairs:
        token_count += num_tokens_from_string(t)
    batch_size = 4096
    # BUG FIX: the previous manual compute_score loop here duplicated the
    # batching work and its result was immediately discarded when `res`
    # was reassigned from _process_batch — removed as dead code.
    res = self._process_batch(pairs, max_batch_size=batch_size)
    return np.array(res), token_count
|
|
|
|
|
|
|
|
|
|
|
@@ -155,14 +200,7 @@ class YoudaoRerank(DefaultRerank): |
|
|
|
for _, t in pairs: |
|
|
|
token_count += num_tokens_from_string(t) |
|
|
|
batch_size = 8 |
|
|
|
res = [] |
|
|
|
for i in range(0, len(pairs), batch_size): |
|
|
|
scores = self._model.compute_score(pairs[i:i + batch_size], max_length=self._model.max_length) |
|
|
|
scores = sigmoid(np.array(scores)).tolist() |
|
|
|
if isinstance(scores, float): |
|
|
|
res.append(scores) |
|
|
|
else: |
|
|
|
res.extend(scores) |
|
|
|
res = self._process_batch(pairs, max_batch_size=batch_size) |
|
|
|
return np.array(res), token_count |
|
|
|
|
|
|
|
|