
batch_create_segment_to_index_task.py 5.9KB

import datetime
import logging
import tempfile
import time
import uuid
from pathlib import Path

import click
import pandas as pd
from celery import shared_task  # type: ignore
from sqlalchemy import func
from sqlalchemy.orm import Session

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
from models.dataset import Dataset, Document, DocumentSegment
from models.model import UploadFile
from services.vector_service import VectorService


@shared_task(queue="dataset")
def batch_create_segment_to_index_task(
    job_id: str,
    upload_file_id: str,
    dataset_id: str,
    document_id: str,
    tenant_id: str,
    user_id: str,
):
    """
    Async batch create segment to index
    :param job_id:
    :param upload_file_id:
    :param dataset_id:
    :param document_id:
    :param tenant_id:
    :param user_id:

    Usage: batch_create_segment_to_index_task.delay(job_id, upload_file_id, dataset_id, document_id, tenant_id, user_id)
    """
    logging.info(click.style(f"Start batch create segment jobId: {job_id}", fg="green"))
    start_at = time.perf_counter()

    indexing_cache_key = f"segment_batch_import_{job_id}"

    try:
        with Session(db.engine) as session:
            dataset = session.get(Dataset, dataset_id)
            if not dataset:
                raise ValueError("Dataset not exist.")

            dataset_document = session.get(Document, document_id)
            if not dataset_document:
                raise ValueError("Document not exist.")

            if (
                not dataset_document.enabled
                or dataset_document.archived
                or dataset_document.indexing_status != "completed"
            ):
                raise ValueError("Document is not available.")

            upload_file = session.get(UploadFile, upload_file_id)
            if not upload_file:
                raise ValueError("UploadFile not found.")

            with tempfile.TemporaryDirectory() as temp_dir:
                suffix = Path(upload_file.key).suffix
                # FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here
                file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"  # type: ignore
                storage.download(upload_file.key, file_path)

                # Skip the first row
                df = pd.read_csv(file_path)
                content = []
                for index, row in df.iterrows():
                    if dataset_document.doc_form == "qa_model":
                        data = {"content": row.iloc[0], "answer": row.iloc[1]}
                    else:
                        data = {"content": row.iloc[0]}
                    content.append(data)
                if len(content) == 0:
                    raise ValueError("The CSV file is empty.")

                document_segments = []
                embedding_model = None
                if dataset.indexing_technique == "high_quality":
                    model_manager = ModelManager()
                    embedding_model = model_manager.get_model_instance(
                        tenant_id=dataset.tenant_id,
                        provider=dataset.embedding_model_provider,
                        model_type=ModelType.TEXT_EMBEDDING,
                        model=dataset.embedding_model,
                    )

                word_count_change = 0
                if embedding_model:
                    tokens_list = embedding_model.get_text_embedding_num_tokens(
                        texts=[segment["content"] for segment in content]
                    )
                else:
                    tokens_list = [0] * len(content)

                for segment, tokens in zip(content, tokens_list):
                    content = segment["content"]
                    doc_id = str(uuid.uuid4())
                    segment_hash = helper.generate_text_hash(content)  # type: ignore
                    max_position = (
                        db.session.query(func.max(DocumentSegment.position))
                        .where(DocumentSegment.document_id == dataset_document.id)
                        .scalar()
                    )
                    segment_document = DocumentSegment(
                        tenant_id=tenant_id,
                        dataset_id=dataset_id,
                        document_id=document_id,
                        index_node_id=doc_id,
                        index_node_hash=segment_hash,
                        position=max_position + 1 if max_position else 1,
                        content=content,
                        word_count=len(content),
                        tokens=tokens,
                        created_by=user_id,
                        indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
                        status="completed",
                        completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
                    )
                    if dataset_document.doc_form == "qa_model":
                        segment_document.answer = segment["answer"]
                        segment_document.word_count += len(segment["answer"])
                    word_count_change += segment_document.word_count
                    db.session.add(segment_document)
                    document_segments.append(segment_document)

                # update document word count
                assert dataset_document.word_count is not None
                dataset_document.word_count += word_count_change
                db.session.add(dataset_document)

                # add index to db
                VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)

                db.session.commit()
                redis_client.setex(indexing_cache_key, 600, "completed")
                end_at = time.perf_counter()
                logging.info(
                    click.style(
                        f"Segment batch created job: {job_id} latency: {end_at - start_at}",
                        fg="green",
                    )
                )
    except Exception:
        logging.exception("Segments batch created index failed")
        redis_client.setex(indexing_cache_key, 600, "error")
    finally:
        db.session.close()
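
For reference, a minimal caller-side sketch follows, matching the Usage note in the docstring: it generates a job_id, seeds the segment_batch_import_{job_id} key in Redis, enqueues the task with .delay(), and later reads that key, which the task overwrites with "completed" or "error" (600-second TTL). The helper names start_batch_import and batch_import_status, the initial "waiting" value, and the module import path are illustrative assumptions, not part of this file.

# Illustrative caller sketch (assumed names; only the .delay() signature and the
# Redis key format come from the task above).
import uuid

from extensions.ext_redis import redis_client
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task  # module path assumed


def start_batch_import(upload_file_id: str, dataset_id: str, document_id: str, tenant_id: str, user_id: str) -> str:
    job_id = str(uuid.uuid4())
    # Seed the status key so pollers see something before the worker finishes (assumed convention).
    redis_client.setnx(f"segment_batch_import_{job_id}", "waiting")
    batch_create_segment_to_index_task.delay(job_id, upload_file_id, dataset_id, document_id, tenant_id, user_id)
    return job_id


def batch_import_status(job_id: str):
    # The task sets this key to "completed" on success or "error" on failure.
    return redis_client.get(f"segment_batch_import_{job_id}")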