|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167 |
- #
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- """
- Reference:
- - [graphrag](https://github.com/microsoft/graphrag)
- """
-
- import argparse
- import html
- import json
- import logging
- import numbers
- import re
- import traceback
- from collections.abc import Callable
- from dataclasses import dataclass
-
- from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
- from rag.llm.chat_model import Base as CompletionLLM
- import networkx as nx
-
- from rag.utils import num_tokens_from_string
-
# Prompt template used to merge several descriptions into one.
# `{entity_name}` and `{description_list}` are the placeholders that
# SummarizeExtractor fills (through perform_variable_replacements) under its
# default entity_name_key / input_descriptions_key.
SUMMARIZE_PROMPT = """
You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
Given one or two entities, and a list of descriptions, all related to the same entity or group of entities.
Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
Make sure it is written in third person, and include the entity names so that we have the full context.

#######
-Data-
Entities: {entity_name}
Description List: {description_list}
#######
Output:
"""

# Max token size for input prompts (template plus all packed descriptions).
DEFAULT_MAX_INPUT_TOKENS = 4_000
# Max token count for LLM answers.
DEFAULT_MAX_SUMMARY_LENGTH = 128
-
-
- @dataclass
- class SummarizationResult:
- """Unipartite graph extraction result class definition."""
-
- items: str | tuple[str, str]
- description: str
-
-
class SummarizeExtractor:
    """Merge several descriptions of one entity (or pair of entities) via an LLM.

    Calling an instance returns a ``SummarizationResult``:

    * no descriptions yield an empty description;
    * a single description is returned unchanged;
    * two or more are merged by prompting the LLM.

    Descriptions are packed into prompts within a ``max_input_tokens`` budget.
    When the budget overflows, the partial summary produced so far becomes the
    first description of the next prompt, and packing continues.
    """

    _llm: CompletionLLM
    _entity_name_key: str
    _input_descriptions_key: str
    _summarization_prompt: str
    _on_error: ErrorHandlerFn
    _max_summary_length: int
    _max_input_tokens: int

    def __init__(
        self,
        llm_invoker: CompletionLLM,
        entity_name_key: str | None = None,
        input_descriptions_key: str | None = None,
        summarization_prompt: str | None = None,
        on_error: ErrorHandlerFn | None = None,
        max_summary_length: int | None = None,
        max_input_tokens: int | None = None,
    ):
        """Store the LLM handle and the prompt/token-budget configuration.

        Args:
            llm_invoker: Chat model whose ``chat`` method produces summaries.
            entity_name_key: Prompt placeholder for the entity name(s);
                defaults to ``"entity_name"``.
            input_descriptions_key: Prompt placeholder for the description
                list; defaults to ``"description_list"``.
            summarization_prompt: Prompt template; defaults to
                ``SUMMARIZE_PROMPT``.
            on_error: Three-argument error callback; defaults to a no-op.
                NOTE(review): stored but never invoked by this class.
            max_summary_length: Token cap requested for each LLM answer;
                defaults to ``DEFAULT_MAX_SUMMARY_LENGTH``.
            max_input_tokens: Token budget for a single prompt (template plus
                descriptions); defaults to ``DEFAULT_MAX_INPUT_TOKENS``.

        Any falsy argument (``None``, ``0``, ``""``) falls back to its default.
        """
        # TODO: streamline construction
        self._llm = llm_invoker
        self._entity_name_key = entity_name_key or "entity_name"
        self._input_descriptions_key = input_descriptions_key or "description_list"

        self._summarization_prompt = summarization_prompt or SUMMARIZE_PROMPT
        self._on_error = on_error or (lambda _e, _s, _d: None)
        self._max_summary_length = max_summary_length or DEFAULT_MAX_SUMMARY_LENGTH
        self._max_input_tokens = max_input_tokens or DEFAULT_MAX_INPUT_TOKENS

    def __call__(
        self,
        items: str | tuple[str, str],
        descriptions: list[str],
    ) -> SummarizationResult:
        """Return a single merged description for ``items``.

        Args:
            items: An entity name, or a pair of names.
            descriptions: The descriptions to merge.

        Returns:
            A ``SummarizationResult`` that echoes ``items``. Its description is
            ``""`` for no input, the sole description for one input, and the
            LLM summary otherwise (``""`` if the LLM result is falsy).
        """
        if len(descriptions) == 0:
            result = ""
        elif len(descriptions) == 1:
            result = descriptions[0]
        else:
            result = self._summarize_descriptions(items, descriptions)

        return SummarizationResult(
            items=items,
            description=result or "",
        )

    def _summarize_descriptions(
        self, items: str | tuple[str, str], descriptions: list[str]
    ) -> str:
        """Summarize descriptions, re-prompting whenever the token budget overflows.

        Descriptions are added to a buffer one at a time. The LLM is called when
        the budget goes negative (and the buffer holds more than one entry) or
        when the last description has been added. After an intermediate call,
        the buffer restarts with the partial summary as its only entry.

        Returns:
            The result of the final LLM call ("" if ``descriptions`` is empty).
        """
        # Only lists are sorted; strings and tuples are passed through as given.
        sorted_items = sorted(items) if isinstance(items, list) else items

        # Safety check, should always be a list
        if not isinstance(descriptions, list):
            descriptions = [descriptions]

        # The template's cost is the same for every prompt, so count it once.
        prompt_tokens = num_tokens_from_string(self._summarization_prompt)
        usable_tokens = self._max_input_tokens - prompt_tokens
        descriptions_collected: list[str] = []
        result = ""
        last_index = len(descriptions) - 1

        for i, description in enumerate(descriptions):
            usable_tokens -= num_tokens_from_string(description)
            descriptions_collected.append(description)

            # Summarize when the buffer is full, or once all descriptions are in.
            # A lone over-budget description is still sent rather than looping.
            buffer_full = usable_tokens < 0 and len(descriptions_collected) > 1
            if buffer_full or i == last_index:
                # Calculate result (final or partial)
                result = self._summarize_descriptions_with_llm(
                    sorted_items, descriptions_collected
                )

                # If we go for another loop, seed the buffer with the partial
                # summary and recompute the remaining budget.
                if i != last_index:
                    descriptions_collected = [result]
                    usable_tokens = (
                        self._max_input_tokens
                        - prompt_tokens
                        - num_tokens_from_string(result)
                    )

        return result

    def _summarize_descriptions_with_llm(
        self, items: str | tuple[str, str] | list[str], descriptions: list[str]
    ):
        """Ask the LLM to merge ``descriptions`` for ``items`` in a single prompt.

        Items and (sorted) descriptions are JSON-encoded into the template's
        placeholders. The return value is whatever ``self._llm.chat`` returns.
        NOTE(review): callers treat it as the summary string. Confirm that the
        chat implementation returns plain text and not e.g. (text, token_count).
        """
        variables = {
            # ensure_ascii=False keeps non-ASCII names/descriptions readable
            # for the LLM instead of sending \uXXXX escape sequences.
            self._entity_name_key: json.dumps(items, ensure_ascii=False),
            self._input_descriptions_key: json.dumps(
                sorted(descriptions), ensure_ascii=False
            ),
        }
        text = perform_variable_replacements(self._summarization_prompt, variables=variables)
        # Bound each answer by the configured max_summary_length.
        # NOTE(review): assumes chat() takes a generation-config dict as its
        # third positional argument (max_tokens key) — confirm against the LLM wrapper.
        gen_conf = {"max_tokens": self._max_summary_length}
        return self._llm.chat("", [{"role": "user", "content": text}], gen_conf)
|