from __future__ import annotations

import copy
import logging
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, Collection, Iterable, Sequence, Set
from dataclasses import dataclass
from enum import Enum
from typing import (
    Any,
    Literal,
    Optional,
    TypedDict,
    TypeVar,
    Union,
)

from core.rag.models.document import BaseDocumentTransformer, Document

logger = logging.getLogger(__name__)

TS = TypeVar("TS", bound="TextSplitter")


def _split_text_with_regex(
    text: str, separator: str, keep_separator: bool
) -> list[str]:
    # Now that we have the separator, split the text
    if separator:
        if keep_separator:
            # The parentheses in the pattern keep the delimiters in the result.
            _splits = re.split(f"({separator})", text)
            splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
            if len(_splits) % 2 == 0:
                splits += _splits[-1:]
            splits = [_splits[0]] + splits
        else:
            splits = re.split(separator, text)
    else:
        splits = list(text)
    return [s for s in splits if s != ""]
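
# Illustrative sketch (comment only, not executed at import time) of how the
# helper behaves. With keep_separator=True the captured delimiter is re-attached
# to the start of the following piece; with keep_separator=False it is dropped:
#
#     _split_text_with_regex("a. b. c", r"\. ", keep_separator=True)
#     # -> ['a', '. b', '. c']
#     _split_text_with_regex("a. b. c", r"\. ", keep_separator=False)
#     # -> ['a', 'b', 'c']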


class TextSplitter(BaseDocumentTransformer, ABC):
    """Interface for splitting text into chunks."""

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: Callable[[str], int] = len,
        keep_separator: bool = False,
        add_start_index: bool = False,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            chunk_size: Maximum size of chunks to return
            chunk_overlap: Overlap in characters between chunks
            length_function: Function that measures the length of given chunks
            keep_separator: Whether to keep the separator in the chunks
            add_start_index: If `True`, includes chunk's start index in metadata
        """
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function
        self._keep_separator = keep_separator
        self._add_start_index = add_start_index

    @abstractmethod
    def split_text(self, text: str) -> list[str]:
        """Split text into multiple components."""

    def create_documents(
        self, texts: list[str], metadatas: Optional[list[dict]] = None
    ) -> list[Document]:
        """Create documents from a list of texts."""
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            index = -1
            for chunk in self.split_text(text):
                metadata = copy.deepcopy(_metadatas[i])
                if self._add_start_index:
                    index = text.find(chunk, index + 1)
                    metadata["start_index"] = index
                new_doc = Document(page_content=chunk, metadata=metadata)
                documents.append(new_doc)
        return documents
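
    # Illustrative sketch (comment only): with add_start_index=True each chunk's
    # character offset in the original text is recorded in metadata. Assuming a
    # concrete subclass such as CharacterTextSplitter:
    #
    #     splitter = CharacterTextSplitter(
    #         separator=" ", chunk_size=10, chunk_overlap=0, add_start_index=True
    #     )
    #     splitter.create_documents(["foo bar baz"], metadatas=[{"source": "a.txt"}])
    #     # -> two Documents, "foo bar" (start_index=0) and "baz" (start_index=8),
    #     #    each carrying a deep copy of {"source": "a.txt"}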

    def split_documents(self, documents: Iterable[Document]) -> list[Document]:
        """Split documents."""
        texts, metadatas = [], []
        for doc in documents:
            texts.append(doc.page_content)
            metadatas.append(doc.metadata)
        return self.create_documents(texts, metadatas=metadatas)

    def _join_docs(self, docs: list[str], separator: str) -> Optional[str]:
        text = separator.join(docs)
        text = text.strip()
        if text == "":
            return None
        else:
            return text

    def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]:
        # We now want to combine these smaller pieces into medium size
        # chunks to send to the LLM.
        separator_len = self._length_function(separator)

        docs = []
        current_doc: list[str] = []
        total = 0
        for d in splits:
            _len = self._length_function(d)
            if (
                total + _len + (separator_len if len(current_doc) > 0 else 0)
                > self._chunk_size
            ):
                if total > self._chunk_size:
                    logger.warning(
                        f"Created a chunk of size {total}, "
                        f"which is longer than the specified {self._chunk_size}"
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > self._chunk_overlap or (
                        total + _len + (separator_len if len(current_doc) > 0 else 0)
                        > self._chunk_size
                        and total > 0
                    ):
                        total -= self._length_function(current_doc[0]) + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs
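
    # Rough illustration (comment only) of the merge-with-overlap behaviour above,
    # assuming chunk_size=10, chunk_overlap=3, separator " " and the default len()
    # length function:
    #
    #     splitter._merge_splits(["foo", "bar", "baz", "qux"], " ")
    #     # -> ['foo bar', 'bar baz', 'baz qux']
    #
    # Each time the running total would exceed chunk_size, the current window is
    # emitted and leading pieces are popped until the retained tail fits within
    # chunk_overlap.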

    @classmethod
    def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
        """Text splitter that uses HuggingFace tokenizer to count length."""
        try:
            from transformers import PreTrainedTokenizerBase

            if not isinstance(tokenizer, PreTrainedTokenizerBase):
                raise ValueError(
                    "Tokenizer received was not an instance of PreTrainedTokenizerBase"
                )

            def _huggingface_tokenizer_length(text: str) -> int:
                return len(tokenizer.encode(text))

        except ImportError:
            raise ValueError(
                "Could not import transformers python package. "
                "Please install it with `pip install transformers`."
            )
        return cls(length_function=_huggingface_tokenizer_length, **kwargs)

    @classmethod
    def from_tiktoken_encoder(
        cls: type[TS],
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], Set[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> TS:
        """Text splitter that uses tiktoken encoder to count length."""
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to count tokens when splitting text. "
                "Please install it with `pip install tiktoken`."
            )

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)

        def _tiktoken_encoder(text: str) -> int:
            return len(
                enc.encode(
                    text,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )

        if issubclass(cls, TokenTextSplitter):
            extra_kwargs = {
                "encoding_name": encoding_name,
                "model_name": model_name,
                "allowed_special": allowed_special,
                "disallowed_special": disallowed_special,
            }
            kwargs = {**kwargs, **extra_kwargs}

        return cls(length_function=_tiktoken_encoder, **kwargs)

    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Transform sequence of documents by splitting them."""
        return self.split_documents(list(documents))

    async def atransform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Asynchronously transform a sequence of documents by splitting them."""
        raise NotImplementedError


class CharacterTextSplitter(TextSplitter):
    """Splitting text that looks at characters."""

    def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None:
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        self._separator = separator

    def split_text(self, text: str) -> list[str]:
        """Split incoming text and return chunks."""
        # First we naively split the large input into a bunch of smaller ones.
        splits = _split_text_with_regex(text, self._separator, self._keep_separator)
        _separator = "" if self._keep_separator else self._separator
        return self._merge_splits(splits, _separator)
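
    # Hedged usage sketch (comment only), splitting on blank lines with the
    # default character-count length function:
    #
    #     splitter = CharacterTextSplitter(separator="\n\n", chunk_size=25, chunk_overlap=0)
    #     splitter.split_text("para one\n\npara two\n\npara three")
    #     # -> ['para one\n\npara two', 'para three']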


class LineType(TypedDict):
    """Line type as typed dict."""

    metadata: dict[str, str]
    content: str


class HeaderType(TypedDict):
    """Header type as typed dict."""

    level: int
    name: str
    data: str


class MarkdownHeaderTextSplitter:
    """Splitting markdown files based on specified headers."""

    def __init__(
        self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False
    ):
        """Create a new MarkdownHeaderTextSplitter.

        Args:
            headers_to_split_on: Headers we want to track
            return_each_line: Return each line w/ associated headers
        """
        # Output line-by-line or aggregated into chunks w/ common headers
        self.return_each_line = return_each_line
        # Given the headers we want to split on
        # (e.g., "#", "##", etc.), order by length
        self.headers_to_split_on = sorted(
            headers_to_split_on, key=lambda split: len(split[0]), reverse=True
        )

    def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]:
        """Combine lines with common metadata into chunks.

        Args:
            lines: Line of text / associated header metadata
        """
        aggregated_chunks: list[LineType] = []

        for line in lines:
            if (
                aggregated_chunks
                and aggregated_chunks[-1]["metadata"] == line["metadata"]
            ):
                # If the last line in the aggregated list
                # has the same metadata as the current line,
                # append the current content to the last line's content
                aggregated_chunks[-1]["content"] += " \n" + line["content"]
            else:
                # Otherwise, append the current line to the aggregated list
                aggregated_chunks.append(line)

        return [
            Document(page_content=chunk["content"], metadata=chunk["metadata"])
            for chunk in aggregated_chunks
        ]

    def split_text(self, text: str) -> list[Document]:
        """Split markdown file.

        Args:
            text: Markdown file
        """
        # Split the input text by newline character ("\n").
        lines = text.split("\n")
        # Final output
        lines_with_metadata: list[LineType] = []
        # Content and metadata of the chunk currently being processed
        current_content: list[str] = []
        current_metadata: dict[str, str] = {}
        # Keep track of the nested header structure
        # header_stack: List[Dict[str, Union[int, str]]] = []
        header_stack: list[HeaderType] = []
        initial_metadata: dict[str, str] = {}

        for line in lines:
            stripped_line = line.strip()
            # Check each line against each of the header types (e.g., #, ##)
            for sep, name in self.headers_to_split_on:
                # Check if line starts with a header that we intend to split on
                if stripped_line.startswith(sep) and (
                    # Header with no text OR header is followed by space
                    # Both are valid conditions that sep is being used as a header
                    len(stripped_line) == len(sep)
                    or stripped_line[len(sep)] == " "
                ):
                    # Ensure we are tracking the header as metadata
                    if name is not None:
                        # Get the current header level
                        current_header_level = sep.count("#")

                        # Pop out headers of lower or same level from the stack
                        while (
                            header_stack
                            and header_stack[-1]["level"] >= current_header_level
                        ):
                            # We have encountered a new header
                            # at the same or higher level
                            popped_header = header_stack.pop()
                            # Clear the metadata for the
                            # popped header in initial_metadata
                            if popped_header["name"] in initial_metadata:
                                initial_metadata.pop(popped_header["name"])

                        # Push the current header to the stack
                        header: HeaderType = {
                            "level": current_header_level,
                            "name": name,
                            "data": stripped_line[len(sep):].strip(),
                        }
                        header_stack.append(header)
                        # Update initial_metadata with the current header
                        initial_metadata[name] = header["data"]

                    # Add the previous line to the lines_with_metadata
                    # only if current_content is not empty
                    if current_content:
                        lines_with_metadata.append(
                            {
                                "content": "\n".join(current_content),
                                "metadata": current_metadata.copy(),
                            }
                        )
                        current_content.clear()

                    break
            else:
                if stripped_line:
                    current_content.append(stripped_line)
                elif current_content:
                    lines_with_metadata.append(
                        {
                            "content": "\n".join(current_content),
                            "metadata": current_metadata.copy(),
                        }
                    )
                    current_content.clear()

            current_metadata = initial_metadata.copy()

        if current_content:
            lines_with_metadata.append(
                {"content": "\n".join(current_content), "metadata": current_metadata}
            )

        # lines_with_metadata has each line with associated header metadata
        # aggregate these into chunks based on common metadata
        if not self.return_each_line:
            return self.aggregate_lines_to_chunks(lines_with_metadata)
        else:
            return [
                Document(page_content=chunk["content"], metadata=chunk["metadata"])
                for chunk in lines_with_metadata
            ]
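
    # Hedged usage sketch (comment only): tracked headers become metadata keys and
    # the body text under each header becomes the page content.
    #
    #     splitter = MarkdownHeaderTextSplitter([("#", "Header 1"), ("##", "Header 2")])
    #     splitter.split_text("# Intro\nSome text.\n## Details\nMore text.")
    #     # -> two Documents: page_content "Some text." with metadata
    #     #    {"Header 1": "Intro"}, and "More text." with metadata
    #     #    {"Header 1": "Intro", "Header 2": "Details"}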


# should be in newer Python versions (3.10+)
# @dataclass(frozen=True, kw_only=True, slots=True)
@dataclass(frozen=True)
class Tokenizer:
    chunk_overlap: int
    tokens_per_chunk: int
    decode: Callable[[list[int]], str]
    encode: Callable[[str], list[int]]


def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:
    """Split incoming text and return chunks using tokenizer."""
    splits: list[str] = []
    input_ids = tokenizer.encode(text)
    start_idx = 0
    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
    chunk_ids = input_ids[start_idx:cur_idx]
    while start_idx < len(input_ids):
        splits.append(tokenizer.decode(chunk_ids))
        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
        chunk_ids = input_ids[start_idx:cur_idx]
    return splits
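
# Illustrative sketch (comment only) with a toy tokenizer that maps each
# character to one "token"; with tokens_per_chunk=4 and chunk_overlap=2 the
# window advances by 2 tokens at a time:
#
#     toy = Tokenizer(
#         chunk_overlap=2,
#         tokens_per_chunk=4,
#         decode=lambda ids: "".join(chr(i) for i in ids),
#         encode=lambda t: [ord(c) for c in t],
#     )
#     split_text_on_tokens(text="abcdefgh", tokenizer=toy)
#     # -> ['abcd', 'cdef', 'efgh', 'gh']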


class TokenTextSplitter(TextSplitter):
    """Splitting text to tokens using model tokenizer."""

    def __init__(
        self,
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], Set[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed for TokenTextSplitter. "
                "Please install it with `pip install tiktoken`."
            )

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)
        self._tokenizer = enc
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special

    def split_text(self, text: str) -> list[str]:
        def _encode(_text: str) -> list[int]:
            return self._tokenizer.encode(
                _text,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )

        tokenizer = Tokenizer(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self._chunk_size,
            decode=self._tokenizer.decode,
            encode=_encode,
        )

        return split_text_on_tokens(text=text, tokenizer=tokenizer)
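
    # Hedged usage sketch (comment only); assumes the optional tiktoken package is
    # installed. Here chunk_size and chunk_overlap are counted in tokens, not
    # characters, and long_text is a placeholder for any string to split:
    #
    #     splitter = TokenTextSplitter(
    #         encoding_name="cl100k_base", chunk_size=512, chunk_overlap=64
    #     )
    #     chunks = splitter.split_text(long_text)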


class Language(str, Enum):
    """Enum of the programming languages."""

    CPP = "cpp"
    GO = "go"
    JAVA = "java"
    JS = "js"
    PHP = "php"
    PROTO = "proto"
    PYTHON = "python"
    RST = "rst"
    RUBY = "ruby"
    RUST = "rust"
    SCALA = "scala"
    SWIFT = "swift"
    MARKDOWN = "markdown"
    LATEX = "latex"
    HTML = "html"
    SOL = "sol"


class RecursiveCharacterTextSplitter(TextSplitter):
    """Splitting text by recursively looking at characters.

    Recursively tries to split by different characters to find one
    that works.
    """

    def __init__(
        self,
        separators: Optional[list[str]] = None,
        keep_separator: bool = True,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(keep_separator=keep_separator, **kwargs)
        self._separators = separators or ["\n\n", "\n", " ", ""]

    def _split_text(self, text: str, separators: list[str]) -> list[str]:
        """Split incoming text and return chunks."""
        final_chunks = []
        # Get appropriate separator to use
        separator = separators[-1]
        new_separators = []
        for i, _s in enumerate(separators):
            if _s == "":
                separator = _s
                break
            if re.search(_s, text):
                separator = _s
                new_separators = separators[i + 1:]
                break

        splits = _split_text_with_regex(text, separator, self._keep_separator)
        # Now go merging things, recursively splitting longer texts.
        _good_splits = []
        _separator = "" if self._keep_separator else separator
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            else:
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, _separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                if not new_separators:
                    final_chunks.append(s)
                else:
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator)
            final_chunks.extend(merged_text)
        return final_chunks

    def split_text(self, text: str) -> list[str]:
        return self._split_text(text, self._separators)
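
    # Hedged usage sketch (comment only): with the default separators the text is
    # first split on blank lines, then single newlines, then spaces, and finally
    # individual characters, recursing only when a piece is still too large.
    # some_long_text is a placeholder for any string to split:
    #
    #     splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)
    #     chunks = splitter.split_text(some_long_text)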

    @classmethod
    def from_language(
        cls, language: Language, **kwargs: Any
    ) -> RecursiveCharacterTextSplitter:
        separators = cls.get_separators_for_language(language)
        return cls(separators=separators, **kwargs)
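
    # Hedged usage sketch (comment only), using the language-specific separators
    # defined below so that, e.g., Python source is preferentially split at class
    # and function boundaries. python_source is a placeholder for any source string:
    #
    #     splitter = RecursiveCharacterTextSplitter.from_language(
    #         Language.PYTHON, chunk_size=200, chunk_overlap=0
    #     )
    #     docs = splitter.create_documents([python_source])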

    @staticmethod
    def get_separators_for_language(language: Language) -> list[str]:
        if language == Language.CPP:
            return [
                # Split along class definitions
                "\nclass ",
                # Split along function definitions
                "\nvoid ",
                "\nint ",
                "\nfloat ",
                "\ndouble ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.GO:
            return [
                # Split along function definitions
                "\nfunc ",
                "\nvar ",
                "\nconst ",
                "\ntype ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.JAVA:
            return [
                # Split along class definitions
                "\nclass ",
                # Split along method definitions
                "\npublic ",
                "\nprotected ",
                "\nprivate ",
                "\nstatic ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.JS:
            return [
                # Split along function definitions
                "\nfunction ",
                "\nconst ",
                "\nlet ",
                "\nvar ",
                "\nclass ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                "\ndefault ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.PHP:
            return [
                # Split along function definitions
                "\nfunction ",
                # Split along class definitions
                "\nclass ",
                # Split along control flow statements
                "\nif ",
                "\nforeach ",
                "\nwhile ",
                "\ndo ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.PROTO:
            return [
                # Split along message definitions
                "\nmessage ",
                # Split along service definitions
                "\nservice ",
                # Split along enum definitions
                "\nenum ",
                # Split along option definitions
                "\noption ",
                # Split along import statements
                "\nimport ",
                # Split along syntax declarations
                "\nsyntax ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.PYTHON:
            return [
                # First, try to split along class definitions
                "\nclass ",
                "\ndef ",
                "\n\tdef ",
                # Now split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.RST:
            return [
                # Split along section titles
                "\n=+\n",
                "\n-+\n",
                "\n\*+\n",
                # Split along directive markers
                "\n\n.. *\n\n",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.RUBY:
            return [
                # Split along method definitions
                "\ndef ",
                "\nclass ",
                # Split along control flow statements
                "\nif ",
                "\nunless ",
                "\nwhile ",
                "\nfor ",
                "\ndo ",
                "\nbegin ",
                "\nrescue ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.RUST:
            return [
                # Split along function definitions
                "\nfn ",
                "\nconst ",
                "\nlet ",
                # Split along control flow statements
                "\nif ",
                "\nwhile ",
                "\nfor ",
                "\nloop ",
                "\nmatch ",
                "\nconst ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.SCALA:
            return [
                # Split along class definitions
                "\nclass ",
                "\nobject ",
                # Split along method definitions
                "\ndef ",
                "\nval ",
                "\nvar ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nmatch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.SWIFT:
            return [
                # Split along function definitions
                "\nfunc ",
                # Split along class definitions
                "\nclass ",
                "\nstruct ",
                "\nenum ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\ndo ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.MARKDOWN:
            return [
                # First, try to split along Markdown headings (levels 1 through 6)
                "\n#{1,6} ",
                # Note the alternative (setext) syntax for headings is not handled here:
                # Heading level 2
                # ---------------
                # End of code block
                "```\n",
                # Horizontal lines
                "\n\*\*\*+\n",
                "\n---+\n",
                "\n___+\n",
                # Note: the patterns above cover horizontal rules made of three or
                # more *, -, or _ on a line of their own
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.LATEX:
            return [
                # First, try to split along Latex sections
                "\n\\\chapter{",
                "\n\\\section{",
                "\n\\\subsection{",
                "\n\\\subsubsection{",
                # Now split by environments
                "\n\\\begin{enumerate}",
                "\n\\\begin{itemize}",
                "\n\\\begin{description}",
                "\n\\\begin{list}",
                "\n\\\begin{quote}",
                "\n\\\begin{quotation}",
                "\n\\\begin{verse}",
                "\n\\\begin{verbatim}",
                # Now split by math environments
                "\n\\\begin{align}",
                "$$",
                "$",
                # Now split by the normal type of lines
                " ",
                "",
            ]
        elif language == Language.HTML:
            return [
                # First, try to split along HTML tags
                "<body",
                "<div",
                "<p",
                "<br",
                "<li",
                "<h1",
                "<h2",
                "<h3",
                "<h4",
                "<h5",
                "<h6",
                "<span",
                "<table",
                "<tr",
                "<td",
                "<th",
                "<ul",
                "<ol",
                "<header",
                "<footer",
                "<nav",
                # Head
                "<head",
                "<style",
                "<script",
                "<meta",
                "<title",
                "",
            ]
        elif language == Language.SOL:
            return [
                # Split along compiler information definitions
                "\npragma ",
                "\nusing ",
                # Split along contract definitions
                "\ncontract ",
                "\ninterface ",
                "\nlibrary ",
                # Split along method definitions
                "\nconstructor ",
                "\ntype ",
                "\nfunction ",
                "\nevent ",
                "\nmodifier ",
                "\nerror ",
                "\nstruct ",
                "\nenum ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\ndo while ",
                "\nassembly ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        else:
            raise ValueError(
                f"Language {language} is not supported! "
                f"Please choose from {list(Language)}"
            )