Nelze vybrat více než 25 témat Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.

benchmark.py 13KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import json
  17. import os
  18. from collections import defaultdict
  19. from concurrent.futures import ThreadPoolExecutor
  20. from copy import deepcopy
  21. from api.db import LLMType
  22. from api.db.services.llm_service import LLMBundle
  23. from api.db.services.knowledgebase_service import KnowledgebaseService
  24. from api.settings import retrievaler
  25. from api.utils import get_uuid
  26. from api.utils.file_utils import get_project_base_directory
  27. from rag.nlp import tokenize, search
  28. from rag.utils.es_conn import ELASTICSEARCH
  29. from ranx import evaluate
  30. import pandas as pd
  31. from tqdm import tqdm
  32. from ranx import Qrels, Run
  33. class Benchmark:
  34. def __init__(self, kb_id):
  35. e, self.kb = KnowledgebaseService.get_by_id(kb_id)
  36. self.similarity_threshold = self.kb.similarity_threshold
  37. self.vector_similarity_weight = self.kb.vector_similarity_weight
  38. self.embd_mdl = LLMBundle(self.kb.tenant_id, LLMType.EMBEDDING, llm_name=self.kb.embd_id, lang=self.kb.language)
  39. def _get_benchmarks(self, query, dataset_idxnm, count=16):
  40. req = {"question": query, "size": count, "vector": True, "similarity": self.similarity_threshold}
  41. sres = retrievaler.search(req, search.index_name(dataset_idxnm), self.embd_mdl)
  42. return sres
  43. def _get_retrieval(self, qrels, dataset_idxnm):
  44. run = defaultdict(dict)
  45. query_list = list(qrels.keys())
  46. for query in query_list:
  47. ranks = retrievaler.retrieval(query, self.embd_mdl,
  48. dataset_idxnm, [self.kb.id], 1, 30,
  49. 0.0, self.vector_similarity_weight)
  50. for c in ranks["chunks"]:
  51. if "vector" in c:
  52. del c["vector"]
  53. run[query][c["chunk_id"]] = c["similarity"]
  54. return run
  55. def embedding(self, docs, batch_size=16):
  56. vects = []
  57. cnts = [d["content_with_weight"] for d in docs]
  58. for i in range(0, len(cnts), batch_size):
  59. vts, c = self.embd_mdl.encode(cnts[i: i + batch_size])
  60. vects.extend(vts.tolist())
  61. assert len(docs) == len(vects)
  62. for i, d in enumerate(docs):
  63. v = vects[i]
  64. d["q_%d_vec" % len(v)] = v
  65. return docs
  66. @staticmethod
  67. def init_kb(index_name):
  68. idxnm = search.index_name(index_name)
  69. if ELASTICSEARCH.indexExist(idxnm):
  70. ELASTICSEARCH.deleteIdx(search.index_name(index_name))
  71. return ELASTICSEARCH.createIdx(idxnm, json.load(
  72. open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
  73. def ms_marco_index(self, file_path, index_name):
  74. qrels = defaultdict(dict)
  75. texts = defaultdict(dict)
  76. docs = []
  77. filelist = os.listdir(file_path)
  78. self.init_kb(index_name)
  79. max_workers = int(os.environ.get('MAX_WORKERS', 3))
  80. exe = ThreadPoolExecutor(max_workers=max_workers)
  81. threads = []
  82. def slow_actions(es_docs, idx_nm):
  83. es_docs = self.embedding(es_docs)
  84. ELASTICSEARCH.bulk(es_docs, idx_nm)
  85. return True
  86. for dir in filelist:
  87. data = pd.read_parquet(os.path.join(file_path, dir))
  88. for i in tqdm(range(len(data)), colour="green", desc="Tokenizing:" + dir):
  89. query = data.iloc[i]['query']
  90. for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
  91. d = {
  92. "id": get_uuid(),
  93. "kb_id": self.kb.id,
  94. "docnm_kwd": "xxxxx",
  95. "doc_id": "ksksks"
  96. }
  97. tokenize(d, text, "english")
  98. docs.append(d)
  99. texts[d["id"]] = text
  100. qrels[query][d["id"]] = int(rel)
  101. if len(docs) >= 32:
  102. threads.append(
  103. exe.submit(slow_actions, deepcopy(docs), search.index_name(index_name)))
  104. docs = []
  105. threads.append(
  106. exe.submit(slow_actions, deepcopy(docs), search.index_name(index_name)))
  107. for i in tqdm(range(len(threads)), colour="red", desc="Indexing:" + dir):
  108. if not threads[i].result().output:
  109. print("Indexing error...")
  110. return qrels, texts
  111. def trivia_qa_index(self, file_path, index_name):
  112. qrels = defaultdict(dict)
  113. texts = defaultdict(dict)
  114. docs = []
  115. filelist = os.listdir(file_path)
  116. for dir in filelist:
  117. data = pd.read_parquet(os.path.join(file_path, dir))
  118. for i in tqdm(range(len(data)), colour="green", desc="Indexing:" + dir):
  119. query = data.iloc[i]['question']
  120. for rel, text in zip(data.iloc[i]["search_results"]['rank'],
  121. data.iloc[i]["search_results"]['search_context']):
  122. d = {
  123. "id": get_uuid(),
  124. "kb_id": self.kb.id,
  125. "docnm_kwd": "xxxxx",
  126. "doc_id": "ksksks"
  127. }
  128. tokenize(d, text, "english")
  129. docs.append(d)
  130. texts[d["id"]] = text
  131. qrels[query][d["id"]] = int(rel)
  132. if len(docs) >= 32:
  133. docs = self.embedding(docs)
  134. ELASTICSEARCH.bulk(docs, search.index_name(index_name))
  135. docs = []
  136. docs = self.embedding(docs)
  137. ELASTICSEARCH.bulk(docs, search.index_name(index_name))
  138. return qrels, texts
  139. def miracl_index(self, file_path, corpus_path, index_name):
  140. corpus_total = {}
  141. for corpus_file in os.listdir(corpus_path):
  142. tmp_data = pd.read_json(os.path.join(corpus_path, corpus_file), lines=True)
  143. for index, i in tmp_data.iterrows():
  144. corpus_total[i['docid']] = i['text']
  145. topics_total = {}
  146. for topics_file in os.listdir(os.path.join(file_path, 'topics')):
  147. if 'test' in topics_file:
  148. continue
  149. tmp_data = pd.read_csv(os.path.join(file_path, 'topics', topics_file), sep='\t', names=['qid', 'query'])
  150. for index, i in tmp_data.iterrows():
  151. topics_total[i['qid']] = i['query']
  152. qrels = defaultdict(dict)
  153. texts = defaultdict(dict)
  154. docs = []
  155. for qrels_file in os.listdir(os.path.join(file_path, 'qrels')):
  156. if 'test' in qrels_file:
  157. continue
  158. tmp_data = pd.read_csv(os.path.join(file_path, 'qrels', qrels_file), sep='\t',
  159. names=['qid', 'Q0', 'docid', 'relevance'])
  160. for i in tqdm(range(len(tmp_data)), colour="green", desc="Indexing:" + qrels_file):
  161. query = topics_total[tmp_data.iloc[i]['qid']]
  162. text = corpus_total[tmp_data.iloc[i]['docid']]
  163. rel = tmp_data.iloc[i]['relevance']
  164. d = {
  165. "id": get_uuid(),
  166. "kb_id": self.kb.id,
  167. "docnm_kwd": "xxxxx",
  168. "doc_id": "ksksks"
  169. }
  170. tokenize(d, text, 'english')
  171. docs.append(d)
  172. texts[d["id"]] = text
  173. qrels[query][d["id"]] = int(rel)
  174. if len(docs) >= 32:
  175. docs = self.embedding(docs)
  176. ELASTICSEARCH.bulk(docs, search.index_name(index_name))
  177. docs = []
  178. docs = self.embedding(docs)
  179. ELASTICSEARCH.bulk(docs, search.index_name(index_name))
  180. return qrels, texts
  181. def save_results(self, qrels, run, texts, dataset, file_path):
  182. keep_result = []
  183. run_keys = list(run.keys())
  184. for run_i in tqdm(range(len(run_keys)), desc="Calculating ndcg@10 for single query"):
  185. key = run_keys[run_i]
  186. keep_result.append({'query': key, 'qrel': qrels[key], 'run': run[key],
  187. 'ndcg@10': evaluate(Qrels({key: qrels[key]}), Run({key: run[key]}), "ndcg@10")})
  188. keep_result = sorted(keep_result, key=lambda kk: kk['ndcg@10'])
  189. with open(os.path.join(file_path, dataset + 'result.md'), 'w', encoding='utf-8') as f:
  190. f.write('## Score For Every Query\n')
  191. for keep_result_i in keep_result:
  192. f.write('### query: ' + keep_result_i['query'] + ' ndcg@10:' + str(keep_result_i['ndcg@10']) + '\n')
  193. scores = [[i[0], i[1]] for i in keep_result_i['run'].items()]
  194. scores = sorted(scores, key=lambda kk: kk[1])
  195. for score in scores[:10]:
  196. f.write('- text: ' + str(texts[score[0]]) + '\t qrel: ' + str(score[1]) + '\n')
  197. json.dump(qrels, open(os.path.join(file_path, dataset + '.qrels.json'), "w+"), indent=2)
  198. json.dump(run, open(os.path.join(file_path, dataset + '.run.json'), "w+"), indent=2)
  199. print(os.path.join(file_path, dataset + '_result.md'), 'Saved!')
  200. def __call__(self, dataset, file_path, miracl_corpus=''):
  201. if dataset == "ms_marco_v1.1":
  202. qrels, texts = self.ms_marco_index(file_path, "benchmark_ms_marco_v1.1")
  203. run = self._get_retrieval(qrels, "benchmark_ms_marco_v1.1")
  204. print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr"]))
  205. self.save_results(qrels, run, texts, dataset, file_path)
  206. if dataset == "trivia_qa":
  207. qrels, texts = self.trivia_qa_index(file_path, "benchmark_trivia_qa")
  208. run = self._get_retrieval(qrels, "benchmark_trivia_qa")
  209. print(dataset, evaluate((qrels), Run(run), ["ndcg@10", "map@5", "mrr"]))
  210. self.save_results(qrels, run, texts, dataset, file_path)
  211. if dataset == "miracl":
  212. for lang in ['ar', 'bn', 'de', 'en', 'es', 'fa', 'fi', 'fr', 'hi', 'id', 'ja', 'ko', 'ru', 'sw', 'te', 'th',
  213. 'yo', 'zh']:
  214. if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang)):
  215. print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang) + ' not found!')
  216. continue
  217. if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang, 'qrels')):
  218. print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang, 'qrels') + 'not found!')
  219. continue
  220. if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang, 'topics')):
  221. print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang, 'topics') + 'not found!')
  222. continue
  223. if not os.path.isdir(os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang)):
  224. print('Directory: ' + os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang) + ' not found!')
  225. continue
  226. qrels, texts = self.miracl_index(os.path.join(file_path, 'miracl-v1.0-' + lang),
  227. os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang),
  228. "benchmark_miracl_" + lang)
  229. run = self._get_retrieval(qrels, "benchmark_miracl_" + lang)
  230. print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr"]))
  231. self.save_results(qrels, run, texts, dataset, file_path)
  232. if __name__ == '__main__':
  233. print('*****************RAGFlow Benchmark*****************')
  234. kb_id = input('Please input kb_id:\n')
  235. ex = Benchmark(kb_id)
  236. dataset = input(
  237. 'RAGFlow Benchmark Support:\n\tms_marco_v1.1:<https://huggingface.co/datasets/microsoft/ms_marco>\n\ttrivia_qa:<https://huggingface.co/datasets/mandarjoshi/trivia_qa>\n\tmiracl:<https://huggingface.co/datasets/miracl/miracl>\nPlease input dataset choice:\n')
  238. if dataset in ['ms_marco_v1.1', 'trivia_qa']:
  239. if dataset == "ms_marco_v1.1":
  240. print("Notice: Please provide the ms_marco_v1.1 dataset only. ms_marco_v2.1 is not supported!")
  241. dataset_path = input('Please input ' + dataset + ' dataset path:\n')
  242. ex(dataset, dataset_path)
  243. elif dataset == 'miracl':
  244. dataset_path = input('Please input ' + dataset + ' dataset path:\n')
  245. corpus_path = input('Please input ' + dataset + '-corpus dataset path:\n')
  246. ex(dataset, dataset_path, miracl_corpus=corpus_path)
  247. else:
  248. print("Dataset: ", dataset, "not supported!")