Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

datasets_segments.py 29KB


  1. import uuid
  2. from flask import request
  3. from flask_login import current_user
  4. from flask_restful import Resource, marshal, reqparse
  5. from sqlalchemy import select
  6. from werkzeug.exceptions import Forbidden, NotFound
  7. import services
  8. from controllers.console import api
  9. from controllers.console.app.error import ProviderNotInitializeError
  10. from controllers.console.datasets.error import (
  11. ChildChunkDeleteIndexError,
  12. ChildChunkIndexingError,
  13. InvalidActionError,
  14. )
  15. from controllers.console.wraps import (
  16. account_initialization_required,
  17. cloud_edition_billing_knowledge_limit_check,
  18. cloud_edition_billing_rate_limit_check,
  19. cloud_edition_billing_resource_check,
  20. setup_required,
  21. )
  22. from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
  23. from core.model_manager import ModelManager
  24. from core.model_runtime.entities.model_entities import ModelType
  25. from extensions.ext_database import db
  26. from extensions.ext_redis import redis_client
  27. from fields.segment_fields import child_chunk_fields, segment_fields
  28. from libs.login import login_required
  29. from models.dataset import ChildChunk, DocumentSegment
  30. from models.model import UploadFile
  31. from services.dataset_service import DatasetService, DocumentService, SegmentService
  32. from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs
  33. from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
  34. from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
  35. from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
  class DatasetDocumentSegmentListApi(Resource):
      """List and bulk-delete the segments of a single dataset document."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, dataset_id, document_id):
          """Return one page of the document's segments, optionally filtered.

          Query-string arguments: ``page`` (default 1), ``limit`` (default 20,
          capped at 100), ``status`` (repeatable), ``hit_count_gte``,
          ``keyword`` (case-insensitive ``ILIKE`` match on the segment content)
          and ``enabled`` (``"true"`` / ``"false"``; ``"all"`` or any other
          value adds no filter).

          Raises:
              NotFound: If the dataset or the document cannot be found.
              Forbidden: If the dataset permission check rejects the current user.
          """
          dataset_id, document_id = str(dataset_id), str(document_id)

          dataset = DatasetService.get_dataset(dataset_id)
          if not dataset:
              raise NotFound("Dataset not found.")

          try:
              DatasetService.check_dataset_permission(dataset, current_user)
          except services.errors.account.NoPermissionError as permission_error:
              raise Forbidden(str(permission_error))

          if not DocumentService.get_document(dataset_id, document_id):
              raise NotFound("Document not found.")

          # Declare the query-string arguments in a table; registration order is preserved.
          query_arguments = (
              ("limit", {"type": int, "default": 20}),
              ("status", {"type": str, "action": "append", "default": []}),
              ("hit_count_gte", {"type": int, "default": None}),
              ("enabled", {"type": str, "default": "all"}),
              ("keyword", {"type": str, "default": None}),
              ("page", {"type": int, "default": 1}),
          )
          parser = reqparse.RequestParser()
          for argument_name, argument_options in query_arguments:
              parser.add_argument(argument_name, location="args", **argument_options)
          args = parser.parse_args()

          page = args["page"]
          limit = min(args["limit"], 100)

          # Mandatory scoping first, then one condition per filter the caller supplied.
          conditions = [
              DocumentSegment.document_id == document_id,
              DocumentSegment.tenant_id == current_user.current_tenant_id,
          ]
          if args["status"]:
              conditions.append(DocumentSegment.status.in_(args["status"]))
          if args["hit_count_gte"] is not None:
              conditions.append(DocumentSegment.hit_count >= args["hit_count_gte"])
          if args["keyword"]:
              conditions.append(DocumentSegment.content.ilike(f"%{args['keyword']}%"))
          # "all" and unrecognised values map to None, i.e. no filter on the enabled flag.
          enabled_flag = {"true": True, "false": False}.get(args["enabled"].lower())
          if enabled_flag is not None:
              conditions.append(DocumentSegment.enabled == enabled_flag)

          query = select(DocumentSegment).where(*conditions).order_by(DocumentSegment.position.asc())
          segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)

          return {
              "data": marshal(segments.items, segment_fields),
              "limit": limit,
              "total": segments.total,
              "total_pages": segments.pages,
              "page": page,
          }, 200

      @setup_required
      @login_required
      @account_initialization_required
      @cloud_edition_billing_rate_limit_check("knowledge")
      def delete(self, dataset_id, document_id):
          """Delete the segments named by the repeatable ``segment_id`` query argument.

          Raises:
              NotFound: If the dataset or the document cannot be found.
              Forbidden: If the current user is not a dataset editor or the
                  dataset permission check rejects them.
          """
          dataset_id = str(dataset_id)
          dataset = DatasetService.get_dataset(dataset_id)
          if not dataset:
              raise NotFound("Dataset not found.")

          # Validate the user's model setting for this dataset before going further.
          DatasetService.check_dataset_model_setting(dataset)

          document_id = str(document_id)
          document = DocumentService.get_document(dataset_id, document_id)
          if not document:
              raise NotFound("Document not found.")

          ids_to_delete = request.args.getlist("segment_id")

          # Only roles allowed to edit datasets may delete segments
          # (original note: admin, owner, dataset_operator, or editor).
          if not current_user.is_dataset_editor:
              raise Forbidden()
          try:
              DatasetService.check_dataset_permission(dataset, current_user)
          except services.errors.account.NoPermissionError as permission_error:
              raise Forbidden(str(permission_error))

          SegmentService.delete_segments(ids_to_delete, document, dataset)
          return {"result": "success"}, 204
  class DatasetDocumentSegmentApi(Resource):
      """Apply a status action to a batch of segments of one document."""

      @setup_required
      @login_required
      @account_initialization_required
      @cloud_edition_billing_resource_check("vector_space")
      @cloud_edition_billing_rate_limit_check("knowledge")
      def patch(self, dataset_id, document_id, action):
          """Run ``action`` on the segments listed in the ``segment_id`` query argument.

          The ``action`` string is passed through to
          ``SegmentService.update_segments_status`` unchanged.

          Raises:
              NotFound: If the dataset or the document cannot be found.
              Forbidden: If the current user is not a dataset editor or the
                  dataset permission check rejects them.
              ProviderNotInitializeError: If the dataset is "high_quality" and
                  its embedding model cannot be obtained.
              InvalidActionError: If the document's indexing cache key is set
                  in Redis, or the status update raises any exception.
          """
          dataset_id = str(dataset_id)
          dataset = DatasetService.get_dataset(dataset_id)
          if not dataset:
              raise NotFound("Dataset not found.")

          document_id = str(document_id)
          document = DocumentService.get_document(dataset_id, document_id)
          if not document:
              raise NotFound("Document not found.")

          # Validate the user's model setting for this dataset.
          DatasetService.check_dataset_model_setting(dataset)

          # Only roles allowed to edit datasets may change segment status
          # (original note: admin, owner, dataset_operator, or editor).
          if not current_user.is_dataset_editor:
              raise Forbidden()
          try:
              DatasetService.check_dataset_permission(dataset, current_user)
          except services.errors.account.NoPermissionError as permission_error:
              raise Forbidden(str(permission_error))

          if dataset.indexing_technique == "high_quality":
              # The embedding model must be obtainable for a high-quality dataset.
              try:
                  ModelManager().get_model_instance(
                      tenant_id=current_user.current_tenant_id,
                      provider=dataset.embedding_model_provider,
                      model_type=ModelType.TEXT_EMBEDDING,
                      model=dataset.embedding_model,
                  )
              except LLMBadRequestError:
                  raise ProviderNotInitializeError(
                      "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
                  )
              except ProviderTokenNotInitError as token_error:
                  raise ProviderNotInitializeError(token_error.description)

          target_segment_ids = request.args.getlist("segment_id")

          # A value under this key means the document is currently being indexed.
          if redis_client.get(f"document_{document.id}_indexing") is not None:
              raise InvalidActionError("Document is being indexed, please try again later")

          try:
              SegmentService.update_segments_status(target_segment_ids, action, dataset, document)
          except Exception as update_error:
              raise InvalidActionError(str(update_error))
          return {"result": "success"}, 200
  class DatasetDocumentSegmentAddApi(Resource):
      """Create a single new segment inside a document."""

      @setup_required
      @login_required
      @account_initialization_required
      @cloud_edition_billing_resource_check("vector_space")
      @cloud_edition_billing_knowledge_limit_check("add_segment")
      @cloud_edition_billing_rate_limit_check("knowledge")
      def post(self, dataset_id, document_id):
          """Create a segment from the JSON body and return it with the document's ``doc_form``.

          JSON body: ``content`` (required, non-null), ``answer`` (optional),
          ``keywords`` (optional list).

          Raises:
              NotFound: If the dataset or the document cannot be found.
              Forbidden: If the current user is not a dataset editor or the
                  dataset permission check rejects them.
              ProviderNotInitializeError: If the dataset is "high_quality" and
                  its embedding model cannot be obtained.
          """
          dataset_id = str(dataset_id)
          dataset = DatasetService.get_dataset(dataset_id)
          if not dataset:
              raise NotFound("Dataset not found.")

          document_id = str(document_id)
          document = DocumentService.get_document(dataset_id, document_id)
          if not document:
              raise NotFound("Document not found.")

          if not current_user.is_dataset_editor:
              raise Forbidden()

          if dataset.indexing_technique == "high_quality":
              # The embedding model must be obtainable for a high-quality dataset.
              try:
                  ModelManager().get_model_instance(
                      tenant_id=current_user.current_tenant_id,
                      provider=dataset.embedding_model_provider,
                      model_type=ModelType.TEXT_EMBEDDING,
                      model=dataset.embedding_model,
                  )
              except LLMBadRequestError:
                  raise ProviderNotInitializeError(
                      "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
                  )
              except ProviderTokenNotInitError as token_error:
                  raise ProviderNotInitializeError(token_error.description)

          try:
              DatasetService.check_dataset_permission(dataset, current_user)
          except services.errors.account.NoPermissionError as permission_error:
              raise Forbidden(str(permission_error))

          # Body schema: (name, type, required, nullable), registered in this order.
          body_schema = (
              ("content", str, True, False),
              ("answer", str, False, True),
              ("keywords", list, False, True),
          )
          parser = reqparse.RequestParser()
          for field_name, field_type, is_required, allows_null in body_schema:
              parser.add_argument(
                  field_name, type=field_type, required=is_required, nullable=allows_null, location="json"
              )
          args = parser.parse_args()

          SegmentService.segment_create_args_validate(args, document)
          created_segment = SegmentService.create_segment(args, document, dataset)
          return {"data": marshal(created_segment, segment_fields), "doc_form": document.doc_form}, 200
  class DatasetDocumentSegmentUpdateApi(Resource):
      """Update or delete one segment addressed by dataset, document and segment id."""

      @staticmethod
      def _get_segment_or_404(document_id, segment_id):
          """Return the segment with this id inside ``document_id`` for the current tenant.

          The lookup is scoped to the document from the URL so that the
          dataset-level permission checks done by the callers actually cover
          the segment being modified; a segment id belonging to another
          document is reported as not found.

          Raises:
              NotFound: If no matching segment exists.
          """
          segment = (
              db.session.query(DocumentSegment)
              .where(
                  DocumentSegment.id == str(segment_id),
                  DocumentSegment.document_id == document_id,
                  DocumentSegment.tenant_id == current_user.current_tenant_id,
              )
              .first()
          )
          if not segment:
              raise NotFound("Segment not found.")
          return segment

      @setup_required
      @login_required
      @account_initialization_required
      @cloud_edition_billing_resource_check("vector_space")
      @cloud_edition_billing_rate_limit_check("knowledge")
      def patch(self, dataset_id, document_id, segment_id):
          """Update a segment from the JSON body and return it with the document's ``doc_form``.

          JSON body: ``content`` (required, non-null), ``answer`` (optional),
          ``keywords`` (optional list), ``regenerate_child_chunks`` (optional
          bool, default ``False``).

          Raises:
              NotFound: If the dataset, document or segment cannot be found.
              Forbidden: If the current user is not a dataset editor or the
                  dataset permission check rejects them.
              ProviderNotInitializeError: If the dataset is "high_quality" and
                  its embedding model cannot be obtained.
          """
          # check dataset
          dataset_id = str(dataset_id)
          dataset = DatasetService.get_dataset(dataset_id)
          if not dataset:
              raise NotFound("Dataset not found.")
          # check user's model setting
          DatasetService.check_dataset_model_setting(dataset)
          # check document
          document_id = str(document_id)
          document = DocumentService.get_document(dataset_id, document_id)
          if not document:
              raise NotFound("Document not found.")
          if dataset.indexing_technique == "high_quality":
              # check embedding model setting
              try:
                  model_manager = ModelManager()
                  model_manager.get_model_instance(
                      tenant_id=current_user.current_tenant_id,
                      provider=dataset.embedding_model_provider,
                      model_type=ModelType.TEXT_EMBEDDING,
                      model=dataset.embedding_model,
                  )
              except LLMBadRequestError as ex:
                  raise ProviderNotInitializeError(
                      "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
                  ) from ex
              except ProviderTokenNotInitError as ex:
                  raise ProviderNotInitializeError(ex.description) from ex
          # check segment (must belong to the document addressed by the URL)
          segment = self._get_segment_or_404(document_id, segment_id)
          # Only roles allowed to edit datasets may update segments
          # (original note: admin, owner, dataset_operator, or editor).
          if not current_user.is_dataset_editor:
              raise Forbidden()
          try:
              DatasetService.check_dataset_permission(dataset, current_user)
          except services.errors.account.NoPermissionError as e:
              raise Forbidden(str(e)) from e
          # validate args
          parser = reqparse.RequestParser()
          parser.add_argument("content", type=str, required=True, nullable=False, location="json")
          parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
          parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
          parser.add_argument(
              "regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
          )
          args = parser.parse_args()
          SegmentService.segment_create_args_validate(args, document)
          segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset)
          return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200

      @setup_required
      @login_required
      @account_initialization_required
      @cloud_edition_billing_rate_limit_check("knowledge")
      def delete(self, dataset_id, document_id, segment_id):
          """Delete one segment of the document.

          Raises:
              NotFound: If the dataset, document or segment cannot be found.
              Forbidden: If the current user is not a dataset editor or the
                  dataset permission check rejects them.
          """
          # check dataset
          dataset_id = str(dataset_id)
          dataset = DatasetService.get_dataset(dataset_id)
          if not dataset:
              raise NotFound("Dataset not found.")
          # check user's model setting
          DatasetService.check_dataset_model_setting(dataset)
          # check document
          document_id = str(document_id)
          document = DocumentService.get_document(dataset_id, document_id)
          if not document:
              raise NotFound("Document not found.")
          # check segment (must belong to the document addressed by the URL)
          segment = self._get_segment_or_404(document_id, segment_id)
          # Only roles allowed to edit datasets may delete segments
          # (original note: admin, owner, dataset_operator, or editor).
          if not current_user.is_dataset_editor:
              raise Forbidden()
          try:
              DatasetService.check_dataset_permission(dataset, current_user)
          except services.errors.account.NoPermissionError as e:
              raise Forbidden(str(e)) from e
          SegmentService.delete_segment(segment, document, dataset)
          return {"result": "success"}, 204
  class DatasetDocumentSegmentBatchImportApi(Resource):
      """Start a CSV batch import of segments and report the status of an import job."""

      @setup_required
      @login_required
      @account_initialization_required
      @cloud_edition_billing_resource_check("vector_space")
      @cloud_edition_billing_knowledge_limit_check("add_segment")
      @cloud_edition_billing_rate_limit_check("knowledge")
      def post(self, dataset_id, document_id):
          """Queue a background job that creates segments from an uploaded CSV file.

          JSON body: ``upload_file_id`` (required, non-null).

          Returns:
              ``({"job_id": ..., "job_status": "waiting"}, 200)`` once the job
              is queued, or ``({"error": ...}, 500)`` if marking the job in
              Redis or dispatching the task raises.

          Raises:
              NotFound: If the dataset, document or upload file cannot be found.
              Forbidden: If the current user is not a dataset editor or the
                  dataset permission check rejects them.
              ValueError: If the upload file's name does not end in ``.csv``.
          """
          # check dataset
          dataset_id = str(dataset_id)
          dataset = DatasetService.get_dataset(dataset_id)
          if not dataset:
              raise NotFound("Dataset not found.")
          # check document
          document_id = str(document_id)
          document = DocumentService.get_document(dataset_id, document_id)
          if not document:
              raise NotFound("Document not found.")
          # A batch import creates segments, so it requires the same authorization
          # as the single-segment create endpoint (DatasetDocumentSegmentAddApi.post).
          if not current_user.is_dataset_editor:
              raise Forbidden()
          try:
              DatasetService.check_dataset_permission(dataset, current_user)
          except services.errors.account.NoPermissionError as e:
              raise Forbidden(str(e)) from e

          parser = reqparse.RequestParser()
          parser.add_argument("upload_file_id", type=str, required=True, nullable=False, location="json")
          args = parser.parse_args()
          upload_file_id = args["upload_file_id"]

          # NOTE(review): the file is looked up by id only; presumably it should also be
          # restricted to the current tenant -- confirm against the UploadFile model.
          upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
          if not upload_file:
              raise NotFound("UploadFile not found.")
          # check file type
          if not upload_file.name or not upload_file.name.lower().endswith(".csv"):
              raise ValueError("Invalid file type. Only CSV files are allowed")

          # async job
          job_id = str(uuid.uuid4())
          indexing_cache_key = f"segment_batch_import_{job_id}"
          try:
              # Mark the job as waiting, then send the batch add segments task.
              redis_client.setnx(indexing_cache_key, "waiting")
              batch_create_segment_to_index_task.delay(
                  job_id, upload_file_id, dataset_id, document_id, current_user.current_tenant_id, current_user.id
              )
          except Exception as e:
              return {"error": str(e)}, 500
          return {"job_id": job_id, "job_status": "waiting"}, 200

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, job_id):
          """Return the status stored in Redis for a batch import job.

          Raises:
              ValueError: If no value is stored for this job id.
          """
          job_id = str(job_id)
          indexing_cache_key = f"segment_batch_import_{job_id}"
          cache_result = redis_client.get(indexing_cache_key)
          if cache_result is None:
              raise ValueError("The job does not exist.")
          # The stored value is decoded before being returned (presumably bytes from Redis).
          return {"job_id": job_id, "job_status": cache_result.decode()}, 200
  class ChildChunkAddApi(Resource):
      """Create, list and bulk-update the child chunks of one segment."""

      @staticmethod
      def _get_segment_or_404(document_id, segment_id):
          """Return the segment with this id inside ``document_id`` for the current tenant.

          The lookup is scoped to the document from the URL so that the
          dataset-level permission checks done by the callers cover the
          segment being accessed; a segment id belonging to another document
          is reported as not found.

          Raises:
              NotFound: If no matching segment exists.
          """
          segment = (
              db.session.query(DocumentSegment)
              .where(
                  DocumentSegment.id == str(segment_id),
                  DocumentSegment.document_id == document_id,
                  DocumentSegment.tenant_id == current_user.current_tenant_id,
              )
              .first()
          )
          if not segment:
              raise NotFound("Segment not found.")
          return segment

      @setup_required
      @login_required
      @account_initialization_required
      @cloud_edition_billing_resource_check("vector_space")
      @cloud_edition_billing_knowledge_limit_check("add_segment")
      @cloud_edition_billing_rate_limit_check("knowledge")
      def post(self, dataset_id, document_id, segment_id):
          """Create a child chunk under the segment from the JSON ``content`` field.

          Raises:
              NotFound: If the dataset, document or segment cannot be found.
              Forbidden: If the current user is not a dataset editor or the
                  dataset permission check rejects them.
              ProviderNotInitializeError: If the dataset is "high_quality" and
                  its embedding model cannot be obtained.
              ChildChunkIndexingError: If the service reports an indexing error.
          """
          # check dataset
          dataset_id = str(dataset_id)
          dataset = DatasetService.get_dataset(dataset_id)
          if not dataset:
              raise NotFound("Dataset not found.")
          # check document
          document_id = str(document_id)
          document = DocumentService.get_document(dataset_id, document_id)
          if not document:
              raise NotFound("Document not found.")
          # check segment (must belong to the document addressed by the URL)
          segment = self._get_segment_or_404(document_id, segment_id)
          if not current_user.is_dataset_editor:
              raise Forbidden()
          # check embedding model setting
          if dataset.indexing_technique == "high_quality":
              try:
                  model_manager = ModelManager()
                  model_manager.get_model_instance(
                      tenant_id=current_user.current_tenant_id,
                      provider=dataset.embedding_model_provider,
                      model_type=ModelType.TEXT_EMBEDDING,
                      model=dataset.embedding_model,
                  )
              except LLMBadRequestError as ex:
                  raise ProviderNotInitializeError(
                      "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
                  ) from ex
              except ProviderTokenNotInitError as ex:
                  raise ProviderNotInitializeError(ex.description) from ex
          try:
              DatasetService.check_dataset_permission(dataset, current_user)
          except services.errors.account.NoPermissionError as e:
              raise Forbidden(str(e)) from e
          # validate args
          parser = reqparse.RequestParser()
          parser.add_argument("content", type=str, required=True, nullable=False, location="json")
          args = parser.parse_args()
          try:
              child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset)
          except ChildChunkIndexingServiceError as e:
              raise ChildChunkIndexingError(str(e)) from e
          return {"data": marshal(child_chunk, child_chunk_fields)}, 200

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, dataset_id, document_id, segment_id):
          """Return one page of the segment's child chunks.

          Query-string arguments: ``page`` (default 1), ``limit`` (default 20,
          capped at 100) and ``keyword``.

          Raises:
              NotFound: If the dataset, document or segment cannot be found.
              Forbidden: If the dataset permission check rejects the current user.
          """
          # check dataset
          dataset_id = str(dataset_id)
          dataset = DatasetService.get_dataset(dataset_id)
          if not dataset:
              raise NotFound("Dataset not found.")
          # check user's model setting
          DatasetService.check_dataset_model_setting(dataset)
          # check document
          document_id = str(document_id)
          document = DocumentService.get_document(dataset_id, document_id)
          if not document:
              raise NotFound("Document not found.")
          # check segment (must belong to the document addressed by the URL)
          segment_id = str(segment_id)
          self._get_segment_or_404(document_id, segment_id)
          # Reading child chunks requires the same dataset permission as reading the
          # segment list (DatasetDocumentSegmentListApi.get).
          try:
              DatasetService.check_dataset_permission(dataset, current_user)
          except services.errors.account.NoPermissionError as e:
              raise Forbidden(str(e)) from e
          parser = reqparse.RequestParser()
          parser.add_argument("limit", type=int, default=20, location="args")
          parser.add_argument("keyword", type=str, default=None, location="args")
          parser.add_argument("page", type=int, default=1, location="args")
          args = parser.parse_args()
          page = args["page"]
          limit = min(args["limit"], 100)
          keyword = args["keyword"]
          child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
          return {
              "data": marshal(child_chunks.items, child_chunk_fields),
              "total": child_chunks.total,
              "total_pages": child_chunks.pages,
              "page": page,
              "limit": limit,
          }, 200

      @setup_required
      @login_required
      @account_initialization_required
      @cloud_edition_billing_resource_check("vector_space")
      @cloud_edition_billing_rate_limit_check("knowledge")
      def patch(self, dataset_id, document_id, segment_id):
          """Update the segment's child chunks from the JSON ``chunks`` list.

          Each item of ``chunks`` is expanded as keyword arguments into
          ``ChildChunkUpdateArgs``.

          Raises:
              NotFound: If the dataset, document or segment cannot be found.
              Forbidden: If the current user is not a dataset editor or the
                  dataset permission check rejects them.
              ChildChunkIndexingError: If the service reports an indexing error.
          """
          # check dataset
          dataset_id = str(dataset_id)
          dataset = DatasetService.get_dataset(dataset_id)
          if not dataset:
              raise NotFound("Dataset not found.")
          # check user's model setting
          DatasetService.check_dataset_model_setting(dataset)
          # check document
          document_id = str(document_id)
          document = DocumentService.get_document(dataset_id, document_id)
          if not document:
              raise NotFound("Document not found.")
          # check segment (must belong to the document addressed by the URL)
          segment = self._get_segment_or_404(document_id, segment_id)
          # Only roles allowed to edit datasets may update child chunks
          # (original note: admin, owner, dataset_operator, or editor).
          if not current_user.is_dataset_editor:
              raise Forbidden()
          try:
              DatasetService.check_dataset_permission(dataset, current_user)
          except services.errors.account.NoPermissionError as e:
              raise Forbidden(str(e)) from e
          # validate args
          parser = reqparse.RequestParser()
          parser.add_argument("chunks", type=list, required=True, nullable=False, location="json")
          args = parser.parse_args()
          try:
              chunks = [ChildChunkUpdateArgs(**chunk) for chunk in args.get("chunks")]
              child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
          except ChildChunkIndexingServiceError as e:
              raise ChildChunkIndexingError(str(e)) from e
          return {"data": marshal(child_chunks, child_chunk_fields)}, 200
  class ChildChunkUpdateApi(Resource):
      """Delete or update one child chunk addressed by dataset, document, segment and chunk id."""

      @staticmethod
      def _get_segment_or_404(document_id, segment_id):
          """Return the segment with this id inside ``document_id`` for the current tenant.

          The lookup is scoped to the document from the URL so that the
          dataset-level permission checks done by the callers cover the
          segment being modified; a segment id belonging to another document
          is reported as not found.

          Raises:
              NotFound: If no matching segment exists.
          """
          segment = (
              db.session.query(DocumentSegment)
              .where(
                  DocumentSegment.id == str(segment_id),
                  DocumentSegment.document_id == document_id,
                  DocumentSegment.tenant_id == current_user.current_tenant_id,
              )
              .first()
          )
          if not segment:
              raise NotFound("Segment not found.")
          return segment

      @staticmethod
      def _get_child_chunk_or_404(child_chunk_id):
          """Return the child chunk with this id for the current tenant.

          Raises:
              NotFound: If no matching child chunk exists.
          """
          # NOTE(review): the chunk is matched by id and tenant only; presumably it should
          # also be restricted to the segment from the URL -- confirm against the
          # ChildChunk model before adding that filter.
          child_chunk = (
              db.session.query(ChildChunk)
              .where(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
              .first()
          )
          if not child_chunk:
              raise NotFound("Child chunk not found.")
          return child_chunk

      @setup_required
      @login_required
      @account_initialization_required
      @cloud_edition_billing_rate_limit_check("knowledge")
      def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
          """Delete one child chunk.

          Raises:
              NotFound: If the dataset, document, segment or child chunk cannot be found.
              Forbidden: If the current user is not a dataset editor or the
                  dataset permission check rejects them.
              ChildChunkDeleteIndexError: If the service reports an index deletion error.
          """
          # check dataset
          dataset_id = str(dataset_id)
          dataset = DatasetService.get_dataset(dataset_id)
          if not dataset:
              raise NotFound("Dataset not found.")
          # check user's model setting
          DatasetService.check_dataset_model_setting(dataset)
          # check document
          document_id = str(document_id)
          document = DocumentService.get_document(dataset_id, document_id)
          if not document:
              raise NotFound("Document not found.")
          # check segment (must belong to the document addressed by the URL)
          self._get_segment_or_404(document_id, segment_id)
          # check child chunk
          child_chunk = self._get_child_chunk_or_404(child_chunk_id)
          # Only roles allowed to edit datasets may delete child chunks
          # (original note: admin, owner, dataset_operator, or editor).
          if not current_user.is_dataset_editor:
              raise Forbidden()
          try:
              DatasetService.check_dataset_permission(dataset, current_user)
          except services.errors.account.NoPermissionError as e:
              raise Forbidden(str(e)) from e
          try:
              SegmentService.delete_child_chunk(child_chunk, dataset)
          except ChildChunkDeleteIndexServiceError as e:
              raise ChildChunkDeleteIndexError(str(e)) from e
          return {"result": "success"}, 204

      @setup_required
      @login_required
      @account_initialization_required
      @cloud_edition_billing_resource_check("vector_space")
      @cloud_edition_billing_rate_limit_check("knowledge")
      def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
          """Replace the content of one child chunk with the JSON ``content`` field.

          Raises:
              NotFound: If the dataset, document, segment or child chunk cannot be found.
              Forbidden: If the current user is not a dataset editor or the
                  dataset permission check rejects them.
              ChildChunkIndexingError: If the service reports an indexing error.
          """
          # check dataset
          dataset_id = str(dataset_id)
          dataset = DatasetService.get_dataset(dataset_id)
          if not dataset:
              raise NotFound("Dataset not found.")
          # check user's model setting
          DatasetService.check_dataset_model_setting(dataset)
          # check document
          document_id = str(document_id)
          document = DocumentService.get_document(dataset_id, document_id)
          if not document:
              raise NotFound("Document not found.")
          # check segment (must belong to the document addressed by the URL)
          segment = self._get_segment_or_404(document_id, segment_id)
          # check child chunk
          child_chunk = self._get_child_chunk_or_404(child_chunk_id)
          # Only roles allowed to edit datasets may update child chunks
          # (original note: admin, owner, dataset_operator, or editor).
          if not current_user.is_dataset_editor:
              raise Forbidden()
          try:
              DatasetService.check_dataset_permission(dataset, current_user)
          except services.errors.account.NoPermissionError as e:
              raise Forbidden(str(e)) from e
          # validate args
          parser = reqparse.RequestParser()
          parser.add_argument("content", type=str, required=True, nullable=False, location="json")
          args = parser.parse_args()
          try:
              child_chunk = SegmentService.update_child_chunk(
                  args.get("content"), child_chunk, segment, document, dataset
              )
          except ChildChunkIndexingServiceError as e:
              raise ChildChunkIndexingError(str(e)) from e
          return {"data": marshal(child_chunk, child_chunk_fields)}, 200
  # Routing table: each resource class paired with the console API URL rule(s) it serves.
  # Resources are registered in the order listed.
  _SEGMENT_ROUTES = (
      (
          DatasetDocumentSegmentListApi,
          ("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments",),
      ),
      (
          DatasetDocumentSegmentApi,
          ("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>",),
      ),
      (
          DatasetDocumentSegmentAddApi,
          ("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment",),
      ),
      (
          DatasetDocumentSegmentUpdateApi,
          ("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>",),
      ),
      (
          # Two rules: starting an import (POST) and polling a job's status (GET).
          DatasetDocumentSegmentBatchImportApi,
          (
              "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import",
              "/datasets/batch_import_status/<uuid:job_id>",
          ),
      ),
      (
          ChildChunkAddApi,
          ("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks",),
      ),
      (
          ChildChunkUpdateApi,
          (
              "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>",
          ),
      ),
  )
  for _resource_class, _url_rules in _SEGMENT_ROUTES:
      api.add_resource(_resource_class, *_url_rules)