You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

datasets_segments.py 29KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693
  1. import uuid
  2. import pandas as pd
  3. from flask import request
  4. from flask_login import current_user
  5. from flask_restful import Resource, marshal, reqparse
  6. from sqlalchemy import select
  7. from werkzeug.exceptions import Forbidden, NotFound
  8. import services
  9. from controllers.console import api
  10. from controllers.console.app.error import ProviderNotInitializeError
  11. from controllers.console.datasets.error import (
  12. ChildChunkDeleteIndexError,
  13. ChildChunkIndexingError,
  14. InvalidActionError,
  15. NoFileUploadedError,
  16. TooManyFilesError,
  17. )
  18. from controllers.console.wraps import (
  19. account_initialization_required,
  20. cloud_edition_billing_knowledge_limit_check,
  21. cloud_edition_billing_rate_limit_check,
  22. cloud_edition_billing_resource_check,
  23. setup_required,
  24. )
  25. from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
  26. from core.model_manager import ModelManager
  27. from core.model_runtime.entities.model_entities import ModelType
  28. from extensions.ext_database import db
  29. from extensions.ext_redis import redis_client
  30. from fields.segment_fields import child_chunk_fields, segment_fields
  31. from libs.login import login_required
  32. from models.dataset import ChildChunk, DocumentSegment
  33. from services.dataset_service import DatasetService, DocumentService, SegmentService
  34. from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs
  35. from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
  36. from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
  37. from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
  38. class DatasetDocumentSegmentListApi(Resource):
  39. @setup_required
  40. @login_required
  41. @account_initialization_required
  42. def get(self, dataset_id, document_id):
  43. dataset_id = str(dataset_id)
  44. document_id = str(document_id)
  45. dataset = DatasetService.get_dataset(dataset_id)
  46. if not dataset:
  47. raise NotFound("Dataset not found.")
  48. try:
  49. DatasetService.check_dataset_permission(dataset, current_user)
  50. except services.errors.account.NoPermissionError as e:
  51. raise Forbidden(str(e))
  52. document = DocumentService.get_document(dataset_id, document_id)
  53. if not document:
  54. raise NotFound("Document not found.")
  55. parser = reqparse.RequestParser()
  56. parser.add_argument("limit", type=int, default=20, location="args")
  57. parser.add_argument("status", type=str, action="append", default=[], location="args")
  58. parser.add_argument("hit_count_gte", type=int, default=None, location="args")
  59. parser.add_argument("enabled", type=str, default="all", location="args")
  60. parser.add_argument("keyword", type=str, default=None, location="args")
  61. parser.add_argument("page", type=int, default=1, location="args")
  62. args = parser.parse_args()
  63. page = args["page"]
  64. limit = min(args["limit"], 100)
  65. status_list = args["status"]
  66. hit_count_gte = args["hit_count_gte"]
  67. keyword = args["keyword"]
  68. query = (
  69. select(DocumentSegment)
  70. .where(
  71. DocumentSegment.document_id == str(document_id),
  72. DocumentSegment.tenant_id == current_user.current_tenant_id,
  73. )
  74. .order_by(DocumentSegment.position.asc())
  75. )
  76. if status_list:
  77. query = query.where(DocumentSegment.status.in_(status_list))
  78. if hit_count_gte is not None:
  79. query = query.where(DocumentSegment.hit_count >= hit_count_gte)
  80. if keyword:
  81. query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
  82. if args["enabled"].lower() != "all":
  83. if args["enabled"].lower() == "true":
  84. query = query.where(DocumentSegment.enabled == True)
  85. elif args["enabled"].lower() == "false":
  86. query = query.where(DocumentSegment.enabled == False)
  87. segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
  88. response = {
  89. "data": marshal(segments.items, segment_fields),
  90. "limit": limit,
  91. "total": segments.total,
  92. "total_pages": segments.pages,
  93. "page": page,
  94. }
  95. return response, 200
  96. @setup_required
  97. @login_required
  98. @account_initialization_required
  99. @cloud_edition_billing_rate_limit_check("knowledge")
  100. def delete(self, dataset_id, document_id):
  101. # check dataset
  102. dataset_id = str(dataset_id)
  103. dataset = DatasetService.get_dataset(dataset_id)
  104. if not dataset:
  105. raise NotFound("Dataset not found.")
  106. # check user's model setting
  107. DatasetService.check_dataset_model_setting(dataset)
  108. # check document
  109. document_id = str(document_id)
  110. document = DocumentService.get_document(dataset_id, document_id)
  111. if not document:
  112. raise NotFound("Document not found.")
  113. segment_ids = request.args.getlist("segment_id")
  114. # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
  115. if not current_user.is_dataset_editor:
  116. raise Forbidden()
  117. try:
  118. DatasetService.check_dataset_permission(dataset, current_user)
  119. except services.errors.account.NoPermissionError as e:
  120. raise Forbidden(str(e))
  121. SegmentService.delete_segments(segment_ids, document, dataset)
  122. return {"result": "success"}, 204
  123. class DatasetDocumentSegmentApi(Resource):
  124. @setup_required
  125. @login_required
  126. @account_initialization_required
  127. @cloud_edition_billing_resource_check("vector_space")
  128. @cloud_edition_billing_rate_limit_check("knowledge")
  129. def patch(self, dataset_id, document_id, action):
  130. dataset_id = str(dataset_id)
  131. dataset = DatasetService.get_dataset(dataset_id)
  132. if not dataset:
  133. raise NotFound("Dataset not found.")
  134. document_id = str(document_id)
  135. document = DocumentService.get_document(dataset_id, document_id)
  136. if not document:
  137. raise NotFound("Document not found.")
  138. # check user's model setting
  139. DatasetService.check_dataset_model_setting(dataset)
  140. # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
  141. if not current_user.is_dataset_editor:
  142. raise Forbidden()
  143. try:
  144. DatasetService.check_dataset_permission(dataset, current_user)
  145. except services.errors.account.NoPermissionError as e:
  146. raise Forbidden(str(e))
  147. if dataset.indexing_technique == "high_quality":
  148. # check embedding model setting
  149. try:
  150. model_manager = ModelManager()
  151. model_manager.get_model_instance(
  152. tenant_id=current_user.current_tenant_id,
  153. provider=dataset.embedding_model_provider,
  154. model_type=ModelType.TEXT_EMBEDDING,
  155. model=dataset.embedding_model,
  156. )
  157. except LLMBadRequestError:
  158. raise ProviderNotInitializeError(
  159. "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
  160. )
  161. except ProviderTokenNotInitError as ex:
  162. raise ProviderNotInitializeError(ex.description)
  163. segment_ids = request.args.getlist("segment_id")
  164. document_indexing_cache_key = f"document_{document.id}_indexing"
  165. cache_result = redis_client.get(document_indexing_cache_key)
  166. if cache_result is not None:
  167. raise InvalidActionError("Document is being indexed, please try again later")
  168. try:
  169. SegmentService.update_segments_status(segment_ids, action, dataset, document)
  170. except Exception as e:
  171. raise InvalidActionError(str(e))
  172. return {"result": "success"}, 200
  173. class DatasetDocumentSegmentAddApi(Resource):
  174. @setup_required
  175. @login_required
  176. @account_initialization_required
  177. @cloud_edition_billing_resource_check("vector_space")
  178. @cloud_edition_billing_knowledge_limit_check("add_segment")
  179. @cloud_edition_billing_rate_limit_check("knowledge")
  180. def post(self, dataset_id, document_id):
  181. # check dataset
  182. dataset_id = str(dataset_id)
  183. dataset = DatasetService.get_dataset(dataset_id)
  184. if not dataset:
  185. raise NotFound("Dataset not found.")
  186. # check document
  187. document_id = str(document_id)
  188. document = DocumentService.get_document(dataset_id, document_id)
  189. if not document:
  190. raise NotFound("Document not found.")
  191. if not current_user.is_dataset_editor:
  192. raise Forbidden()
  193. # check embedding model setting
  194. if dataset.indexing_technique == "high_quality":
  195. try:
  196. model_manager = ModelManager()
  197. model_manager.get_model_instance(
  198. tenant_id=current_user.current_tenant_id,
  199. provider=dataset.embedding_model_provider,
  200. model_type=ModelType.TEXT_EMBEDDING,
  201. model=dataset.embedding_model,
  202. )
  203. except LLMBadRequestError:
  204. raise ProviderNotInitializeError(
  205. "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
  206. )
  207. except ProviderTokenNotInitError as ex:
  208. raise ProviderNotInitializeError(ex.description)
  209. try:
  210. DatasetService.check_dataset_permission(dataset, current_user)
  211. except services.errors.account.NoPermissionError as e:
  212. raise Forbidden(str(e))
  213. # validate args
  214. parser = reqparse.RequestParser()
  215. parser.add_argument("content", type=str, required=True, nullable=False, location="json")
  216. parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
  217. parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
  218. args = parser.parse_args()
  219. SegmentService.segment_create_args_validate(args, document)
  220. segment = SegmentService.create_segment(args, document, dataset)
  221. return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
  222. class DatasetDocumentSegmentUpdateApi(Resource):
  223. @setup_required
  224. @login_required
  225. @account_initialization_required
  226. @cloud_edition_billing_resource_check("vector_space")
  227. @cloud_edition_billing_rate_limit_check("knowledge")
  228. def patch(self, dataset_id, document_id, segment_id):
  229. # check dataset
  230. dataset_id = str(dataset_id)
  231. dataset = DatasetService.get_dataset(dataset_id)
  232. if not dataset:
  233. raise NotFound("Dataset not found.")
  234. # check user's model setting
  235. DatasetService.check_dataset_model_setting(dataset)
  236. # check document
  237. document_id = str(document_id)
  238. document = DocumentService.get_document(dataset_id, document_id)
  239. if not document:
  240. raise NotFound("Document not found.")
  241. if dataset.indexing_technique == "high_quality":
  242. # check embedding model setting
  243. try:
  244. model_manager = ModelManager()
  245. model_manager.get_model_instance(
  246. tenant_id=current_user.current_tenant_id,
  247. provider=dataset.embedding_model_provider,
  248. model_type=ModelType.TEXT_EMBEDDING,
  249. model=dataset.embedding_model,
  250. )
  251. except LLMBadRequestError:
  252. raise ProviderNotInitializeError(
  253. "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
  254. )
  255. except ProviderTokenNotInitError as ex:
  256. raise ProviderNotInitializeError(ex.description)
  257. # check segment
  258. segment_id = str(segment_id)
  259. segment = (
  260. db.session.query(DocumentSegment)
  261. .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
  262. .first()
  263. )
  264. if not segment:
  265. raise NotFound("Segment not found.")
  266. # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
  267. if not current_user.is_dataset_editor:
  268. raise Forbidden()
  269. try:
  270. DatasetService.check_dataset_permission(dataset, current_user)
  271. except services.errors.account.NoPermissionError as e:
  272. raise Forbidden(str(e))
  273. # validate args
  274. parser = reqparse.RequestParser()
  275. parser.add_argument("content", type=str, required=True, nullable=False, location="json")
  276. parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
  277. parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
  278. parser.add_argument(
  279. "regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
  280. )
  281. args = parser.parse_args()
  282. SegmentService.segment_create_args_validate(args, document)
  283. segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset)
  284. return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
  285. @setup_required
  286. @login_required
  287. @account_initialization_required
  288. @cloud_edition_billing_rate_limit_check("knowledge")
  289. def delete(self, dataset_id, document_id, segment_id):
  290. # check dataset
  291. dataset_id = str(dataset_id)
  292. dataset = DatasetService.get_dataset(dataset_id)
  293. if not dataset:
  294. raise NotFound("Dataset not found.")
  295. # check user's model setting
  296. DatasetService.check_dataset_model_setting(dataset)
  297. # check document
  298. document_id = str(document_id)
  299. document = DocumentService.get_document(dataset_id, document_id)
  300. if not document:
  301. raise NotFound("Document not found.")
  302. # check segment
  303. segment_id = str(segment_id)
  304. segment = (
  305. db.session.query(DocumentSegment)
  306. .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
  307. .first()
  308. )
  309. if not segment:
  310. raise NotFound("Segment not found.")
  311. # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
  312. if not current_user.is_dataset_editor:
  313. raise Forbidden()
  314. try:
  315. DatasetService.check_dataset_permission(dataset, current_user)
  316. except services.errors.account.NoPermissionError as e:
  317. raise Forbidden(str(e))
  318. SegmentService.delete_segment(segment, document, dataset)
  319. return {"result": "success"}, 204
  320. class DatasetDocumentSegmentBatchImportApi(Resource):
  321. @setup_required
  322. @login_required
  323. @account_initialization_required
  324. @cloud_edition_billing_resource_check("vector_space")
  325. @cloud_edition_billing_knowledge_limit_check("add_segment")
  326. @cloud_edition_billing_rate_limit_check("knowledge")
  327. def post(self, dataset_id, document_id):
  328. # check dataset
  329. dataset_id = str(dataset_id)
  330. dataset = DatasetService.get_dataset(dataset_id)
  331. if not dataset:
  332. raise NotFound("Dataset not found.")
  333. # check document
  334. document_id = str(document_id)
  335. document = DocumentService.get_document(dataset_id, document_id)
  336. if not document:
  337. raise NotFound("Document not found.")
  338. # get file from request
  339. file = request.files["file"]
  340. # check file
  341. if "file" not in request.files:
  342. raise NoFileUploadedError()
  343. if len(request.files) > 1:
  344. raise TooManyFilesError()
  345. # check file type
  346. if not file.filename or not file.filename.lower().endswith(".csv"):
  347. raise ValueError("Invalid file type. Only CSV files are allowed")
  348. try:
  349. # Skip the first row
  350. df = pd.read_csv(file)
  351. result = []
  352. for index, row in df.iterrows():
  353. if document.doc_form == "qa_model":
  354. data = {"content": row.iloc[0], "answer": row.iloc[1]}
  355. else:
  356. data = {"content": row.iloc[0]}
  357. result.append(data)
  358. if len(result) == 0:
  359. raise ValueError("The CSV file is empty.")
  360. # async job
  361. job_id = str(uuid.uuid4())
  362. indexing_cache_key = f"segment_batch_import_{str(job_id)}"
  363. # send batch add segments task
  364. redis_client.setnx(indexing_cache_key, "waiting")
  365. batch_create_segment_to_index_task.delay(
  366. str(job_id), result, dataset_id, document_id, current_user.current_tenant_id, current_user.id
  367. )
  368. except Exception as e:
  369. return {"error": str(e)}, 500
  370. return {"job_id": job_id, "job_status": "waiting"}, 200
  371. @setup_required
  372. @login_required
  373. @account_initialization_required
  374. def get(self, job_id):
  375. job_id = str(job_id)
  376. indexing_cache_key = f"segment_batch_import_{job_id}"
  377. cache_result = redis_client.get(indexing_cache_key)
  378. if cache_result is None:
  379. raise ValueError("The job does not exist.")
  380. return {"job_id": job_id, "job_status": cache_result.decode()}, 200
  381. class ChildChunkAddApi(Resource):
  382. @setup_required
  383. @login_required
  384. @account_initialization_required
  385. @cloud_edition_billing_resource_check("vector_space")
  386. @cloud_edition_billing_knowledge_limit_check("add_segment")
  387. @cloud_edition_billing_rate_limit_check("knowledge")
  388. def post(self, dataset_id, document_id, segment_id):
  389. # check dataset
  390. dataset_id = str(dataset_id)
  391. dataset = DatasetService.get_dataset(dataset_id)
  392. if not dataset:
  393. raise NotFound("Dataset not found.")
  394. # check document
  395. document_id = str(document_id)
  396. document = DocumentService.get_document(dataset_id, document_id)
  397. if not document:
  398. raise NotFound("Document not found.")
  399. # check segment
  400. segment_id = str(segment_id)
  401. segment = (
  402. db.session.query(DocumentSegment)
  403. .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
  404. .first()
  405. )
  406. if not segment:
  407. raise NotFound("Segment not found.")
  408. if not current_user.is_dataset_editor:
  409. raise Forbidden()
  410. # check embedding model setting
  411. if dataset.indexing_technique == "high_quality":
  412. try:
  413. model_manager = ModelManager()
  414. model_manager.get_model_instance(
  415. tenant_id=current_user.current_tenant_id,
  416. provider=dataset.embedding_model_provider,
  417. model_type=ModelType.TEXT_EMBEDDING,
  418. model=dataset.embedding_model,
  419. )
  420. except LLMBadRequestError:
  421. raise ProviderNotInitializeError(
  422. "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
  423. )
  424. except ProviderTokenNotInitError as ex:
  425. raise ProviderNotInitializeError(ex.description)
  426. try:
  427. DatasetService.check_dataset_permission(dataset, current_user)
  428. except services.errors.account.NoPermissionError as e:
  429. raise Forbidden(str(e))
  430. # validate args
  431. parser = reqparse.RequestParser()
  432. parser.add_argument("content", type=str, required=True, nullable=False, location="json")
  433. args = parser.parse_args()
  434. try:
  435. child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset)
  436. except ChildChunkIndexingServiceError as e:
  437. raise ChildChunkIndexingError(str(e))
  438. return {"data": marshal(child_chunk, child_chunk_fields)}, 200
  439. @setup_required
  440. @login_required
  441. @account_initialization_required
  442. def get(self, dataset_id, document_id, segment_id):
  443. # check dataset
  444. dataset_id = str(dataset_id)
  445. dataset = DatasetService.get_dataset(dataset_id)
  446. if not dataset:
  447. raise NotFound("Dataset not found.")
  448. # check user's model setting
  449. DatasetService.check_dataset_model_setting(dataset)
  450. # check document
  451. document_id = str(document_id)
  452. document = DocumentService.get_document(dataset_id, document_id)
  453. if not document:
  454. raise NotFound("Document not found.")
  455. # check segment
  456. segment_id = str(segment_id)
  457. segment = (
  458. db.session.query(DocumentSegment)
  459. .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
  460. .first()
  461. )
  462. if not segment:
  463. raise NotFound("Segment not found.")
  464. parser = reqparse.RequestParser()
  465. parser.add_argument("limit", type=int, default=20, location="args")
  466. parser.add_argument("keyword", type=str, default=None, location="args")
  467. parser.add_argument("page", type=int, default=1, location="args")
  468. args = parser.parse_args()
  469. page = args["page"]
  470. limit = min(args["limit"], 100)
  471. keyword = args["keyword"]
  472. child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
  473. return {
  474. "data": marshal(child_chunks.items, child_chunk_fields),
  475. "total": child_chunks.total,
  476. "total_pages": child_chunks.pages,
  477. "page": page,
  478. "limit": limit,
  479. }, 200
  480. @setup_required
  481. @login_required
  482. @account_initialization_required
  483. @cloud_edition_billing_resource_check("vector_space")
  484. @cloud_edition_billing_rate_limit_check("knowledge")
  485. def patch(self, dataset_id, document_id, segment_id):
  486. # check dataset
  487. dataset_id = str(dataset_id)
  488. dataset = DatasetService.get_dataset(dataset_id)
  489. if not dataset:
  490. raise NotFound("Dataset not found.")
  491. # check user's model setting
  492. DatasetService.check_dataset_model_setting(dataset)
  493. # check document
  494. document_id = str(document_id)
  495. document = DocumentService.get_document(dataset_id, document_id)
  496. if not document:
  497. raise NotFound("Document not found.")
  498. # check segment
  499. segment_id = str(segment_id)
  500. segment = (
  501. db.session.query(DocumentSegment)
  502. .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
  503. .first()
  504. )
  505. if not segment:
  506. raise NotFound("Segment not found.")
  507. # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
  508. if not current_user.is_dataset_editor:
  509. raise Forbidden()
  510. try:
  511. DatasetService.check_dataset_permission(dataset, current_user)
  512. except services.errors.account.NoPermissionError as e:
  513. raise Forbidden(str(e))
  514. # validate args
  515. parser = reqparse.RequestParser()
  516. parser.add_argument("chunks", type=list, required=True, nullable=False, location="json")
  517. args = parser.parse_args()
  518. try:
  519. chunks = [ChildChunkUpdateArgs(**chunk) for chunk in args.get("chunks")]
  520. child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
  521. except ChildChunkIndexingServiceError as e:
  522. raise ChildChunkIndexingError(str(e))
  523. return {"data": marshal(child_chunks, child_chunk_fields)}, 200
  524. class ChildChunkUpdateApi(Resource):
  525. @setup_required
  526. @login_required
  527. @account_initialization_required
  528. @cloud_edition_billing_rate_limit_check("knowledge")
  529. def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
  530. # check dataset
  531. dataset_id = str(dataset_id)
  532. dataset = DatasetService.get_dataset(dataset_id)
  533. if not dataset:
  534. raise NotFound("Dataset not found.")
  535. # check user's model setting
  536. DatasetService.check_dataset_model_setting(dataset)
  537. # check document
  538. document_id = str(document_id)
  539. document = DocumentService.get_document(dataset_id, document_id)
  540. if not document:
  541. raise NotFound("Document not found.")
  542. # check segment
  543. segment_id = str(segment_id)
  544. segment = (
  545. db.session.query(DocumentSegment)
  546. .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
  547. .first()
  548. )
  549. if not segment:
  550. raise NotFound("Segment not found.")
  551. # check child chunk
  552. child_chunk_id = str(child_chunk_id)
  553. child_chunk = (
  554. db.session.query(ChildChunk)
  555. .where(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
  556. .first()
  557. )
  558. if not child_chunk:
  559. raise NotFound("Child chunk not found.")
  560. # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
  561. if not current_user.is_dataset_editor:
  562. raise Forbidden()
  563. try:
  564. DatasetService.check_dataset_permission(dataset, current_user)
  565. except services.errors.account.NoPermissionError as e:
  566. raise Forbidden(str(e))
  567. try:
  568. SegmentService.delete_child_chunk(child_chunk, dataset)
  569. except ChildChunkDeleteIndexServiceError as e:
  570. raise ChildChunkDeleteIndexError(str(e))
  571. return {"result": "success"}, 204
  572. @setup_required
  573. @login_required
  574. @account_initialization_required
  575. @cloud_edition_billing_resource_check("vector_space")
  576. @cloud_edition_billing_rate_limit_check("knowledge")
  577. def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
  578. # check dataset
  579. dataset_id = str(dataset_id)
  580. dataset = DatasetService.get_dataset(dataset_id)
  581. if not dataset:
  582. raise NotFound("Dataset not found.")
  583. # check user's model setting
  584. DatasetService.check_dataset_model_setting(dataset)
  585. # check document
  586. document_id = str(document_id)
  587. document = DocumentService.get_document(dataset_id, document_id)
  588. if not document:
  589. raise NotFound("Document not found.")
  590. # check segment
  591. segment_id = str(segment_id)
  592. segment = (
  593. db.session.query(DocumentSegment)
  594. .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
  595. .first()
  596. )
  597. if not segment:
  598. raise NotFound("Segment not found.")
  599. # check child chunk
  600. child_chunk_id = str(child_chunk_id)
  601. child_chunk = (
  602. db.session.query(ChildChunk)
  603. .where(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
  604. .first()
  605. )
  606. if not child_chunk:
  607. raise NotFound("Child chunk not found.")
  608. # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
  609. if not current_user.is_dataset_editor:
  610. raise Forbidden()
  611. try:
  612. DatasetService.check_dataset_permission(dataset, current_user)
  613. except services.errors.account.NoPermissionError as e:
  614. raise Forbidden(str(e))
  615. # validate args
  616. parser = reqparse.RequestParser()
  617. parser.add_argument("content", type=str, required=True, nullable=False, location="json")
  618. args = parser.parse_args()
  619. try:
  620. child_chunk = SegmentService.update_child_chunk(
  621. args.get("content"), child_chunk, segment, document, dataset
  622. )
  623. except ChildChunkIndexingServiceError as e:
  624. raise ChildChunkIndexingError(str(e))
  625. return {"data": marshal(child_chunk, child_chunk_fields)}, 200
  626. api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
  627. api.add_resource(
  628. DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>"
  629. )
  630. api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
  631. api.add_resource(
  632. DatasetDocumentSegmentUpdateApi,
  633. "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>",
  634. )
  635. api.add_resource(
  636. DatasetDocumentSegmentBatchImportApi,
  637. "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import",
  638. "/datasets/batch_import_status/<uuid:job_id>",
  639. )
  640. api.add_resource(
  641. ChildChunkAddApi,
  642. "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks",
  643. )
  644. api.add_resource(
  645. ChildChunkUpdateApi,
  646. "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>",
  647. )