Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

enable_segments_to_index_task.py 4.4KB

10 месяцев назад
10 месяцев назад
10 месяцев назад
10 месяцев назад
7 месяцев назад
10 месяцев назад
7 месяцев назад
10 месяцев назад
10 месяцев назад
10 месяцев назад
10 месяцев назад
10 месяцев назад
10 месяцев назад
10 месяцев назад
10 месяцев назад
10 месяцев назад
10 месяцев назад
10 месяцев назад
10 месяцев назад
10 месяцев назад
10 месяцев назад
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. import logging
  2. import time
  3. import click
  4. from celery import shared_task
  5. from core.rag.index_processor.constant.index_type import IndexType
  6. from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
  7. from core.rag.models.document import ChildDocument, Document
  8. from extensions.ext_database import db
  9. from extensions.ext_redis import redis_client
  10. from libs.datetime_utils import naive_utc_now
  11. from models.dataset import Dataset, DocumentSegment
  12. from models.dataset import Document as DatasetDocument
  13. logger = logging.getLogger(__name__)
  14. @shared_task(queue="dataset")
  15. def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_id: str):
  16. """
  17. Async enable segments to index
  18. :param segment_ids: list of segment ids
  19. :param dataset_id: dataset id
  20. :param document_id: document id
  21. Usage: enable_segments_to_index_task.delay(segment_ids, dataset_id, document_id)
  22. """
  23. start_at = time.perf_counter()
  24. dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
  25. if not dataset:
  26. logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
  27. return
  28. dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
  29. if not dataset_document:
  30. logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
  31. db.session.close()
  32. return
  33. if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
  34. logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan"))
  35. db.session.close()
  36. return
  37. # sync index processor
  38. index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
  39. segments = (
  40. db.session.query(DocumentSegment)
  41. .where(
  42. DocumentSegment.id.in_(segment_ids),
  43. DocumentSegment.dataset_id == dataset_id,
  44. DocumentSegment.document_id == document_id,
  45. )
  46. .all()
  47. )
  48. if not segments:
  49. logger.info(click.style(f"Segments not found: {segment_ids}", fg="cyan"))
  50. db.session.close()
  51. return
  52. try:
  53. documents = []
  54. for segment in segments:
  55. document = Document(
  56. page_content=segment.content,
  57. metadata={
  58. "doc_id": segment.index_node_id,
  59. "doc_hash": segment.index_node_hash,
  60. "document_id": document_id,
  61. "dataset_id": dataset_id,
  62. },
  63. )
  64. if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
  65. child_chunks = segment.get_child_chunks()
  66. if child_chunks:
  67. child_documents = []
  68. for child_chunk in child_chunks:
  69. child_document = ChildDocument(
  70. page_content=child_chunk.content,
  71. metadata={
  72. "doc_id": child_chunk.index_node_id,
  73. "doc_hash": child_chunk.index_node_hash,
  74. "document_id": document_id,
  75. "dataset_id": dataset_id,
  76. },
  77. )
  78. child_documents.append(child_document)
  79. document.children = child_documents
  80. documents.append(document)
  81. # save vector index
  82. index_processor.load(dataset, documents)
  83. end_at = time.perf_counter()
  84. logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green"))
  85. except Exception as e:
  86. logger.exception("enable segments to index failed")
  87. # update segment error msg
  88. db.session.query(DocumentSegment).where(
  89. DocumentSegment.id.in_(segment_ids),
  90. DocumentSegment.dataset_id == dataset_id,
  91. DocumentSegment.document_id == document_id,
  92. ).update(
  93. {
  94. "error": str(e),
  95. "status": "error",
  96. "disabled_at": naive_utc_now(),
  97. "enabled": False,
  98. }
  99. )
  100. db.session.commit()
  101. finally:
  102. for segment in segments:
  103. indexing_cache_key = f"segment_{segment.id}_indexing"
  104. redis_client.delete(indexing_cache_key)
  105. db.session.close()