#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
from collections import defaultdict

from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.settings import retrievaler
from api.utils import get_uuid
from rag.nlp import tokenize, search
from rag.utils.es_conn import ELASTICSEARCH
from ranx import evaluate


class benchmark_ndcg10:
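    """Index a labelled benchmark file into the "benchmark" index and score retrieval with nDCG@10 (via ranx)."""
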
    def __init__(self, kb_id):
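        # Pull retrieval settings and the embedding model from the target knowledge base.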
        e, kb = KnowledgebaseService.get_by_id(kb_id)
        self.similarity_threshold = kb.similarity_threshold
        self.vector_similarity_weight = kb.vector_similarity_weight
        self.embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)

    def _get_benchmarks(self, query, count=16):
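        # Run a hybrid (keyword + vector) search for one query against the "benchmark" index.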
        req = {"question": query, "size": count, "vector": True, "similarity": self.similarity_threshold}
        sres = retrievaler.search(req, search.index_name("benchmark"), self.embd_mdl)
        return sres

    def _get_retrieval(self, qrels):
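        # Build a ranx-style run, {query: {chunk_id: score}}, reranked with the knowledge base's keyword/vector weights.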
        run = defaultdict(dict)
        query_list = list(qrels.keys())
        for query in query_list:
            sres = self._get_benchmarks(query)
            sim, _, _ = retrievaler.rerank(sres, query, 1 - self.vector_similarity_weight,
                                           self.vector_similarity_weight)
            for i, chunk_id in enumerate(sres.ids):
                run[query][chunk_id] = sim[i]
        return run

    def embedding(self, docs, batch_size=16):
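        # Attach an embedding vector to each chunk under the "q_<dim>_vec" field expected by the index.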
        vects = []
        cnts = [d["content_with_weight"] for d in docs]
        for i in range(0, len(cnts), batch_size):
            vts, c = self.embd_mdl.encode(cnts[i: i + batch_size])
            vects.extend(vts.tolist())
        assert len(docs) == len(vects)
        for i, d in enumerate(docs):
            v = vects[i]
            d["q_%d_vec" % len(v)] = v
        return docs

    def __call__(self, file_path):
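        # Each input line is expected to be a tab-separated triple: query, chunk text, relevance label.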
        qrels = defaultdict(dict)

        docs = []
        with open(file_path) as f:
            for line in f:
                query, text, rel = line.strip("\n").split("\t")
                d = {
                    "id": get_uuid()
                }
                tokenize(d, text)
                docs.append(d)
                if len(docs) >= 32:
                    # Embed before indexing so these chunks are searchable by vector as well.
                    docs = self.embedding(docs)
                    ELASTICSEARCH.bulk(docs, search.index_name("benchmark"))
                    docs = []
                qrels[query][d["id"]] = float(rel)
        if docs:
            docs = self.embedding(docs)
            ELASTICSEARCH.bulk(docs, search.index_name("benchmark"))

        run = self._get_retrieval(qrels)
        return evaluate(qrels, run, "ndcg@10")


if __name__ == '__main__':
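    # Example invocation (script and file names are placeholders): python benchmark_ndcg10.py -f qrels.tsv -k <kb_id>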
    parser = argparse.ArgumentParser()
    parser.add_argument('-f', '--filepath', default='', help="file path", action='store', required=True)
    parser.add_argument('-k', '--kb_id', default='', help="kb_id", action='store', required=True)
    args = parser.parse_args()

    ex = benchmark_ndcg10(args.kb_id)
    print(ex(args.filepath))