您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

datasets_document.py 43KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039
  1. import logging
  2. from argparse import ArgumentTypeError
  3. from typing import Literal, cast
  4. from flask import request
  5. from flask_login import current_user
  6. from flask_restx import Resource, marshal, marshal_with, reqparse
  7. from sqlalchemy import asc, desc, select
  8. from werkzeug.exceptions import Forbidden, NotFound
  9. import services
  10. from controllers.console import api
  11. from controllers.console.app.error import (
  12. ProviderModelCurrentlyNotSupportError,
  13. ProviderNotInitializeError,
  14. ProviderQuotaExceededError,
  15. )
  16. from controllers.console.datasets.error import (
  17. ArchivedDocumentImmutableError,
  18. DocumentAlreadyFinishedError,
  19. DocumentIndexingError,
  20. IndexingEstimateError,
  21. InvalidActionError,
  22. InvalidMetadataError,
  23. )
  24. from controllers.console.wraps import (
  25. account_initialization_required,
  26. cloud_edition_billing_rate_limit_check,
  27. cloud_edition_billing_resource_check,
  28. setup_required,
  29. )
  30. from core.errors.error import (
  31. LLMBadRequestError,
  32. ModelCurrentlyNotSupportError,
  33. ProviderTokenNotInitError,
  34. QuotaExceededError,
  35. )
  36. from core.indexing_runner import IndexingRunner
  37. from core.model_manager import ModelManager
  38. from core.model_runtime.entities.model_entities import ModelType
  39. from core.model_runtime.errors.invoke import InvokeAuthorizationError
  40. from core.plugin.impl.exc import PluginDaemonClientSideError
  41. from core.rag.extractor.entity.datasource_type import DatasourceType
  42. from core.rag.extractor.entity.extract_setting import ExtractSetting
  43. from extensions.ext_database import db
  44. from fields.document_fields import (
  45. dataset_and_document_fields,
  46. document_fields,
  47. document_status_fields,
  48. document_with_segments_fields,
  49. )
  50. from libs.datetime_utils import naive_utc_now
  51. from libs.login import login_required
  52. from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
  53. from services.dataset_service import DatasetService, DocumentService
  54. from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
  55. logger = logging.getLogger(__name__)
  56. class DocumentResource(Resource):
  57. def get_document(self, dataset_id: str, document_id: str) -> Document:
  58. dataset = DatasetService.get_dataset(dataset_id)
  59. if not dataset:
  60. raise NotFound("Dataset not found.")
  61. try:
  62. DatasetService.check_dataset_permission(dataset, current_user)
  63. except services.errors.account.NoPermissionError as e:
  64. raise Forbidden(str(e))
  65. document = DocumentService.get_document(dataset_id, document_id)
  66. if not document:
  67. raise NotFound("Document not found.")
  68. if document.tenant_id != current_user.current_tenant_id:
  69. raise Forbidden("No permission.")
  70. return document
  71. def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]:
  72. dataset = DatasetService.get_dataset(dataset_id)
  73. if not dataset:
  74. raise NotFound("Dataset not found.")
  75. try:
  76. DatasetService.check_dataset_permission(dataset, current_user)
  77. except services.errors.account.NoPermissionError as e:
  78. raise Forbidden(str(e))
  79. documents = DocumentService.get_batch_documents(dataset_id, batch)
  80. if not documents:
  81. raise NotFound("Documents not found.")
  82. return documents
  83. class GetProcessRuleApi(Resource):
  84. @setup_required
  85. @login_required
  86. @account_initialization_required
  87. def get(self):
  88. req_data = request.args
  89. document_id = req_data.get("document_id")
  90. # get default rules
  91. mode = DocumentService.DEFAULT_RULES["mode"]
  92. rules = DocumentService.DEFAULT_RULES["rules"]
  93. limits = DocumentService.DEFAULT_RULES["limits"]
  94. if document_id:
  95. # get the latest process rule
  96. document = db.get_or_404(Document, document_id)
  97. dataset = DatasetService.get_dataset(document.dataset_id)
  98. if not dataset:
  99. raise NotFound("Dataset not found.")
  100. try:
  101. DatasetService.check_dataset_permission(dataset, current_user)
  102. except services.errors.account.NoPermissionError as e:
  103. raise Forbidden(str(e))
  104. # get the latest process rule
  105. dataset_process_rule = (
  106. db.session.query(DatasetProcessRule)
  107. .where(DatasetProcessRule.dataset_id == document.dataset_id)
  108. .order_by(DatasetProcessRule.created_at.desc())
  109. .limit(1)
  110. .one_or_none()
  111. )
  112. if dataset_process_rule:
  113. mode = dataset_process_rule.mode
  114. rules = dataset_process_rule.rules_dict
  115. return {"mode": mode, "rules": rules, "limits": limits}
  116. class DatasetDocumentListApi(Resource):
  117. @setup_required
  118. @login_required
  119. @account_initialization_required
  120. def get(self, dataset_id):
  121. dataset_id = str(dataset_id)
  122. page = request.args.get("page", default=1, type=int)
  123. limit = request.args.get("limit", default=20, type=int)
  124. search = request.args.get("keyword", default=None, type=str)
  125. sort = request.args.get("sort", default="-created_at", type=str)
  126. # "yes", "true", "t", "y", "1" convert to True, while others convert to False.
  127. try:
  128. fetch_val = request.args.get("fetch", default="false")
  129. if isinstance(fetch_val, bool):
  130. fetch = fetch_val
  131. else:
  132. if fetch_val.lower() in ("yes", "true", "t", "y", "1"):
  133. fetch = True
  134. elif fetch_val.lower() in ("no", "false", "f", "n", "0"):
  135. fetch = False
  136. else:
  137. raise ArgumentTypeError(
  138. f"Truthy value expected: got {fetch_val} but expected one of yes/no, true/false, t/f, y/n, 1/0 "
  139. f"(case insensitive)."
  140. )
  141. except (ArgumentTypeError, ValueError, Exception):
  142. fetch = False
  143. dataset = DatasetService.get_dataset(dataset_id)
  144. if not dataset:
  145. raise NotFound("Dataset not found.")
  146. try:
  147. DatasetService.check_dataset_permission(dataset, current_user)
  148. except services.errors.account.NoPermissionError as e:
  149. raise Forbidden(str(e))
  150. query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
  151. if search:
  152. search = f"%{search}%"
  153. query = query.where(Document.name.like(search))
  154. if sort.startswith("-"):
  155. sort_logic = desc
  156. sort = sort[1:]
  157. else:
  158. sort_logic = asc
  159. if sort == "hit_count":
  160. sub_query = (
  161. db.select(DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
  162. .group_by(DocumentSegment.document_id)
  163. .subquery()
  164. )
  165. query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by(
  166. sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)),
  167. sort_logic(Document.position),
  168. )
  169. elif sort == "created_at":
  170. query = query.order_by(
  171. sort_logic(Document.created_at),
  172. sort_logic(Document.position),
  173. )
  174. else:
  175. query = query.order_by(
  176. desc(Document.created_at),
  177. desc(Document.position),
  178. )
  179. paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
  180. documents = paginated_documents.items
  181. if fetch:
  182. for document in documents:
  183. completed_segments = (
  184. db.session.query(DocumentSegment)
  185. .where(
  186. DocumentSegment.completed_at.isnot(None),
  187. DocumentSegment.document_id == str(document.id),
  188. DocumentSegment.status != "re_segment",
  189. )
  190. .count()
  191. )
  192. total_segments = (
  193. db.session.query(DocumentSegment)
  194. .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
  195. .count()
  196. )
  197. document.completed_segments = completed_segments
  198. document.total_segments = total_segments
  199. data = marshal(documents, document_with_segments_fields)
  200. else:
  201. data = marshal(documents, document_fields)
  202. response = {
  203. "data": data,
  204. "has_more": len(documents) == limit,
  205. "limit": limit,
  206. "total": paginated_documents.total,
  207. "page": page,
  208. }
  209. return response
  210. @setup_required
  211. @login_required
  212. @account_initialization_required
  213. @marshal_with(dataset_and_document_fields)
  214. @cloud_edition_billing_resource_check("vector_space")
  215. @cloud_edition_billing_rate_limit_check("knowledge")
  216. def post(self, dataset_id):
  217. dataset_id = str(dataset_id)
  218. dataset = DatasetService.get_dataset(dataset_id)
  219. if not dataset:
  220. raise NotFound("Dataset not found.")
  221. # The role of the current user in the ta table must be admin, owner, or editor
  222. if not current_user.is_dataset_editor:
  223. raise Forbidden()
  224. try:
  225. DatasetService.check_dataset_permission(dataset, current_user)
  226. except services.errors.account.NoPermissionError as e:
  227. raise Forbidden(str(e))
  228. parser = reqparse.RequestParser()
  229. parser.add_argument(
  230. "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
  231. )
  232. parser.add_argument("data_source", type=dict, required=False, location="json")
  233. parser.add_argument("process_rule", type=dict, required=False, location="json")
  234. parser.add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
  235. parser.add_argument("original_document_id", type=str, required=False, location="json")
  236. parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
  237. parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
  238. parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
  239. parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
  240. parser.add_argument(
  241. "doc_language", type=str, default="English", required=False, nullable=False, location="json"
  242. )
  243. args = parser.parse_args()
  244. knowledge_config = KnowledgeConfig(**args)
  245. if not dataset.indexing_technique and not knowledge_config.indexing_technique:
  246. raise ValueError("indexing_technique is required.")
  247. # validate args
  248. DocumentService.document_create_args_validate(knowledge_config)
  249. try:
  250. documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user)
  251. dataset = DatasetService.get_dataset(dataset_id)
  252. except ProviderTokenNotInitError as ex:
  253. raise ProviderNotInitializeError(ex.description)
  254. except QuotaExceededError:
  255. raise ProviderQuotaExceededError()
  256. except ModelCurrentlyNotSupportError:
  257. raise ProviderModelCurrentlyNotSupportError()
  258. return {"dataset": dataset, "documents": documents, "batch": batch}
  259. @setup_required
  260. @login_required
  261. @account_initialization_required
  262. @cloud_edition_billing_rate_limit_check("knowledge")
  263. def delete(self, dataset_id):
  264. dataset_id = str(dataset_id)
  265. dataset = DatasetService.get_dataset(dataset_id)
  266. if dataset is None:
  267. raise NotFound("Dataset not found.")
  268. # check user's model setting
  269. DatasetService.check_dataset_model_setting(dataset)
  270. try:
  271. document_ids = request.args.getlist("document_id")
  272. DocumentService.delete_documents(dataset, document_ids)
  273. except services.errors.document.DocumentIndexingError:
  274. raise DocumentIndexingError("Cannot delete document during indexing.")
  275. return {"result": "success"}, 204
  276. class DatasetInitApi(Resource):
  277. @setup_required
  278. @login_required
  279. @account_initialization_required
  280. @marshal_with(dataset_and_document_fields)
  281. @cloud_edition_billing_resource_check("vector_space")
  282. @cloud_edition_billing_rate_limit_check("knowledge")
  283. def post(self):
  284. # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
  285. if not current_user.is_dataset_editor:
  286. raise Forbidden()
  287. parser = reqparse.RequestParser()
  288. parser.add_argument(
  289. "indexing_technique",
  290. type=str,
  291. choices=Dataset.INDEXING_TECHNIQUE_LIST,
  292. required=True,
  293. nullable=False,
  294. location="json",
  295. )
  296. parser.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
  297. parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
  298. parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
  299. parser.add_argument(
  300. "doc_language", type=str, default="English", required=False, nullable=False, location="json"
  301. )
  302. parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
  303. parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
  304. parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
  305. args = parser.parse_args()
  306. knowledge_config = KnowledgeConfig(**args)
  307. if knowledge_config.indexing_technique == "high_quality":
  308. if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
  309. raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
  310. try:
  311. model_manager = ModelManager()
  312. model_manager.get_model_instance(
  313. tenant_id=current_user.current_tenant_id,
  314. provider=args["embedding_model_provider"],
  315. model_type=ModelType.TEXT_EMBEDDING,
  316. model=args["embedding_model"],
  317. )
  318. except InvokeAuthorizationError:
  319. raise ProviderNotInitializeError(
  320. "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
  321. )
  322. except ProviderTokenNotInitError as ex:
  323. raise ProviderNotInitializeError(ex.description)
  324. # validate args
  325. DocumentService.document_create_args_validate(knowledge_config)
  326. try:
  327. dataset, documents, batch = DocumentService.save_document_without_dataset_id(
  328. tenant_id=current_user.current_tenant_id, knowledge_config=knowledge_config, account=current_user
  329. )
  330. except ProviderTokenNotInitError as ex:
  331. raise ProviderNotInitializeError(ex.description)
  332. except QuotaExceededError:
  333. raise ProviderQuotaExceededError()
  334. except ModelCurrentlyNotSupportError:
  335. raise ProviderModelCurrentlyNotSupportError()
  336. response = {"dataset": dataset, "documents": documents, "batch": batch}
  337. return response
  338. class DocumentIndexingEstimateApi(DocumentResource):
  339. @setup_required
  340. @login_required
  341. @account_initialization_required
  342. def get(self, dataset_id, document_id):
  343. dataset_id = str(dataset_id)
  344. document_id = str(document_id)
  345. document = self.get_document(dataset_id, document_id)
  346. if document.indexing_status in {"completed", "error"}:
  347. raise DocumentAlreadyFinishedError()
  348. data_process_rule = document.dataset_process_rule
  349. data_process_rule_dict = data_process_rule.to_dict()
  350. response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}
  351. if document.data_source_type == "upload_file":
  352. data_source_info = document.data_source_info_dict
  353. if data_source_info and "upload_file_id" in data_source_info:
  354. file_id = data_source_info["upload_file_id"]
  355. file = (
  356. db.session.query(UploadFile)
  357. .where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
  358. .first()
  359. )
  360. # raise error if file not found
  361. if not file:
  362. raise NotFound("File not found.")
  363. extract_setting = ExtractSetting(
  364. datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form
  365. )
  366. indexing_runner = IndexingRunner()
  367. try:
  368. estimate_response = indexing_runner.indexing_estimate(
  369. current_user.current_tenant_id,
  370. [extract_setting],
  371. data_process_rule_dict,
  372. document.doc_form,
  373. "English",
  374. dataset_id,
  375. )
  376. return estimate_response.model_dump(), 200
  377. except LLMBadRequestError:
  378. raise ProviderNotInitializeError(
  379. "No Embedding Model available. Please configure a valid provider "
  380. "in the Settings -> Model Provider."
  381. )
  382. except ProviderTokenNotInitError as ex:
  383. raise ProviderNotInitializeError(ex.description)
  384. except PluginDaemonClientSideError as ex:
  385. raise ProviderNotInitializeError(ex.description)
  386. except Exception as e:
  387. raise IndexingEstimateError(str(e))
  388. return response, 200
  389. class DocumentBatchIndexingEstimateApi(DocumentResource):
  390. @setup_required
  391. @login_required
  392. @account_initialization_required
  393. def get(self, dataset_id, batch):
  394. dataset_id = str(dataset_id)
  395. batch = str(batch)
  396. documents = self.get_batch_documents(dataset_id, batch)
  397. if not documents:
  398. return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200
  399. data_process_rule = documents[0].dataset_process_rule
  400. data_process_rule_dict = data_process_rule.to_dict()
  401. extract_settings = []
  402. for document in documents:
  403. if document.indexing_status in {"completed", "error"}:
  404. raise DocumentAlreadyFinishedError()
  405. data_source_info = document.data_source_info_dict
  406. if document.data_source_type == "upload_file":
  407. if not data_source_info:
  408. continue
  409. file_id = data_source_info["upload_file_id"]
  410. file_detail = (
  411. db.session.query(UploadFile)
  412. .where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id)
  413. .first()
  414. )
  415. if file_detail is None:
  416. raise NotFound("File not found.")
  417. extract_setting = ExtractSetting(
  418. datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form
  419. )
  420. extract_settings.append(extract_setting)
  421. elif document.data_source_type == "notion_import":
  422. if not data_source_info:
  423. continue
  424. extract_setting = ExtractSetting(
  425. datasource_type=DatasourceType.NOTION.value,
  426. notion_info={
  427. "notion_workspace_id": data_source_info["notion_workspace_id"],
  428. "notion_obj_id": data_source_info["notion_page_id"],
  429. "notion_page_type": data_source_info["type"],
  430. "tenant_id": current_user.current_tenant_id,
  431. },
  432. document_model=document.doc_form,
  433. )
  434. extract_settings.append(extract_setting)
  435. elif document.data_source_type == "website_crawl":
  436. if not data_source_info:
  437. continue
  438. extract_setting = ExtractSetting(
  439. datasource_type=DatasourceType.WEBSITE.value,
  440. website_info={
  441. "provider": data_source_info["provider"],
  442. "job_id": data_source_info["job_id"],
  443. "url": data_source_info["url"],
  444. "tenant_id": current_user.current_tenant_id,
  445. "mode": data_source_info["mode"],
  446. "only_main_content": data_source_info["only_main_content"],
  447. },
  448. document_model=document.doc_form,
  449. )
  450. extract_settings.append(extract_setting)
  451. else:
  452. raise ValueError("Data source type not support")
  453. indexing_runner = IndexingRunner()
  454. try:
  455. response = indexing_runner.indexing_estimate(
  456. current_user.current_tenant_id,
  457. extract_settings,
  458. data_process_rule_dict,
  459. document.doc_form,
  460. "English",
  461. dataset_id,
  462. )
  463. return response.model_dump(), 200
  464. except LLMBadRequestError:
  465. raise ProviderNotInitializeError(
  466. "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
  467. )
  468. except ProviderTokenNotInitError as ex:
  469. raise ProviderNotInitializeError(ex.description)
  470. except PluginDaemonClientSideError as ex:
  471. raise ProviderNotInitializeError(ex.description)
  472. except Exception as e:
  473. raise IndexingEstimateError(str(e))
  474. class DocumentBatchIndexingStatusApi(DocumentResource):
  475. @setup_required
  476. @login_required
  477. @account_initialization_required
  478. def get(self, dataset_id, batch):
  479. dataset_id = str(dataset_id)
  480. batch = str(batch)
  481. documents = self.get_batch_documents(dataset_id, batch)
  482. documents_status = []
  483. for document in documents:
  484. completed_segments = (
  485. db.session.query(DocumentSegment)
  486. .where(
  487. DocumentSegment.completed_at.isnot(None),
  488. DocumentSegment.document_id == str(document.id),
  489. DocumentSegment.status != "re_segment",
  490. )
  491. .count()
  492. )
  493. total_segments = (
  494. db.session.query(DocumentSegment)
  495. .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
  496. .count()
  497. )
  498. # Create a dictionary with document attributes and additional fields
  499. document_dict = {
  500. "id": document.id,
  501. "indexing_status": "paused" if document.is_paused else document.indexing_status,
  502. "processing_started_at": document.processing_started_at,
  503. "parsing_completed_at": document.parsing_completed_at,
  504. "cleaning_completed_at": document.cleaning_completed_at,
  505. "splitting_completed_at": document.splitting_completed_at,
  506. "completed_at": document.completed_at,
  507. "paused_at": document.paused_at,
  508. "error": document.error,
  509. "stopped_at": document.stopped_at,
  510. "completed_segments": completed_segments,
  511. "total_segments": total_segments,
  512. }
  513. documents_status.append(marshal(document_dict, document_status_fields))
  514. data = {"data": documents_status}
  515. return data
  516. class DocumentIndexingStatusApi(DocumentResource):
  517. @setup_required
  518. @login_required
  519. @account_initialization_required
  520. def get(self, dataset_id, document_id):
  521. dataset_id = str(dataset_id)
  522. document_id = str(document_id)
  523. document = self.get_document(dataset_id, document_id)
  524. completed_segments = (
  525. db.session.query(DocumentSegment)
  526. .where(
  527. DocumentSegment.completed_at.isnot(None),
  528. DocumentSegment.document_id == str(document_id),
  529. DocumentSegment.status != "re_segment",
  530. )
  531. .count()
  532. )
  533. total_segments = (
  534. db.session.query(DocumentSegment)
  535. .where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment")
  536. .count()
  537. )
  538. # Create a dictionary with document attributes and additional fields
  539. document_dict = {
  540. "id": document.id,
  541. "indexing_status": "paused" if document.is_paused else document.indexing_status,
  542. "processing_started_at": document.processing_started_at,
  543. "parsing_completed_at": document.parsing_completed_at,
  544. "cleaning_completed_at": document.cleaning_completed_at,
  545. "splitting_completed_at": document.splitting_completed_at,
  546. "completed_at": document.completed_at,
  547. "paused_at": document.paused_at,
  548. "error": document.error,
  549. "stopped_at": document.stopped_at,
  550. "completed_segments": completed_segments,
  551. "total_segments": total_segments,
  552. }
  553. return marshal(document_dict, document_status_fields)
  554. class DocumentApi(DocumentResource):
  555. METADATA_CHOICES = {"all", "only", "without"}
  556. @setup_required
  557. @login_required
  558. @account_initialization_required
  559. def get(self, dataset_id, document_id):
  560. dataset_id = str(dataset_id)
  561. document_id = str(document_id)
  562. document = self.get_document(dataset_id, document_id)
  563. metadata = request.args.get("metadata", "all")
  564. if metadata not in self.METADATA_CHOICES:
  565. raise InvalidMetadataError(f"Invalid metadata value: {metadata}")
  566. if metadata == "only":
  567. response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
  568. elif metadata == "without":
  569. dataset_process_rules = DatasetService.get_process_rules(dataset_id)
  570. document_process_rules = document.dataset_process_rule.to_dict()
  571. data_source_info = document.data_source_detail_dict
  572. response = {
  573. "id": document.id,
  574. "position": document.position,
  575. "data_source_type": document.data_source_type,
  576. "data_source_info": data_source_info,
  577. "dataset_process_rule_id": document.dataset_process_rule_id,
  578. "dataset_process_rule": dataset_process_rules,
  579. "document_process_rule": document_process_rules,
  580. "name": document.name,
  581. "created_from": document.created_from,
  582. "created_by": document.created_by,
  583. "created_at": document.created_at.timestamp(),
  584. "tokens": document.tokens,
  585. "indexing_status": document.indexing_status,
  586. "completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
  587. "updated_at": int(document.updated_at.timestamp()) if document.updated_at else None,
  588. "indexing_latency": document.indexing_latency,
  589. "error": document.error,
  590. "enabled": document.enabled,
  591. "disabled_at": int(document.disabled_at.timestamp()) if document.disabled_at else None,
  592. "disabled_by": document.disabled_by,
  593. "archived": document.archived,
  594. "segment_count": document.segment_count,
  595. "average_segment_length": document.average_segment_length,
  596. "hit_count": document.hit_count,
  597. "display_status": document.display_status,
  598. "doc_form": document.doc_form,
  599. "doc_language": document.doc_language,
  600. }
  601. else:
  602. dataset_process_rules = DatasetService.get_process_rules(dataset_id)
  603. document_process_rules = document.dataset_process_rule.to_dict()
  604. data_source_info = document.data_source_detail_dict
  605. response = {
  606. "id": document.id,
  607. "position": document.position,
  608. "data_source_type": document.data_source_type,
  609. "data_source_info": data_source_info,
  610. "dataset_process_rule_id": document.dataset_process_rule_id,
  611. "dataset_process_rule": dataset_process_rules,
  612. "document_process_rule": document_process_rules,
  613. "name": document.name,
  614. "created_from": document.created_from,
  615. "created_by": document.created_by,
  616. "created_at": document.created_at.timestamp(),
  617. "tokens": document.tokens,
  618. "indexing_status": document.indexing_status,
  619. "completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
  620. "updated_at": int(document.updated_at.timestamp()) if document.updated_at else None,
  621. "indexing_latency": document.indexing_latency,
  622. "error": document.error,
  623. "enabled": document.enabled,
  624. "disabled_at": int(document.disabled_at.timestamp()) if document.disabled_at else None,
  625. "disabled_by": document.disabled_by,
  626. "archived": document.archived,
  627. "doc_type": document.doc_type,
  628. "doc_metadata": document.doc_metadata_details,
  629. "segment_count": document.segment_count,
  630. "average_segment_length": document.average_segment_length,
  631. "hit_count": document.hit_count,
  632. "display_status": document.display_status,
  633. "doc_form": document.doc_form,
  634. "doc_language": document.doc_language,
  635. }
  636. return response, 200
  637. @setup_required
  638. @login_required
  639. @account_initialization_required
  640. @cloud_edition_billing_rate_limit_check("knowledge")
  641. def delete(self, dataset_id, document_id):
  642. dataset_id = str(dataset_id)
  643. document_id = str(document_id)
  644. dataset = DatasetService.get_dataset(dataset_id)
  645. if dataset is None:
  646. raise NotFound("Dataset not found.")
  647. # check user's model setting
  648. DatasetService.check_dataset_model_setting(dataset)
  649. document = self.get_document(dataset_id, document_id)
  650. try:
  651. DocumentService.delete_document(document)
  652. except services.errors.document.DocumentIndexingError:
  653. raise DocumentIndexingError("Cannot delete document during indexing.")
  654. return {"result": "success"}, 204
  class DocumentProcessingApi(DocumentResource):
      @setup_required
      @login_required
      @account_initialization_required
      @cloud_edition_billing_rate_limit_check("knowledge")
      def patch(self, dataset_id, document_id, action: Literal["pause", "resume"]):
          """Pause or resume indexing of a single document.

          ``action`` comes from the URL path. The rule registers it as
          ``<string:action>``, so the ``Literal`` annotation is not enforced by
          the router and the value is validated here.

          Pausing requires ``indexing_status == "indexing"``. It records who
          paused the document and when, and sets ``is_paused``. Resuming
          requires ``indexing_status`` to be "paused" or "error", and it clears
          those fields again.

          Raises:
              Forbidden: the current user is not a dataset editor.
              InvalidActionError: the document's status does not allow the
                  action, or ``action`` is neither "pause" nor "resume".
          """
          dataset_id = str(dataset_id)
          document_id = str(document_id)
          document = self.get_document(dataset_id, document_id)

          # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
          if not current_user.is_dataset_editor:
              raise Forbidden()

          if action == "pause":
              if document.indexing_status != "indexing":
                  raise InvalidActionError("Document not in indexing state.")
              document.paused_by = current_user.id
              document.paused_at = naive_utc_now()
              document.is_paused = True
          elif action == "resume":
              if document.indexing_status not in {"paused", "error"}:
                  raise InvalidActionError("Document not in paused or error state.")
              document.paused_by = None
              document.paused_at = None
              document.is_paused = False
          else:
              # Reject unknown actions instead of reporting success for a no-op.
              raise InvalidActionError(f"Unsupported action: {action}.")

          db.session.commit()
          return {"result": "success"}, 200
  class DocumentMetadataApi(DocumentResource):
      @setup_required
      @login_required
      @account_initialization_required
      def put(self, dataset_id, document_id):
          """Replace a document's ``doc_type`` and ``doc_metadata``.

          Expects a JSON object body with ``doc_type`` and ``doc_metadata``.
          When ``doc_type`` is "others", the supplied dict is stored as-is.
          Otherwise a key is kept only if it appears in
          ``DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]`` and its value
          is a non-None instance of the declared type; all other keys are
          dropped. ``updated_at`` is refreshed and the session is committed.

          Raises:
              Forbidden: the current user is not a dataset editor.
              ValueError: the body is not a JSON object, either field is
                  missing, ``doc_type`` is unknown, or ``doc_metadata`` is not
                  a dict.
          """
          dataset_id = str(dataset_id)
          document_id = str(document_id)
          document = self.get_document(dataset_id, document_id)

          # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
          if not current_user.is_dataset_editor:
              raise Forbidden()

          req_data = request.get_json()
          # A JSON array or scalar body would otherwise fail with AttributeError on .get().
          if not isinstance(req_data, dict):
              raise ValueError("Request body must be a JSON object.")
          doc_type = req_data.get("doc_type")
          doc_metadata = req_data.get("doc_metadata")

          if doc_type is None or doc_metadata is None:
              raise ValueError("Both doc_type and doc_metadata must be provided.")
          if doc_type not in DocumentService.DOCUMENT_METADATA_SCHEMA:
              raise ValueError("Invalid doc_type.")
          if not isinstance(doc_metadata, dict):
              raise ValueError("doc_metadata must be a dictionary.")

          if doc_type == "others":
              new_metadata = doc_metadata
          else:
              metadata_schema: dict = cast(dict, DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type])
              new_metadata = {}
              for key, value_type in metadata_schema.items():
                  value = doc_metadata.get(key)
                  if value is not None and isinstance(value, value_type):
                      new_metadata[key] = value

          # Build the final dict first and assign it once. Correctness then does
          # not depend on whether the ORM tracks in-place mutation of the JSON value.
          document.doc_metadata = new_metadata
          document.doc_type = doc_type
          document.updated_at = naive_utc_now()
          db.session.commit()

          return {"result": "success", "message": "Document metadata updated."}, 200
  class DocumentStatusApi(DocumentResource):
      @setup_required
      @login_required
      @account_initialization_required
      @cloud_edition_billing_resource_check("vector_space")
      @cloud_edition_billing_rate_limit_check("knowledge")
      def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
          """Apply ``action`` to every document listed in the ``document_id`` query args.

          The ids come from ``request.args.getlist("document_id")``. The actual
          status change is delegated to
          ``DocumentService.batch_update_document_status``.

          Raises:
              NotFound: the dataset does not exist. A ``NotFound`` raised by
                  the service also propagates unchanged.
              Forbidden: the current user is not a dataset editor.
              InvalidActionError: the service raised
                  ``services.errors.document.DocumentIndexingError`` or
                  ``ValueError``.
          """
          dataset = DatasetService.get_dataset(str(dataset_id))
          if dataset is None:
              raise NotFound("Dataset not found.")

          # The role of the current user in the ta table must be admin, owner, or editor
          if not current_user.is_dataset_editor:
              raise Forbidden()

          # check user's model setting
          DatasetService.check_dataset_model_setting(dataset)
          # check user's permission
          DatasetService.check_dataset_permission(dataset, current_user)

          document_ids = request.args.getlist("document_id")
          try:
              DocumentService.batch_update_document_status(dataset, document_ids, action, current_user)
          except (services.errors.document.DocumentIndexingError, ValueError) as e:
              raise InvalidActionError(str(e)) from e
          # NotFound is deliberately not caught. Re-raising NotFound(str(e)) would
          # embed "404 Not Found: ..." inside the new error's description.

          return {"result": "success"}, 200
  class DocumentPauseApi(DocumentResource):
      @setup_required
      @login_required
      @account_initialization_required
      @cloud_edition_billing_rate_limit_check("knowledge")
      def patch(self, dataset_id, document_id):
          """Pause indexing of a single document.

          The document is loaded with ``DocumentService.get_document`` rather
          than ``self.get_document``. The role check and the tenant-ownership
          check are therefore performed explicitly here.

          Raises:
              Forbidden: the user is not a dataset editor, or the document
                  belongs to another tenant.
              NotFound: the dataset or the document does not exist.
              ArchivedDocumentImmutableError: the document is archived.
              DocumentIndexingError: ``DocumentService.pause_document`` raised
                  ``services.errors.document.DocumentIndexingError``.
          """
          # Same role requirement as DocumentProcessingApi, which performs the same pause.
          if not current_user.is_dataset_editor:
              raise Forbidden()

          dataset_id = str(dataset_id)
          document_id = str(document_id)
          dataset = DatasetService.get_dataset(dataset_id)
          if not dataset:
              raise NotFound("Dataset not found.")

          document = DocumentService.get_document(dataset.id, document_id)
          # 404 if document not found
          if document is None:
              raise NotFound("Document Not Exists.")
          # 403 if the document belongs to another tenant (same check as WebsiteDocumentSyncApi)
          if document.tenant_id != current_user.current_tenant_id:
              raise Forbidden("No permission.")
          # 403 if document is archived
          if DocumentService.check_archived(document):
              raise ArchivedDocumentImmutableError()

          try:
              DocumentService.pause_document(document)
          except services.errors.document.DocumentIndexingError as e:
              raise DocumentIndexingError("Cannot pause completed document.") from e

          return {"result": "success"}, 204
  class DocumentRecoverApi(DocumentResource):
      @setup_required
      @login_required
      @account_initialization_required
      @cloud_edition_billing_rate_limit_check("knowledge")
      def patch(self, dataset_id, document_id):
          """Resume (recover) indexing of a single document.

          The document is loaded with ``DocumentService.get_document`` rather
          than ``self.get_document``. The role check and the tenant-ownership
          check are therefore performed explicitly here.

          Raises:
              Forbidden: the user is not a dataset editor, or the document
                  belongs to another tenant.
              NotFound: the dataset or the document does not exist.
              ArchivedDocumentImmutableError: the document is archived.
              DocumentIndexingError: ``DocumentService.recover_document`` raised
                  ``services.errors.document.DocumentIndexingError``.
          """
          # Same role requirement as DocumentProcessingApi, which performs the same resume.
          if not current_user.is_dataset_editor:
              raise Forbidden()

          dataset_id = str(dataset_id)
          document_id = str(document_id)
          dataset = DatasetService.get_dataset(dataset_id)
          if not dataset:
              raise NotFound("Dataset not found.")

          document = DocumentService.get_document(dataset.id, document_id)
          # 404 if document not found
          if document is None:
              raise NotFound("Document Not Exists.")
          # 403 if the document belongs to another tenant (same check as WebsiteDocumentSyncApi)
          if document.tenant_id != current_user.current_tenant_id:
              raise Forbidden("No permission.")
          # 403 if document is archived
          if DocumentService.check_archived(document):
              raise ArchivedDocumentImmutableError()

          try:
              # recover (resume) document
              DocumentService.recover_document(document)
          except services.errors.document.DocumentIndexingError as e:
              raise DocumentIndexingError("Document is not in paused status.") from e

          return {"result": "success"}, 204
  class DocumentRetryApi(DocumentResource):
      @setup_required
      @login_required
      @account_initialization_required
      @cloud_edition_billing_rate_limit_check("knowledge")
      def post(self, dataset_id):
          """Retry indexing for the documents listed in the JSON body.

          Body: ``{"document_ids": [...]}``. Each id is resolved on its own.
          An id is logged and skipped, rather than failing the whole request,
          when the document:
            - is missing,
            - belongs to another tenant,
            - is archived, or
            - already has ``indexing_status == "completed"``.
          Repeated ids are retried once. The remaining documents are passed to
          ``DocumentService.retry_document`` in request order.

          Raises:
              Forbidden: the current user is not a dataset editor.
              NotFound: the dataset does not exist.
          """
          # Same role requirement as the other mutating document endpoints in this module.
          if not current_user.is_dataset_editor:
              raise Forbidden()

          parser = reqparse.RequestParser()
          parser.add_argument("document_ids", type=list, required=True, nullable=False, location="json")
          args = parser.parse_args()

          dataset_id = str(dataset_id)
          dataset = DatasetService.get_dataset(dataset_id)
          if not dataset:
              raise NotFound("Dataset not found.")

          retry_documents = []
          # dict.fromkeys de-duplicates while keeping the caller's order, so a
          # repeated id does not queue the same document twice.
          for document_id in dict.fromkeys(str(raw_id) for raw_id in args["document_ids"]):
              try:
                  document = DocumentService.get_document(dataset.id, document_id)
                  # skip: document not found
                  if document is None:
                      raise NotFound("Document Not Exists.")
                  # skip: document belongs to another tenant
                  if document.tenant_id != current_user.current_tenant_id:
                      raise Forbidden("No permission.")
                  # skip: archived documents are immutable
                  if DocumentService.check_archived(document):
                      raise ArchivedDocumentImmutableError()
                  # skip: nothing to retry for a completed document
                  if document.indexing_status == "completed":
                      raise DocumentAlreadyFinishedError()
                  retry_documents.append(document)
              except Exception:
                  # Best effort: one bad id must not block the rest of the batch.
                  logger.exception("Failed to retry document, document id: %s", document_id)

          # retry document
          DocumentService.retry_document(dataset_id, retry_documents)
          return {"result": "success"}, 204
  class DocumentRenameApi(DocumentResource):
      @setup_required
      @login_required
      @account_initialization_required
      @marshal_with(document_fields)
      def post(self, dataset_id, document_id):
          """Rename a document and return it, marshalled with ``document_fields``.

          Body: ``{"name": "<new name>"}``. The rename is delegated to
          ``DocumentService.rename_document``.

          Raises:
              Forbidden: the current user is not a dataset editor.
              NotFound: the dataset does not exist.
              DocumentIndexingError: ``DocumentService.rename_document`` raised
                  ``services.errors.document.DocumentIndexingError``.
          """
          # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
          if not current_user.is_dataset_editor:
              raise Forbidden()

          dataset_id = str(dataset_id)
          document_id = str(document_id)
          dataset = DatasetService.get_dataset(dataset_id)
          # Guard before the operator-permission check so it never receives None.
          if dataset is None:
              raise NotFound("Dataset not found.")
          DatasetService.check_dataset_operator_permission(current_user, dataset)

          parser = reqparse.RequestParser()
          parser.add_argument("name", type=str, required=True, nullable=False, location="json")
          args = parser.parse_args()

          try:
              document = DocumentService.rename_document(dataset_id, document_id, args["name"])
          except services.errors.document.DocumentIndexingError as e:
              raise DocumentIndexingError("Cannot rename document during indexing.") from e

          return document
  class WebsiteDocumentSyncApi(DocumentResource):
      @setup_required
      @login_required
      @account_initialization_required
      def get(self, dataset_id, document_id):
          """Re-sync a document that was created from a website crawl.

          The guards run in this order:
            1. dataset exists,
            2. document exists,
            3. document belongs to the caller's current tenant,
            4. its ``data_source_type`` is "website_crawl",
            5. it is not archived.
          Then ``DocumentService.sync_website_document`` is called.

          Raises:
              NotFound: the dataset or the document does not exist.
              Forbidden: the document belongs to another tenant.
              ValueError: the document is not a website document.
              ArchivedDocumentImmutableError: the document is archived.
          """
          dataset_key = str(dataset_id)
          dataset = DatasetService.get_dataset(dataset_key)
          if not dataset:
              raise NotFound("Dataset not found.")

          document = DocumentService.get_document(dataset.id, str(document_id))
          if not document:
              raise NotFound("Document not found.")
          if document.tenant_id != current_user.current_tenant_id:
              raise Forbidden("No permission.")
          if document.data_source_type != "website_crawl":
              raise ValueError("Document is not a website document.")
          # 403 if document is archived
          if DocumentService.check_archived(document):
              raise ArchivedDocumentImmutableError()

          DocumentService.sync_website_document(dataset_key, document)
          return {"result": "success"}, 200
  # Resource -> URL rule registrations, performed in this exact order.
  # NOTE(review): the PauseApi ".../processing/pause" and RecoverApi
  # ".../processing/resume" rules overlap DocumentProcessingApi's
  # ".../processing/<string:action>" rule. Werkzeug presumably prefers the
  # fully static rules for those two exact paths. Confirm before changing
  # either registration.
  for _resource, _url in (
      (GetProcessRuleApi, "/datasets/process-rule"),
      (DatasetDocumentListApi, "/datasets/<uuid:dataset_id>/documents"),
      (DatasetInitApi, "/datasets/init"),
      (
          DocumentIndexingEstimateApi,
          "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate",
      ),
      (DocumentBatchIndexingEstimateApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-estimate"),
      (DocumentBatchIndexingStatusApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status"),
      (DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status"),
      (DocumentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>"),
      (
          DocumentProcessingApi,
          "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/<string:action>",
      ),
      (DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata"),
      (DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>/batch"),
      (DocumentPauseApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause"),
      (DocumentRecoverApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume"),
      (DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry"),
      (DocumentRenameApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename"),
      (WebsiteDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync"),
  ):
      api.add_resource(_resource, _url)
  # Do not leave the loop variables behind as module attributes.
  del _resource, _url