
benchmark.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import os
import sys
import time
import argparse
from collections import defaultdict

from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.db.services.knowledgebase_service import KnowledgebaseService
from api import settings
from api.utils import get_uuid
from rag.nlp import tokenize, search
from ranx import evaluate
from ranx import Qrels, Run
import pandas as pd
from tqdm import tqdm

# Indexing stops once this many passages have been processed; overwritten
# from the max_docs command-line argument in the __main__ block below.
max_docs = sys.maxsize

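# Benchmark drives one evaluation run against a knowledge base: it (re)builds
# a doc-store index for the chosen dataset, embeds passages with the KB's
# embedding model, retrieves for every query, and scores the run with ranx.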
class Benchmark:
    def __init__(self, kb_id):
        self.kb_id = kb_id
        e, self.kb = KnowledgebaseService.get_by_id(kb_id)
        self.similarity_threshold = self.kb.similarity_threshold
        self.vector_similarity_weight = self.kb.vector_similarity_weight
        self.embd_mdl = LLMBundle(self.kb.tenant_id, LLMType.EMBEDDING, llm_name=self.kb.embd_id, lang=self.kb.language)
        self.tenant_id = ''
        self.index_name = ''
        self.initialized_index = False

    def _get_retrieval(self, qrels):
        # Need to wait for the ES and Infinity index to be ready
        time.sleep(20)
        run = defaultdict(dict)
        query_list = list(qrels.keys())
        for query in query_list:
            ranks = settings.retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
                                                   0.0, self.vector_similarity_weight)
            if len(ranks["chunks"]) == 0:
                print(f"deleted query: {query}")
                del qrels[query]
                continue
            for c in ranks["chunks"]:
                c.pop("vector", None)
                run[query][c["chunk_id"]] = c["similarity"]
        return run

    def embedding(self, docs, batch_size=16):
        vects = []
        cnts = [d["content_with_weight"] for d in docs]
        for i in range(0, len(cnts), batch_size):
            vts, c = self.embd_mdl.encode(cnts[i: i + batch_size])
            vects.extend(vts.tolist())
        assert len(docs) == len(vects)
        vector_size = 0
        for i, d in enumerate(docs):
            v = vects[i]
            vector_size = len(v)
            # The doc store expects the embedding under a size-suffixed key,
            # e.g. "q_1024_vec" for a 1024-dimensional vector.
            d["q_%d_vec" % len(v)] = v
        return docs, vector_size

    def init_index(self, vector_size: int):
        if self.initialized_index:
            return
        if settings.docStoreConn.indexExist(self.index_name, self.kb_id):
            settings.docStoreConn.deleteIdx(self.index_name, self.kb_id)
        settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
        self.initialized_index = True

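    # The three *_index helpers below share one pattern: read the raw dataset,
    # tokenize each passage into a doc-store record, embed and insert in
    # batches of 32, and accumulate qrels (query -> doc id -> relevance)
    # alongside the raw passage texts for the final report.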
    def ms_marco_index(self, file_path, index_name):
        qrels = defaultdict(dict)
        texts = defaultdict(dict)
        docs_count = 0
        docs = []
        filelist = sorted(os.listdir(file_path))
        for fn in filelist:
            if docs_count >= max_docs:
                break
            if not fn.endswith(".parquet"):
                continue
            data = pd.read_parquet(os.path.join(file_path, fn))
            for i in tqdm(range(len(data)), colour="green", desc="Tokenizing:" + fn):
                if docs_count >= max_docs:
                    break
                query = data.iloc[i]['query']
                for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
                    d = {
                        "id": get_uuid(),
                        "kb_id": self.kb.id,
                        "docnm_kwd": "xxxxx",
                        "doc_id": "ksksks"
                    }
                    tokenize(d, text, "english")
                    docs.append(d)
                    texts[d["id"]] = text
                    qrels[query][d["id"]] = int(rel)
                # embed and insert in batches of 32 documents
                if len(docs) >= 32:
                    docs_count += len(docs)
                    docs, vector_size = self.embedding(docs)
                    self.init_index(vector_size)
                    settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
                    docs = []
        # flush whatever is left of the final batch
        if docs:
            docs, vector_size = self.embedding(docs)
            self.init_index(vector_size)
            settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
        return qrels, texts

    def trivia_qa_index(self, file_path, index_name):
        qrels = defaultdict(dict)
        texts = defaultdict(dict)
        docs_count = 0
        docs = []
        filelist = sorted(os.listdir(file_path))
        for fn in filelist:
            if docs_count >= max_docs:
                break
            if not fn.endswith(".parquet"):
                continue
            data = pd.read_parquet(os.path.join(file_path, fn))
            for i in tqdm(range(len(data)), colour="green", desc="Indexing:" + fn):
                if docs_count >= max_docs:
                    break
                query = data.iloc[i]['question']
                for rel, text in zip(data.iloc[i]["search_results"]['rank'],
                                     data.iloc[i]["search_results"]['search_context']):
                    d = {
                        "id": get_uuid(),
                        "kb_id": self.kb.id,
                        "docnm_kwd": "xxxxx",
                        "doc_id": "ksksks"
                    }
                    tokenize(d, text, "english")
                    docs.append(d)
                    texts[d["id"]] = text
                    qrels[query][d["id"]] = int(rel)
                if len(docs) >= 32:
                    docs_count += len(docs)
                    docs, vector_size = self.embedding(docs)
                    self.init_index(vector_size)
                    settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
                    docs = []
        # guard the final flush so an empty batch is not embedded or inserted
        if docs:
            docs, vector_size = self.embedding(docs)
            self.init_index(vector_size)
            settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
        return qrels, texts

    def miracl_index(self, file_path, corpus_path, index_name):
        corpus_total = {}
        for corpus_file in os.listdir(corpus_path):
            tmp_data = pd.read_json(os.path.join(corpus_path, corpus_file), lines=True)
            for index, i in tmp_data.iterrows():
                corpus_total[i['docid']] = i['text']
        topics_total = {}
        for topics_file in os.listdir(os.path.join(file_path, 'topics')):
            if 'test' in topics_file:
                continue
            tmp_data = pd.read_csv(os.path.join(file_path, 'topics', topics_file), sep='\t', names=['qid', 'query'])
            for index, i in tmp_data.iterrows():
                topics_total[i['qid']] = i['query']
        qrels = defaultdict(dict)
        texts = defaultdict(dict)
        docs_count = 0
        docs = []
        for qrels_file in os.listdir(os.path.join(file_path, 'qrels')):
            if 'test' in qrels_file:
                continue
            if docs_count >= max_docs:
                break
            tmp_data = pd.read_csv(os.path.join(file_path, 'qrels', qrels_file), sep='\t',
                                   names=['qid', 'Q0', 'docid', 'relevance'])
            for i in tqdm(range(len(tmp_data)), colour="green", desc="Indexing:" + qrels_file):
                if docs_count >= max_docs:
                    break
                query = topics_total[tmp_data.iloc[i]['qid']]
                text = corpus_total[tmp_data.iloc[i]['docid']]
                rel = tmp_data.iloc[i]['relevance']
                d = {
                    "id": get_uuid(),
                    "kb_id": self.kb.id,
                    "docnm_kwd": "xxxxx",
                    "doc_id": "ksksks"
                }
                tokenize(d, text, 'english')
                docs.append(d)
                texts[d["id"]] = text
                qrels[query][d["id"]] = int(rel)
                if len(docs) >= 32:
                    docs_count += len(docs)
                    docs, vector_size = self.embedding(docs)
                    self.init_index(vector_size)
                    settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
                    docs = []
        if docs:
            docs, vector_size = self.embedding(docs)
            self.init_index(vector_size)
            settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
        return qrels, texts

    def save_results(self, qrels, run, texts, dataset, file_path):
        keep_result = []
        run_keys = list(run.keys())
        for run_i in tqdm(range(len(run_keys)), desc="Calculating ndcg@10 for single query"):
            key = run_keys[run_i]
            keep_result.append({'query': key, 'qrel': qrels[key], 'run': run[key],
                                'ndcg@10': evaluate({key: qrels[key]}, {key: run[key]}, "ndcg@10")})
        keep_result = sorted(keep_result, key=lambda kk: kk['ndcg@10'])
        with open(os.path.join(file_path, dataset + '_result.md'), 'w', encoding='utf-8') as f:
            f.write('## Score For Every Query\n')
            for keep_result_i in keep_result:
                f.write('### query: ' + keep_result_i['query'] + ' ndcg@10:' + str(keep_result_i['ndcg@10']) + '\n')
                scores = [[i[0], i[1]] for i in keep_result_i['run'].items()]
                scores = sorted(scores, key=lambda kk: kk[1])
                for score in scores[:10]:
                    f.write('- text: ' + str(texts[score[0]]) + '\t qrel: ' + str(score[1]) + '\n')
        json.dump(qrels, open(os.path.join(file_path, dataset + '.qrels.json'), "w+", encoding='utf-8'), indent=2)
        json.dump(run, open(os.path.join(file_path, dataset + '.run.json'), "w+", encoding='utf-8'), indent=2)
        print(os.path.join(file_path, dataset + '_result.md'), 'Saved!')

    def __call__(self, dataset, file_path, miracl_corpus=''):
        if dataset == "ms_marco_v1.1":
            self.tenant_id = "benchmark_ms_marco_v11"
            self.index_name = search.index_name(self.tenant_id)
            qrels, texts = self.ms_marco_index(file_path, "benchmark_ms_marco_v1.1")
            run = self._get_retrieval(qrels)
            print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr@10"]))
            self.save_results(qrels, run, texts, dataset, file_path)
        if dataset == "trivia_qa":
            self.tenant_id = "benchmark_trivia_qa"
            self.index_name = search.index_name(self.tenant_id)
            qrels, texts = self.trivia_qa_index(file_path, "benchmark_trivia_qa")
            run = self._get_retrieval(qrels)
            print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr@10"]))
            self.save_results(qrels, run, texts, dataset, file_path)
        if dataset == "miracl":
            for lang in ['ar', 'bn', 'de', 'en', 'es', 'fa', 'fi', 'fr', 'hi', 'id', 'ja', 'ko', 'ru', 'sw', 'te', 'th',
                         'yo', 'zh']:
                if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang)):
                    print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang) + ' not found!')
                    continue
                if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang, 'qrels')):
                    print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang, 'qrels') + ' not found!')
                    continue
                if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang, 'topics')):
                    print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang, 'topics') + ' not found!')
                    continue
                if not os.path.isdir(os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang)):
                    print('Directory: ' + os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang) + ' not found!')
                    continue
                self.tenant_id = "benchmark_miracl_" + lang
                self.index_name = search.index_name(self.tenant_id)
                self.initialized_index = False
                qrels, texts = self.miracl_index(os.path.join(file_path, 'miracl-v1.0-' + lang),
                                                 os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang),
                                                 "benchmark_miracl_" + lang)
                run = self._get_retrieval(qrels)
                print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr@10"]))
                self.save_results(qrels, run, texts, dataset, file_path)

if __name__ == '__main__':
    print('*****************RAGFlow Benchmark*****************')
    parser = argparse.ArgumentParser(usage="benchmark.py <max_docs> <kb_id> <dataset> <dataset_path> [<miracl_corpus_path>]",
                                     description='RAGFlow Benchmark')
    parser.add_argument('max_docs', metavar='max_docs', type=int, help='max docs to evaluate')
    parser.add_argument('kb_id', metavar='kb_id', help='knowledgebase id')
    parser.add_argument('dataset', metavar='dataset',
                        help='dataset name, must be one of ms_marco_v1.1 (https://huggingface.co/datasets/microsoft/ms_marco), '
                             'trivia_qa (https://huggingface.co/datasets/mandarjoshi/trivia_qa), '
                             'miracl (https://huggingface.co/datasets/miracl/miracl)')
    parser.add_argument('dataset_path', metavar='dataset_path', help='dataset path')
    parser.add_argument('miracl_corpus_path', metavar='miracl_corpus_path', nargs='?', default="",
                        help='miracl corpus path. Only needed when dataset is miracl')
    args = parser.parse_args()
    max_docs = args.max_docs
    kb_id = args.kb_id
    ex = Benchmark(kb_id)
    dataset = args.dataset
    dataset_path = args.dataset_path
    if dataset == "ms_marco_v1.1" or dataset == "trivia_qa":
        ex(dataset, dataset_path)
    elif dataset == "miracl":
        # argparse already parsed the optional positional; just check it was given
        if not args.miracl_corpus_path:
            print('Please input the correct parameters!')
            sys.exit(1)
        ex(dataset, dataset_path, miracl_corpus=args.miracl_corpus_path)
    else:
        print("Dataset: ", dataset, "not supported!")
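
# Example invocations, following the usage string above (a sketch; the kb_id
# and dataset paths are placeholders, not values from the source):
#   python benchmark.py 10000 <kb_id> ms_marco_v1.1 ./datasets/ms_marco_v1.1
#   python benchmark.py 10000 <kb_id> trivia_qa ./datasets/trivia_qa
#   python benchmark.py 10000 <kb_id> miracl ./datasets/miracl ./datasets/miracl-corpus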