
batch_create_segment_to_index_task.py 5.9KB

import logging
import tempfile
import time
import uuid
from pathlib import Path

import click
import pandas as pd
from celery import shared_task
from sqlalchemy import func
from sqlalchemy.orm import Session

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from models.model import UploadFile
from services.vector_service import VectorService

logger = logging.getLogger(__name__)


@shared_task(queue="dataset")
def batch_create_segment_to_index_task(
    job_id: str,
    upload_file_id: str,
    dataset_id: str,
    document_id: str,
    tenant_id: str,
    user_id: str,
):
    """
    Async batch create segment to index
    :param job_id:
    :param upload_file_id:
    :param dataset_id:
    :param document_id:
    :param tenant_id:
    :param user_id:

    Usage: batch_create_segment_to_index_task.delay(job_id, upload_file_id, dataset_id, document_id, tenant_id, user_id)
    """
    logger.info(click.style(f"Start batch create segment jobId: {job_id}", fg="green"))
    start_at = time.perf_counter()

    indexing_cache_key = f"segment_batch_import_{job_id}"

    try:
        with Session(db.engine) as session:
            dataset = session.get(Dataset, dataset_id)
            if not dataset:
                raise ValueError("Dataset not exist.")

            dataset_document = session.get(Document, document_id)
            if not dataset_document:
                raise ValueError("Document not exist.")

            if (
                not dataset_document.enabled
                or dataset_document.archived
                or dataset_document.indexing_status != "completed"
            ):
                raise ValueError("Document is not available.")

            upload_file = session.get(UploadFile, upload_file_id)
            if not upload_file:
                raise ValueError("UploadFile not found.")

            with tempfile.TemporaryDirectory() as temp_dir:
                suffix = Path(upload_file.key).suffix
                # FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here
                file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"  # type: ignore
                storage.download(upload_file.key, file_path)

                # Skip the first row
                df = pd.read_csv(file_path)
                content = []
                for _, row in df.iterrows():
                    if dataset_document.doc_form == "qa_model":
                        data = {"content": row.iloc[0], "answer": row.iloc[1]}
                    else:
                        data = {"content": row.iloc[0]}
                    content.append(data)
                if len(content) == 0:
                    raise ValueError("The CSV file is empty.")

                document_segments = []
                embedding_model = None
                if dataset.indexing_technique == "high_quality":
                    model_manager = ModelManager()
                    embedding_model = model_manager.get_model_instance(
                        tenant_id=dataset.tenant_id,
                        provider=dataset.embedding_model_provider,
                        model_type=ModelType.TEXT_EMBEDDING,
                        model=dataset.embedding_model,
                    )
                word_count_change = 0
                if embedding_model:
                    tokens_list = embedding_model.get_text_embedding_num_tokens(
                        texts=[segment["content"] for segment in content]
                    )
                else:
                    tokens_list = [0] * len(content)

                for segment, tokens in zip(content, tokens_list):
                    content = segment["content"]
                    doc_id = str(uuid.uuid4())
                    segment_hash = helper.generate_text_hash(content)  # type: ignore
                    max_position = (
                        db.session.query(func.max(DocumentSegment.position))
                        .where(DocumentSegment.document_id == dataset_document.id)
                        .scalar()
                    )
                    segment_document = DocumentSegment(
                        tenant_id=tenant_id,
                        dataset_id=dataset_id,
                        document_id=document_id,
                        index_node_id=doc_id,
                        index_node_hash=segment_hash,
                        position=max_position + 1 if max_position else 1,
                        content=content,
                        word_count=len(content),
                        tokens=tokens,
                        created_by=user_id,
                        indexing_at=naive_utc_now(),
                        status="completed",
                        completed_at=naive_utc_now(),
                    )
                    if dataset_document.doc_form == "qa_model":
                        segment_document.answer = segment["answer"]
                        segment_document.word_count += len(segment["answer"])
                    word_count_change += segment_document.word_count
                    db.session.add(segment_document)
                    document_segments.append(segment_document)

                # update document word count
                assert dataset_document.word_count is not None
                dataset_document.word_count += word_count_change
                db.session.add(dataset_document)
                # add index to db
                VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
                db.session.commit()

        redis_client.setex(indexing_cache_key, 600, "completed")
        end_at = time.perf_counter()
        logger.info(
            click.style(
                f"Segment batch created job: {job_id} latency: {end_at - start_at}",
                fg="green",
            )
        )
    except Exception:
        logger.exception("Segments batch created index failed")
        redis_client.setex(indexing_cache_key, 600, "error")
    finally:
        db.session.close()