
vector_service.py

import logging
from typing import Optional

from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import ParentMode

_logger = logging.getLogger(__name__)


class VectorService:
    @classmethod
    def create_segments_vector(
        cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset, doc_form: str
    ):
        documents = []
        for segment in segments:
            if doc_form == IndexType.PARENT_CHILD_INDEX:
                document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
                if not document:
                    _logger.warning(
                        "Expected DatasetDocument record to exist, but none was found, document_id=%s, segment_id=%s",
                        segment.document_id,
                        segment.id,
                    )
                    continue
                # get the process rule
                processing_rule = (
                    db.session.query(DatasetProcessRule)
                    .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
                    .first()
                )
                if not processing_rule:
                    raise ValueError("No processing rule found.")
                # get embedding model instance
                if dataset.indexing_technique == "high_quality":
                    # check embedding model setting
                    model_manager = ModelManager()
                    if dataset.embedding_model_provider:
                        embedding_model_instance = model_manager.get_model_instance(
                            tenant_id=dataset.tenant_id,
                            provider=dataset.embedding_model_provider,
                            model_type=ModelType.TEXT_EMBEDDING,
                            model=dataset.embedding_model,
                        )
                    else:
                        embedding_model_instance = model_manager.get_default_model_instance(
                            tenant_id=dataset.tenant_id,
                            model_type=ModelType.TEXT_EMBEDDING,
                        )
                else:
                    raise ValueError("The knowledge base index technique is not high quality!")
                cls.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, False)
            else:
                document = Document(  # type: ignore
                    page_content=segment.content,
                    metadata={
                        "doc_id": segment.index_node_id,
                        "doc_hash": segment.index_node_hash,
                        "document_id": segment.document_id,
                        "dataset_id": segment.dataset_id,
                    },
                )
                documents.append(document)
        if len(documents) > 0:
            index_processor = IndexProcessorFactory(doc_form).init_index_processor()
            index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)  # type: ignore

    @classmethod
    def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset):
        # update segment index task
        # format new index
        document = Document(
            page_content=segment.content,
            metadata={
                "doc_id": segment.index_node_id,
                "doc_hash": segment.index_node_hash,
                "document_id": segment.document_id,
                "dataset_id": segment.dataset_id,
            },
        )
        if dataset.indexing_technique == "high_quality":
            # update vector index
            vector = Vector(dataset=dataset)
            vector.delete_by_ids([segment.index_node_id])
            vector.add_texts([document], duplicate_check=True)
        # update keyword index
        keyword = Keyword(dataset)
        keyword.delete_by_ids([segment.index_node_id])
        # save keyword index
        if keywords and len(keywords) > 0:
            keyword.add_texts([document], keywords_list=[keywords])
        else:
            keyword.add_texts([document])

    @classmethod
    def generate_child_chunks(
        cls,
        segment: DocumentSegment,
        dataset_document: DatasetDocument,
        dataset: Dataset,
        embedding_model_instance: ModelInstance,
        processing_rule: DatasetProcessRule,
        regenerate: bool = False,
    ):
        index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
        if regenerate:
            # delete child chunks
            index_processor.clean(dataset, [segment.index_node_id], with_keywords=True, delete_child_chunks=True)
        # generate child chunks
        document = Document(
            page_content=segment.content,
            metadata={
                "doc_id": segment.index_node_id,
                "doc_hash": segment.index_node_hash,
                "document_id": segment.document_id,
                "dataset_id": segment.dataset_id,
            },
        )
        # use full doc mode to generate segment's child chunk
        processing_rule_dict = processing_rule.to_dict()
        processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC.value
        documents = index_processor.transform(
            [document],
            embedding_model_instance=embedding_model_instance,
            process_rule=processing_rule_dict,
            tenant_id=dataset.tenant_id,
            doc_language=dataset_document.doc_language,
        )
        # save child chunks
        if documents and documents[0].children:
            index_processor.load(dataset, documents)
            for position, child_chunk in enumerate(documents[0].children, start=1):
                child_segment = ChildChunk(
                    tenant_id=dataset.tenant_id,
                    dataset_id=dataset.id,
                    document_id=dataset_document.id,
                    segment_id=segment.id,
                    position=position,
                    index_node_id=child_chunk.metadata["doc_id"],
                    index_node_hash=child_chunk.metadata["doc_hash"],
                    content=child_chunk.page_content,
                    word_count=len(child_chunk.page_content),
                    type="automatic",
                    created_by=dataset_document.created_by,
                )
                db.session.add(child_segment)
        db.session.commit()

    @classmethod
    def create_child_chunk_vector(cls, child_segment: ChildChunk, dataset: Dataset):
        child_document = Document(
            page_content=child_segment.content,
            metadata={
                "doc_id": child_segment.index_node_id,
                "doc_hash": child_segment.index_node_hash,
                "document_id": child_segment.document_id,
                "dataset_id": child_segment.dataset_id,
            },
        )
        if dataset.indexing_technique == "high_quality":
            # save vector index
            vector = Vector(dataset=dataset)
            vector.add_texts([child_document], duplicate_check=True)

    @classmethod
    def update_child_chunk_vector(
        cls,
        new_child_chunks: list[ChildChunk],
        update_child_chunks: list[ChildChunk],
        delete_child_chunks: list[ChildChunk],
        dataset: Dataset,
    ):
        documents = []
        delete_node_ids = []
        for new_child_chunk in new_child_chunks:
            new_child_document = Document(
                page_content=new_child_chunk.content,
                metadata={
                    "doc_id": new_child_chunk.index_node_id,
                    "doc_hash": new_child_chunk.index_node_hash,
                    "document_id": new_child_chunk.document_id,
                    "dataset_id": new_child_chunk.dataset_id,
                },
            )
            documents.append(new_child_document)
        for update_child_chunk in update_child_chunks:
            child_document = Document(
                page_content=update_child_chunk.content,
                metadata={
                    "doc_id": update_child_chunk.index_node_id,
                    "doc_hash": update_child_chunk.index_node_hash,
                    "document_id": update_child_chunk.document_id,
                    "dataset_id": update_child_chunk.dataset_id,
                },
            )
            documents.append(child_document)
            delete_node_ids.append(update_child_chunk.index_node_id)
        for delete_child_chunk in delete_child_chunks:
            delete_node_ids.append(delete_child_chunk.index_node_id)
        if dataset.indexing_technique == "high_quality":
            # update vector index
            vector = Vector(dataset=dataset)
            if delete_node_ids:
                vector.delete_by_ids(delete_node_ids)
            if documents:
                vector.add_texts(documents, duplicate_check=True)

    @classmethod
    def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset):
        vector = Vector(dataset=dataset)
        vector.delete_by_ids([child_chunk.index_node_id])
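
For orientation, a minimal usage sketch follows. It is not part of vector_service.py: the module path services.vector_service, the helper names index_new_segments and reindex_segment, and the caller-supplied dataset_id, segment_ids, and doc_form values are assumptions for illustration, and the calls are presumed to run inside the application's Flask app context so that db.session is bound.

# Hypothetical usage sketch (not part of the file above). Assumes an active
# app context and that the Dataset and DocumentSegment rows are already committed.
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from services.vector_service import VectorService  # assumed module path


def index_new_segments(dataset_id: str, segment_ids: list[str], doc_form: str) -> None:
    """Hypothetical helper: index freshly created segments of one dataset."""
    dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
    if not dataset:
        raise ValueError("Dataset not found.")
    segments = db.session.query(DocumentSegment).filter(DocumentSegment.id.in_(segment_ids)).all()

    # keywords_list is optional; when supplied it presumably lines up one keyword
    # list per segment, in the same order as `segments`.
    VectorService.create_segments_vector(
        keywords_list=None,
        segments=segments,
        dataset=dataset,
        doc_form=doc_form,
    )


def reindex_segment(dataset: Dataset, segment: DocumentSegment) -> None:
    """Hypothetical helper: rebuild the indexes for a single edited segment."""
    # With keywords=None, the keyword index is rebuilt without an explicit keyword list.
    VectorService.update_segment_vector(keywords=None, segment=segment, dataset=dataset)

Note that on the parent-child path create_segments_vector raises a ValueError unless the dataset uses the "high_quality" indexing technique, so callers going through that path need an embedding model configured for the dataset or a tenant-level default.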