#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
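
# Orchestration of the GraphRAG indexing pipeline for a single document: extract a
# per-document subgraph, merge it into the knowledge base graph, and optionally run
# entity resolution and community report extraction on the merged graph.
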
import json
import logging

import networkx as nx
import trio

from api import settings
from api.utils import get_uuid
from api.utils.api_utils import timeout
from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
from graphrag.general.community_reports_extractor import CommunityReportsExtractor
from graphrag.entity_resolution import EntityResolution
from graphrag.general.extractor import Extractor
from graphrag.utils import (
    graph_merge,
    get_graph,
    set_graph,
    chunk_id,
    does_graph_contains,
    tidy_graph,
    GraphChange,
)
from rag.nlp import rag_tokenizer, search
from rag.utils.redis_conn import RedisDistributedLock


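# Lightweight pre-flight probe: call the embedding model and the chat model once and
# raise if the chat response carries the "**ERROR**" marker, so a broken model binding
# fails fast before the expensive GraphRAG run.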
@timeout(30, 2)
async def _is_strong_enough(chat_model, embedding_model):
    _ = await trio.to_thread.run_sync(lambda: embedding_model.encode(["Are you strong enough!?"]))
    res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role": "user", "content": "Are you strong enough!?"}], {}))
    if res.find("**ERROR**") >= 0:
        raise Exception(res)


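# Entry point for one document's GraphRAG task. `row` is expected to carry "tenant_id",
# "kb_id", "doc_id" and "kb_parser_config" (whose "graphrag" section selects the
# extraction method and entity types). The pipeline: pull the document's chunks,
# extract a subgraph, merge it into the KB-level graph under a Redis lock, then
# optionally resolve duplicate entities and build community reports.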
async def run_graphrag(
    row: dict,
    language,
    with_resolution: bool,
    with_community: bool,
    chat_model,
    embedding_model,
    callback,
):
    # Pressure test for GraphRAG task
    async with trio.open_nursery() as nursery:
        for _ in range(12):
            nursery.start_soon(_is_strong_enough, chat_model, embedding_model)

    start = trio.current_time()
    tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
    chunks = []
    for d in settings.retrievaler.chunk_list(
        doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"]
    ):
        chunks.append(d["content_with_weight"])

    subgraph = await generate_subgraph(
        LightKGExt
        if "method" not in row["kb_parser_config"].get("graphrag", {}) or row["kb_parser_config"]["graphrag"]["method"] != "general"
        else GeneralKGExt,
        tenant_id,
        kb_id,
        doc_id,
        chunks,
        language,
        row["kb_parser_config"]["graphrag"].get("entity_types", []),
        chat_model,
        embedding_model,
        callback,
    )
    if not subgraph:
        return

    graphrag_task_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value=doc_id, timeout=1200)
    await graphrag_task_lock.spin_acquire()
    callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired")

    try:
        subgraph_nodes = set(subgraph.nodes())
        new_graph = await merge_subgraph(
            tenant_id,
            kb_id,
            doc_id,
            subgraph,
            embedding_model,
            callback,
        )
        assert new_graph is not None

        if not with_resolution and not with_community:
            return

        if with_resolution:
            await graphrag_task_lock.spin_acquire()
            callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired")
            await resolve_entities(
                new_graph,
                subgraph_nodes,
                tenant_id,
                kb_id,
                doc_id,
                chat_model,
                embedding_model,
                callback,
            )
        if with_community:
            await graphrag_task_lock.spin_acquire()
            callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired")
            await extract_community(
                new_graph,
                tenant_id,
                kb_id,
                doc_id,
                chat_model,
                embedding_model,
                callback,
            )
    finally:
        graphrag_task_lock.release()

    now = trio.current_time()
    callback(msg=f"GraphRAG for doc {doc_id} done in {now - start:.2f} seconds.")
    return


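# Run the chosen LLM extractor over the document chunks and assemble the resulting
# entities and relations into a NetworkX subgraph, which is also persisted in the
# document store as a "subgraph" chunk keyed by this doc_id.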
@timeout(60*60, 1)
async def generate_subgraph(
    extractor: Extractor,
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    chunks: list[str],
    language,
    entity_types,
    llm_bdl,
    embed_bdl,
    callback,
):
    contains = await does_graph_contains(tenant_id, kb_id, doc_id)
    if contains:
        callback(msg=f"Graph already contains {doc_id}")
        return None

    start = trio.current_time()
    ext = extractor(
        llm_bdl,
        language=language,
        entity_types=entity_types,
    )
    ents, rels = await ext(doc_id, chunks, callback)

    subgraph = nx.Graph()
    for ent in ents:
        assert "description" in ent, f"entity {ent} does not have description"
        ent["source_id"] = [doc_id]
        subgraph.add_node(ent["entity_name"], **ent)

    ignored_rels = 0
    for rel in rels:
        assert "description" in rel, f"relation {rel} does not have description"
        if not subgraph.has_node(rel["src_id"]) or not subgraph.has_node(rel["tgt_id"]):
            ignored_rels += 1
            continue
        rel["source_id"] = [doc_id]
        subgraph.add_edge(
            rel["src_id"],
            rel["tgt_id"],
            **rel,
        )
    if ignored_rels:
        callback(msg=f"ignored {ignored_rels} relations due to missing entities.")

    tidy_graph(subgraph, callback, check_attribute=False)
    subgraph.graph["source_id"] = [doc_id]

    chunk = {
        "content_with_weight": json.dumps(
            nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False
        ),
        "knowledge_graph_kwd": "subgraph",
        "kb_id": kb_id,
        "source_id": [doc_id],
        "available_int": 0,
        "removed_kwd": "N",
    }
    cid = chunk_id(chunk)
    await trio.to_thread.run_sync(
        lambda: settings.docStoreConn.delete(
            {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id
        )
    )
    await trio.to_thread.run_sync(
        lambda: settings.docStoreConn.insert(
            [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id
        )
    )

    now = trio.current_time()
    callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
    return subgraph


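# Merge the freshly extracted subgraph into the knowledge base's existing graph
# (if any), recompute PageRank for every node, and persist the merged graph.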
@timeout(60*3)
async def merge_subgraph(
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    subgraph: nx.Graph,
    embedding_model,
    callback,
):
    start = trio.current_time()
    change = GraphChange()
    old_graph = await get_graph(tenant_id, kb_id, subgraph.graph["source_id"])
    if old_graph is not None:
        logging.info("Merge with an existing graph...")
        tidy_graph(old_graph, callback)
        new_graph = graph_merge(old_graph, subgraph, change)
    else:
        new_graph = subgraph
        change.added_updated_nodes = set(new_graph.nodes())
        change.added_updated_edges = set(new_graph.edges())
    pr = nx.pagerank(new_graph)
    for node_name, pagerank in pr.items():
        new_graph.nodes[node_name]["pagerank"] = pagerank

    await set_graph(tenant_id, kb_id, embedding_model, new_graph, change, callback)
    now = trio.current_time()
    callback(
        msg=f"merging subgraph for doc {doc_id} into the global graph done in {now - start:.2f} seconds."
    )
    return new_graph


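# Deduplicate entities with the LLM-based EntityResolution pass and persist the
# pruned graph; the callback reports how many nodes and edges were removed.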
@timeout(60*30, 1)
async def resolve_entities(
    graph,
    subgraph_nodes: set[str],
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    llm_bdl,
    embed_bdl,
    callback,
):
    start = trio.current_time()
    er = EntityResolution(
        llm_bdl,
    )
    reso = await er(graph, subgraph_nodes, callback=callback)
    graph = reso.graph
    change = reso.change
    callback(msg=f"Graph resolution removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges.")
    callback(msg="Graph resolution updated pagerank.")

    await set_graph(tenant_id, kb_id, embed_bdl, graph, change, callback)
    now = trio.current_time()
    callback(msg=f"Graph resolution done in {now - start:.2f}s.")


@timeout(60*30, 1)
async def extract_community(
    graph,
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    llm_bdl,
    embed_bdl,
    callback,
):
    start = trio.current_time()
    ext = CommunityReportsExtractor(
        llm_bdl,
    )
    cr = await ext(graph, callback=callback)
    community_structure = cr.structured_output
    community_reports = cr.output
    doc_ids = graph.graph["source_id"]

    now = trio.current_time()
    callback(
        msg=f"Graph extracted {len(cr.structured_output)} communities in {now - start:.2f}s."
    )
    start = now

    chunks = []
    for stru, rep in zip(community_structure, community_reports):
        obj = {
            "report": rep,
            "evidences": "\n".join([f.get("explanation", "") for f in stru["findings"]]),
        }
        chunk = {
            "id": get_uuid(),
            "docnm_kwd": stru["title"],
            "title_tks": rag_tokenizer.tokenize(stru["title"]),
            "content_with_weight": json.dumps(obj, ensure_ascii=False),
            "content_ltks": rag_tokenizer.tokenize(
                obj["report"] + " " + obj["evidences"]
            ),
            "knowledge_graph_kwd": "community_report",
            "weight_flt": stru["weight"],
            "entities_kwd": stru["entities"],
            "important_kwd": stru["entities"],
            "kb_id": kb_id,
            "source_id": list(doc_ids),
            "available_int": 0,
        }
        chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(
            chunk["content_ltks"]
        )
        chunks.append(chunk)

    await trio.to_thread.run_sync(
        lambda: settings.docStoreConn.delete(
            {"knowledge_graph_kwd": "community_report", "kb_id": kb_id},
            search.index_name(tenant_id),
            kb_id,
        )
    )
    es_bulk_size = 4
    for b in range(0, len(chunks), es_bulk_size):
        doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(tenant_id), kb_id))
        if doc_store_result:
            error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
            raise Exception(error_message)

    now = trio.current_time()
    callback(
        msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s."
    )
    return community_structure, community_reports