from __future__ import annotations

import copy
import logging
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, Collection, Iterable, Sequence, Set
from dataclasses import dataclass
from typing import (
    Any,
    Literal,
    Optional,
    TypedDict,
    TypeVar,
    Union,
)

from core.rag.models.document import BaseDocumentTransformer, Document

logger = logging.getLogger(__name__)

TS = TypeVar("TS", bound="TextSplitter")


def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> list[str]:
    # Now that we have the separator, split the text
    if separator:
        if keep_separator:
            # The parentheses in the pattern keep the delimiters in the result.
            _splits = re.split(f"({re.escape(separator)})", text)
            splits = [_splits[i - 1] + _splits[i] for i in range(1, len(_splits), 2)]
            if len(_splits) % 2 != 0:
                splits += _splits[-1:]
        else:
            splits = re.split(separator, text)
    else:
        splits = list(text)
    return [s for s in splits if (s not in {"", "\n"})]
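
# Illustrative sketch (editor's note, not part of the original module): how keep_separator
# changes the output of _split_text_with_regex for separator "\n\n".
#   _split_text_with_regex("a\n\nb\n\nc", "\n\n", keep_separator=False)
#   -> ["a", "b", "c"]
#   _split_text_with_regex("a\n\nb\n\nc", "\n\n", keep_separator=True)
#   -> ["a\n\n", "b\n\n", "c"]  # separators stay attached to the preceding piece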


class TextSplitter(BaseDocumentTransformer, ABC):
    """Interface for splitting text into chunks."""

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: Callable[[list[str]], list[int]] = lambda texts: [len(text) for text in texts],
        keep_separator: bool = False,
        add_start_index: bool = False,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            chunk_size: Maximum size of chunks to return
            chunk_overlap: Overlap in characters between chunks
            length_function: Function that measures the lengths of a batch of chunks
            keep_separator: Whether to keep the separator in the chunks
            add_start_index: If `True`, includes chunk's start index in metadata
        """
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size ({chunk_size}), should be smaller."
            )
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function
        self._keep_separator = keep_separator
        self._add_start_index = add_start_index
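
    # Note (editor's illustration, not in the original): length_function is batched.
    # It maps a list of strings to a list of lengths, e.g. the default gives
    #   length_function(["ab", "cde"]) -> [2, 3]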

    @abstractmethod
    def split_text(self, text: str) -> list[str]:
        """Split text into multiple components."""

    def create_documents(self, texts: list[str], metadatas: Optional[list[dict]] = None) -> list[Document]:
        """Create documents from a list of texts."""
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            index = -1
            for chunk in self.split_text(text):
                metadata = copy.deepcopy(_metadatas[i])
                if self._add_start_index:
                    index = text.find(chunk, index + 1)
                    metadata["start_index"] = index
                new_doc = Document(page_content=chunk, metadata=metadata)
                documents.append(new_doc)
        return documents

    def split_documents(self, documents: Iterable[Document]) -> list[Document]:
        """Split documents."""
        texts, metadatas = [], []
        for doc in documents:
            texts.append(doc.page_content)
            metadatas.append(doc.metadata or {})
        return self.create_documents(texts, metadatas=metadatas)

    def _join_docs(self, docs: list[str], separator: str) -> Optional[str]:
        text = separator.join(docs)
        text = text.strip()
        if text == "":
            return None
        else:
            return text

    def _merge_splits(self, splits: Iterable[str], separator: str, lengths: list[int]) -> list[str]:
        # We now want to combine these smaller pieces into medium-sized
        # chunks to send to the LLM.
        separator_len = self._length_function([separator])[0]

        docs = []
        current_doc: list[str] = []
        total = 0
        index = 0
        for d in splits:
            _len = lengths[index]
            if total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size:
                if total > self._chunk_size:
                    logger.warning(
                        f"Created a chunk of size {total}, which is longer than the specified {self._chunk_size}"
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep on popping if:
                    # - we have a larger chunk than the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > self._chunk_overlap or (
                        total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size and total > 0
                    ):
                        total -= self._length_function([current_doc[0]])[0] + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
            index += 1
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs
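
    # Illustrative sketch (editor's note, not in the original module): with chunk_size=10,
    # chunk_overlap=5, a single-space separator and character lengths, _merge_splits packs
    # the pre-split pieces into overlapping chunks:
    #   splitter = CharacterTextSplitter(separator=" ", chunk_size=10, chunk_overlap=5)
    #   splitter._merge_splits(["aaaa", "bbbb", "cccc"], " ", [4, 4, 4])
    #   -> ["aaaa bbbb", "bbbb cccc"]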

    @classmethod
    def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
        """Text splitter that uses HuggingFace tokenizer to count length."""
        try:
            from transformers import PreTrainedTokenizerBase  # type: ignore

            if not isinstance(tokenizer, PreTrainedTokenizerBase):
                raise ValueError("Tokenizer received was not an instance of PreTrainedTokenizerBase")

            def _huggingface_tokenizer_length(text: str) -> int:
                return len(tokenizer.encode(text))

        except ImportError:
            raise ValueError(
                "Could not import transformers python package. Please install it with `pip install transformers`."
            )
        return cls(length_function=lambda x: [_huggingface_tokenizer_length(text) for text in x], **kwargs)
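
    # Illustrative usage sketch (editor's note, not in the original module; assumes
    # `transformers` is installed and a tokenizer such as "bert-base-uncased" is available):
    #   from transformers import AutoTokenizer
    #   hf_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    #   splitter = CharacterTextSplitter.from_huggingface_tokenizer(
    #       hf_tokenizer, chunk_size=256, chunk_overlap=32
    #   )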

    def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
        """Transform sequence of documents by splitting them."""
        return self.split_documents(list(documents))

    async def atransform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
        """Asynchronously transform a sequence of documents by splitting them."""
        raise NotImplementedError


class CharacterTextSplitter(TextSplitter):
    """Splitting text that looks at characters."""

    def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None:
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        self._separator = separator

    def split_text(self, text: str) -> list[str]:
        """Split incoming text and return chunks."""
        # First we naively split the large input into a bunch of smaller ones.
        splits = _split_text_with_regex(text, self._separator, self._keep_separator)
        _separator = "" if self._keep_separator else self._separator
        _good_splits_lengths = []  # cache the lengths of the splits
        if splits:
            _good_splits_lengths.extend(self._length_function(splits))
        return self._merge_splits(splits, _separator, _good_splits_lengths)
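

# Illustrative usage sketch (editor's note, not in the original module):
#   char_splitter = CharacterTextSplitter(separator="\n\n", chunk_size=20, chunk_overlap=0)
#   char_splitter.split_text("first paragraph\n\nsecond paragraph")
#   -> ["first paragraph", "second paragraph"]
# With add_start_index=True, create_documents() also records each chunk's
# "start_index" in the document metadata.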


class LineType(TypedDict):
    """Line type as typed dict."""

    metadata: dict[str, str]
    content: str


class HeaderType(TypedDict):
    """Header type as typed dict."""

    level: int
    name: str
    data: str


class MarkdownHeaderTextSplitter:
    """Splitting markdown files based on specified headers."""

    def __init__(self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False):
        """Create a new MarkdownHeaderTextSplitter.

        Args:
            headers_to_split_on: Headers we want to track
            return_each_line: Return each line w/ associated headers
        """
        # Output line-by-line or aggregated into chunks w/ common headers
        self.return_each_line = return_each_line
        # Given the headers we want to split on,
        # (e.g., "#, ##, etc") order by length
        self.headers_to_split_on = sorted(headers_to_split_on, key=lambda split: len(split[0]), reverse=True)

    def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]:
        """Combine lines with common metadata into chunks.

        Args:
            lines: Line of text / associated header metadata
        """
        aggregated_chunks: list[LineType] = []

        for line in lines:
            if aggregated_chunks and aggregated_chunks[-1]["metadata"] == line["metadata"]:
                # If the last line in the aggregated list
                # has the same metadata as the current line,
                # append the current content to the last line's content
                aggregated_chunks[-1]["content"] += "  \n" + line["content"]
            else:
                # Otherwise, append the current line to the aggregated list
                aggregated_chunks.append(line)

        return [Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in aggregated_chunks]
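
    # Illustrative sketch (editor's note, not in the original module): lines that share
    # identical header metadata are merged into a single Document, joined with "  \n":
    #   aggregate_lines_to_chunks([
    #       {"content": "Hello", "metadata": {"Header 1": "Intro"}},
    #       {"content": "World", "metadata": {"Header 1": "Intro"}},
    #   ])
    #   -> [Document(page_content="Hello  \nWorld", metadata={"Header 1": "Intro"})]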

    def split_text(self, text: str) -> list[Document]:
        """Split markdown file.

        Args:
            text: Markdown file
        """
        # Split the input text by newline character ("\n").
        lines = text.split("\n")
        # Final output
        lines_with_metadata: list[LineType] = []
        # Content and metadata of the chunk currently being processed
        current_content: list[str] = []
        current_metadata: dict[str, str] = {}
        # Keep track of the nested header structure
        # header_stack: List[Dict[str, Union[int, str]]] = []
        header_stack: list[HeaderType] = []
        initial_metadata: dict[str, str] = {}

        for line in lines:
            stripped_line = line.strip()
            # Check each line against each of the header types (e.g., #, ##)
            for sep, name in self.headers_to_split_on:
                # Check if line starts with a header that we intend to split on
                if stripped_line.startswith(sep) and (
                    # Header with no text OR header followed by a space;
                    # both are valid signs that sep is being used as a header.
                    len(stripped_line) == len(sep) or stripped_line[len(sep)] == " "
                ):
                    # Ensure we are tracking the header as metadata
                    if name is not None:
                        # Get the current header level
                        current_header_level = sep.count("#")

                        # Pop out headers of lower or same level from the stack
                        while header_stack and header_stack[-1]["level"] >= current_header_level:
                            # We have encountered a new header
                            # at the same or higher level
                            popped_header = header_stack.pop()
                            # Clear the metadata for the
                            # popped header in initial_metadata
                            if popped_header["name"] in initial_metadata:
                                initial_metadata.pop(popped_header["name"])

                        # Push the current header to the stack
                        header: HeaderType = {
                            "level": current_header_level,
                            "name": name,
                            "data": stripped_line[len(sep) :].strip(),
                        }
                        header_stack.append(header)
                        # Update initial_metadata with the current header
                        initial_metadata[name] = header["data"]

                    # Add the previous line to the lines_with_metadata
                    # only if current_content is not empty
                    if current_content:
                        lines_with_metadata.append(
                            {
                                "content": "\n".join(current_content),
                                "metadata": current_metadata.copy(),
                            }
                        )
                        current_content.clear()

                    break
            else:
                if stripped_line:
                    current_content.append(stripped_line)
                elif current_content:
                    lines_with_metadata.append(
                        {
                            "content": "\n".join(current_content),
                            "metadata": current_metadata.copy(),
                        }
                    )
                    current_content.clear()

            current_metadata = initial_metadata.copy()

        if current_content:
            lines_with_metadata.append({"content": "\n".join(current_content), "metadata": current_metadata})

        # lines_with_metadata has each line with associated header metadata;
        # aggregate these into chunks based on common metadata
        if not self.return_each_line:
            return self.aggregate_lines_to_chunks(lines_with_metadata)
        else:
            return [
                Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in lines_with_metadata
            ]
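

# Illustrative usage sketch (editor's note, not in the original module):
#   md_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=[("#", "Header 1"), ("##", "Header 2")])
#   md_splitter.split_text("# Intro\n\nHello\n\n## Details\n\nWorld")
#   -> [Document(page_content="Hello", metadata={"Header 1": "Intro"}),
#       Document(page_content="World", metadata={"Header 1": "Intro", "Header 2": "Details"})]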


# With newer Python versions (3.10+) this could be:
# @dataclass(frozen=True, kw_only=True, slots=True)
@dataclass(frozen=True)
class Tokenizer:
    chunk_overlap: int
    tokens_per_chunk: int
    decode: Callable[[list[int]], str]
    encode: Callable[[str], list[int]]


def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:
    """Split incoming text and return chunks using tokenizer."""
    splits: list[str] = []
    input_ids = tokenizer.encode(text)
    start_idx = 0
    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
    chunk_ids = input_ids[start_idx:cur_idx]
    while start_idx < len(input_ids):
        splits.append(tokenizer.decode(chunk_ids))
        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
        chunk_ids = input_ids[start_idx:cur_idx]
    return splits
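
# Illustrative sketch (editor's note, not in the original module): split_text_on_tokens
# with a toy whitespace "tokenizer" (token id = word position).
#   words = "one two three four five".split()
#   toy = Tokenizer(
#       chunk_overlap=1,
#       tokens_per_chunk=2,
#       decode=lambda ids: " ".join(words[i] for i in ids),
#       encode=lambda t: list(range(len(t.split()))),
#   )
#   split_text_on_tokens(text="one two three four five", tokenizer=toy)
#   -> ["one two", "two three", "three four", "four five", "five"]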


class TokenTextSplitter(TextSplitter):
    """Splitting text to tokens using model tokenizer."""

    def __init__(
        self,
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], Set[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed for TokenTextSplitter. "
                "Please install it with `pip install tiktoken`."
            )

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)
        self._tokenizer = enc
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special

    def split_text(self, text: str) -> list[str]:
        def _encode(_text: str) -> list[int]:
            return self._tokenizer.encode(
                _text,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )

        tokenizer = Tokenizer(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self._chunk_size,
            decode=self._tokenizer.decode,
            encode=_encode,
        )

        return split_text_on_tokens(text=text, tokenizer=tokenizer)
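

# Illustrative usage sketch (editor's note, not in the original module; requires `tiktoken`).
# Note that chunk_size and chunk_overlap are counted in tokens here, not characters:
#   token_splitter = TokenTextSplitter(encoding_name="cl100k_base", chunk_size=512, chunk_overlap=64)
#   chunks = token_splitter.split_text(long_text)  # `long_text` is a placeholder string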


class RecursiveCharacterTextSplitter(TextSplitter):
    """Splitting text by recursively looking at characters.

    Recursively tries to split by different characters to find one
    that works.
    """

    def __init__(
        self,
        separators: Optional[list[str]] = None,
        keep_separator: bool = True,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(keep_separator=keep_separator, **kwargs)
        self._separators = separators or ["\n\n", "\n", " ", ""]

    def _split_text(self, text: str, separators: list[str]) -> list[str]:
        final_chunks = []
        separator = separators[-1]
        new_separators = []

        for i, _s in enumerate(separators):
            if _s == "":
                separator = _s
                break
            if re.search(_s, text):
                separator = _s
                new_separators = separators[i + 1 :]
                break

        splits = _split_text_with_regex(text, separator, self._keep_separator)
        _good_splits = []
        _good_splits_lengths = []  # cache the lengths of the splits
        _separator = "" if self._keep_separator else separator
        s_lens = self._length_function(splits)
        for s, s_len in zip(splits, s_lens):
            if s_len < self._chunk_size:
                _good_splits.append(s)
                _good_splits_lengths.append(s_len)
            else:
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                    _good_splits_lengths = []
                if not new_separators:
                    final_chunks.append(s)
                else:
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)

        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
            final_chunks.extend(merged_text)

        return final_chunks

    def split_text(self, text: str) -> list[str]:
        return self._split_text(text, self._separators)
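

# Illustrative usage sketch (editor's note, not in the original module). The default
# separators are tried in order ("\n\n", "\n", " ", ""), and each separator is treated
# as a regex pattern, so a custom separator containing special characters should be escaped.
#   recursive_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)
#   chunks = recursive_splitter.split_text(long_text)  # `long_text` is a placeholder string
#   docs = recursive_splitter.create_documents([long_text], metadatas=[{"source": "example.md"}])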
 
 