
vector_service.py 9.7KB

import logging
from typing import Optional

from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import ParentMode

logger = logging.getLogger(__name__)


class VectorService:
    @classmethod
    def create_segments_vector(
        cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset, doc_form: str
    ):
        """Create index entries for newly added segments.

        Parent/child segments are split into child chunks and embedded via
        generate_child_chunks(); all other segments are collected into Document
        objects and loaded in one batch, optionally with per-segment keywords.
        """
        documents: list[Document] = []
        for segment in segments:
            if doc_form == IndexType.PARENT_CHILD_INDEX:
                dataset_document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
                if not dataset_document:
                    logger.warning(
                        "Expected DatasetDocument record to exist, but none was found, document_id=%s, segment_id=%s",
                        segment.document_id,
                        segment.id,
                    )
                    continue
                # get the process rule
                processing_rule = (
                    db.session.query(DatasetProcessRule)
                    .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
                    .first()
                )
                if not processing_rule:
                    raise ValueError("No processing rule found.")
                # get embedding model instance
                if dataset.indexing_technique == "high_quality":
                    # check embedding model setting
                    model_manager = ModelManager()
                    if dataset.embedding_model_provider:
                        embedding_model_instance = model_manager.get_model_instance(
                            tenant_id=dataset.tenant_id,
                            provider=dataset.embedding_model_provider,
                            model_type=ModelType.TEXT_EMBEDDING,
                            model=dataset.embedding_model,
                        )
                    else:
                        embedding_model_instance = model_manager.get_default_model_instance(
                            tenant_id=dataset.tenant_id,
                            model_type=ModelType.TEXT_EMBEDDING,
                        )
                else:
                    raise ValueError("The knowledge base index technique is not high quality!")
                cls.generate_child_chunks(
                    segment, dataset_document, dataset, embedding_model_instance, processing_rule, False
                )
            else:
                rag_document = Document(
                    page_content=segment.content,
                    metadata={
                        "doc_id": segment.index_node_id,
                        "doc_hash": segment.index_node_hash,
                        "document_id": segment.document_id,
                        "dataset_id": segment.dataset_id,
                    },
                )
                documents.append(rag_document)
        if len(documents) > 0:
            index_processor = IndexProcessorFactory(doc_form).init_index_processor()
            index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)
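
    # Note (editor's addition): keywords_list, when provided, is assumed to be
    # positionally aligned with `segments` -- keywords_list[i] supplies the
    # keywords for segments[i] in the batch load above. Parent/child segments
    # are indexed per segment through generate_child_chunks() and are therefore
    # never appended to `documents`.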

    @classmethod
    def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset):
        """Rebuild the index entry for an edited segment.

        The old node is deleted first, then the new content is re-added to the
        vector index (high_quality datasets) or the keyword index (otherwise).
        """
        # format new index
        document = Document(
            page_content=segment.content,
            metadata={
                "doc_id": segment.index_node_id,
                "doc_hash": segment.index_node_hash,
                "document_id": segment.document_id,
                "dataset_id": segment.dataset_id,
            },
        )
        if dataset.indexing_technique == "high_quality":
            # update vector index
            vector = Vector(dataset=dataset)
            vector.delete_by_ids([segment.index_node_id])
            vector.add_texts([document], duplicate_check=True)
        else:
            # update keyword index
            keyword = Keyword(dataset)
            keyword.delete_by_ids([segment.index_node_id])
            # save keyword index
            if keywords and len(keywords) > 0:
                keyword.add_texts([document], keywords_list=[keywords])
            else:
                keyword.add_texts([document])

    @classmethod
    def generate_child_chunks(
        cls,
        segment: DocumentSegment,
        dataset_document: DatasetDocument,
        dataset: Dataset,
        embedding_model_instance: ModelInstance,
        processing_rule: DatasetProcessRule,
        regenerate: bool = False,
    ):
        """Split a parent segment into child chunks and index them.

        The segment is transformed in full-doc mode, the resulting child chunks
        are loaded into the index, and one ChildChunk row is persisted per
        chunk. With regenerate=True the segment's existing child-chunk entries
        are cleaned first.
        """
        index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
        if regenerate:
            # delete child chunks
            index_processor.clean(dataset, [segment.index_node_id], with_keywords=True, delete_child_chunks=True)
        # generate child chunks
        document = Document(
            page_content=segment.content,
            metadata={
                "doc_id": segment.index_node_id,
                "doc_hash": segment.index_node_hash,
                "document_id": segment.document_id,
                "dataset_id": segment.dataset_id,
            },
        )
        # use full doc mode to generate segment's child chunk
        processing_rule_dict = processing_rule.to_dict()
        processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC.value
        documents = index_processor.transform(
            [document],
            embedding_model_instance=embedding_model_instance,
            process_rule=processing_rule_dict,
            tenant_id=dataset.tenant_id,
            doc_language=dataset_document.doc_language,
        )
        # save child chunks
        if documents and documents[0].children:
            index_processor.load(dataset, documents)
            for position, child_chunk in enumerate(documents[0].children, start=1):
                child_segment = ChildChunk(
                    tenant_id=dataset.tenant_id,
                    dataset_id=dataset.id,
                    document_id=dataset_document.id,
                    segment_id=segment.id,
                    position=position,
                    index_node_id=child_chunk.metadata["doc_id"],
                    index_node_hash=child_chunk.metadata["doc_hash"],
                    content=child_chunk.page_content,
                    word_count=len(child_chunk.page_content),
                    type="automatic",
                    created_by=dataset_document.created_by,
                )
                db.session.add(child_segment)
            db.session.commit()
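
    # Usage sketch (editor's addition, not part of the original service): how a
    # caller might re-chunk an edited parent segment. The objects below are
    # assumed to be loaded by the caller, in the same way create_segments_vector()
    # loads them above; regenerate=True cleans the segment's existing child-chunk
    # index entries before the new chunks are transformed and loaded.
    #
    #     VectorService.generate_child_chunks(
    #         segment,
    #         dataset_document,
    #         dataset,
    #         embedding_model_instance,
    #         processing_rule,
    #         regenerate=True,
    #     )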

    @classmethod
    def create_child_chunk_vector(cls, child_segment: ChildChunk, dataset: Dataset):
        """Add a single child chunk to the vector index (high_quality datasets only)."""
        child_document = Document(
            page_content=child_segment.content,
            metadata={
                "doc_id": child_segment.index_node_id,
                "doc_hash": child_segment.index_node_hash,
                "document_id": child_segment.document_id,
                "dataset_id": child_segment.dataset_id,
            },
        )
        if dataset.indexing_technique == "high_quality":
            # save vector index
            vector = Vector(dataset=dataset)
            vector.add_texts([child_document], duplicate_check=True)

    @classmethod
    def update_child_chunk_vector(
        cls,
        new_child_chunks: list[ChildChunk],
        update_child_chunks: list[ChildChunk],
        delete_child_chunks: list[ChildChunk],
        dataset: Dataset,
    ):
        """Apply a batch of child-chunk edits to the vector index.

        New and updated chunks are (re-)embedded; updated and deleted chunks
        have their previous index nodes removed first.
        """
        documents = []
        delete_node_ids = []
        for new_child_chunk in new_child_chunks:
            new_child_document = Document(
                page_content=new_child_chunk.content,
                metadata={
                    "doc_id": new_child_chunk.index_node_id,
                    "doc_hash": new_child_chunk.index_node_hash,
                    "document_id": new_child_chunk.document_id,
                    "dataset_id": new_child_chunk.dataset_id,
                },
            )
            documents.append(new_child_document)
        for update_child_chunk in update_child_chunks:
            child_document = Document(
                page_content=update_child_chunk.content,
                metadata={
                    "doc_id": update_child_chunk.index_node_id,
                    "doc_hash": update_child_chunk.index_node_hash,
                    "document_id": update_child_chunk.document_id,
                    "dataset_id": update_child_chunk.dataset_id,
                },
            )
            documents.append(child_document)
            delete_node_ids.append(update_child_chunk.index_node_id)
        for delete_child_chunk in delete_child_chunks:
            delete_node_ids.append(delete_child_chunk.index_node_id)
        if dataset.indexing_technique == "high_quality":
            # update vector index
            vector = Vector(dataset=dataset)
            if delete_node_ids:
                vector.delete_by_ids(delete_node_ids)
            if documents:
                vector.add_texts(documents, duplicate_check=True)
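
    # Note (editor's addition): callers are expected to diff the edited child
    # chunks against the stored ones and pass three lists: brand-new chunks,
    # chunks whose content changed, and chunks that were removed. For datasets
    # that are not high_quality this method does nothing, mirroring the
    # high_quality-only check in create_segments_vector().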

    @classmethod
    def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset):
        """Remove a child chunk's node from the vector index."""
        vector = Vector(dataset=dataset)
        vector.delete_by_ids([child_chunk.index_node_id])
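

# ---------------------------------------------------------------------------
# Usage sketch (editor's addition): a minimal example of how an indexing task
# might hand freshly committed segments to VectorService. The helper name and
# the surrounding objects (dataset, segments, keyword lists) are assumptions;
# an application/db-session context is expected to be active, as elsewhere in
# this module.
# ---------------------------------------------------------------------------
def _index_new_segments_example(
    dataset: Dataset,
    segments: list[DocumentSegment],
    keywords_list: Optional[list[list[str]]] = None,
):
    # Hypothetical helper: doc_form selects the indexing path -- parent/child
    # datasets go through generate_child_chunks(), everything else is loaded
    # as one Document per segment, optionally with per-segment keywords.
    VectorService.create_segments_vector(
        keywords_list=keywords_list,
        segments=segments,
        dataset=dataset,
        doc_form=dataset.doc_form,
    )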