浏览代码

fix benchmark issue (#3324)

### What problem does this PR solve?



### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
tags/v0.14.0
Kevin Hu 11 个月前
父节点
当前提交
5e5a35191e
没有帐户链接到提交者的电子邮件
共有 1 个文件被更改,包括 18 次插入9 次删除
  1. 18
    9
      rag/benchmark.py

+ 18
- 9
rag/benchmark.py 查看文件

from ranx import evaluate from ranx import evaluate
import pandas as pd import pandas as pd
from tqdm import tqdm from tqdm import tqdm
from ranx import Qrels, Run
class Benchmark: class Benchmark:
query_list = list(qrels.keys()) query_list = list(qrels.keys())
for query in query_list: for query in query_list:
ranks = retrievaler.retrieval(query, self.embd_mdl, dataset_idxnm.replace("ragflow_", ""),
[self.kb.id], 0, 30,
ranks = retrievaler.retrieval(query, self.embd_mdl,
dataset_idxnm, [self.kb.id], 1, 30,
0.0, self.vector_similarity_weight) 0.0, self.vector_similarity_weight)
for c in ranks["chunks"]: for c in ranks["chunks"]:
if "vector" in c: if "vector" in c:
for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']): for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
d = { d = {
"id": get_uuid(), "id": get_uuid(),
"kb_id": self.kb.id
"kb_id": self.kb.id,
"docnm_kwd": "xxxxx",
"doc_id": "ksksks"
} }
tokenize(d, text, "english") tokenize(d, text, "english")
docs.append(d) docs.append(d)
for rel, text in zip(data.iloc[i]["search_results"]['rank'], for rel, text in zip(data.iloc[i]["search_results"]['rank'],
data.iloc[i]["search_results"]['search_context']): data.iloc[i]["search_results"]['search_context']):
d = { d = {
"id": get_uuid()
"id": get_uuid(),
"kb_id": self.kb.id,
"docnm_kwd": "xxxxx",
"doc_id": "ksksks"
} }
tokenize(d, text, "english") tokenize(d, text, "english")
docs.append(d) docs.append(d)
text = corpus_total[tmp_data.iloc[i]['docid']] text = corpus_total[tmp_data.iloc[i]['docid']]
rel = tmp_data.iloc[i]['relevance'] rel = tmp_data.iloc[i]['relevance']
d = { d = {
"id": get_uuid()
"id": get_uuid(),
"kb_id": self.kb.id,
"docnm_kwd": "xxxxx",
"doc_id": "ksksks"
} }
tokenize(d, text, 'english') tokenize(d, text, 'english')
docs.append(d) docs.append(d)
for run_i in tqdm(range(len(run_keys)), desc="Calculating ndcg@10 for single query"): for run_i in tqdm(range(len(run_keys)), desc="Calculating ndcg@10 for single query"):
key = run_keys[run_i] key = run_keys[run_i]
keep_result.append({'query': key, 'qrel': qrels[key], 'run': run[key], keep_result.append({'query': key, 'qrel': qrels[key], 'run': run[key],
'ndcg@10': evaluate({key: qrels[key]}, {key: run[key]}, "ndcg@10")})
'ndcg@10': evaluate(Qrels({key: qrels[key]}), Run({key: run[key]}), "ndcg@10")})
keep_result = sorted(keep_result, key=lambda kk: kk['ndcg@10']) keep_result = sorted(keep_result, key=lambda kk: kk['ndcg@10'])
with open(os.path.join(file_path, dataset + 'result.md'), 'w', encoding='utf-8') as f: with open(os.path.join(file_path, dataset + 'result.md'), 'w', encoding='utf-8') as f:
f.write('## Score For Every Query\n') f.write('## Score For Every Query\n')
if dataset == "ms_marco_v1.1": if dataset == "ms_marco_v1.1":
qrels, texts = self.ms_marco_index(file_path, "benchmark_ms_marco_v1.1") qrels, texts = self.ms_marco_index(file_path, "benchmark_ms_marco_v1.1")
run = self._get_retrieval(qrels, "benchmark_ms_marco_v1.1") run = self._get_retrieval(qrels, "benchmark_ms_marco_v1.1")
print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr"]))
self.save_results(qrels, run, texts, dataset, file_path) self.save_results(qrels, run, texts, dataset, file_path)
if dataset == "trivia_qa": if dataset == "trivia_qa":
qrels, texts = self.trivia_qa_index(file_path, "benchmark_trivia_qa") qrels, texts = self.trivia_qa_index(file_path, "benchmark_trivia_qa")
run = self._get_retrieval(qrels, "benchmark_trivia_qa") run = self._get_retrieval(qrels, "benchmark_trivia_qa")
print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
print(dataset, evaluate((qrels), Run(run), ["ndcg@10", "map@5", "mrr"]))
self.save_results(qrels, run, texts, dataset, file_path) self.save_results(qrels, run, texts, dataset, file_path)
if dataset == "miracl": if dataset == "miracl":
for lang in ['ar', 'bn', 'de', 'en', 'es', 'fa', 'fi', 'fr', 'hi', 'id', 'ja', 'ko', 'ru', 'sw', 'te', 'th', for lang in ['ar', 'bn', 'de', 'en', 'es', 'fa', 'fi', 'fr', 'hi', 'id', 'ja', 'ko', 'ru', 'sw', 'te', 'th',
os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang), os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang),
"benchmark_miracl_" + lang) "benchmark_miracl_" + lang)
run = self._get_retrieval(qrels, "benchmark_miracl_" + lang) run = self._get_retrieval(qrels, "benchmark_miracl_" + lang)
print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr"]))
self.save_results(qrels, run, texts, dataset, file_path) self.save_results(qrels, run, texts, dataset, file_path)

正在加载...
取消
保存