#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
from collections import defaultdict

from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.settings import retrievaler
from api.utils import get_uuid
from rag.nlp import tokenize, search
from rag.utils.es_conn import ELASTICSEARCH
from ranx import evaluate


class benchmark_ndcg10:
    def __init__(self, kb_id):
        e, kb = KnowledgebaseService.get_by_id(kb_id)
        self.similarity_threshold = kb.similarity_threshold
        self.vector_similarity_weight = kb.vector_similarity_weight
        self.embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)

    def _get_benchmarks(self, query, count=16):
        # Retrieve the top `count` candidate chunks for a query from the benchmark index.
        req = {"question": query, "size": count, "vector": True, "similarity": self.similarity_threshold}
        sres = retrievaler.search(req, search.index_name("benchmark"), self.embd_mdl)
        return sres

    def _get_retrieval(self, qrels):
        # Build the ranx "run": for each query, a mapping of document id -> rerank score.
        run = defaultdict(dict)
        for query in qrels.keys():
            sres = self._get_benchmarks(query)
            sim, _, _ = retrievaler.rerank(sres, query, 1 - self.vector_similarity_weight,
                                           self.vector_similarity_weight)
            for index, doc_id in enumerate(sres.ids):
                run[query][doc_id] = sim[index]
        return run

    def embedding(self, docs, batch_size=16):
        # Encode chunk contents in batches and attach each vector to its doc
        # under the "q_<dim>_vec" field name used by the search layer.
        vects = []
        cnts = [d["content_with_weight"] for d in docs]
        for i in range(0, len(cnts), batch_size):
            vts, _ = self.embd_mdl.encode(cnts[i: i + batch_size])
            vects.extend(vts.tolist())
        assert len(docs) == len(vects)
        for i, d in enumerate(docs):
            v = vects[i]
            d["q_%d_vec" % len(v)] = v
        return docs

    def __call__(self, file_path):
        # Each input line holds three whitespace-separated fields: query, text, relevance.
        qrels = defaultdict(dict)
        docs = []
        with open(file_path) as f:
            for line in f:
                query, text, rel = line.strip('\n').split()
                d = {"id": get_uuid()}
                tokenize(d, text)
                docs.append(d)
                qrels[query][d["id"]] = float(rel)
                if len(docs) >= 32:
                    # Embed and flush a full batch into the benchmark index.
                    docs = self.embedding(docs)
                    ELASTICSEARCH.bulk(docs, search.index_name("benchmark"))
                    docs = []
        if docs:
            docs = self.embedding(docs)
            ELASTICSEARCH.bulk(docs, search.index_name("benchmark"))

        run = self._get_retrieval(qrels)
        return evaluate(qrels, run, "ndcg@10")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-f', '--filepath', default='', help="file path", action='store', required=True)
    parser.add_argument('-k', '--kb_id', default='', help="kb_id", action='store', required=True)
    args = parser.parse_args()

    ex = benchmark_ndcg10(args.kb_id)
    print(ex(args.filepath))
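
# Usage sketch (the script path and file name below are placeholders, not values
# confirmed by the repository): the relevance file is assumed to follow the
# whitespace-separated "query text rel" layout that __call__ parses, one judgment
# per line, e.g. "q1 some_chunk_text 2.0". Invoking
#
#     python path/to/this_script.py -f qrels.txt -k <kb_id>
#
# indexes the texts into the "benchmark" index, embeds them with the knowledge
# base's configured embedding model, reranks the retrieved chunks per query, and
# prints the ndcg@10 score computed by ranx.evaluate.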