
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
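"""GraphRAG indexing orchestration: extract a per-document subgraph, merge it
into the knowledge-base graph, then optionally resolve duplicate entities and
extract community reports."""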
import json
import logging
from functools import partial

import networkx as nx
import trio

from api import settings
from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
from graphrag.general.community_reports_extractor import CommunityReportsExtractor
from graphrag.entity_resolution import EntityResolution
from graphrag.general.extractor import Extractor
from graphrag.utils import (
    graph_merge,
    set_entity,
    get_relation,
    set_relation,
    get_entity,
    get_graph,
    set_graph,
    chunk_id,
    update_nodes_pagerank_nhop_neighbour,
    does_graph_contains,
    get_graph_doc_ids,
)
from rag.nlp import rag_tokenizer, search
from rag.utils.redis_conn import REDIS_CONN


def graphrag_task_set(tenant_id, kb_id, doc_id) -> bool:
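    """Claim the GraphRAG task marker for this KB in Redis (24-hour expiry)."""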
    key = f"graphrag:{tenant_id}:{kb_id}"
    ok = REDIS_CONN.set(key, doc_id, exp=3600 * 24)
    if not ok:
        raise Exception(f"Failed to set {key} to {doc_id}")
    return True  # succeed explicitly so the declared bool return type holds


def graphrag_task_get(tenant_id, kb_id) -> str | None:
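    """Return the doc_id currently holding the GraphRAG task marker for this KB, if any."""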
    key = f"graphrag:{tenant_id}:{kb_id}"
    doc_id = REDIS_CONN.get(key)
    return doc_id


async def run_graphrag(
    row: dict,
    language,
    with_resolution: bool,
    with_community: bool,
    chat_model,
    embedding_model,
    callback,
):
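    """Run the GraphRAG pipeline for one document: extract a subgraph from its
    chunks, merge it into the KB graph, then optionally resolve entities and
    extract communities. Progress is reported through `callback`.
    """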
    start = trio.current_time()
    tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
    chunks = []
    for d in settings.retrievaler.chunk_list(
        doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"]
    ):
        chunks.append(d["content_with_weight"])

    graph, doc_ids = await update_graph(
        LightKGExt
        if row["parser_config"]["graphrag"]["method"] != "general"
        else GeneralKGExt,
        tenant_id,
        kb_id,
        doc_id,
        chunks,
        language,
        row["parser_config"]["graphrag"]["entity_types"],
        chat_model,
        embedding_model,
        callback,
    )
    if not graph:
        return
    if with_resolution or with_community:
        graphrag_task_set(tenant_id, kb_id, doc_id)
    if with_resolution:
        await resolve_entities(
            graph,
            doc_ids,
            tenant_id,
            kb_id,
            doc_id,
            chat_model,
            embedding_model,
            callback,
        )
    if with_community:
        await extract_community(
            graph,
            doc_ids,
            tenant_id,
            kb_id,
            doc_id,
            chat_model,
            embedding_model,
            callback,
        )
    now = trio.current_time()
    callback(msg=f"GraphRAG for doc {doc_id} done in {now - start:.2f} seconds.")
    return


async def update_graph(
    extractor: Extractor,
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    chunks: list[str],
    language,
    entity_types,
    llm_bdl,
    embed_bdl,
    callback,
):
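    """Extract entities and relations from `chunks` into a subgraph, persist it,
    and merge it into the KB-level graph. Returns (graph, doc_ids), or
    (None, None) if the document is already part of the graph.
    """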
    contains = await does_graph_contains(tenant_id, kb_id, doc_id)
    if contains:
        callback(msg=f"Graph already contains {doc_id}, cancel myself")
        return None, None
    start = trio.current_time()
    ext = extractor(
        llm_bdl,
        language=language,
        entity_types=entity_types,
        get_entity=partial(get_entity, tenant_id, kb_id),
        set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
        get_relation=partial(get_relation, tenant_id, kb_id),
        set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
    )
    ents, rels = await ext(doc_id, chunks, callback)
    subgraph = nx.Graph()
    for en in ents:
        subgraph.add_node(en["entity_name"], entity_type=en["entity_type"])
    for rel in rels:
        subgraph.add_edge(
            rel["src_id"],
            rel["tgt_id"],
            weight=rel["weight"],
            # description=rel["description"]
        )
    # TODO: infinity doesn't support array search
    chunk = {
        "content_with_weight": json.dumps(
            nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False, indent=2
        ),
        "knowledge_graph_kwd": "subgraph",
        "kb_id": kb_id,
        "source_id": [doc_id],
        "available_int": 0,
        "removed_kwd": "N",
    }
    cid = chunk_id(chunk)
    await trio.to_thread.run_sync(
        lambda: settings.docStoreConn.insert(
            [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id
        )
    )
    now = trio.current_time()
    callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
    start = now
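    # Optimistic-concurrency merge: if another task updates the global graph
    # while we merge, re-read it and retry until the doc-id set is stable.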
    while True:
        new_graph = subgraph
        now_docids = set([doc_id])
        old_graph, old_doc_ids = await get_graph(tenant_id, kb_id)
        if old_graph is not None:
            logging.info("Merge with an existing graph...")
            new_graph = graph_merge(old_graph, subgraph)
        await update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, new_graph, 2)
        if old_doc_ids:
            for old_doc_id in old_doc_ids:
                now_docids.add(old_doc_id)
        old_doc_ids2 = await get_graph_doc_ids(tenant_id, kb_id)
        delta_doc_ids = set(old_doc_ids2) - set(old_doc_ids)
        if delta_doc_ids:
            callback(msg="The global graph has changed during merging, try again")
            await trio.sleep(1)
            continue
        break
    await set_graph(tenant_id, kb_id, new_graph, list(now_docids))
    now = trio.current_time()
    callback(
        msg=f"merging subgraph for doc {doc_id} into the global graph done in {now - start:.2f} seconds."
    )
    return new_graph, now_docids


async def resolve_entities(
    graph,
    doc_ids,
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    llm_bdl,
    embed_bdl,
    callback,
):
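    """Merge duplicate entities in `graph`, recompute pagerank, and purge the
    removed entities and their relations from the document store.
    """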
    working_doc_id = graphrag_task_get(tenant_id, kb_id)
    if doc_id != working_doc_id:
        callback(
            msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
        )
        return
    start = trio.current_time()
    er = EntityResolution(
        llm_bdl,
        get_entity=partial(get_entity, tenant_id, kb_id),
        set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
        get_relation=partial(get_relation, tenant_id, kb_id),
        set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
    )
    reso = await er(graph)
    graph = reso.graph
    callback(msg=f"Graph resolution removed {len(reso.removed_entities)} nodes.")
    await update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, 2)
    callback(msg="Graph resolution updated pagerank.")
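    # Re-check task ownership after the long-running LLM step: another task may
    # have claimed the KB in the meantime, in which case our graph is stale.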
    working_doc_id = graphrag_task_get(tenant_id, kb_id)
    if doc_id != working_doc_id:
        callback(
            msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
        )
        return
    await set_graph(tenant_id, kb_id, graph, doc_ids)
    await trio.to_thread.run_sync(
        lambda: settings.docStoreConn.delete(
            {
                "knowledge_graph_kwd": "relation",
                "kb_id": kb_id,
                "from_entity_kwd": reso.removed_entities,
            },
            search.index_name(tenant_id),
            kb_id,
        )
    )
    await trio.to_thread.run_sync(
        lambda: settings.docStoreConn.delete(
            {
                "knowledge_graph_kwd": "relation",
                "kb_id": kb_id,
                "to_entity_kwd": reso.removed_entities,
            },
            search.index_name(tenant_id),
            kb_id,
        )
    )
    await trio.to_thread.run_sync(
        lambda: settings.docStoreConn.delete(
            {
                "knowledge_graph_kwd": "entity",
                "kb_id": kb_id,
                "entity_kwd": reso.removed_entities,
            },
            search.index_name(tenant_id),
            kb_id,
        )
    )
    now = trio.current_time()
    callback(msg=f"Graph resolution done in {now - start:.2f}s.")


async def extract_community(
    graph,
    doc_ids,
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    llm_bdl,
    embed_bdl,
    callback,
):
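    """Detect communities in `graph` and index one report chunk per community.
    Returns (community_structure, community_reports).
    """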
    working_doc_id = graphrag_task_get(tenant_id, kb_id)
    if doc_id != working_doc_id:
        callback(
            msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
        )
        return
    start = trio.current_time()
    ext = CommunityReportsExtractor(
        llm_bdl,
        get_entity=partial(get_entity, tenant_id, kb_id),
        set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
        get_relation=partial(get_relation, tenant_id, kb_id),
        set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
    )
    cr = await ext(graph, callback=callback)
    community_structure = cr.structured_output
    community_reports = cr.output
    working_doc_id = graphrag_task_get(tenant_id, kb_id)
    if doc_id != working_doc_id:
        callback(
            msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
        )
        return
    await set_graph(tenant_id, kb_id, graph, doc_ids)
    now = trio.current_time()
    callback(
        msg=f"Graph extracted {len(cr.structured_output)} communities in {now - start:.2f}s."
    )
    start = now
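    # Community membership may have changed, so drop all previously indexed
    # reports for this KB before inserting the fresh ones.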
    await trio.to_thread.run_sync(
        lambda: settings.docStoreConn.delete(
            {"knowledge_graph_kwd": "community_report", "kb_id": kb_id},
            search.index_name(tenant_id),
            kb_id,
        )
    )
    for stru, rep in zip(community_structure, community_reports):
        obj = {
            "report": rep,
            "evidences": "\n".join([f["explanation"] for f in stru["findings"]]),
        }
        chunk = {
            "docnm_kwd": stru["title"],
            "title_tks": rag_tokenizer.tokenize(stru["title"]),
            "content_with_weight": json.dumps(obj, ensure_ascii=False),
            "content_ltks": rag_tokenizer.tokenize(
                obj["report"] + " " + obj["evidences"]
            ),
            "knowledge_graph_kwd": "community_report",
            "weight_flt": stru["weight"],
            "entities_kwd": stru["entities"],
            "important_kwd": stru["entities"],
            "kb_id": kb_id,
            "source_id": doc_ids,
            "available_int": 0,
        }
        chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(
            chunk["content_ltks"]
        )
        # try:
        #     ebd, _ = embed_bdl.encode([", ".join(community["entities"])])
        #     chunk["q_%d_vec" % len(ebd[0])] = ebd[0]
        # except Exception as e:
        #     logging.exception(f"Fail to embed entity relation: {e}")
        await trio.to_thread.run_sync(
            lambda: settings.docStoreConn.insert(
                [{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id)
            )
        )
    now = trio.current_time()
    callback(
        msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s."
    )
    return community_structure, community_reports