Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. # Copyright (c) 2024 Microsoft Corporation.
  2. # Licensed under the MIT License
  3. """
  4. Reference:
  5. - [graphrag](https://github.com/microsoft/graphrag)
  6. """
  7. import json
  8. from dataclasses import dataclass
  9. from graphrag.extractor import Extractor
  10. from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
  11. from rag.llm.chat_model import Base as CompletionLLM
  12. from rag.utils import num_tokens_from_string
  13. SUMMARIZE_PROMPT = """
  14. You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
  15. Given one or two entities, and a list of descriptions, all related to the same entity or group of entities.
  16. Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
  17. If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
  18. Make sure it is written in third person, and include the entity names so we the have full context.
  19. #######
  20. -Data-
  21. Entities: {entity_name}
  22. Description List: {description_list}
  23. #######
  24. Output:
  25. """
  26. # Max token size for input prompts
  27. DEFAULT_MAX_INPUT_TOKENS = 4_000
  28. # Max token count for LLM answers
  29. DEFAULT_MAX_SUMMARY_LENGTH = 128
  30. @dataclass
  31. class SummarizationResult:
  32. """Unipartite graph extraction result class definition."""
  33. items: str | tuple[str, str]
  34. description: str
  35. class SummarizeExtractor(Extractor):
  36. """Unipartite graph extractor class definition."""
  37. _entity_name_key: str
  38. _input_descriptions_key: str
  39. _summarization_prompt: str
  40. _on_error: ErrorHandlerFn
  41. _max_summary_length: int
  42. _max_input_tokens: int
  43. def __init__(
  44. self,
  45. llm_invoker: CompletionLLM,
  46. entity_name_key: str | None = None,
  47. input_descriptions_key: str | None = None,
  48. summarization_prompt: str | None = None,
  49. on_error: ErrorHandlerFn | None = None,
  50. max_summary_length: int | None = None,
  51. max_input_tokens: int | None = None,
  52. ):
  53. """Init method definition."""
  54. # TODO: streamline construction
  55. self._llm = llm_invoker
  56. self._entity_name_key = entity_name_key or "entity_name"
  57. self._input_descriptions_key = input_descriptions_key or "description_list"
  58. self._summarization_prompt = summarization_prompt or SUMMARIZE_PROMPT
  59. self._on_error = on_error or (lambda _e, _s, _d: None)
  60. self._max_summary_length = max_summary_length or DEFAULT_MAX_SUMMARY_LENGTH
  61. self._max_input_tokens = max_input_tokens or DEFAULT_MAX_INPUT_TOKENS
  62. def __call__(
  63. self,
  64. items: str | tuple[str, str],
  65. descriptions: list[str],
  66. ) -> SummarizationResult:
  67. """Call method definition."""
  68. result = ""
  69. if len(descriptions) == 0:
  70. result = ""
  71. if len(descriptions) == 1:
  72. result = descriptions[0]
  73. else:
  74. result = self._summarize_descriptions(items, descriptions)
  75. return SummarizationResult(
  76. items=items,
  77. description=result or "",
  78. )
  79. def _summarize_descriptions(
  80. self, items: str | tuple[str, str], descriptions: list[str]
  81. ) -> str:
  82. """Summarize descriptions into a single description."""
  83. sorted_items = sorted(items) if isinstance(items, list) else items
  84. # Safety check, should always be a list
  85. if not isinstance(descriptions, list):
  86. descriptions = [descriptions]
  87. # Iterate over descriptions, adding all until the max input tokens is reached
  88. usable_tokens = self._max_input_tokens - num_tokens_from_string(
  89. self._summarization_prompt
  90. )
  91. descriptions_collected = []
  92. result = ""
  93. for i, description in enumerate(descriptions):
  94. usable_tokens -= num_tokens_from_string(description)
  95. descriptions_collected.append(description)
  96. # If buffer is full, or all descriptions have been added, summarize
  97. if (usable_tokens < 0 and len(descriptions_collected) > 1) or (
  98. i == len(descriptions) - 1
  99. ):
  100. # Calculate result (final or partial)
  101. result = await self._summarize_descriptions_with_llm(
  102. sorted_items, descriptions_collected
  103. )
  104. # If we go for another loop, reset values to new
  105. if i != len(descriptions) - 1:
  106. descriptions_collected = [result]
  107. usable_tokens = (
  108. self._max_input_tokens
  109. - num_tokens_from_string(self._summarization_prompt)
  110. - num_tokens_from_string(result)
  111. )
  112. return result
  113. def _summarize_descriptions_with_llm(
  114. self, items: str | tuple[str, str] | list[str], descriptions: list[str]
  115. ):
  116. """Summarize descriptions using the LLM."""
  117. variables = {
  118. self._entity_name_key: json.dumps(items),
  119. self._input_descriptions_key: json.dumps(sorted(descriptions)),
  120. }
  121. text = perform_variable_replacements(self._summarization_prompt, variables=variables)
  122. return self._chat("", [{"role": "user", "content": text}])