
description_summary.py 5.9KB

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
  16. """
  17. Reference:
  18. - [graphrag](https://github.com/microsoft/graphrag)
  19. """
import argparse
import html
import json
import logging
import numbers
import re
import traceback
from collections.abc import Callable
from dataclasses import dataclass

from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
from rag.llm.chat_model import Base as CompletionLLM
import networkx as nx
from rag.utils import num_tokens_from_string

SUMMARIZE_PROMPT = """
You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
You are given one or two entities, and a list of descriptions, all related to the same entity or group of entities.
Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
Make sure it is written in third person, and include the entity names so we have the full context.
#######
-Data-
Entities: {entity_name}
Description List: {description_list}
#######
Output:
"""

# Max token size for input prompts
DEFAULT_MAX_INPUT_TOKENS = 4_000
# Max token count for LLM answers
DEFAULT_MAX_SUMMARY_LENGTH = 128


@dataclass
class SummarizationResult:
    """Description summarization result."""

    items: str | tuple[str, str]
    description: str


class SummarizeExtractor:
    """Description summarization extractor class definition."""

    _llm: CompletionLLM
    _entity_name_key: str
    _input_descriptions_key: str
    _summarization_prompt: str
    _on_error: ErrorHandlerFn
    _max_summary_length: int
    _max_input_tokens: int

    def __init__(
        self,
        llm_invoker: CompletionLLM,
        entity_name_key: str | None = None,
        input_descriptions_key: str | None = None,
        summarization_prompt: str | None = None,
        on_error: ErrorHandlerFn | None = None,
        max_summary_length: int | None = None,
        max_input_tokens: int | None = None,
    ):
        """Init method definition."""
        # TODO: streamline construction
        self._llm = llm_invoker
        self._entity_name_key = entity_name_key or "entity_name"
        self._input_descriptions_key = input_descriptions_key or "description_list"
        self._summarization_prompt = summarization_prompt or SUMMARIZE_PROMPT
        self._on_error = on_error or (lambda _e, _s, _d: None)
        self._max_summary_length = max_summary_length or DEFAULT_MAX_SUMMARY_LENGTH
        self._max_input_tokens = max_input_tokens or DEFAULT_MAX_INPUT_TOKENS

    def __call__(
        self,
        items: str | tuple[str, str],
        descriptions: list[str],
    ) -> SummarizationResult:
        """Call method definition."""
        result = ""
        if len(descriptions) == 0:
            result = ""
        elif len(descriptions) == 1:
            # A single description needs no summarization.
            result = descriptions[0]
        else:
            result = self._summarize_descriptions(items, descriptions)

        return SummarizationResult(
            items=items,
            description=result or "",
        )

    def _summarize_descriptions(
        self, items: str | tuple[str, str], descriptions: list[str]
    ) -> str:
        """Summarize descriptions into a single description."""
        sorted_items = sorted(items) if isinstance(items, list) else items

        # Safety check, should always be a list
        if not isinstance(descriptions, list):
            descriptions = [descriptions]

        # Iterate over descriptions, adding all until the max input tokens is reached
        usable_tokens = self._max_input_tokens - num_tokens_from_string(
            self._summarization_prompt
        )
        descriptions_collected = []
        result = ""

        for i, description in enumerate(descriptions):
            usable_tokens -= num_tokens_from_string(description)
            descriptions_collected.append(description)

            # If buffer is full, or all descriptions have been added, summarize
            if (usable_tokens < 0 and len(descriptions_collected) > 1) or (
                i == len(descriptions) - 1
            ):
                # Calculate result (final or partial)
                result = self._summarize_descriptions_with_llm(
                    sorted_items, descriptions_collected
                )

                # If we go for another loop, reset values to new
                if i != len(descriptions) - 1:
                    descriptions_collected = [result]
                    usable_tokens = (
                        self._max_input_tokens
                        - num_tokens_from_string(self._summarization_prompt)
                        - num_tokens_from_string(result)
                    )

        return result

    def _summarize_descriptions_with_llm(
        self, items: str | tuple[str, str] | list[str], descriptions: list[str]
    ):
        """Summarize descriptions using the LLM."""
        variables = {
            self._entity_name_key: json.dumps(items),
            self._input_descriptions_key: json.dumps(sorted(descriptions)),
        }
        text = perform_variable_replacements(self._summarization_prompt, variables=variables)
        return self._llm.chat("", [{"role": "user", "content": text}])
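
Below is a minimal usage sketch of SummarizeExtractor; it is not part of the file above. The StubLLM class stands in for any chat model compatible with rag.llm.chat_model.Base (anything exposing a chat(system, history) method), and the module path in the import is an assumption about how the file is packaged.

from graphrag.description_summary import SummarizeExtractor  # assumed module path


class StubLLM:
    """Hypothetical stand-in for a real CompletionLLM; returns a canned summary."""

    def chat(self, system: str, history: list[dict]) -> str:
        return "Bob Smith founded Acme Corp in 2001 and continues to lead the company."


extractor = SummarizeExtractor(StubLLM())
result = extractor(
    ("Acme Corp", "Bob Smith"),  # one entity name, or a tuple for a relationship
    [
        "Bob Smith founded Acme Corp in 2001.",
        "Acme Corp is led by its founder, Bob Smith.",
    ],
)
print(result.description)  # the merged description returned by the (stub) LLM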