Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

benchmark.py 14KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import json
  17. import os
  18. import sys
  19. import time
  20. import argparse
  21. from collections import defaultdict
  22. from api.db import LLMType
  23. from api.db.services.llm_service import LLMBundle
  24. from api.db.services.knowledgebase_service import KnowledgebaseService
  25. from api import settings
  26. from api.utils import get_uuid
  27. from rag.nlp import tokenize, search
  28. from ranx import evaluate
  29. from ranx import Qrels, Run
  30. import pandas as pd
  31. from tqdm import tqdm
  32. global max_docs
  33. max_docs = sys.maxsize
  34. class Benchmark:
  35. def __init__(self, kb_id):
  36. self.kb_id = kb_id
  37. e, self.kb = KnowledgebaseService.get_by_id(kb_id)
  38. self.similarity_threshold = self.kb.similarity_threshold
  39. self.vector_similarity_weight = self.kb.vector_similarity_weight
  40. self.embd_mdl = LLMBundle(self.kb.tenant_id, LLMType.EMBEDDING, llm_name=self.kb.embd_id, lang=self.kb.language)
  41. self.tenant_id = ''
  42. self.index_name = ''
  43. self.initialized_index = False
  44. def _get_retrieval(self, qrels):
  45. # Need to wait for the ES and Infinity index to be ready
  46. time.sleep(20)
  47. run = defaultdict(dict)
  48. query_list = list(qrels.keys())
  49. for query in query_list:
  50. ranks = settings.retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
  51. 0.0, self.vector_similarity_weight)
  52. if len(ranks["chunks"]) == 0:
  53. print(f"deleted query: {query}")
  54. del qrels[query]
  55. continue
  56. for c in ranks["chunks"]:
  57. c.pop("vector", None)
  58. run[query][c["chunk_id"]] = c["similarity"]
  59. return run
  60. def embedding(self, docs):
  61. texts = [d["content_with_weight"] for d in docs]
  62. embeddings, _ = self.embd_mdl.encode(texts)
  63. assert len(docs) == len(embeddings)
  64. vector_size = 0
  65. for i, d in enumerate(docs):
  66. v = embeddings[i]
  67. vector_size = len(v)
  68. d["q_%d_vec" % len(v)] = v
  69. return docs, vector_size
  70. def init_index(self, vector_size: int):
  71. if self.initialized_index:
  72. return
  73. if settings.docStoreConn.indexExist(self.index_name, self.kb_id):
  74. settings.docStoreConn.deleteIdx(self.index_name, self.kb_id)
  75. settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
  76. self.initialized_index = True
  77. def ms_marco_index(self, file_path, index_name):
  78. qrels = defaultdict(dict)
  79. texts = defaultdict(dict)
  80. docs_count = 0
  81. docs = []
  82. filelist = sorted(os.listdir(file_path))
  83. for fn in filelist:
  84. if docs_count >= max_docs:
  85. break
  86. if not fn.endswith(".parquet"):
  87. continue
  88. data = pd.read_parquet(os.path.join(file_path, fn))
  89. for i in tqdm(range(len(data)), colour="green", desc="Tokenizing:" + fn):
  90. if docs_count >= max_docs:
  91. break
  92. query = data.iloc[i]['query']
  93. for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
  94. d = {
  95. "id": get_uuid(),
  96. "kb_id": self.kb.id,
  97. "docnm_kwd": "xxxxx",
  98. "doc_id": "ksksks"
  99. }
  100. tokenize(d, text, "english")
  101. docs.append(d)
  102. texts[d["id"]] = text
  103. qrels[query][d["id"]] = int(rel)
  104. if len(docs) >= 32:
  105. docs_count += len(docs)
  106. docs, vector_size = self.embedding(docs)
  107. self.init_index(vector_size)
  108. settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
  109. docs = []
  110. if docs:
  111. docs, vector_size = self.embedding(docs)
  112. self.init_index(vector_size)
  113. settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
  114. return qrels, texts
  115. def trivia_qa_index(self, file_path, index_name):
  116. qrels = defaultdict(dict)
  117. texts = defaultdict(dict)
  118. docs_count = 0
  119. docs = []
  120. filelist = sorted(os.listdir(file_path))
  121. for fn in filelist:
  122. if docs_count >= max_docs:
  123. break
  124. if not fn.endswith(".parquet"):
  125. continue
  126. data = pd.read_parquet(os.path.join(file_path, fn))
  127. for i in tqdm(range(len(data)), colour="green", desc="Indexing:" + fn):
  128. if docs_count >= max_docs:
  129. break
  130. query = data.iloc[i]['question']
  131. for rel, text in zip(data.iloc[i]["search_results"]['rank'],
  132. data.iloc[i]["search_results"]['search_context']):
  133. d = {
  134. "id": get_uuid(),
  135. "kb_id": self.kb.id,
  136. "docnm_kwd": "xxxxx",
  137. "doc_id": "ksksks"
  138. }
  139. tokenize(d, text, "english")
  140. docs.append(d)
  141. texts[d["id"]] = text
  142. qrels[query][d["id"]] = int(rel)
  143. if len(docs) >= 32:
  144. docs_count += len(docs)
  145. docs, vector_size = self.embedding(docs)
  146. self.init_index(vector_size)
  147. settings.docStoreConn.insert(docs,self.index_name)
  148. docs = []
  149. docs, vector_size = self.embedding(docs)
  150. self.init_index(vector_size)
  151. settings.docStoreConn.insert(docs, self.index_name)
  152. return qrels, texts
  153. def miracl_index(self, file_path, corpus_path, index_name):
  154. corpus_total = {}
  155. for corpus_file in os.listdir(corpus_path):
  156. tmp_data = pd.read_json(os.path.join(corpus_path, corpus_file), lines=True)
  157. for index, i in tmp_data.iterrows():
  158. corpus_total[i['docid']] = i['text']
  159. topics_total = {}
  160. for topics_file in os.listdir(os.path.join(file_path, 'topics')):
  161. if 'test' in topics_file:
  162. continue
  163. tmp_data = pd.read_csv(os.path.join(file_path, 'topics', topics_file), sep='\t', names=['qid', 'query'])
  164. for index, i in tmp_data.iterrows():
  165. topics_total[i['qid']] = i['query']
  166. qrels = defaultdict(dict)
  167. texts = defaultdict(dict)
  168. docs_count = 0
  169. docs = []
  170. for qrels_file in os.listdir(os.path.join(file_path, 'qrels')):
  171. if 'test' in qrels_file:
  172. continue
  173. if docs_count >= max_docs:
  174. break
  175. tmp_data = pd.read_csv(os.path.join(file_path, 'qrels', qrels_file), sep='\t',
  176. names=['qid', 'Q0', 'docid', 'relevance'])
  177. for i in tqdm(range(len(tmp_data)), colour="green", desc="Indexing:" + qrels_file):
  178. if docs_count >= max_docs:
  179. break
  180. query = topics_total[tmp_data.iloc[i]['qid']]
  181. text = corpus_total[tmp_data.iloc[i]['docid']]
  182. rel = tmp_data.iloc[i]['relevance']
  183. d = {
  184. "id": get_uuid(),
  185. "kb_id": self.kb.id,
  186. "docnm_kwd": "xxxxx",
  187. "doc_id": "ksksks"
  188. }
  189. tokenize(d, text, 'english')
  190. docs.append(d)
  191. texts[d["id"]] = text
  192. qrels[query][d["id"]] = int(rel)
  193. if len(docs) >= 32:
  194. docs_count += len(docs)
  195. docs, vector_size = self.embedding(docs)
  196. self.init_index(vector_size)
  197. settings.docStoreConn.insert(docs, self.index_name)
  198. docs = []
  199. docs, vector_size = self.embedding(docs)
  200. self.init_index(vector_size)
  201. settings.docStoreConn.insert(docs, self.index_name)
  202. return qrels, texts
  203. def save_results(self, qrels, run, texts, dataset, file_path):
  204. keep_result = []
  205. run_keys = list(run.keys())
  206. for run_i in tqdm(range(len(run_keys)), desc="Calculating ndcg@10 for single query"):
  207. key = run_keys[run_i]
  208. keep_result.append({'query': key, 'qrel': qrels[key], 'run': run[key],
  209. 'ndcg@10': evaluate({key: qrels[key]}, {key: run[key]}, "ndcg@10")})
  210. keep_result = sorted(keep_result, key=lambda kk: kk['ndcg@10'])
  211. with open(os.path.join(file_path, dataset + 'result.md'), 'w', encoding='utf-8') as f:
  212. f.write('## Score For Every Query\n')
  213. for keep_result_i in keep_result:
  214. f.write('### query: ' + keep_result_i['query'] + ' ndcg@10:' + str(keep_result_i['ndcg@10']) + '\n')
  215. scores = [[i[0], i[1]] for i in keep_result_i['run'].items()]
  216. scores = sorted(scores, key=lambda kk: kk[1])
  217. for score in scores[:10]:
  218. f.write('- text: ' + str(texts[score[0]]) + '\t qrel: ' + str(score[1]) + '\n')
  219. json.dump(qrels, open(os.path.join(file_path, dataset + '.qrels.json'), "w+", encoding='utf-8'), indent=2)
  220. json.dump(run, open(os.path.join(file_path, dataset + '.run.json'), "w+", encoding='utf-8'), indent=2)
  221. print(os.path.join(file_path, dataset + '_result.md'), 'Saved!')
  222. def __call__(self, dataset, file_path, miracl_corpus=''):
  223. if dataset == "ms_marco_v1.1":
  224. self.tenant_id = "benchmark_ms_marco_v11"
  225. self.index_name = search.index_name(self.tenant_id)
  226. qrels, texts = self.ms_marco_index(file_path, "benchmark_ms_marco_v1.1")
  227. run = self._get_retrieval(qrels)
  228. print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr@10"]))
  229. self.save_results(qrels, run, texts, dataset, file_path)
  230. if dataset == "trivia_qa":
  231. self.tenant_id = "benchmark_trivia_qa"
  232. self.index_name = search.index_name(self.tenant_id)
  233. qrels, texts = self.trivia_qa_index(file_path, "benchmark_trivia_qa")
  234. run = self._get_retrieval(qrels)
  235. print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr@10"]))
  236. self.save_results(qrels, run, texts, dataset, file_path)
  237. if dataset == "miracl":
  238. for lang in ['ar', 'bn', 'de', 'en', 'es', 'fa', 'fi', 'fr', 'hi', 'id', 'ja', 'ko', 'ru', 'sw', 'te', 'th',
  239. 'yo', 'zh']:
  240. if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang)):
  241. print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang) + ' not found!')
  242. continue
  243. if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang, 'qrels')):
  244. print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang, 'qrels') + 'not found!')
  245. continue
  246. if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang, 'topics')):
  247. print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang, 'topics') + 'not found!')
  248. continue
  249. if not os.path.isdir(os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang)):
  250. print('Directory: ' + os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang) + ' not found!')
  251. continue
  252. self.tenant_id = "benchmark_miracl_" + lang
  253. self.index_name = search.index_name(self.tenant_id)
  254. self.initialized_index = False
  255. qrels, texts = self.miracl_index(os.path.join(file_path, 'miracl-v1.0-' + lang),
  256. os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang),
  257. "benchmark_miracl_" + lang)
  258. run = self._get_retrieval(qrels)
  259. print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr@10"]))
  260. self.save_results(qrels, run, texts, dataset, file_path)
  261. if __name__ == '__main__':
  262. print('*****************RAGFlow Benchmark*****************')
  263. parser = argparse.ArgumentParser(usage="benchmark.py <max_docs> <kb_id> <dataset> <dataset_path> [<miracl_corpus_path>])", description='RAGFlow Benchmark')
  264. parser.add_argument('max_docs', metavar='max_docs', type=int, help='max docs to evaluate')
  265. parser.add_argument('kb_id', metavar='kb_id', help='knowledgebase id')
  266. parser.add_argument('dataset', metavar='dataset', help='dataset name, shall be one of ms_marco_v1.1(https://huggingface.co/datasets/microsoft/ms_marco), trivia_qa(https://huggingface.co/datasets/mandarjoshi/trivia_qa>), miracl(https://huggingface.co/datasets/miracl/miracl')
  267. parser.add_argument('dataset_path', metavar='dataset_path', help='dataset path')
  268. parser.add_argument('miracl_corpus_path', metavar='miracl_corpus_path', nargs='?', default="", help='miracl corpus path. Only needed when dataset is miracl')
  269. args = parser.parse_args()
  270. max_docs = args.max_docs
  271. kb_id = args.kb_id
  272. ex = Benchmark(kb_id)
  273. dataset = args.dataset
  274. dataset_path = args.dataset_path
  275. if dataset == "ms_marco_v1.1" or dataset == "trivia_qa":
  276. ex(dataset, dataset_path)
  277. elif dataset == "miracl":
  278. if len(args) < 5:
  279. print('Please input the correct parameters!')
  280. exit(1)
  281. miracl_corpus_path = args[4]
  282. ex(dataset, dataset_path, miracl_corpus=args.miracl_corpus_path)
  283. else:
  284. print("Dataset: ", dataset, "not supported!")