
vector_service.py

import logging

from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import ParentMode

logger = logging.getLogger(__name__)


class VectorService:
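    """Create, update and delete vector and keyword indexes for dataset segments and their child chunks."""
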
    @classmethod
    def create_segments_vector(
        cls, keywords_list: list[list[str]] | None, segments: list[DocumentSegment], dataset: Dataset, doc_form: str
    ):
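        """Build indexes for newly created segments.

        For parent-child datasets each segment is split into child chunks via
        generate_child_chunks, which requires high-quality indexing and an
        embedding model. For every other document form the segments are wrapped
        as RAG documents and loaded in one batch through the index processor,
        optionally together with per-segment keyword lists.
        """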
        documents: list[Document] = []
        for segment in segments:
            if doc_form == IndexType.PARENT_CHILD_INDEX:
                dataset_document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
                if not dataset_document:
                    logger.warning(
                        "Expected DatasetDocument record to exist, but none was found, document_id=%s, segment_id=%s",
                        segment.document_id,
                        segment.id,
                    )
                    continue
                # get the process rule
                processing_rule = (
                    db.session.query(DatasetProcessRule)
                    .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
                    .first()
                )
                if not processing_rule:
                    raise ValueError("No processing rule found.")
                # get embedding model instance
                if dataset.indexing_technique == "high_quality":
                    # check embedding model setting
                    model_manager = ModelManager()
                    if dataset.embedding_model_provider:
                        embedding_model_instance = model_manager.get_model_instance(
                            tenant_id=dataset.tenant_id,
                            provider=dataset.embedding_model_provider,
                            model_type=ModelType.TEXT_EMBEDDING,
                            model=dataset.embedding_model,
                        )
                    else:
                        embedding_model_instance = model_manager.get_default_model_instance(
                            tenant_id=dataset.tenant_id,
                            model_type=ModelType.TEXT_EMBEDDING,
                        )
                else:
                    raise ValueError("The knowledge base index technique is not high quality!")
                cls.generate_child_chunks(
                    segment, dataset_document, dataset, embedding_model_instance, processing_rule, False
                )
            else:
                rag_document = Document(
                    page_content=segment.content,
                    metadata={
                        "doc_id": segment.index_node_id,
                        "doc_hash": segment.index_node_hash,
                        "document_id": segment.document_id,
                        "dataset_id": segment.dataset_id,
                    },
                )
                documents.append(rag_document)
        if len(documents) > 0:
            index_processor = IndexProcessorFactory(doc_form).init_index_processor()
            index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)

    @classmethod
    def update_segment_vector(cls, keywords: list[str] | None, segment: DocumentSegment, dataset: Dataset):
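        """Re-index a single segment after its content changed.

        High-quality datasets refresh the vector index (old node deleted, new
        document added); otherwise the keyword index is refreshed, using the
        supplied keywords when provided.
        """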
        # update segment index task
        # format new index
        document = Document(
            page_content=segment.content,
            metadata={
                "doc_id": segment.index_node_id,
                "doc_hash": segment.index_node_hash,
                "document_id": segment.document_id,
                "dataset_id": segment.dataset_id,
            },
        )
        if dataset.indexing_technique == "high_quality":
            # update vector index
            vector = Vector(dataset=dataset)
            vector.delete_by_ids([segment.index_node_id])
            vector.add_texts([document], duplicate_check=True)
        else:
            # update keyword index
            keyword = Keyword(dataset)
            keyword.delete_by_ids([segment.index_node_id])
            # save keyword index
            if keywords and len(keywords) > 0:
                keyword.add_texts([document], keywords_list=[keywords])
            else:
                keyword.add_texts([document])

    @classmethod
    def generate_child_chunks(
        cls,
        segment: DocumentSegment,
        dataset_document: DatasetDocument,
        dataset: Dataset,
        embedding_model_instance: ModelInstance,
        processing_rule: DatasetProcessRule,
        regenerate: bool = False,
    ):
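        """Split a segment into child chunks and index them.

        The segment is transformed with parent_mode forced to FULL_DOC, the
        resulting child documents are loaded into the index, and one ChildChunk
        row is persisted per child. With regenerate=True the segment's existing
        child chunks are removed from the index first.
        """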
        index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
        if regenerate:
            # delete child chunks
            index_processor.clean(dataset, [segment.index_node_id], with_keywords=True, delete_child_chunks=True)

        # generate child chunks
        document = Document(
            page_content=segment.content,
            metadata={
                "doc_id": segment.index_node_id,
                "doc_hash": segment.index_node_hash,
                "document_id": segment.document_id,
                "dataset_id": segment.dataset_id,
            },
        )
        # use full doc mode to generate segment's child chunk
        processing_rule_dict = processing_rule.to_dict()
        processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC.value
        documents = index_processor.transform(
            [document],
            embedding_model_instance=embedding_model_instance,
            process_rule=processing_rule_dict,
            tenant_id=dataset.tenant_id,
            doc_language=dataset_document.doc_language,
        )
        # save child chunks
        if documents and documents[0].children:
            index_processor.load(dataset, documents)
            for position, child_chunk in enumerate(documents[0].children, start=1):
                child_segment = ChildChunk(
                    tenant_id=dataset.tenant_id,
                    dataset_id=dataset.id,
                    document_id=dataset_document.id,
                    segment_id=segment.id,
                    position=position,
                    index_node_id=child_chunk.metadata["doc_id"],
                    index_node_hash=child_chunk.metadata["doc_hash"],
                    content=child_chunk.page_content,
                    word_count=len(child_chunk.page_content),
                    type="automatic",
                    created_by=dataset_document.created_by,
                )
                db.session.add(child_segment)
        db.session.commit()

    @classmethod
    def create_child_chunk_vector(cls, child_segment: ChildChunk, dataset: Dataset):
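        """Add a single child chunk to the vector index (high-quality datasets only)."""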
        child_document = Document(
            page_content=child_segment.content,
            metadata={
                "doc_id": child_segment.index_node_id,
                "doc_hash": child_segment.index_node_hash,
                "document_id": child_segment.document_id,
                "dataset_id": child_segment.dataset_id,
            },
        )
        if dataset.indexing_technique == "high_quality":
            # save vector index
            vector = Vector(dataset=dataset)
            vector.add_texts([child_document], duplicate_check=True)

    @classmethod
    def update_child_chunk_vector(
        cls,
        new_child_chunks: list[ChildChunk],
        update_child_chunks: list[ChildChunk],
        delete_child_chunks: list[ChildChunk],
        dataset: Dataset,
    ):
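        """Apply a batch of child-chunk changes to the vector index.

        New and updated chunks are (re)added, while updated and deleted chunks
        have their previous index nodes removed; only high-quality datasets
        touch the vector store.
        """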
        documents = []
        delete_node_ids = []
        for new_child_chunk in new_child_chunks:
            new_child_document = Document(
                page_content=new_child_chunk.content,
                metadata={
                    "doc_id": new_child_chunk.index_node_id,
                    "doc_hash": new_child_chunk.index_node_hash,
                    "document_id": new_child_chunk.document_id,
                    "dataset_id": new_child_chunk.dataset_id,
                },
            )
            documents.append(new_child_document)
        for update_child_chunk in update_child_chunks:
            child_document = Document(
                page_content=update_child_chunk.content,
                metadata={
                    "doc_id": update_child_chunk.index_node_id,
                    "doc_hash": update_child_chunk.index_node_hash,
                    "document_id": update_child_chunk.document_id,
                    "dataset_id": update_child_chunk.dataset_id,
                },
            )
            documents.append(child_document)
            delete_node_ids.append(update_child_chunk.index_node_id)
        for delete_child_chunk in delete_child_chunks:
            delete_node_ids.append(delete_child_chunk.index_node_id)
        if dataset.indexing_technique == "high_quality":
            # update vector index
            vector = Vector(dataset=dataset)
            if delete_node_ids:
                vector.delete_by_ids(delete_node_ids)
            if documents:
                vector.add_texts(documents, duplicate_check=True)

    @classmethod
    def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset):
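        """Remove a single child chunk from the vector index."""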
        vector = Vector(dataset=dataset)
        vector.delete_by_ids([child_chunk.index_node_id])
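

# Usage sketch (illustrative only): indexing freshly created segments from
# application code. It assumes an application / database session context and
# already-loaded `dataset`, `document`, and `segments` ORM objects from
# models.dataset; those names are placeholders, not part of this module.
#
#     VectorService.create_segments_vector(
#         keywords_list=None,  # or one keyword list per segment
#         segments=segments,
#         dataset=dataset,
#         doc_form=document.doc_form,
#     )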