
benchmark.py 14KB

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import json
import os
import sys
import time
from collections import defaultdict

import pandas as pd
from ranx import Qrels, Run, evaluate
from tqdm import tqdm

from api import settings
from api.db import LLMType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.utils import get_uuid
from rag.nlp import search, tokenize

max_docs = sys.maxsize
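
# Benchmark wires a knowledge base's embedding model to the configured
# document store, indexes a public dataset into it, runs retrieval for every
# query, and scores the resulting run against the qrels with ranx.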
class Benchmark:
    def __init__(self, kb_id):
        self.kb_id = kb_id
        e, self.kb = KnowledgebaseService.get_by_id(kb_id)
        self.similarity_threshold = self.kb.similarity_threshold
        self.vector_similarity_weight = self.kb.vector_similarity_weight
        self.embd_mdl = LLMBundle(self.kb.tenant_id, LLMType.EMBEDDING, llm_name=self.kb.embd_id, lang=self.kb.language)
        self.tenant_id = ''
        self.index_name = ''
        self.initialized_index = False
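
    # Build a ranx-style run: for every query in the qrels, retrieve up to 30
    # chunks and record {query: {chunk_id: similarity}}. Queries that return
    # no chunks are dropped from the qrels so the metrics stay comparable.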
    def _get_retrieval(self, qrels):
        # Need to wait for the ES and Infinity index to be ready
        time.sleep(20)
        run = defaultdict(dict)
        query_list = list(qrels.keys())
        for query in query_list:
            ranks = settings.retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
                                                   0.0, self.vector_similarity_weight)
            if len(ranks["chunks"]) == 0:
                print(f"deleted query: {query}")
                del qrels[query]
                continue
            for c in ranks["chunks"]:
                if "vector" in c:
                    del c["vector"]
                run[query][c["chunk_id"]] = c["similarity"]
        return run
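
    # Embed chunk contents in batches and attach each vector to its chunk
    # under the "q_<dim>_vec" key expected by the document store.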
    def embedding(self, docs, batch_size=16):
        vects = []
        cnts = [d["content_with_weight"] for d in docs]
        for i in range(0, len(cnts), batch_size):
            vts, c = self.embd_mdl.encode(cnts[i: i + batch_size])
            vects.extend(vts.tolist())
        assert len(docs) == len(vects)
        vector_size = 0
        for i, d in enumerate(docs):
            v = vects[i]
            vector_size = len(v)
            d["q_%d_vec" % len(v)] = v
        return docs, vector_size
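
    # Create the index once per benchmark (guarded by initialized_index);
    # any pre-existing index is dropped first so each run starts clean.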
    def init_index(self, vector_size: int):
        if self.initialized_index:
            return
        if settings.docStoreConn.indexExist(self.index_name, self.kb_id):
            settings.docStoreConn.deleteIdx(self.index_name, self.kb_id)
        settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
        self.initialized_index = True
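
    # Index MS MARCO v1.1 parquet files: every passage becomes a chunk, and
    # the per-passage "is_selected" flag is recorded as the relevance judgement.
    # Chunks are embedded and inserted in batches of 32.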
    def ms_marco_index(self, file_path, index_name):
        qrels = defaultdict(dict)
        texts = defaultdict(dict)
        docs_count = 0
        docs = []
        filelist = sorted(os.listdir(file_path))
        for fn in filelist:
            if docs_count >= max_docs:
                break
            if not fn.endswith(".parquet"):
                continue
            data = pd.read_parquet(os.path.join(file_path, fn))
            for i in tqdm(range(len(data)), colour="green", desc="Tokenizing:" + fn):
                if docs_count >= max_docs:
                    break
                query = data.iloc[i]['query']
                for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
                    d = {
                        "id": get_uuid(),
                        "kb_id": self.kb.id,
                        "docnm_kwd": "xxxxx",
                        "doc_id": "ksksks"
                    }
                    tokenize(d, text, "english")
                    docs.append(d)
                    texts[d["id"]] = text
                    qrels[query][d["id"]] = int(rel)
                if len(docs) >= 32:
                    docs_count += len(docs)
                    docs, vector_size = self.embedding(docs)
                    self.init_index(vector_size)
                    settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
                    docs = []
        if docs:
            docs, vector_size = self.embedding(docs)
            self.init_index(vector_size)
            settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
        return qrels, texts
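
    # Index TriviaQA parquet files: each search result becomes a chunk and its
    # "rank" value is recorded as the relevance judgement.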
    def trivia_qa_index(self, file_path, index_name):
        qrels = defaultdict(dict)
        texts = defaultdict(dict)
        docs_count = 0
        docs = []
        filelist = sorted(os.listdir(file_path))
        for fn in filelist:
            if docs_count >= max_docs:
                break
            if not fn.endswith(".parquet"):
                continue
            data = pd.read_parquet(os.path.join(file_path, fn))
            for i in tqdm(range(len(data)), colour="green", desc="Indexing:" + fn):
                if docs_count >= max_docs:
                    break
                query = data.iloc[i]['question']
                for rel, text in zip(data.iloc[i]["search_results"]['rank'],
                                     data.iloc[i]["search_results"]['search_context']):
                    d = {
                        "id": get_uuid(),
                        "kb_id": self.kb.id,
                        "docnm_kwd": "xxxxx",
                        "doc_id": "ksksks"
                    }
                    tokenize(d, text, "english")
                    docs.append(d)
                    texts[d["id"]] = text
                    qrels[query][d["id"]] = int(rel)
                if len(docs) >= 32:
                    docs_count += len(docs)
                    docs, vector_size = self.embedding(docs)
                    self.init_index(vector_size)
                    settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
                    docs = []
        if docs:
            docs, vector_size = self.embedding(docs)
            self.init_index(vector_size)
            settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
        return qrels, texts
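
    # Index one MIRACL language split: topics (queries) and qrels come from
    # TSV files, passage texts from the separate JSONL corpus dump; test
    # splits are skipped because their qrels are not public.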
    def miracl_index(self, file_path, corpus_path, index_name):
        corpus_total = {}
        for corpus_file in os.listdir(corpus_path):
            tmp_data = pd.read_json(os.path.join(corpus_path, corpus_file), lines=True)
            for index, i in tmp_data.iterrows():
                corpus_total[i['docid']] = i['text']
        topics_total = {}
        for topics_file in os.listdir(os.path.join(file_path, 'topics')):
            if 'test' in topics_file:
                continue
            tmp_data = pd.read_csv(os.path.join(file_path, 'topics', topics_file), sep='\t', names=['qid', 'query'])
            for index, i in tmp_data.iterrows():
                topics_total[i['qid']] = i['query']
        qrels = defaultdict(dict)
        texts = defaultdict(dict)
        docs_count = 0
        docs = []
        for qrels_file in os.listdir(os.path.join(file_path, 'qrels')):
            if 'test' in qrels_file:
                continue
            if docs_count >= max_docs:
                break
            tmp_data = pd.read_csv(os.path.join(file_path, 'qrels', qrels_file), sep='\t',
                                   names=['qid', 'Q0', 'docid', 'relevance'])
            for i in tqdm(range(len(tmp_data)), colour="green", desc="Indexing:" + qrels_file):
                if docs_count >= max_docs:
                    break
                query = topics_total[tmp_data.iloc[i]['qid']]
                text = corpus_total[tmp_data.iloc[i]['docid']]
                rel = tmp_data.iloc[i]['relevance']
                d = {
                    "id": get_uuid(),
                    "kb_id": self.kb.id,
                    "docnm_kwd": "xxxxx",
                    "doc_id": "ksksks"
                }
                tokenize(d, text, 'english')
                docs.append(d)
                texts[d["id"]] = text
                qrels[query][d["id"]] = int(rel)
                if len(docs) >= 32:
                    docs_count += len(docs)
                    docs, vector_size = self.embedding(docs)
                    self.init_index(vector_size)
                    settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
                    docs = []
        if docs:
            docs, vector_size = self.embedding(docs)
            self.init_index(vector_size)
            settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
        return qrels, texts
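
    # Write a per-query report (sorted worst-first by ndcg@10) plus the raw
    # qrels and run as JSON, so weak queries can be inspected by hand.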
    def save_results(self, qrels, run, texts, dataset, file_path):
        keep_result = []
        run_keys = list(run.keys())
        for run_i in tqdm(range(len(run_keys)), desc="Calculating ndcg@10 for each query"):
            key = run_keys[run_i]
            keep_result.append({'query': key, 'qrel': qrels[key], 'run': run[key],
                                'ndcg@10': evaluate({key: qrels[key]}, {key: run[key]}, "ndcg@10")})
        keep_result = sorted(keep_result, key=lambda kk: kk['ndcg@10'])
        with open(os.path.join(file_path, dataset + '_result.md'), 'w', encoding='utf-8') as f:
            f.write('## Score For Every Query\n')
            for keep_result_i in keep_result:
                f.write('### query: ' + keep_result_i['query'] + ' ndcg@10: ' + str(keep_result_i['ndcg@10']) + '\n')
                scores = [[i[0], i[1]] for i in keep_result_i['run'].items()]
                # Top-10 retrieved chunks, highest similarity first.
                scores = sorted(scores, key=lambda kk: kk[1], reverse=True)
                for score in scores[:10]:
                    f.write('- text: ' + str(texts[score[0]]) + '\t similarity: ' + str(score[1]) + '\n')
        json.dump(qrels, open(os.path.join(file_path, dataset + '.qrels.json'), "w+"), indent=2)
        json.dump(run, open(os.path.join(file_path, dataset + '.run.json'), "w+"), indent=2)
        print(os.path.join(file_path, dataset + '_result.md'), 'Saved!')
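
    # Entry point: pick the dataset-specific indexer, build the run, print the
    # aggregate ranx metrics, and write the per-query results next to the data.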
    def __call__(self, dataset, file_path, miracl_corpus=''):
        if dataset == "ms_marco_v1.1":
            self.tenant_id = "benchmark_ms_marco_v11"
            self.index_name = search.index_name(self.tenant_id)
            qrels, texts = self.ms_marco_index(file_path, "benchmark_ms_marco_v1.1")
            run = self._get_retrieval(qrels)
            print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr@10"]))
            self.save_results(qrels, run, texts, dataset, file_path)
        if dataset == "trivia_qa":
            self.tenant_id = "benchmark_trivia_qa"
            self.index_name = search.index_name(self.tenant_id)
            qrels, texts = self.trivia_qa_index(file_path, "benchmark_trivia_qa")
            run = self._get_retrieval(qrels)
            print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr@10"]))
            self.save_results(qrels, run, texts, dataset, file_path)
        if dataset == "miracl":
            for lang in ['ar', 'bn', 'de', 'en', 'es', 'fa', 'fi', 'fr', 'hi', 'id', 'ja', 'ko', 'ru', 'sw', 'te', 'th',
                         'yo', 'zh']:
                if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang)):
                    print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang) + ' not found!')
                    continue
                if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang, 'qrels')):
                    print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang, 'qrels') + ' not found!')
                    continue
                if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang, 'topics')):
                    print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang, 'topics') + ' not found!')
                    continue
                if not os.path.isdir(os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang)):
                    print('Directory: ' + os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang) + ' not found!')
                    continue
                self.tenant_id = "benchmark_miracl_" + lang
                self.index_name = search.index_name(self.tenant_id)
                self.initialized_index = False
                qrels, texts = self.miracl_index(os.path.join(file_path, 'miracl-v1.0-' + lang),
                                                 os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang),
                                                 "benchmark_miracl_" + lang)
                run = self._get_retrieval(qrels)
                print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr@10"]))
                self.save_results(qrels, run, texts, dataset, file_path)

if __name__ == '__main__':
    print('*****************RAGFlow Benchmark*****************')
    parser = argparse.ArgumentParser(usage="benchmark.py <max_docs> <kb_id> <dataset> <dataset_path> [<miracl_corpus_path>]", description='RAGFlow Benchmark')
    parser.add_argument('max_docs', metavar='max_docs', type=int, help='max docs to evaluate')
    parser.add_argument('kb_id', metavar='kb_id', help='knowledgebase id')
    parser.add_argument('dataset', metavar='dataset', help='dataset name, shall be one of ms_marco_v1.1 (https://huggingface.co/datasets/microsoft/ms_marco), trivia_qa (https://huggingface.co/datasets/mandarjoshi/trivia_qa), miracl (https://huggingface.co/datasets/miracl/miracl)')
    parser.add_argument('dataset_path', metavar='dataset_path', help='dataset path')
    parser.add_argument('miracl_corpus_path', metavar='miracl_corpus_path', nargs='?', default="", help='miracl corpus path. Only needed when dataset is miracl')
    args = parser.parse_args()
    max_docs = args.max_docs
    kb_id = args.kb_id
    ex = Benchmark(kb_id)
    dataset = args.dataset
    dataset_path = args.dataset_path
    if dataset == "ms_marco_v1.1" or dataset == "trivia_qa":
        ex(dataset, dataset_path)
    elif dataset == "miracl":
        if not args.miracl_corpus_path:
            print('The miracl dataset requires <miracl_corpus_path>!')
            exit(1)
        ex(dataset, dataset_path, miracl_corpus=args.miracl_corpus_path)
    else:
        print("Dataset: ", dataset, "not supported!")