
datasets_segments.py

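"""Console API resources for dataset document segments and child chunks:
listing, creating, updating, status changes, deletion, and CSV batch import."""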
import uuid

from flask import request
from flask_login import current_user
from flask_restx import Resource, marshal, reqparse
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound

import services
from controllers.console import console_ns
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import (
    ChildChunkDeleteIndexError,
    ChildChunkIndexingError,
    InvalidActionError,
)
from controllers.console.wraps import (
    account_initialization_required,
    cloud_edition_billing_knowledge_limit_check,
    cloud_edition_billing_rate_limit_check,
    cloud_edition_billing_resource_check,
    setup_required,
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.segment_fields import child_chunk_fields, segment_fields
from libs.login import login_required
from models.dataset import ChildChunk, DocumentSegment
from models.model import UploadFile
from services.dataset_service import DatasetService, DocumentService, SegmentService
from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
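

# Listing and bulk deletion of a document's segments. GET paginates results and
# supports status, enabled, hit_count_gte, and keyword filters; DELETE removes
# the segments whose IDs arrive as repeated `segment_id` query parameters.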
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
class DatasetDocumentSegmentListApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id, document_id):
        dataset_id = str(dataset_id)
        document_id = str(document_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound("Dataset not found.")

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound("Document not found.")

        parser = reqparse.RequestParser()
        parser.add_argument("limit", type=int, default=20, location="args")
        parser.add_argument("status", type=str, action="append", default=[], location="args")
        parser.add_argument("hit_count_gte", type=int, default=None, location="args")
        parser.add_argument("enabled", type=str, default="all", location="args")
        parser.add_argument("keyword", type=str, default=None, location="args")
        parser.add_argument("page", type=int, default=1, location="args")
        args = parser.parse_args()

        page = args["page"]
        limit = min(args["limit"], 100)
        status_list = args["status"]
        hit_count_gte = args["hit_count_gte"]
        keyword = args["keyword"]

        query = (
            select(DocumentSegment)
            .where(
                DocumentSegment.document_id == str(document_id),
                DocumentSegment.tenant_id == current_user.current_tenant_id,
            )
            .order_by(DocumentSegment.position.asc())
        )

        if status_list:
            query = query.where(DocumentSegment.status.in_(status_list))

        if hit_count_gte is not None:
            query = query.where(DocumentSegment.hit_count >= hit_count_gte)

        if keyword:
            query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))

        if args["enabled"].lower() != "all":
            if args["enabled"].lower() == "true":
                query = query.where(DocumentSegment.enabled == True)
            elif args["enabled"].lower() == "false":
                query = query.where(DocumentSegment.enabled == False)

        segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)

        response = {
            "data": marshal(segments.items, segment_fields),
            "limit": limit,
            "total": segments.total,
            "total_pages": segments.pages,
            "page": page,
        }
        return response, 200

    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_rate_limit_check("knowledge")
    def delete(self, dataset_id, document_id):
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound("Dataset not found.")
        # check user's model setting
        DatasetService.check_dataset_model_setting(dataset)
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound("Document not found.")
        segment_ids = request.args.getlist("segment_id")

        # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
        if not current_user.is_dataset_editor:
            raise Forbidden()
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))
        SegmentService.delete_segments(segment_ids, document, dataset)
        return {"result": "success"}, 204
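

# Bulk status change for segments: the <action> path component and the repeated
# `segment_id` query parameters are handed to SegmentService.update_segments_status.
# The request is refused while the parent document is still being indexed.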
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>")
class DatasetDocumentSegmentApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check("vector_space")
    @cloud_edition_billing_rate_limit_check("knowledge")
    def patch(self, dataset_id, document_id, action):
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound("Dataset not found.")
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound("Document not found.")
        # check user's model setting
        DatasetService.check_dataset_model_setting(dataset)
        # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
        if not current_user.is_dataset_editor:
            raise Forbidden()
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        if dataset.indexing_technique == "high_quality":
            # check embedding model setting
            try:
                model_manager = ModelManager()
                model_manager.get_model_instance(
                    tenant_id=current_user.current_tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model,
                )
            except LLMBadRequestError:
                raise ProviderNotInitializeError(
                    "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
                )
            except ProviderTokenNotInitError as ex:
                raise ProviderNotInitializeError(ex.description)

        segment_ids = request.args.getlist("segment_id")

        document_indexing_cache_key = f"document_{document.id}_indexing"
        cache_result = redis_client.get(document_indexing_cache_key)
        if cache_result is not None:
            raise InvalidActionError("Document is being indexed, please try again later")

        try:
            SegmentService.update_segments_status(segment_ids, action, dataset, document)
        except Exception as e:
            raise InvalidActionError(str(e))
        return {"result": "success"}, 200
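

# Creation of a single segment from JSON (content, optional answer and keywords),
# subject to editor permission, billing limits, and the embedding model check.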
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
class DatasetDocumentSegmentAddApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check("vector_space")
    @cloud_edition_billing_knowledge_limit_check("add_segment")
    @cloud_edition_billing_rate_limit_check("knowledge")
    def post(self, dataset_id, document_id):
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound("Dataset not found.")
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound("Document not found.")
        if not current_user.is_dataset_editor:
            raise Forbidden()
        # check embedding model setting
        if dataset.indexing_technique == "high_quality":
            try:
                model_manager = ModelManager()
                model_manager.get_model_instance(
                    tenant_id=current_user.current_tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model,
                )
            except LLMBadRequestError:
                raise ProviderNotInitializeError(
                    "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
                )
            except ProviderTokenNotInitError as ex:
                raise ProviderNotInitializeError(ex.description)
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))
        # validate args
        parser = reqparse.RequestParser()
        parser.add_argument("content", type=str, required=True, nullable=False, location="json")
        parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
        parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
        args = parser.parse_args()
        SegmentService.segment_create_args_validate(args, document)
        segment = SegmentService.create_segment(args, document, dataset)
        return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
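

# Update and deletion of an individual segment. PATCH accepts content, answer,
# keywords, and an optional regenerate_child_chunks flag; DELETE removes the segment.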
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
class DatasetDocumentSegmentUpdateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check("vector_space")
    @cloud_edition_billing_rate_limit_check("knowledge")
    def patch(self, dataset_id, document_id, segment_id):
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound("Dataset not found.")
        # check user's model setting
        DatasetService.check_dataset_model_setting(dataset)
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound("Document not found.")
        if dataset.indexing_technique == "high_quality":
            # check embedding model setting
            try:
                model_manager = ModelManager()
                model_manager.get_model_instance(
                    tenant_id=current_user.current_tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model,
                )
            except LLMBadRequestError:
                raise ProviderNotInitializeError(
                    "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
                )
            except ProviderTokenNotInitError as ex:
                raise ProviderNotInitializeError(ex.description)
        # check segment
        segment_id = str(segment_id)
        segment = (
            db.session.query(DocumentSegment)
            .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
            .first()
        )
        if not segment:
            raise NotFound("Segment not found.")
        # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
        if not current_user.is_dataset_editor:
            raise Forbidden()
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))
        # validate args
        parser = reqparse.RequestParser()
        parser.add_argument("content", type=str, required=True, nullable=False, location="json")
        parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
        parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
        parser.add_argument(
            "regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
        )
        args = parser.parse_args()
        SegmentService.segment_create_args_validate(args, document)
        segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset)
        return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200

    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_rate_limit_check("knowledge")
    def delete(self, dataset_id, document_id, segment_id):
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound("Dataset not found.")
        # check user's model setting
        DatasetService.check_dataset_model_setting(dataset)
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound("Document not found.")
        # check segment
        segment_id = str(segment_id)
        segment = (
            db.session.query(DocumentSegment)
            .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
            .first()
        )
        if not segment:
            raise NotFound("Segment not found.")
        # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
        if not current_user.is_dataset_editor:
            raise Forbidden()
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))
        SegmentService.delete_segment(segment, document, dataset)
        return {"result": "success"}, 204
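

# CSV batch import. POST validates the uploaded CSV, records a "waiting" status in
# Redis, and dispatches batch_create_segment_to_index_task; GET reports the job
# status stored under the same Redis key.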
@console_ns.route(
    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import",
    "/datasets/batch_import_status/<uuid:job_id>",
)
class DatasetDocumentSegmentBatchImportApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check("vector_space")
    @cloud_edition_billing_knowledge_limit_check("add_segment")
    @cloud_edition_billing_rate_limit_check("knowledge")
    def post(self, dataset_id, document_id):
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound("Dataset not found.")
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound("Document not found.")
        parser = reqparse.RequestParser()
        parser.add_argument("upload_file_id", type=str, required=True, nullable=False, location="json")
        args = parser.parse_args()
        upload_file_id = args["upload_file_id"]
        upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
        if not upload_file:
            raise NotFound("UploadFile not found.")
        # check file type
        if not upload_file.name or not upload_file.name.lower().endswith(".csv"):
            raise ValueError("Invalid file type. Only CSV files are allowed")
        try:
            # async job
            job_id = str(uuid.uuid4())
            indexing_cache_key = f"segment_batch_import_{str(job_id)}"
            # send batch add segments task
            redis_client.setnx(indexing_cache_key, "waiting")
            batch_create_segment_to_index_task.delay(
                str(job_id), upload_file_id, dataset_id, document_id, current_user.current_tenant_id, current_user.id
            )
        except Exception as e:
            return {"error": str(e)}, 500
        return {"job_id": job_id, "job_status": "waiting"}, 200

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, job_id=None, dataset_id=None, document_id=None):
        if job_id is None:
            raise NotFound("The job does not exist.")
        job_id = str(job_id)
        indexing_cache_key = f"segment_batch_import_{job_id}"
        cache_result = redis_client.get(indexing_cache_key)
        if cache_result is None:
            raise ValueError("The job does not exist.")
        return {"job_id": job_id, "job_status": cache_result.decode()}, 200
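

# Child-chunk collection under a segment: POST creates a chunk, GET lists chunks
# with pagination and keyword search, PATCH replaces the segment's child chunks in bulk.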
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks")
class ChildChunkAddApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check("vector_space")
    @cloud_edition_billing_knowledge_limit_check("add_segment")
    @cloud_edition_billing_rate_limit_check("knowledge")
    def post(self, dataset_id, document_id, segment_id):
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound("Dataset not found.")
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound("Document not found.")
        # check segment
        segment_id = str(segment_id)
        segment = (
            db.session.query(DocumentSegment)
            .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
            .first()
        )
        if not segment:
            raise NotFound("Segment not found.")
        if not current_user.is_dataset_editor:
            raise Forbidden()
        # check embedding model setting
        if dataset.indexing_technique == "high_quality":
            try:
                model_manager = ModelManager()
                model_manager.get_model_instance(
                    tenant_id=current_user.current_tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model,
                )
            except LLMBadRequestError:
                raise ProviderNotInitializeError(
                    "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
                )
            except ProviderTokenNotInitError as ex:
                raise ProviderNotInitializeError(ex.description)
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))
        # validate args
        parser = reqparse.RequestParser()
        parser.add_argument("content", type=str, required=True, nullable=False, location="json")
        args = parser.parse_args()
        try:
            child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset)
        except ChildChunkIndexingServiceError as e:
            raise ChildChunkIndexingError(str(e))
        return {"data": marshal(child_chunk, child_chunk_fields)}, 200

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id, document_id, segment_id):
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound("Dataset not found.")
        # check user's model setting
        DatasetService.check_dataset_model_setting(dataset)
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound("Document not found.")
        # check segment
        segment_id = str(segment_id)
        segment = (
            db.session.query(DocumentSegment)
            .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
            .first()
        )
        if not segment:
            raise NotFound("Segment not found.")
        parser = reqparse.RequestParser()
        parser.add_argument("limit", type=int, default=20, location="args")
        parser.add_argument("keyword", type=str, default=None, location="args")
        parser.add_argument("page", type=int, default=1, location="args")
        args = parser.parse_args()
        page = args["page"]
        limit = min(args["limit"], 100)
        keyword = args["keyword"]
        child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
        return {
            "data": marshal(child_chunks.items, child_chunk_fields),
            "total": child_chunks.total,
            "total_pages": child_chunks.pages,
            "page": page,
            "limit": limit,
        }, 200

    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check("vector_space")
    @cloud_edition_billing_rate_limit_check("knowledge")
    def patch(self, dataset_id, document_id, segment_id):
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound("Dataset not found.")
        # check user's model setting
        DatasetService.check_dataset_model_setting(dataset)
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound("Document not found.")
        # check segment
        segment_id = str(segment_id)
        segment = (
            db.session.query(DocumentSegment)
            .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
            .first()
        )
        if not segment:
            raise NotFound("Segment not found.")
        # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
        if not current_user.is_dataset_editor:
            raise Forbidden()
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))
        # validate args
        parser = reqparse.RequestParser()
        parser.add_argument("chunks", type=list, required=True, nullable=False, location="json")
        args = parser.parse_args()
        try:
            chunks = [ChildChunkUpdateArgs(**chunk) for chunk in args.get("chunks")]
            child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
        except ChildChunkIndexingServiceError as e:
            raise ChildChunkIndexingError(str(e))
        return {"data": marshal(child_chunks, child_chunk_fields)}, 200
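

# Single child-chunk operations: DELETE removes the chunk via
# SegmentService.delete_child_chunk, PATCH replaces its content.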
@console_ns.route(
    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>"
)
class ChildChunkUpdateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_rate_limit_check("knowledge")
    def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound("Dataset not found.")
        # check user's model setting
        DatasetService.check_dataset_model_setting(dataset)
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound("Document not found.")
        # check segment
        segment_id = str(segment_id)
        segment = (
            db.session.query(DocumentSegment)
            .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
            .first()
        )
        if not segment:
            raise NotFound("Segment not found.")
        # check child chunk
        child_chunk_id = str(child_chunk_id)
        child_chunk = (
            db.session.query(ChildChunk)
            .where(
                ChildChunk.id == str(child_chunk_id),
                ChildChunk.tenant_id == current_user.current_tenant_id,
                ChildChunk.segment_id == segment.id,
                ChildChunk.document_id == document_id,
            )
            .first()
        )
        if not child_chunk:
            raise NotFound("Child chunk not found.")
        # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
        if not current_user.is_dataset_editor:
            raise Forbidden()
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))
        try:
            SegmentService.delete_child_chunk(child_chunk, dataset)
        except ChildChunkDeleteIndexServiceError as e:
            raise ChildChunkDeleteIndexError(str(e))
        return {"result": "success"}, 204

    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check("vector_space")
    @cloud_edition_billing_rate_limit_check("knowledge")
    def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound("Dataset not found.")
        # check user's model setting
        DatasetService.check_dataset_model_setting(dataset)
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound("Document not found.")
        # check segment
        segment_id = str(segment_id)
        segment = (
            db.session.query(DocumentSegment)
            .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
            .first()
        )
        if not segment:
            raise NotFound("Segment not found.")
        # check child chunk
        child_chunk_id = str(child_chunk_id)
        child_chunk = (
            db.session.query(ChildChunk)
            .where(
                ChildChunk.id == str(child_chunk_id),
                ChildChunk.tenant_id == current_user.current_tenant_id,
                ChildChunk.segment_id == segment.id,
                ChildChunk.document_id == document_id,
            )
            .first()
        )
        if not child_chunk:
            raise NotFound("Child chunk not found.")
        # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
        if not current_user.is_dataset_editor:
            raise Forbidden()
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))
        # validate args
        parser = reqparse.RequestParser()
        parser.add_argument("content", type=str, required=True, nullable=False, location="json")
        args = parser.parse_args()
        try:
            child_chunk = SegmentService.update_child_chunk(
                args.get("content"), child_chunk, segment, document, dataset
            )
        except ChildChunkIndexingServiceError as e:
            raise ChildChunkIndexingError(str(e))
        return {"data": marshal(child_chunk, child_chunk_fields)}, 200