raptor.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import re
from threading import Lock

import umap
import numpy as np
from sklearn.mixture import GaussianMixture
import trio

from graphrag.utils import get_llm_cache, get_embed_cache, set_embed_cache, set_llm_cache, chat_limiter
from rag.utils import truncate


class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
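    """RAPTOR: Recursive Abstractive Processing for Tree-Organized Retrieval.

    Builds a summary tree bottom-up: cluster the embeddings of the current
    layer of chunks, summarize each cluster with the LLM, append the
    summaries (with their own embeddings) as the next layer, and repeat
    until a layer shrinks to a single node.
    """
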
    def __init__(self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1):
        self._max_cluster = max_cluster
        self._llm_model = llm_model
        self._embd_model = embd_model
        self._threshold = threshold
        self._prompt = prompt
        self._max_token = max_token

    def _chat(self, system, history, gen_conf):
        # Serve from the LLM cache when possible.
        response = get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)
        if response:
            return response
        response = self._llm_model.chat(system, history, gen_conf)
        # Strip <think>...</think> reasoning blocks emitted by some models.
        response = re.sub(r"<think>.*</think>", "", response, flags=re.DOTALL)
        if response.find("**ERROR**") >= 0:
            raise Exception(response)
        set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
        return response

    def _embedding_encode(self, txt):
        # Serve from the embedding cache when possible.
        response = get_embed_cache(self._embd_model.llm_name, txt)
        if response is not None:
            return response
        embds, _ = self._embd_model.encode([txt])
        if len(embds) < 1 or len(embds[0]) < 1:
            raise Exception("Embedding error: empty vector")
        embds = embds[0]
        set_embed_cache(self._embd_model.llm_name, txt, embds)
        return embds

    def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int):
        # Pick the cluster count that minimizes the Bayesian Information
        # Criterion over a sweep of Gaussian mixture sizes.
        max_clusters = min(self._max_cluster, len(embeddings))
        n_clusters = np.arange(1, max_clusters)
        bics = []
        for n in n_clusters:
            gm = GaussianMixture(n_components=n, random_state=random_state)
            gm.fit(embeddings)
            bics.append(gm.bic(embeddings))
        optimal_clusters = n_clusters[np.argmin(bics)]
        return optimal_clusters

    async def __call__(self, chunks, random_state, callback=None):
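        """Build the tree over `chunks`, a list of (text, embedding) pairs.

        Each pass clusters the current layer, appends one LLM summary per
        cluster, and repeats on the new layer. Returns the enlarged chunk
        list; `layers` records the (start, end) span of each level.
        """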
        if len(chunks) <= 1:
            return []
        # Drop chunks with empty text or empty embeddings before computing
        # layer bounds, so `end` matches the filtered list.
        chunks = [(s, a) for s, a in chunks if s and len(a) > 0]
        layers = [(0, len(chunks))]
        start, end = 0, len(chunks)

        async def summarize(ck_idx, lock):
            nonlocal chunks
            try:
                texts = [chunks[i][0] for i in ck_idx]
                # Budget the prompt so the cluster content plus the summary
                # fits within the model's context window.
                len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
                cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
                async with chat_limiter:
                    cnt = await trio.to_thread.run_sync(
                        lambda: self._chat(
                            "You're a helpful assistant.",
                            [{"role": "user", "content": self._prompt.format(cluster_content=cluster_content)}],
                            {"temperature": 0.3, "max_tokens": self._max_token},
                        )
                    )
                # Strip the bilingual "response was truncated, continue?"
                # notice that some models append.
                cnt = re.sub(
                    "(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
                    "",
                    cnt,
                )
                logging.debug(f"SUM: {cnt}")
                with lock:
                    chunks.append((cnt, self._embedding_encode(cnt)))
            except Exception as e:
                logging.exception("summarize got exception")
                return e

        labels = []
        lock = Lock()
        while end - start > 1:
            embeddings = [embd for _, embd in chunks[start:end]]
            if len(embeddings) == 2:
                # Two nodes cannot be clustered; summarize them directly.
                await summarize([start, start + 1], lock)
                if callback:
                    callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
                labels.extend([0, 0])
                layers.append((end, len(chunks)))
                start = end
                end = len(chunks)
                continue
            # Reduce dimensionality with UMAP before fitting the Gaussian
            # mixture, then assign each point to the first cluster whose
            # posterior probability exceeds the threshold.
            n_neighbors = int((len(embeddings) - 1) ** 0.8)
            reduced_embeddings = umap.UMAP(
                n_neighbors=max(2, n_neighbors), n_components=min(12, len(embeddings) - 2), metric="cosine"
            ).fit_transform(embeddings)
            n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state)
            if n_clusters == 1:
                lbls = [0 for _ in range(len(reduced_embeddings))]
            else:
                gm = GaussianMixture(n_components=n_clusters, random_state=random_state)
                gm.fit(reduced_embeddings)
                probs = gm.predict_proba(reduced_embeddings)
                lbls = [np.where(prob > self._threshold)[0] for prob in probs]
                lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls]
            async with trio.open_nursery() as nursery:
                for c in range(n_clusters):
                    ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
                    if not ck_idx:
                        continue
                    async with chat_limiter:
                        # Pass ck_idx as an argument: a zero-argument lambda
                        # here would close over the loop variable, so every
                        # task would see the last cluster's indices.
                        nursery.start_soon(summarize, ck_idx, lock)
            assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
            labels.extend(lbls)
            layers.append((end, len(chunks)))
            if callback:
                callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
            start = end
            end = len(chunks)
        return chunks
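

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. `_FakeChat` and
    # `_FakeEmbed` are hypothetical stand-ins for the model bundles RAGFlow
    # normally injects; they only mimic the chat()/encode() surface this class
    # touches. Running for real also assumes the graphrag.utils cache backend
    # and chat_limiter are configured.
    class _FakeChat:
        llm_name = "fake-chat"
        max_length = 8192

        def chat(self, system, history, gen_conf):
            # Canned "summary": echo the first 60 characters of the prompt.
            return "Summary: " + history[0]["content"][:60]

    class _FakeEmbed:
        llm_name = "fake-embed"

        def encode(self, texts):
            # Deterministic pseudo-embeddings; real models return
            # (vectors, token_count).
            vecs = [np.random.default_rng(abs(hash(t)) % (2**32)).random(64) for t in texts]
            return vecs, 0

    embd = _FakeEmbed()
    raptor = RecursiveAbstractiveProcessing4TreeOrganizedRetrieval(
        max_cluster=8,
        llm_model=_FakeChat(),
        embd_model=embd,
        prompt="Summarize the following passages:\n{cluster_content}",
    )
    docs = [f"passage number {i} about topic {i % 3}" for i in range(12)]
    chunks = [(d, embd.encode([d])[0][0]) for d in docs]
    tree = trio.run(raptor, chunks, 42)
    print(f"{len(tree)} nodes after clustering")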