
feat: optimize split rule when use custom split segment identifier (#35)

Tag: tags/0.2.2
Author: John Wang, 2 years ago
Commit: 815f794eef
2 changed files with 73 additions and 6 deletions
  1. api/core/index/spiltter/fixed_text_splitter.py  (+68, -0)
  2. api/core/indexing_runner.py  (+5, -6)

api/core/index/spiltter/fixed_text_splitter.py  (+68, -0)

@@ -0,0 +1,68 @@
"""Functionality for splitting text."""
from __future__ import annotations

from typing import (
Any,
List,
Optional,
)

from langchain.text_splitter import RecursiveCharacterTextSplitter


class FixedRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
def __init__(self, fixed_separator: str = "\n\n", separators: Optional[List[str]] = None, **kwargs: Any):
"""Create a new TextSplitter."""
super().__init__(**kwargs)
self._fixed_separator = fixed_separator
self._separators = separators or ["\n\n", "\n", " ", ""]

def split_text(self, text: str) -> List[str]:
"""Split incoming text and return chunks."""
if self._fixed_separator:
chunks = text.split(self._fixed_separator)
else:
chunks = list(text)

final_chunks = []
for chunk in chunks:
if self._length_function(chunk) > self._chunk_size:
final_chunks.extend(self.recursive_split_text(chunk))
else:
final_chunks.append(chunk)

return final_chunks

def recursive_split_text(self, text: str) -> List[str]:
"""Split incoming text and return chunks."""
final_chunks = []
# Get appropriate separator to use
separator = self._separators[-1]
for _s in self._separators:
if _s == "":
separator = _s
break
if _s in text:
separator = _s
break
# Now that we have the separator, split the text
if separator:
splits = text.split(separator)
else:
splits = list(text)
# Now go merging things, recursively splitting longer texts.
_good_splits = []
for s in splits:
if self._length_function(s) < self._chunk_size:
_good_splits.append(s)
else:
if _good_splits:
merged_text = self._merge_splits(_good_splits, separator)
final_chunks.extend(merged_text)
_good_splits = []
other_info = self.recursive_split_text(s)
final_chunks.extend(other_info)
if _good_splits:
merged_text = self._merge_splits(_good_splits, separator)
final_chunks.extend(merged_text)
return final_chunks
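
For context, a minimal usage sketch of the new splitter; the separator, chunk size, and sample text below are assumed for illustration and are not part of this commit. It first splits on the fixed separator, then falls back to the generic separators only for chunks that exceed the chunk size.

# Illustrative only -- values are hypothetical, not taken from the commit.
splitter = FixedRecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=100,           # assumed max tokens per segment
    chunk_overlap=0,
    fixed_separator="###",    # hypothetical user-defined segment identifier
    separators=["\n\n", "。", ".", " ", ""]
)
# Short segments are kept as-is; an over-long segment between "###"
# markers would be re-split recursively on the separators above.
chunks = splitter.split_text("first segment###a much longer second segment ...###last segment")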

api/core/indexing_runner.py  (+5, -6)

@@ -18,6 +18,7 @@ from core.docstore.dataset_docstore import DatesetDocumentStore
 from core.index.keyword_table_index import KeywordTableIndex
 from core.index.readers.html_parser import HTMLParser
 from core.index.readers.pdf_parser import PDFParser
+from core.index.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
 from core.index.vector_index import VectorIndex
 from core.llm.token_calculator import TokenCalculator
 from extensions.ext_database import db
@@ -267,16 +268,14 @@ class IndexingRunner:
                 raise ValueError("Custom segment length should be between 50 and 1000.")
 
             separator = segmentation["separator"]
-            if not separator:
-                separators = ["\n\n", "。", ".", " ", ""]
-            else:
+            if separator:
                 separator = separator.replace('\\n', '\n')
-                separators = [separator, ""]
 
-            character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
+            character_splitter = FixedRecursiveCharacterTextSplitter.from_tiktoken_encoder(
                 chunk_size=segmentation["max_tokens"],
                 chunk_overlap=0,
-                separators=separators
+                fixed_separator=separator,
+                separators=["\n\n", "。", ".", " ", ""]
             )
         else:
             # Automatic segmentation
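
As a rough illustration of the separator handling above (the segmentation dict is made up for this example, not taken from the commit): a newline separator typed in the UI arrives as the two characters backslash and n, and is normalized before being handed to the splitter as fixed_separator.

# Hypothetical custom rule as it might arrive in the request payload.
segmentation = {"separator": "\\n", "max_tokens": 500}

separator = segmentation["separator"]
if separator:
    separator = separator.replace('\\n', '\n')  # literal "\n" -> real newline

# separator == "\n" is now passed as fixed_separator, while the generic
# separators list only handles chunks that are still too long.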
