
text_splitter.py 11KB

from __future__ import annotations

import copy
import logging
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, Collection, Iterable, Sequence, Set
from dataclasses import dataclass
from typing import (
    Any,
    Literal,
    Optional,
    TypeVar,
    Union,
)

from core.rag.models.document import BaseDocumentTransformer, Document

logger = logging.getLogger(__name__)

TS = TypeVar("TS", bound="TextSplitter")


def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> list[str]:
    # Now that we have the separator, split the text
    if separator:
        if keep_separator:
            # The parentheses in the pattern keep the delimiters in the result.
            _splits = re.split(f"({re.escape(separator)})", text)
            splits = [_splits[i - 1] + _splits[i] for i in range(1, len(_splits), 2)]
            if len(_splits) % 2 != 0:
                splits += _splits[-1:]
        else:
            splits = re.split(separator, text)
    else:
        splits = list(text)
    return [s for s in splits if (s not in {"", "\n"})]
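

# Illustrative example (not part of the original file): with keep_separator=True the
# separator is escaped and stays attached to the piece that precedes it, e.g.
#   _split_text_with_regex("a. b. c", ". ", keep_separator=True)
#   -> ["a. ", "b. ", "c"]
# With keep_separator=False the separator is treated as a regex pattern and dropped
# from the result.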


class TextSplitter(BaseDocumentTransformer, ABC):
    """Interface for splitting text into chunks."""

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: Callable[[list[str]], list[int]] = lambda x: [len(x) for x in x],
        keep_separator: bool = False,
        add_start_index: bool = False,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            chunk_size: Maximum size of chunks to return
            chunk_overlap: Overlap in characters between chunks
            length_function: Function that measures the length of given chunks
            keep_separator: Whether to keep the separator in the chunks
            add_start_index: If `True`, includes chunk's start index in metadata
        """
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size ({chunk_size}), should be smaller."
            )
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function
        self._keep_separator = keep_separator
        self._add_start_index = add_start_index

    @abstractmethod
    def split_text(self, text: str) -> list[str]:
        """Split text into multiple components."""

    def create_documents(self, texts: list[str], metadatas: Optional[list[dict]] = None) -> list[Document]:
        """Create documents from a list of texts."""
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            index = -1
            for chunk in self.split_text(text):
                metadata = copy.deepcopy(_metadatas[i])
                if self._add_start_index:
                    index = text.find(chunk, index + 1)
                    metadata["start_index"] = index
                new_doc = Document(page_content=chunk, metadata=metadata)
                documents.append(new_doc)
        return documents

    def split_documents(self, documents: Iterable[Document]) -> list[Document]:
        """Split documents."""
        texts, metadatas = [], []
        for doc in documents:
            texts.append(doc.page_content)
            metadatas.append(doc.metadata or {})
        return self.create_documents(texts, metadatas=metadatas)

    def _join_docs(self, docs: list[str], separator: str) -> Optional[str]:
        text = separator.join(docs)
        text = text.strip()
        if text == "":
            return None
        else:
            return text

    def _merge_splits(self, splits: Iterable[str], separator: str, lengths: list[int]) -> list[str]:
        # We now want to combine these smaller pieces into medium size
        # chunks to send to the LLM.
        separator_len = self._length_function([separator])[0]

        docs = []
        current_doc: list[str] = []
        total = 0
        index = 0
        for d in splits:
            _len = lengths[index]
            if total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size:
                if total > self._chunk_size:
                    logger.warning(
                        "Created a chunk of size %s, which is longer than the specified %s", total, self._chunk_size
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > self._chunk_overlap or (
                        total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size and total > 0
                    ):
                        total -= self._length_function([current_doc[0]])[0] + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
            index += 1
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs
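
    # Illustrative example (not in the original file): with chunk_size=10,
    # chunk_overlap=5 and separator=" ", merging the splits ["aaaa", "bbbb", "cccc"]
    # (lengths [4, 4, 4]) yields ["aaaa bbbb", "bbbb cccc"]; "bbbb" is carried over
    # into the next chunk to provide the requested overlap.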

    @classmethod
    def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
        """Text splitter that uses HuggingFace tokenizer to count length."""
        try:
            from transformers import PreTrainedTokenizerBase  # type: ignore

            if not isinstance(tokenizer, PreTrainedTokenizerBase):
                raise ValueError("Tokenizer received was not an instance of PreTrainedTokenizerBase")

            def _huggingface_tokenizer_length(text: str) -> int:
                return len(tokenizer.encode(text))

        except ImportError:
            raise ValueError(
                "Could not import transformers python package. Please install it with `pip install transformers`."
            )

        return cls(length_function=lambda x: [_huggingface_tokenizer_length(text) for text in x], **kwargs)
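
    # Illustrative usage (assumption, not in the original file): the classmethod is
    # normally called on a concrete subclass so that chunk sizes are measured in
    # HuggingFace tokens rather than characters, e.g.
    #   from transformers import AutoTokenizer
    #   tok = AutoTokenizer.from_pretrained("bert-base-uncased")
    #   splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
    #       tok, chunk_size=512, chunk_overlap=50
    #   )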

    def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
        """Transform sequence of documents by splitting them."""
        return self.split_documents(list(documents))

    async def atransform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
        """Asynchronously transform a sequence of documents by splitting them."""
        raise NotImplementedError


# @dataclass(frozen=True, kw_only=True, slots=True)
@dataclass(frozen=True)
class Tokenizer:
    chunk_overlap: int
    tokens_per_chunk: int
    decode: Callable[[list[int]], str]
    encode: Callable[[str], list[int]]


def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:
    """Split incoming text and return chunks using tokenizer."""
    splits: list[str] = []
    input_ids = tokenizer.encode(text)
    start_idx = 0
    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
    chunk_ids = input_ids[start_idx:cur_idx]
    while start_idx < len(input_ids):
        splits.append(tokenizer.decode(chunk_ids))
        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
        chunk_ids = input_ids[start_idx:cur_idx]
    return splits
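

# Illustrative example (not in the original file): with tokens_per_chunk=4 and
# chunk_overlap=1, a 10-token input is decoded as the windows [0:4], [3:7],
# [6:10] and [9:10], i.e. the window advances by tokens_per_chunk - chunk_overlap
# tokens at a time and the final window may be shorter than tokens_per_chunk.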


class TokenTextSplitter(TextSplitter):
    """Splitting text into tokens using a model tokenizer."""

    def __init__(
        self,
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], Set[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to use TokenTextSplitter. "
                "Please install it with `pip install tiktoken`."
            )

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)
        self._tokenizer = enc
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special

    def split_text(self, text: str) -> list[str]:
        def _encode(_text: str) -> list[int]:
            return self._tokenizer.encode(
                _text,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )

        tokenizer = Tokenizer(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self._chunk_size,
            decode=self._tokenizer.decode,
            encode=_encode,
        )

        return split_text_on_tokens(text=text, tokenizer=tokenizer)
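
    # Illustrative usage (assumption, not in the original file): token-based chunking
    # with a tiktoken encoding; requires `pip install tiktoken`.
    #   splitter = TokenTextSplitter(encoding_name="cl100k_base", chunk_size=500, chunk_overlap=50)
    #   chunks = splitter.split_text(long_text)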


class RecursiveCharacterTextSplitter(TextSplitter):
    """Splitting text by recursively looking at characters.

    Recursively tries to split by different characters to find one
    that works.
    """

    def __init__(
        self,
        separators: Optional[list[str]] = None,
        keep_separator: bool = True,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(keep_separator=keep_separator, **kwargs)
        self._separators = separators or ["\n\n", "\n", " ", ""]

    def _split_text(self, text: str, separators: list[str]) -> list[str]:
        final_chunks = []
        # Pick the first separator that appears in the text; the remaining
        # separators are kept for recursive splitting of oversized pieces.
        separator = separators[-1]
        new_separators = []
        for i, _s in enumerate(separators):
            if _s == "":
                separator = _s
                break
            if re.search(_s, text):
                separator = _s
                new_separators = separators[i + 1 :]
                break

        splits = _split_text_with_regex(text, separator, self._keep_separator)

        # Merge small splits into chunks; recurse into splits that are too long.
        _good_splits = []
        _good_splits_lengths = []  # cache the lengths of the splits
        _separator = "" if self._keep_separator else separator
        s_lens = self._length_function(splits)
        for s, s_len in zip(splits, s_lens):
            if s_len < self._chunk_size:
                _good_splits.append(s)
                _good_splits_lengths.append(s_len)
            else:
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                    _good_splits_lengths = []
                if not new_separators:
                    final_chunks.append(s)
                else:
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
            final_chunks.extend(merged_text)
        return final_chunks

    def split_text(self, text: str) -> list[str]:
        return self._split_text(text, self._separators)
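

# Minimal usage sketch (not part of the original file; assumes the module's
# project-internal imports resolve, and the sample text below is made up).
if __name__ == "__main__":
    sample = (
        "Text splitters break long documents into overlapping chunks.\n\n"
        "Each chunk should stay under the configured chunk_size so that it "
        "fits into an embedding model or LLM context window."
    )
    splitter = RecursiveCharacterTextSplitter(chunk_size=80, chunk_overlap=20)
    for chunk in splitter.split_text(sample):
        print(repr(chunk))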