|
|
|
@@ -19,6 +19,7 @@ from collections import defaultdict |
|
|
|
from copy import deepcopy |
|
|
|
import json_repair |
|
|
|
import pandas as pd |
|
|
|
import trio |
|
|
|
|
|
|
|
from api.utils import get_uuid |
|
|
|
from graphrag.query_analyze_prompt import PROMPTS |
|
|
|
@@ -41,7 +42,7 @@ class KGSearch(Dealer): |
|
|
|
return response |
|
|
|
|
|
|
|
def query_rewrite(self, llm, question, idxnms, kb_ids): |
|
|
|
ty2ents = get_entity_type2sampels(idxnms, kb_ids) |
|
|
|
ty2ents = trio.run(lambda: get_entity_type2sampels(idxnms, kb_ids)) |
|
|
|
hint_prompt = PROMPTS["minirag_query2kwd"].format(query=question, |
|
|
|
TYPE_POOL=json.dumps(ty2ents, ensure_ascii=False, indent=2)) |
|
|
|
result = self._chat(llm, hint_prompt, [{"role": "user", "content": "Output:"}], {"temperature": .5}) |