You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

description_summary.py 5.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. # Copyright (c) 2024 Microsoft Corporation.
  2. # Licensed under the MIT License
  3. """
  4. Reference:
  5. - [graphrag](https://github.com/microsoft/graphrag)
  6. """
  7. import argparse
  8. import html
  9. import json
  10. import logging
  11. import numbers
  12. import re
  13. import traceback
  14. from collections.abc import Callable
  15. from dataclasses import dataclass
  16. from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
  17. from rag.llm.chat_model import Base as CompletionLLM
  18. import networkx as nx
  19. from rag.utils import num_tokens_from_string
# Prompt template for merging several descriptions of the same entity (or
# entity pair) into one coherent summary. The {entity_name} and
# {description_list} placeholders are filled via perform_variable_replacements.
SUMMARIZE_PROMPT = """
You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
Given one or two entities, and a list of descriptions, all related to the same entity or group of entities.
Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
Make sure it is written in third person, and include the entity names so we the have full context.
#######
-Data-
Entities: {entity_name}
Description List: {description_list}
#######
Output:
"""

# Max token size for input prompts
DEFAULT_MAX_INPUT_TOKENS = 4_000
# Max token count for LLM answers
DEFAULT_MAX_SUMMARY_LENGTH = 128
@dataclass
class SummarizationResult:
    """Unipartite graph extraction result class definition."""

    # The entity name, or the (source, target) pair, whose descriptions
    # were summarized.
    items: str | tuple[str, str]
    # The single merged description produced for `items`.
    description: str
  42. class SummarizeExtractor:
  43. """Unipartite graph extractor class definition."""
  44. _llm: CompletionLLM
  45. _entity_name_key: str
  46. _input_descriptions_key: str
  47. _summarization_prompt: str
  48. _on_error: ErrorHandlerFn
  49. _max_summary_length: int
  50. _max_input_tokens: int
  51. def __init__(
  52. self,
  53. llm_invoker: CompletionLLM,
  54. entity_name_key: str | None = None,
  55. input_descriptions_key: str | None = None,
  56. summarization_prompt: str | None = None,
  57. on_error: ErrorHandlerFn | None = None,
  58. max_summary_length: int | None = None,
  59. max_input_tokens: int | None = None,
  60. ):
  61. """Init method definition."""
  62. # TODO: streamline construction
  63. self._llm = llm_invoker
  64. self._entity_name_key = entity_name_key or "entity_name"
  65. self._input_descriptions_key = input_descriptions_key or "description_list"
  66. self._summarization_prompt = summarization_prompt or SUMMARIZE_PROMPT
  67. self._on_error = on_error or (lambda _e, _s, _d: None)
  68. self._max_summary_length = max_summary_length or DEFAULT_MAX_SUMMARY_LENGTH
  69. self._max_input_tokens = max_input_tokens or DEFAULT_MAX_INPUT_TOKENS
  70. def __call__(
  71. self,
  72. items: str | tuple[str, str],
  73. descriptions: list[str],
  74. ) -> SummarizationResult:
  75. """Call method definition."""
  76. result = ""
  77. if len(descriptions) == 0:
  78. result = ""
  79. if len(descriptions) == 1:
  80. result = descriptions[0]
  81. else:
  82. result = self._summarize_descriptions(items, descriptions)
  83. return SummarizationResult(
  84. items=items,
  85. description=result or "",
  86. )
  87. def _summarize_descriptions(
  88. self, items: str | tuple[str, str], descriptions: list[str]
  89. ) -> str:
  90. """Summarize descriptions into a single description."""
  91. sorted_items = sorted(items) if isinstance(items, list) else items
  92. # Safety check, should always be a list
  93. if not isinstance(descriptions, list):
  94. descriptions = [descriptions]
  95. # Iterate over descriptions, adding all until the max input tokens is reached
  96. usable_tokens = self._max_input_tokens - num_tokens_from_string(
  97. self._summarization_prompt
  98. )
  99. descriptions_collected = []
  100. result = ""
  101. for i, description in enumerate(descriptions):
  102. usable_tokens -= num_tokens_from_string(description)
  103. descriptions_collected.append(description)
  104. # If buffer is full, or all descriptions have been added, summarize
  105. if (usable_tokens < 0 and len(descriptions_collected) > 1) or (
  106. i == len(descriptions) - 1
  107. ):
  108. # Calculate result (final or partial)
  109. result = await self._summarize_descriptions_with_llm(
  110. sorted_items, descriptions_collected
  111. )
  112. # If we go for another loop, reset values to new
  113. if i != len(descriptions) - 1:
  114. descriptions_collected = [result]
  115. usable_tokens = (
  116. self._max_input_tokens
  117. - num_tokens_from_string(self._summarization_prompt)
  118. - num_tokens_from_string(result)
  119. )
  120. return result
  121. def _summarize_descriptions_with_llm(
  122. self, items: str | tuple[str, str] | list[str], descriptions: list[str]
  123. ):
  124. """Summarize descriptions using the LLM."""
  125. variables = {
  126. self._entity_name_key: json.dumps(items),
  127. self._input_descriptions_key: json.dumps(sorted(descriptions)),
  128. }
  129. text = perform_variable_replacements(self._summarization_prompt, variables=variables)
  130. return self._llm.chat("", [{"role": "user", "content": text}])