You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447
  1. from flask import request
  2. from flask_login import current_user
  3. from flask_restful import marshal, reqparse
  4. from werkzeug.exceptions import NotFound
  5. from controllers.service_api import api
  6. from controllers.service_api.app.error import ProviderNotInitializeError
  7. from controllers.service_api.wraps import (
  8. DatasetApiResource,
  9. cloud_edition_billing_knowledge_limit_check,
  10. cloud_edition_billing_rate_limit_check,
  11. cloud_edition_billing_resource_check,
  12. )
  13. from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
  14. from core.model_manager import ModelManager
  15. from core.model_runtime.entities.model_entities import ModelType
  16. from extensions.ext_database import db
  17. from fields.segment_fields import child_chunk_fields, segment_fields
  18. from models.dataset import Dataset
  19. from services.dataset_service import DatasetService, DocumentService, SegmentService
  20. from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs
  21. from services.errors.chunk import (
  22. ChildChunkDeleteIndexError,
  23. ChildChunkIndexingError,
  24. )
  25. from services.errors.chunk import (
  26. ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError,
  27. )
  28. from services.errors.chunk import (
  29. ChildChunkIndexingError as ChildChunkIndexingServiceError,
  30. )
  31. class SegmentApi(DatasetApiResource):
  32. """Resource for segments."""
  33. @cloud_edition_billing_resource_check("vector_space", "dataset")
  34. @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
  35. @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
  36. def post(self, tenant_id, dataset_id, document_id):
  37. """Create single segment."""
  38. # check dataset
  39. dataset_id = str(dataset_id)
  40. tenant_id = str(tenant_id)
  41. dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
  42. if not dataset:
  43. raise NotFound("Dataset not found.")
  44. # check document
  45. document_id = str(document_id)
  46. document = DocumentService.get_document(dataset.id, document_id)
  47. if not document:
  48. raise NotFound("Document not found.")
  49. if document.indexing_status != "completed":
  50. raise NotFound("Document is not completed.")
  51. if not document.enabled:
  52. raise NotFound("Document is disabled.")
  53. # check embedding model setting
  54. if dataset.indexing_technique == "high_quality":
  55. try:
  56. model_manager = ModelManager()
  57. model_manager.get_model_instance(
  58. tenant_id=current_user.current_tenant_id,
  59. provider=dataset.embedding_model_provider,
  60. model_type=ModelType.TEXT_EMBEDDING,
  61. model=dataset.embedding_model,
  62. )
  63. except LLMBadRequestError:
  64. raise ProviderNotInitializeError(
  65. "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
  66. )
  67. except ProviderTokenNotInitError as ex:
  68. raise ProviderNotInitializeError(ex.description)
  69. # validate args
  70. parser = reqparse.RequestParser()
  71. parser.add_argument("segments", type=list, required=False, nullable=True, location="json")
  72. args = parser.parse_args()
  73. if args["segments"] is not None:
  74. for args_item in args["segments"]:
  75. SegmentService.segment_create_args_validate(args_item, document)
  76. segments = SegmentService.multi_create_segment(args["segments"], document, dataset)
  77. return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200
  78. else:
  79. return {"error": "Segments is required"}, 400
  80. def get(self, tenant_id, dataset_id, document_id):
  81. """Get segments."""
  82. # check dataset
  83. dataset_id = str(dataset_id)
  84. tenant_id = str(tenant_id)
  85. page = request.args.get("page", default=1, type=int)
  86. limit = request.args.get("limit", default=20, type=int)
  87. dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
  88. if not dataset:
  89. raise NotFound("Dataset not found.")
  90. # check document
  91. document_id = str(document_id)
  92. document = DocumentService.get_document(dataset.id, document_id)
  93. if not document:
  94. raise NotFound("Document not found.")
  95. # check embedding model setting
  96. if dataset.indexing_technique == "high_quality":
  97. try:
  98. model_manager = ModelManager()
  99. model_manager.get_model_instance(
  100. tenant_id=current_user.current_tenant_id,
  101. provider=dataset.embedding_model_provider,
  102. model_type=ModelType.TEXT_EMBEDDING,
  103. model=dataset.embedding_model,
  104. )
  105. except LLMBadRequestError:
  106. raise ProviderNotInitializeError(
  107. "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
  108. )
  109. except ProviderTokenNotInitError as ex:
  110. raise ProviderNotInitializeError(ex.description)
  111. parser = reqparse.RequestParser()
  112. parser.add_argument("status", type=str, action="append", default=[], location="args")
  113. parser.add_argument("keyword", type=str, default=None, location="args")
  114. args = parser.parse_args()
  115. segments, total = SegmentService.get_segments(
  116. document_id=document_id,
  117. tenant_id=current_user.current_tenant_id,
  118. status_list=args["status"],
  119. keyword=args["keyword"],
  120. page=page,
  121. limit=limit,
  122. )
  123. response = {
  124. "data": marshal(segments, segment_fields),
  125. "doc_form": document.doc_form,
  126. "total": total,
  127. "has_more": len(segments) == limit,
  128. "limit": limit,
  129. "page": page,
  130. }
  131. return response, 200
  132. class DatasetSegmentApi(DatasetApiResource):
  133. @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
  134. def delete(self, tenant_id, dataset_id, document_id, segment_id):
  135. # check dataset
  136. dataset_id = str(dataset_id)
  137. tenant_id = str(tenant_id)
  138. dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
  139. if not dataset:
  140. raise NotFound("Dataset not found.")
  141. # check user's model setting
  142. DatasetService.check_dataset_model_setting(dataset)
  143. # check document
  144. document_id = str(document_id)
  145. document = DocumentService.get_document(dataset_id, document_id)
  146. if not document:
  147. raise NotFound("Document not found.")
  148. # check segment
  149. segment_id = str(segment_id)
  150. segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
  151. if not segment:
  152. raise NotFound("Segment not found.")
  153. SegmentService.delete_segment(segment, document, dataset)
  154. return 204
  155. @cloud_edition_billing_resource_check("vector_space", "dataset")
  156. @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
  157. def post(self, tenant_id, dataset_id, document_id, segment_id):
  158. # check dataset
  159. dataset_id = str(dataset_id)
  160. tenant_id = str(tenant_id)
  161. dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
  162. if not dataset:
  163. raise NotFound("Dataset not found.")
  164. # check user's model setting
  165. DatasetService.check_dataset_model_setting(dataset)
  166. # check document
  167. document_id = str(document_id)
  168. document = DocumentService.get_document(dataset_id, document_id)
  169. if not document:
  170. raise NotFound("Document not found.")
  171. if dataset.indexing_technique == "high_quality":
  172. # check embedding model setting
  173. try:
  174. model_manager = ModelManager()
  175. model_manager.get_model_instance(
  176. tenant_id=current_user.current_tenant_id,
  177. provider=dataset.embedding_model_provider,
  178. model_type=ModelType.TEXT_EMBEDDING,
  179. model=dataset.embedding_model,
  180. )
  181. except LLMBadRequestError:
  182. raise ProviderNotInitializeError(
  183. "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
  184. )
  185. except ProviderTokenNotInitError as ex:
  186. raise ProviderNotInitializeError(ex.description)
  187. # check segment
  188. segment_id = str(segment_id)
  189. segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
  190. if not segment:
  191. raise NotFound("Segment not found.")
  192. # validate args
  193. parser = reqparse.RequestParser()
  194. parser.add_argument("segment", type=dict, required=False, nullable=True, location="json")
  195. args = parser.parse_args()
  196. updated_segment = SegmentService.update_segment(
  197. SegmentUpdateArgs(**args["segment"]), segment, document, dataset
  198. )
  199. return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200
  200. def get(self, tenant_id, dataset_id, document_id, segment_id):
  201. # check dataset
  202. dataset_id = str(dataset_id)
  203. tenant_id = str(tenant_id)
  204. dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
  205. if not dataset:
  206. raise NotFound("Dataset not found.")
  207. # check user's model setting
  208. DatasetService.check_dataset_model_setting(dataset)
  209. # check document
  210. document_id = str(document_id)
  211. document = DocumentService.get_document(dataset_id, document_id)
  212. if not document:
  213. raise NotFound("Document not found.")
  214. # check segment
  215. segment_id = str(segment_id)
  216. segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
  217. if not segment:
  218. raise NotFound("Segment not found.")
  219. return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
  220. class ChildChunkApi(DatasetApiResource):
  221. """Resource for child chunks."""
  222. @cloud_edition_billing_resource_check("vector_space", "dataset")
  223. @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
  224. @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
  225. def post(self, tenant_id, dataset_id, document_id, segment_id):
  226. """Create child chunk."""
  227. # check dataset
  228. dataset_id = str(dataset_id)
  229. tenant_id = str(tenant_id)
  230. dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
  231. if not dataset:
  232. raise NotFound("Dataset not found.")
  233. # check document
  234. document_id = str(document_id)
  235. document = DocumentService.get_document(dataset.id, document_id)
  236. if not document:
  237. raise NotFound("Document not found.")
  238. # check segment
  239. segment_id = str(segment_id)
  240. segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
  241. if not segment:
  242. raise NotFound("Segment not found.")
  243. # check embedding model setting
  244. if dataset.indexing_technique == "high_quality":
  245. try:
  246. model_manager = ModelManager()
  247. model_manager.get_model_instance(
  248. tenant_id=current_user.current_tenant_id,
  249. provider=dataset.embedding_model_provider,
  250. model_type=ModelType.TEXT_EMBEDDING,
  251. model=dataset.embedding_model,
  252. )
  253. except LLMBadRequestError:
  254. raise ProviderNotInitializeError(
  255. "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
  256. )
  257. except ProviderTokenNotInitError as ex:
  258. raise ProviderNotInitializeError(ex.description)
  259. # validate args
  260. parser = reqparse.RequestParser()
  261. parser.add_argument("content", type=str, required=True, nullable=False, location="json")
  262. args = parser.parse_args()
  263. try:
  264. child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset)
  265. except ChildChunkIndexingServiceError as e:
  266. raise ChildChunkIndexingError(str(e))
  267. return {"data": marshal(child_chunk, child_chunk_fields)}, 200
  268. def get(self, tenant_id, dataset_id, document_id, segment_id):
  269. """Get child chunks."""
  270. # check dataset
  271. dataset_id = str(dataset_id)
  272. tenant_id = str(tenant_id)
  273. dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
  274. if not dataset:
  275. raise NotFound("Dataset not found.")
  276. # check document
  277. document_id = str(document_id)
  278. document = DocumentService.get_document(dataset.id, document_id)
  279. if not document:
  280. raise NotFound("Document not found.")
  281. # check segment
  282. segment_id = str(segment_id)
  283. segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
  284. if not segment:
  285. raise NotFound("Segment not found.")
  286. parser = reqparse.RequestParser()
  287. parser.add_argument("limit", type=int, default=20, location="args")
  288. parser.add_argument("keyword", type=str, default=None, location="args")
  289. parser.add_argument("page", type=int, default=1, location="args")
  290. args = parser.parse_args()
  291. page = args["page"]
  292. limit = min(args["limit"], 100)
  293. keyword = args["keyword"]
  294. child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
  295. return {
  296. "data": marshal(child_chunks.items, child_chunk_fields),
  297. "total": child_chunks.total,
  298. "total_pages": child_chunks.pages,
  299. "page": page,
  300. "limit": limit,
  301. }, 200
  302. class DatasetChildChunkApi(DatasetApiResource):
  303. """Resource for updating child chunks."""
  304. @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
  305. @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
  306. def delete(self, tenant_id, dataset_id, document_id, segment_id, child_chunk_id):
  307. """Delete child chunk."""
  308. # check dataset
  309. dataset_id = str(dataset_id)
  310. tenant_id = str(tenant_id)
  311. dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
  312. if not dataset:
  313. raise NotFound("Dataset not found.")
  314. # check document
  315. document_id = str(document_id)
  316. document = DocumentService.get_document(dataset.id, document_id)
  317. if not document:
  318. raise NotFound("Document not found.")
  319. # check segment
  320. segment_id = str(segment_id)
  321. segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
  322. if not segment:
  323. raise NotFound("Segment not found.")
  324. # validate segment belongs to the specified document
  325. if segment.document_id != document_id:
  326. raise NotFound("Document not found.")
  327. # check child chunk
  328. child_chunk_id = str(child_chunk_id)
  329. child_chunk = SegmentService.get_child_chunk_by_id(
  330. child_chunk_id=child_chunk_id, tenant_id=current_user.current_tenant_id
  331. )
  332. if not child_chunk:
  333. raise NotFound("Child chunk not found.")
  334. # validate child chunk belongs to the specified segment
  335. if child_chunk.segment_id != segment.id:
  336. raise NotFound("Child chunk not found.")
  337. try:
  338. SegmentService.delete_child_chunk(child_chunk, dataset)
  339. except ChildChunkDeleteIndexServiceError as e:
  340. raise ChildChunkDeleteIndexError(str(e))
  341. return 204
  342. @cloud_edition_billing_resource_check("vector_space", "dataset")
  343. @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
  344. @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
  345. def patch(self, tenant_id, dataset_id, document_id, segment_id, child_chunk_id):
  346. """Update child chunk."""
  347. # check dataset
  348. dataset_id = str(dataset_id)
  349. tenant_id = str(tenant_id)
  350. dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
  351. if not dataset:
  352. raise NotFound("Dataset not found.")
  353. # get document
  354. document = DocumentService.get_document(dataset_id, document_id)
  355. if not document:
  356. raise NotFound("Document not found.")
  357. # get segment
  358. segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
  359. if not segment:
  360. raise NotFound("Segment not found.")
  361. # validate segment belongs to the specified document
  362. if segment.document_id != document_id:
  363. raise NotFound("Segment not found.")
  364. # get child chunk
  365. child_chunk = SegmentService.get_child_chunk_by_id(
  366. child_chunk_id=child_chunk_id, tenant_id=current_user.current_tenant_id
  367. )
  368. if not child_chunk:
  369. raise NotFound("Child chunk not found.")
  370. # validate child chunk belongs to the specified segment
  371. if child_chunk.segment_id != segment.id:
  372. raise NotFound("Child chunk not found.")
  373. # validate args
  374. parser = reqparse.RequestParser()
  375. parser.add_argument("content", type=str, required=True, nullable=False, location="json")
  376. args = parser.parse_args()
  377. try:
  378. child_chunk = SegmentService.update_child_chunk(
  379. args.get("content"), child_chunk, segment, document, dataset
  380. )
  381. except ChildChunkIndexingServiceError as e:
  382. raise ChildChunkIndexingError(str(e))
  383. return {"data": marshal(child_chunk, child_chunk_fields)}, 200
# Route registrations for the knowledge (dataset) service API.
# NOTE: the <uuid:...> converters pass uuid.UUID objects to the handlers,
# which is why each handler converts its id arguments with str().
api.add_resource(SegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
api.add_resource(
    DatasetSegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>"
)
api.add_resource(
    ChildChunkApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks"
)
api.add_resource(
    DatasetChildChunkApi,
    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>",
)