Nelze vybrat více než 25 témat Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.

benchmark.py 14KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import json
  17. import os
  18. import sys
  19. import time
  20. import argparse
  21. from collections import defaultdict
  22. from api.db import LLMType
  23. from api.db.services.llm_service import LLMBundle
  24. from api.db.services.knowledgebase_service import KnowledgebaseService
  25. from api.settings import retrievaler, docStoreConn
  26. from api.utils import get_uuid
  27. from rag.nlp import tokenize, search
  28. from ranx import evaluate
  29. import pandas as pd
  30. from tqdm import tqdm
  31. global max_docs
  32. max_docs = sys.maxsize
  33. class Benchmark:
  34. def __init__(self, kb_id):
  35. self.kb_id = kb_id
  36. e, self.kb = KnowledgebaseService.get_by_id(kb_id)
  37. self.similarity_threshold = self.kb.similarity_threshold
  38. self.vector_similarity_weight = self.kb.vector_similarity_weight
  39. self.embd_mdl = LLMBundle(self.kb.tenant_id, LLMType.EMBEDDING, llm_name=self.kb.embd_id, lang=self.kb.language)
  40. self.tenant_id = ''
  41. self.index_name = ''
  42. self.initialized_index = False
  43. def _get_retrieval(self, qrels):
  44. # Need to wait for the ES and Infinity index to be ready
  45. time.sleep(20)
  46. run = defaultdict(dict)
  47. query_list = list(qrels.keys())
  48. for query in query_list:
  49. ranks = retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
  50. 0.0, self.vector_similarity_weight)
  51. if len(ranks["chunks"]) == 0:
  52. print(f"deleted query: {query}")
  53. del qrels[query]
  54. continue
  55. for c in ranks["chunks"]:
  56. if "vector" in c:
  57. del c["vector"]
  58. run[query][c["chunk_id"]] = c["similarity"]
  59. return run
  60. def embedding(self, docs, batch_size=16):
  61. vects = []
  62. cnts = [d["content_with_weight"] for d in docs]
  63. for i in range(0, len(cnts), batch_size):
  64. vts, c = self.embd_mdl.encode(cnts[i: i + batch_size])
  65. vects.extend(vts.tolist())
  66. assert len(docs) == len(vects)
  67. vector_size = 0
  68. for i, d in enumerate(docs):
  69. v = vects[i]
  70. vector_size = len(v)
  71. d["q_%d_vec" % len(v)] = v
  72. return docs, vector_size
  73. def init_index(self, vector_size: int):
  74. if self.initialized_index:
  75. return
  76. if docStoreConn.indexExist(self.index_name, self.kb_id):
  77. docStoreConn.deleteIdx(self.index_name, self.kb_id)
  78. docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
  79. self.initialized_index = True
  80. def ms_marco_index(self, file_path, index_name):
  81. qrels = defaultdict(dict)
  82. texts = defaultdict(dict)
  83. docs_count = 0
  84. docs = []
  85. filelist = sorted(os.listdir(file_path))
  86. for fn in filelist:
  87. if docs_count >= max_docs:
  88. break
  89. if not fn.endswith(".parquet"):
  90. continue
  91. data = pd.read_parquet(os.path.join(file_path, fn))
  92. for i in tqdm(range(len(data)), colour="green", desc="Tokenizing:" + fn):
  93. if docs_count >= max_docs:
  94. break
  95. query = data.iloc[i]['query']
  96. for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
  97. d = {
  98. "id": get_uuid(),
  99. "kb_id": self.kb.id,
  100. "docnm_kwd": "xxxxx",
  101. "doc_id": "ksksks"
  102. }
  103. tokenize(d, text, "english")
  104. docs.append(d)
  105. texts[d["id"]] = text
  106. qrels[query][d["id"]] = int(rel)
  107. if len(docs) >= 32:
  108. docs_count += len(docs)
  109. docs, vector_size = self.embedding(docs)
  110. self.init_index(vector_size)
  111. docStoreConn.insert(docs, self.index_name, self.kb_id)
  112. docs = []
  113. if docs:
  114. docs, vector_size = self.embedding(docs)
  115. self.init_index(vector_size)
  116. docStoreConn.insert(docs, self.index_name, self.kb_id)
  117. return qrels, texts
  118. def trivia_qa_index(self, file_path, index_name):
  119. qrels = defaultdict(dict)
  120. texts = defaultdict(dict)
  121. docs_count = 0
  122. docs = []
  123. filelist = sorted(os.listdir(file_path))
  124. for fn in filelist:
  125. if docs_count >= max_docs:
  126. break
  127. if not fn.endswith(".parquet"):
  128. continue
  129. data = pd.read_parquet(os.path.join(file_path, fn))
  130. for i in tqdm(range(len(data)), colour="green", desc="Indexing:" + fn):
  131. if docs_count >= max_docs:
  132. break
  133. query = data.iloc[i]['question']
  134. for rel, text in zip(data.iloc[i]["search_results"]['rank'],
  135. data.iloc[i]["search_results"]['search_context']):
  136. d = {
  137. "id": get_uuid(),
  138. "kb_id": self.kb.id,
  139. "docnm_kwd": "xxxxx",
  140. "doc_id": "ksksks"
  141. }
  142. tokenize(d, text, "english")
  143. docs.append(d)
  144. texts[d["id"]] = text
  145. qrels[query][d["id"]] = int(rel)
  146. if len(docs) >= 32:
  147. docs_count += len(docs)
  148. docs, vector_size = self.embedding(docs)
  149. self.init_index(vector_size)
  150. docStoreConn.insert(docs,self.index_name)
  151. docs = []
  152. docs, vector_size = self.embedding(docs)
  153. self.init_index(vector_size)
  154. docStoreConn.insert(docs, self.index_name)
  155. return qrels, texts
  156. def miracl_index(self, file_path, corpus_path, index_name):
  157. corpus_total = {}
  158. for corpus_file in os.listdir(corpus_path):
  159. tmp_data = pd.read_json(os.path.join(corpus_path, corpus_file), lines=True)
  160. for index, i in tmp_data.iterrows():
  161. corpus_total[i['docid']] = i['text']
  162. topics_total = {}
  163. for topics_file in os.listdir(os.path.join(file_path, 'topics')):
  164. if 'test' in topics_file:
  165. continue
  166. tmp_data = pd.read_csv(os.path.join(file_path, 'topics', topics_file), sep='\t', names=['qid', 'query'])
  167. for index, i in tmp_data.iterrows():
  168. topics_total[i['qid']] = i['query']
  169. qrels = defaultdict(dict)
  170. texts = defaultdict(dict)
  171. docs_count = 0
  172. docs = []
  173. for qrels_file in os.listdir(os.path.join(file_path, 'qrels')):
  174. if 'test' in qrels_file:
  175. continue
  176. if docs_count >= max_docs:
  177. break
  178. tmp_data = pd.read_csv(os.path.join(file_path, 'qrels', qrels_file), sep='\t',
  179. names=['qid', 'Q0', 'docid', 'relevance'])
  180. for i in tqdm(range(len(tmp_data)), colour="green", desc="Indexing:" + qrels_file):
  181. if docs_count >= max_docs:
  182. break
  183. query = topics_total[tmp_data.iloc[i]['qid']]
  184. text = corpus_total[tmp_data.iloc[i]['docid']]
  185. rel = tmp_data.iloc[i]['relevance']
  186. d = {
  187. "id": get_uuid(),
  188. "kb_id": self.kb.id,
  189. "docnm_kwd": "xxxxx",
  190. "doc_id": "ksksks"
  191. }
  192. tokenize(d, text, 'english')
  193. docs.append(d)
  194. texts[d["id"]] = text
  195. qrels[query][d["id"]] = int(rel)
  196. if len(docs) >= 32:
  197. docs_count += len(docs)
  198. docs, vector_size = self.embedding(docs)
  199. self.init_index(vector_size)
  200. docStoreConn.insert(docs, self.index_name)
  201. docs = []
  202. docs, vector_size = self.embedding(docs)
  203. self.init_index(vector_size)
  204. docStoreConn.insert(docs, self.index_name)
  205. return qrels, texts
  206. def save_results(self, qrels, run, texts, dataset, file_path):
  207. keep_result = []
  208. run_keys = list(run.keys())
  209. for run_i in tqdm(range(len(run_keys)), desc="Calculating ndcg@10 for single query"):
  210. key = run_keys[run_i]
  211. keep_result.append({'query': key, 'qrel': qrels[key], 'run': run[key],
  212. 'ndcg@10': evaluate({key: qrels[key]}, {key: run[key]}, "ndcg@10")})
  213. keep_result = sorted(keep_result, key=lambda kk: kk['ndcg@10'])
  214. with open(os.path.join(file_path, dataset + 'result.md'), 'w', encoding='utf-8') as f:
  215. f.write('## Score For Every Query\n')
  216. for keep_result_i in keep_result:
  217. f.write('### query: ' + keep_result_i['query'] + ' ndcg@10:' + str(keep_result_i['ndcg@10']) + '\n')
  218. scores = [[i[0], i[1]] for i in keep_result_i['run'].items()]
  219. scores = sorted(scores, key=lambda kk: kk[1])
  220. for score in scores[:10]:
  221. f.write('- text: ' + str(texts[score[0]]) + '\t qrel: ' + str(score[1]) + '\n')
  222. json.dump(qrels, open(os.path.join(file_path, dataset + '.qrels.json'), "w+"), indent=2)
  223. json.dump(run, open(os.path.join(file_path, dataset + '.run.json'), "w+"), indent=2)
  224. print(os.path.join(file_path, dataset + '_result.md'), 'Saved!')
  225. def __call__(self, dataset, file_path, miracl_corpus=''):
  226. if dataset == "ms_marco_v1.1":
  227. self.tenant_id = "benchmark_ms_marco_v11"
  228. self.index_name = search.index_name(self.tenant_id)
  229. qrels, texts = self.ms_marco_index(file_path, "benchmark_ms_marco_v1.1")
  230. run = self._get_retrieval(qrels)
  231. print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
  232. self.save_results(qrels, run, texts, dataset, file_path)
  233. if dataset == "trivia_qa":
  234. self.tenant_id = "benchmark_trivia_qa"
  235. self.index_name = search.index_name(self.tenant_id)
  236. qrels, texts = self.trivia_qa_index(file_path, "benchmark_trivia_qa")
  237. run = self._get_retrieval(qrels)
  238. print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
  239. self.save_results(qrels, run, texts, dataset, file_path)
  240. if dataset == "miracl":
  241. for lang in ['ar', 'bn', 'de', 'en', 'es', 'fa', 'fi', 'fr', 'hi', 'id', 'ja', 'ko', 'ru', 'sw', 'te', 'th',
  242. 'yo', 'zh']:
  243. if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang)):
  244. print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang) + ' not found!')
  245. continue
  246. if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang, 'qrels')):
  247. print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang, 'qrels') + 'not found!')
  248. continue
  249. if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang, 'topics')):
  250. print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang, 'topics') + 'not found!')
  251. continue
  252. if not os.path.isdir(os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang)):
  253. print('Directory: ' + os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang) + ' not found!')
  254. continue
  255. self.tenant_id = "benchmark_miracl_" + lang
  256. self.index_name = search.index_name(self.tenant_id)
  257. self.initialized_index = False
  258. qrels, texts = self.miracl_index(os.path.join(file_path, 'miracl-v1.0-' + lang),
  259. os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang),
  260. "benchmark_miracl_" + lang)
  261. run = self._get_retrieval(qrels)
  262. print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
  263. self.save_results(qrels, run, texts, dataset, file_path)
  264. if __name__ == '__main__':
  265. print('*****************RAGFlow Benchmark*****************')
  266. parser = argparse.ArgumentParser(usage="benchmark.py <max_docs> <kb_id> <dataset> <dataset_path> [<miracl_corpus_path>])", description='RAGFlow Benchmark')
  267. parser.add_argument('max_docs', metavar='max_docs', type=int, help='max docs to evaluate')
  268. parser.add_argument('kb_id', metavar='kb_id', help='knowledgebase id')
  269. parser.add_argument('dataset', metavar='dataset', help='dataset name, shall be one of ms_marco_v1.1(https://huggingface.co/datasets/microsoft/ms_marco), trivia_qa(https://huggingface.co/datasets/mandarjoshi/trivia_qa>), miracl(https://huggingface.co/datasets/miracl/miracl')
  270. parser.add_argument('dataset_path', metavar='dataset_path', help='dataset path')
  271. parser.add_argument('miracl_corpus_path', metavar='miracl_corpus_path', nargs='?', default="", help='miracl corpus path. Only needed when dataset is miracl')
  272. args = parser.parse_args()
  273. max_docs = args.max_docs
  274. kb_id = args.kb_id
  275. ex = Benchmark(kb_id)
  276. dataset = args.dataset
  277. dataset_path = args.dataset_path
  278. if dataset == "ms_marco_v1.1" or dataset == "trivia_qa":
  279. ex(dataset, dataset_path)
  280. elif dataset == "miracl":
  281. if len(args) < 5:
  282. print('Please input the correct parameters!')
  283. exit(1)
  284. miracl_corpus_path = args[4]
  285. ex(dataset, dataset_path, miracl_corpus=args.miracl_corpus_path)
  286. else:
  287. print("Dataset: ", dataset, "not supported!")