Du kan inte välja fler än 25 ämnen. Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

raptor.py 5.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import logging
  17. import re
  18. from concurrent.futures import ThreadPoolExecutor, ALL_COMPLETED, wait
  19. from threading import Lock
  20. from typing import Tuple
  21. import umap
  22. import numpy as np
  23. from sklearn.mixture import GaussianMixture
  24. from rag.utils import truncate
  25. class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
  26. def __init__(self, max_cluster, llm_model, embd_model, prompt, max_token=256, threshold=0.1):
  27. self._max_cluster = max_cluster
  28. self._llm_model = llm_model
  29. self._embd_model = embd_model
  30. self._threshold = threshold
  31. self._prompt = prompt
  32. self._max_token = max_token
  33. def _get_optimal_clusters(self, embeddings: np.ndarray, random_state:int):
  34. max_clusters = min(self._max_cluster, len(embeddings))
  35. n_clusters = np.arange(1, max_clusters)
  36. bics = []
  37. for n in n_clusters:
  38. gm = GaussianMixture(n_components=n, random_state=random_state)
  39. gm.fit(embeddings)
  40. bics.append(gm.bic(embeddings))
  41. optimal_clusters = n_clusters[np.argmin(bics)]
  42. return optimal_clusters
  43. def __call__(self, chunks: Tuple[str, np.ndarray], random_state, callback=None):
  44. layers = [(0, len(chunks))]
  45. start, end = 0, len(chunks)
  46. if len(chunks) <= 1: return
  47. chunks = [(s, a) for s, a in chunks if len(a) > 0]
  48. def summarize(ck_idx, lock):
  49. nonlocal chunks
  50. try:
  51. texts = [chunks[i][0] for i in ck_idx]
  52. len_per_chunk = int((self._llm_model.max_length - self._max_token)/len(texts))
  53. cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
  54. cnt = self._llm_model.chat("You're a helpful assistant.",
  55. [{"role": "user", "content": self._prompt.format(cluster_content=cluster_content)}],
  56. {"temperature": 0.3, "max_tokens": self._max_token}
  57. )
  58. cnt = re.sub("(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", "", cnt)
  59. logging.debug(f"SUM: {cnt}")
  60. embds, _ = self._embd_model.encode([cnt])
  61. with lock:
  62. if not len(embds[0]): return
  63. chunks.append((cnt, embds[0]))
  64. except Exception as e:
  65. logging.exception("summarize got exception")
  66. return e
  67. labels = []
  68. while end - start > 1:
  69. embeddings = [embd for _, embd in chunks[start: end]]
  70. if len(embeddings) == 2:
  71. summarize([start, start+1], Lock())
  72. if callback:
  73. callback(msg="Cluster one layer: {} -> {}".format(end-start, len(chunks)-end))
  74. labels.extend([0,0])
  75. layers.append((end, len(chunks)))
  76. start = end
  77. end = len(chunks)
  78. continue
  79. n_neighbors = int((len(embeddings) - 1) ** 0.8)
  80. reduced_embeddings = umap.UMAP(
  81. n_neighbors=max(2, n_neighbors), n_components=min(12, len(embeddings)-2), metric="cosine"
  82. ).fit_transform(embeddings)
  83. n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state)
  84. if n_clusters == 1:
  85. lbls = [0 for _ in range(len(reduced_embeddings))]
  86. else:
  87. gm = GaussianMixture(n_components=n_clusters, random_state=random_state)
  88. gm.fit(reduced_embeddings)
  89. probs = gm.predict_proba(reduced_embeddings)
  90. lbls = [np.where(prob > self._threshold)[0] for prob in probs]
  91. lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls]
  92. lock = Lock()
  93. with ThreadPoolExecutor(max_workers=12) as executor:
  94. threads = []
  95. for c in range(n_clusters):
  96. ck_idx = [i+start for i in range(len(lbls)) if lbls[i] == c]
  97. threads.append(executor.submit(summarize, ck_idx, lock))
  98. wait(threads, return_when=ALL_COMPLETED)
  99. logging.debug(str([t.result() for t in threads]))
  100. assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
  101. labels.extend(lbls)
  102. layers.append((end, len(chunks)))
  103. if callback:
  104. callback(msg="Cluster one layer: {} -> {}".format(end-start, len(chunks)-end))
  105. start = end
  106. end = len(chunks)