|
|
|
@@ -0,0 +1,68 @@ |
|
|
|
"""Functionality for splitting text.""" |
|
|
|
from __future__ import annotations |
|
|
|
|
|
|
|
from typing import ( |
|
|
|
Any, |
|
|
|
List, |
|
|
|
Optional, |
|
|
|
) |
|
|
|
|
|
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
|
|
|
|
|
|
|
|
|
|
class FixedRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): |
|
|
|
def __init__(self, fixed_separator: str = "\n\n", separators: Optional[List[str]] = None, **kwargs: Any): |
|
|
|
"""Create a new TextSplitter.""" |
|
|
|
super().__init__(**kwargs) |
|
|
|
self._fixed_separator = fixed_separator |
|
|
|
self._separators = separators or ["\n\n", "\n", " ", ""] |
|
|
|
|
|
|
|
def split_text(self, text: str) -> List[str]: |
|
|
|
"""Split incoming text and return chunks.""" |
|
|
|
if self._fixed_separator: |
|
|
|
chunks = text.split(self._fixed_separator) |
|
|
|
else: |
|
|
|
chunks = list(text) |
|
|
|
|
|
|
|
final_chunks = [] |
|
|
|
for chunk in chunks: |
|
|
|
if self._length_function(chunk) > self._chunk_size: |
|
|
|
final_chunks.extend(self.recursive_split_text(chunk)) |
|
|
|
else: |
|
|
|
final_chunks.append(chunk) |
|
|
|
|
|
|
|
return final_chunks |
|
|
|
|
|
|
|
def recursive_split_text(self, text: str) -> List[str]: |
|
|
|
"""Split incoming text and return chunks.""" |
|
|
|
final_chunks = [] |
|
|
|
# Get appropriate separator to use |
|
|
|
separator = self._separators[-1] |
|
|
|
for _s in self._separators: |
|
|
|
if _s == "": |
|
|
|
separator = _s |
|
|
|
break |
|
|
|
if _s in text: |
|
|
|
separator = _s |
|
|
|
break |
|
|
|
# Now that we have the separator, split the text |
|
|
|
if separator: |
|
|
|
splits = text.split(separator) |
|
|
|
else: |
|
|
|
splits = list(text) |
|
|
|
# Now go merging things, recursively splitting longer texts. |
|
|
|
_good_splits = [] |
|
|
|
for s in splits: |
|
|
|
if self._length_function(s) < self._chunk_size: |
|
|
|
_good_splits.append(s) |
|
|
|
else: |
|
|
|
if _good_splits: |
|
|
|
merged_text = self._merge_splits(_good_splits, separator) |
|
|
|
final_chunks.extend(merged_text) |
|
|
|
_good_splits = [] |
|
|
|
other_info = self.recursive_split_text(s) |
|
|
|
final_chunks.extend(other_info) |
|
|
|
if _good_splits: |
|
|
|
merged_text = self._merge_splits(_good_splits, separator) |
|
|
|
final_chunks.extend(merged_text) |
|
|
|
return final_chunks |