Ви не можете вибрати більше 25 тем Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033
  1. import logging
  2. from argparse import ArgumentTypeError
  3. from typing import Literal, cast
  4. from flask import request
  5. from flask_login import current_user
  6. from flask_restx import Resource, marshal, marshal_with, reqparse
  7. from sqlalchemy import asc, desc, select
  8. from werkzeug.exceptions import Forbidden, NotFound
  9. import services
  10. from controllers.console import api
  11. from controllers.console.app.error import (
  12. ProviderModelCurrentlyNotSupportError,
  13. ProviderNotInitializeError,
  14. ProviderQuotaExceededError,
  15. )
  16. from controllers.console.datasets.error import (
  17. ArchivedDocumentImmutableError,
  18. DocumentAlreadyFinishedError,
  19. DocumentIndexingError,
  20. IndexingEstimateError,
  21. InvalidActionError,
  22. InvalidMetadataError,
  23. )
  24. from controllers.console.wraps import (
  25. account_initialization_required,
  26. cloud_edition_billing_rate_limit_check,
  27. cloud_edition_billing_resource_check,
  28. setup_required,
  29. )
  30. from core.errors.error import (
  31. LLMBadRequestError,
  32. ModelCurrentlyNotSupportError,
  33. ProviderTokenNotInitError,
  34. QuotaExceededError,
  35. )
  36. from core.indexing_runner import IndexingRunner
  37. from core.model_manager import ModelManager
  38. from core.model_runtime.entities.model_entities import ModelType
  39. from core.model_runtime.errors.invoke import InvokeAuthorizationError
  40. from core.plugin.impl.exc import PluginDaemonClientSideError
  41. from core.rag.extractor.entity.datasource_type import DatasourceType
  42. from core.rag.extractor.entity.extract_setting import ExtractSetting
  43. from extensions.ext_database import db
  44. from fields.document_fields import (
  45. dataset_and_document_fields,
  46. document_fields,
  47. document_status_fields,
  48. document_with_segments_fields,
  49. )
  50. from libs.datetime_utils import naive_utc_now
  51. from libs.login import login_required
  52. from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
  53. from services.dataset_service import DatasetService, DocumentService
  54. from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
  # Module-level logger named after this module.
  logger = logging.getLogger(name=__name__)
  class DocumentResource(Resource):
      """Base resource with helpers that load documents only after dataset access is verified."""

      @staticmethod
      def _get_accessible_dataset(dataset_id: str) -> Dataset:
          """Return the dataset, enforcing that the current user may access it.

          Raises:
              NotFound: the dataset does not exist.
              Forbidden: ``DatasetService.check_dataset_permission`` rejected the current user.
          """
          dataset = DatasetService.get_dataset(dataset_id)
          if not dataset:
              raise NotFound("Dataset not found.")
          try:
              DatasetService.check_dataset_permission(dataset, current_user)
          except services.errors.account.NoPermissionError as e:
              # Keep the original permission error as the cause for debugging.
              raise Forbidden(str(e)) from e
          return dataset

      def get_document(self, dataset_id: str, document_id: str) -> Document:
          """Return one document of the dataset after checking dataset access and tenant ownership.

          Raises:
              NotFound: the dataset or the document does not exist.
              Forbidden: the user lacks dataset permission, or the document belongs to another tenant.
          """
          self._get_accessible_dataset(dataset_id)
          document = DocumentService.get_document(dataset_id, document_id)
          if not document:
              raise NotFound("Document not found.")
          if document.tenant_id != current_user.current_tenant_id:
              raise Forbidden("No permission.")
          return document

      def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]:
          """Return the documents of ``batch`` in the dataset after checking dataset access.

          Raises:
              NotFound: the dataset does not exist, or the batch contains no documents.
              Forbidden: the user lacks dataset permission.
          """
          self._get_accessible_dataset(dataset_id)
          documents = DocumentService.get_batch_documents(dataset_id, batch)
          if not documents:
              raise NotFound("Documents not found.")
          return documents
  class GetProcessRuleApi(Resource):
      """Return the processing rule (mode, rules, limits) for a document's dataset, or the defaults."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self):
          defaults = DocumentService.DEFAULT_RULES
          mode = defaults["mode"]
          rules = defaults["rules"]
          limits = defaults["limits"]

          document_id = request.args.get("document_id")
          if not document_id:
              # No document requested: answer with the service defaults.
              return {"mode": mode, "rules": rules, "limits": limits}

          document = db.get_or_404(Document, document_id)
          dataset = DatasetService.get_dataset(document.dataset_id)
          if not dataset:
              raise NotFound("Dataset not found.")
          try:
              DatasetService.check_dataset_permission(dataset, current_user)
          except services.errors.account.NoPermissionError as e:
              raise Forbidden(str(e))

          # The most recently created rule of the document's dataset overrides mode and rules;
          # limits always come from the defaults.
          latest_rule = (
              db.session.query(DatasetProcessRule)
              .where(DatasetProcessRule.dataset_id == document.dataset_id)
              .order_by(DatasetProcessRule.created_at.desc())
              .limit(1)
              .one_or_none()
          )
          if latest_rule:
              mode = latest_rule.mode
              rules = latest_rule.rules_dict
          return {"mode": mode, "rules": rules, "limits": limits}
  class DatasetDocumentListApi(Resource):
      """List, create and bulk-delete the documents of a dataset."""

      # Case-insensitive tokens that switch the ``fetch`` query flag on.
      _FETCH_TRUE_VALUES = ("yes", "true", "t", "y", "1")

      @classmethod
      def _parse_fetch_flag(cls, raw) -> bool:
          """Interpret the ``fetch`` query parameter.

          A real ``bool`` is returned unchanged. A string is True only when it matches one of
          ``_FETCH_TRUE_VALUES`` (case-insensitive). Anything else, including unrecognised
          strings, yields False, which matches the previous lenient behaviour.
          """
          if isinstance(raw, bool):
              return raw
          return isinstance(raw, str) and raw.lower() in cls._FETCH_TRUE_VALUES

      @staticmethod
      def _apply_sort(query, sort: str):
          """Order ``query`` by ``sort``.

          ``sort`` is "hit_count" or "created_at", optionally prefixed with "-" for descending.
          Position is used as a tiebreaker. Unknown keys fall back to newest-first by creation time.
          """
          sort_logic = desc if sort.startswith("-") else asc
          sort_key = sort[1:] if sort.startswith("-") else sort

          if sort_key == "hit_count":
              hit_counts = (
                  db.select(DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
                  .group_by(DocumentSegment.document_id)
                  .subquery()
              )
              return query.outerjoin(hit_counts, hit_counts.c.document_id == Document.id).order_by(
                  sort_logic(db.func.coalesce(hit_counts.c.total_hit_count, 0)),
                  sort_logic(Document.position),
              )
          if sort_key == "created_at":
              return query.order_by(sort_logic(Document.created_at), sort_logic(Document.position))
          return query.order_by(desc(Document.created_at), desc(Document.position))

      @staticmethod
      def _segment_counts(document_ids: list[str]) -> dict[str, tuple[int, int]]:
          """Return ``{document_id: (completed_segments, total_segments)}`` using one grouped query.

          Segments with status "re_segment" are ignored. A segment counts as completed when
          ``completed_at`` is not NULL; ``COUNT(column)`` skips NULLs. Documents without matching
          segments are absent from the result.
          """
          if not document_ids:
              return {}
          rows = db.session.execute(
              select(
                  DocumentSegment.document_id,
                  db.func.count(DocumentSegment.completed_at),
                  db.func.count(),
              )
              .where(
                  DocumentSegment.document_id.in_(document_ids),
                  DocumentSegment.status != "re_segment",
              )
              .group_by(DocumentSegment.document_id)
          ).all()
          return {str(document_id): (completed, total) for document_id, completed, total in rows}

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, dataset_id):
          """Return one page of the dataset's documents, optionally with segment progress counts."""
          dataset_id = str(dataset_id)
          page = request.args.get("page", default=1, type=int)
          limit = request.args.get("limit", default=20, type=int)
          search = request.args.get("keyword", default=None, type=str)
          sort = request.args.get("sort", default="-created_at", type=str)
          fetch = self._parse_fetch_flag(request.args.get("fetch", default="false"))

          dataset = DatasetService.get_dataset(dataset_id)
          if not dataset:
              raise NotFound("Dataset not found.")
          try:
              DatasetService.check_dataset_permission(dataset, current_user)
          except services.errors.account.NoPermissionError as e:
              raise Forbidden(str(e)) from e

          query = select(Document).filter_by(dataset_id=dataset_id, tenant_id=current_user.current_tenant_id)
          if search:
              query = query.where(Document.name.like(f"%{search}%"))
          query = self._apply_sort(query, sort)

          paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
          documents = paginated_documents.items
          if fetch:
              # One grouped query for the whole page instead of two COUNT queries per document.
              counts = self._segment_counts([str(document.id) for document in documents])
              for document in documents:
                  completed_segments, total_segments = counts.get(str(document.id), (0, 0))
                  document.completed_segments = completed_segments
                  document.total_segments = total_segments
              data = marshal(documents, document_with_segments_fields)
          else:
              data = marshal(documents, document_fields)

          return {
              "data": data,
              # Ask the paginator. Comparing len(items) with the requested limit reports True on a
              # full last page and False whenever limit exceeds max_per_page.
              "has_more": paginated_documents.has_next,
              "limit": limit,
              "total": paginated_documents.total,
              "page": page,
          }

      @setup_required
      @login_required
      @account_initialization_required
      @marshal_with(dataset_and_document_fields)
      @cloud_edition_billing_resource_check("vector_space")
      @cloud_edition_billing_rate_limit_check("knowledge")
      def post(self, dataset_id):
          """Create documents in an existing dataset from the JSON body.

          Raises:
              NotFound: the dataset does not exist.
              Forbidden: the user is not a dataset editor or lacks dataset permission.
              ValueError: neither the dataset nor the request specifies an indexing technique.
          """
          dataset_id = str(dataset_id)
          dataset = DatasetService.get_dataset(dataset_id)
          if not dataset:
              raise NotFound("Dataset not found.")
          # The role of the current user in the ta table must be admin, owner, or editor
          if not current_user.is_dataset_editor:
              raise Forbidden()
          try:
              DatasetService.check_dataset_permission(dataset, current_user)
          except services.errors.account.NoPermissionError as e:
              raise Forbidden(str(e)) from e

          parser = reqparse.RequestParser()
          parser.add_argument(
              "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
          )
          parser.add_argument("data_source", type=dict, required=False, location="json")
          parser.add_argument("process_rule", type=dict, required=False, location="json")
          parser.add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
          parser.add_argument("original_document_id", type=str, required=False, location="json")
          parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
          parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
          parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
          parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
          parser.add_argument(
              "doc_language", type=str, default="English", required=False, nullable=False, location="json"
          )
          args = parser.parse_args()
          knowledge_config = KnowledgeConfig(**args)

          if not dataset.indexing_technique and not knowledge_config.indexing_technique:
              raise ValueError("indexing_technique is required.")

          # validate args
          DocumentService.document_create_args_validate(knowledge_config)

          try:
              documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user)
              # Reload the dataset after saving (presumably so the response reflects fields the save updated).
              dataset = DatasetService.get_dataset(dataset_id)
          except ProviderTokenNotInitError as ex:
              raise ProviderNotInitializeError(ex.description) from ex
          except QuotaExceededError as ex:
              raise ProviderQuotaExceededError() from ex
          except ModelCurrentlyNotSupportError as ex:
              raise ProviderModelCurrentlyNotSupportError() from ex

          return {"dataset": dataset, "documents": documents, "batch": batch}

      @setup_required
      @login_required
      @account_initialization_required
      @cloud_edition_billing_rate_limit_check("knowledge")
      def delete(self, dataset_id):
          """Delete the documents whose ids are passed as repeated ``document_id`` query parameters.

          Raises:
              NotFound: the dataset does not exist.
              Forbidden: the current user lacks permission on the dataset.
              DocumentIndexingError: a document is still being indexed.
          """
          dataset_id = str(dataset_id)
          dataset = DatasetService.get_dataset(dataset_id)
          if dataset is None:
              raise NotFound("Dataset not found.")
          # Enforce dataset access like the single-document delete does (via get_document);
          # without this, any logged-in user could bulk-delete another dataset's documents.
          try:
              DatasetService.check_dataset_permission(dataset, current_user)
          except services.errors.account.NoPermissionError as e:
              raise Forbidden(str(e)) from e

          # check user's model setting
          DatasetService.check_dataset_model_setting(dataset)

          document_ids = request.args.getlist("document_id")
          try:
              DocumentService.delete_documents(dataset, document_ids)
          except services.errors.document.DocumentIndexingError as e:
              raise DocumentIndexingError("Cannot delete document during indexing.") from e

          return {"result": "success"}, 204
  class DatasetInitApi(Resource):
      """Create a new dataset together with its first documents in one request."""

      @setup_required
      @login_required
      @account_initialization_required
      @marshal_with(dataset_and_document_fields)
      @cloud_edition_billing_resource_check("vector_space")
      @cloud_edition_billing_rate_limit_check("knowledge")
      def post(self):
          # Only admin, owner, dataset_operator or editor roles may create datasets.
          if not current_user.is_dataset_editor:
              raise Forbidden()

          # Declared in the same order as before so parse errors surface identically.
          argument_specs = (
              (
                  "indexing_technique",
                  {"type": str, "choices": Dataset.INDEXING_TECHNIQUE_LIST, "required": True, "nullable": False},
              ),
              ("data_source", {"type": dict, "required": True, "nullable": True}),
              ("process_rule", {"type": dict, "required": True, "nullable": True}),
              ("doc_form", {"type": str, "default": "text_model", "required": False, "nullable": False}),
              ("doc_language", {"type": str, "default": "English", "required": False, "nullable": False}),
              ("retrieval_model", {"type": dict, "required": False, "nullable": False}),
              ("embedding_model", {"type": str, "required": False, "nullable": True}),
              ("embedding_model_provider", {"type": str, "required": False, "nullable": True}),
          )
          parser = reqparse.RequestParser()
          for arg_name, arg_options in argument_specs:
              parser.add_argument(arg_name, location="json", **arg_options)
          args = parser.parse_args()
          knowledge_config = KnowledgeConfig(**args)

          if knowledge_config.indexing_technique == "high_quality":
              missing_embedding = (
                  knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None
              )
              if missing_embedding:
                  raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
              # Probe the embedding model up front so misconfiguration fails before anything is saved.
              try:
                  ModelManager().get_model_instance(
                      tenant_id=current_user.current_tenant_id,
                      provider=args["embedding_model_provider"],
                      model_type=ModelType.TEXT_EMBEDDING,
                      model=args["embedding_model"],
                  )
              except InvokeAuthorizationError:
                  raise ProviderNotInitializeError(
                      "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
                  )
              except ProviderTokenNotInitError as ex:
                  raise ProviderNotInitializeError(ex.description)

          DocumentService.document_create_args_validate(knowledge_config)

          try:
              dataset, documents, batch = DocumentService.save_document_without_dataset_id(
                  tenant_id=current_user.current_tenant_id, knowledge_config=knowledge_config, account=current_user
              )
          except ProviderTokenNotInitError as ex:
              raise ProviderNotInitializeError(ex.description)
          except QuotaExceededError:
              raise ProviderQuotaExceededError()
          except ModelCurrentlyNotSupportError:
              raise ProviderModelCurrentlyNotSupportError()

          return {"dataset": dataset, "documents": documents, "batch": batch}
  class DocumentIndexingEstimateApi(DocumentResource):
      """Estimate the indexing cost of a single, not-yet-finished document."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, dataset_id, document_id):
          dataset_id = str(dataset_id)
          document_id = str(document_id)
          document = self.get_document(dataset_id, document_id)
          if document.indexing_status in {"completed", "error"}:
              raise DocumentAlreadyFinishedError()

          process_rule_dict = document.dataset_process_rule.to_dict()
          # Returned as-is whenever there is no uploaded file to estimate.
          empty_estimate = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}

          if document.data_source_type != "upload_file":
              return empty_estimate, 200
          source_info = document.data_source_info_dict
          if not source_info or "upload_file_id" not in source_info:
              return empty_estimate, 200

          upload_file_id = source_info["upload_file_id"]
          upload_file = (
              db.session.query(UploadFile)
              .where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == upload_file_id)
              .first()
          )
          if not upload_file:
              raise NotFound("File not found.")

          extract_setting = ExtractSetting(
              datasource_type=DatasourceType.FILE.value, upload_file=upload_file, document_model=document.doc_form
          )
          runner = IndexingRunner()
          try:
              estimate = runner.indexing_estimate(
                  current_user.current_tenant_id,
                  [extract_setting],
                  process_rule_dict,
                  document.doc_form,
                  "English",
                  dataset_id,
              )
              return estimate.model_dump(), 200
          except LLMBadRequestError:
              raise ProviderNotInitializeError(
                  "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
              )
          except ProviderTokenNotInitError as ex:
              raise ProviderNotInitializeError(ex.description)
          except PluginDaemonClientSideError as ex:
              raise ProviderNotInitializeError(ex.description)
          except Exception as e:
              raise IndexingEstimateError(str(e))
  class DocumentBatchIndexingEstimateApi(DocumentResource):
      """Estimate the indexing cost of every document created in one batch."""

      @staticmethod
      def _extract_setting_for(document) -> ExtractSetting:
          """Build the ExtractSetting describing where ``document``'s content comes from.

          Raises:
              NotFound: the uploaded file referenced by the document is missing for this tenant.
              ValueError: the document's data source type is not handled here.
          """
          source_info = document.data_source_info_dict
          source_type = document.data_source_type

          if source_type == "upload_file":
              upload_file_id = source_info["upload_file_id"]
              upload_file = (
                  db.session.query(UploadFile)
                  .where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == upload_file_id)
                  .first()
              )
              if upload_file is None:
                  raise NotFound("File not found.")
              return ExtractSetting(
                  datasource_type=DatasourceType.FILE.value, upload_file=upload_file, document_model=document.doc_form
              )

          if source_type == "notion_import":
              return ExtractSetting(
                  datasource_type=DatasourceType.NOTION.value,
                  notion_info={
                      "notion_workspace_id": source_info["notion_workspace_id"],
                      "notion_obj_id": source_info["notion_page_id"],
                      "notion_page_type": source_info["type"],
                      "tenant_id": current_user.current_tenant_id,
                  },
                  document_model=document.doc_form,
              )

          if source_type == "website_crawl":
              return ExtractSetting(
                  datasource_type=DatasourceType.WEBSITE.value,
                  website_info={
                      "provider": source_info["provider"],
                      "job_id": source_info["job_id"],
                      "url": source_info["url"],
                      "tenant_id": current_user.current_tenant_id,
                      "mode": source_info["mode"],
                      "only_main_content": source_info["only_main_content"],
                  },
                  document_model=document.doc_form,
              )

          raise ValueError("Data source type not support")

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, dataset_id, batch):
          dataset_id = str(dataset_id)
          batch = str(batch)
          documents = self.get_batch_documents(dataset_id, batch)
          if not documents:
              return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200

          # The first document's process rule is used for the whole batch.
          process_rule_dict = documents[0].dataset_process_rule.to_dict()

          extract_settings = []
          for doc in documents:
              if doc.indexing_status in {"completed", "error"}:
                  raise DocumentAlreadyFinishedError()
              extract_settings.append(self._extract_setting_for(doc))

          runner = IndexingRunner()
          try:
              estimate = runner.indexing_estimate(
                  current_user.current_tenant_id,
                  extract_settings,
                  process_rule_dict,
                  documents[-1].doc_form,
                  "English",
                  dataset_id,
              )
              return estimate.model_dump(), 200
          except LLMBadRequestError:
              raise ProviderNotInitializeError(
                  "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
              )
          except ProviderTokenNotInitError as ex:
              raise ProviderNotInitializeError(ex.description)
          except PluginDaemonClientSideError as ex:
              raise ProviderNotInitializeError(ex.description)
          except Exception as e:
              raise IndexingEstimateError(str(e))
  class DocumentBatchIndexingStatusApi(DocumentResource):
      """Report indexing progress for every document in a batch."""

      @staticmethod
      def _status_of(document) -> dict:
          """Collect the status fields and segment counts of one document.

          Segments with status "re_segment" are excluded from both counts; a segment is
          completed when its ``completed_at`` is set.
          """
          document_key = str(document.id)
          completed = (
              db.session.query(DocumentSegment)
              .where(
                  DocumentSegment.completed_at.isnot(None),
                  DocumentSegment.document_id == document_key,
                  DocumentSegment.status != "re_segment",
              )
              .count()
          )
          total = (
              db.session.query(DocumentSegment)
              .where(DocumentSegment.document_id == document_key, DocumentSegment.status != "re_segment")
              .count()
          )
          return {
              "id": document.id,
              # The pause flag takes precedence over the stored indexing status.
              "indexing_status": "paused" if document.is_paused else document.indexing_status,
              "processing_started_at": document.processing_started_at,
              "parsing_completed_at": document.parsing_completed_at,
              "cleaning_completed_at": document.cleaning_completed_at,
              "splitting_completed_at": document.splitting_completed_at,
              "completed_at": document.completed_at,
              "paused_at": document.paused_at,
              "error": document.error,
              "stopped_at": document.stopped_at,
              "completed_segments": completed,
              "total_segments": total,
          }

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, dataset_id, batch):
          documents = self.get_batch_documents(str(dataset_id), str(batch))
          return {"data": [marshal(self._status_of(doc), document_status_fields) for doc in documents]}
  class DocumentIndexingStatusApi(DocumentResource):
      """Report indexing progress for a single document."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, dataset_id, document_id):
          dataset_id = str(dataset_id)
          document_id = str(document_id)
          document = self.get_document(dataset_id, document_id)

          # Shared filter: this document's segments, ignoring ones queued for re-segmentation.
          segment_filter = (
              DocumentSegment.document_id == document_id,
              DocumentSegment.status != "re_segment",
          )
          completed = (
              db.session.query(DocumentSegment)
              .where(DocumentSegment.completed_at.isnot(None), *segment_filter)
              .count()
          )
          total = db.session.query(DocumentSegment).where(*segment_filter).count()

          status_payload = {
              "id": document.id,
              "indexing_status": "paused" if document.is_paused else document.indexing_status,
              "processing_started_at": document.processing_started_at,
              "parsing_completed_at": document.parsing_completed_at,
              "cleaning_completed_at": document.cleaning_completed_at,
              "splitting_completed_at": document.splitting_completed_at,
              "completed_at": document.completed_at,
              "paused_at": document.paused_at,
              "error": document.error,
              "stopped_at": document.stopped_at,
              "completed_segments": completed,
              "total_segments": total,
          }
          return marshal(status_payload, document_status_fields)
  class DocumentApi(DocumentResource):
      """Fetch or delete a single document."""

      # Allowed values of the ``metadata`` query parameter.
      METADATA_CHOICES = {"all", "only", "without"}

      @staticmethod
      def _epoch_seconds(value):
          """Return ``value.timestamp()`` truncated to int, or None when ``value`` is falsy."""
          return int(value.timestamp()) if value else None

      @classmethod
      def _detail_payload(cls, dataset_id: str, document: Document, *, include_metadata: bool) -> dict:
          """Build the detailed response shared by the "all" and "without" metadata modes.

          With ``include_metadata`` the payload also carries ``doc_type`` and ``doc_metadata``
          (the "all" mode). Key order matches the previous hand-written dicts.
          """
          dataset_process_rules = DatasetService.get_process_rules(dataset_id)
          process_rule = document.dataset_process_rule
          # Guard against a missing rule: calling .to_dict() on None used to raise
          # AttributeError (HTTP 500). NOTE(review): confirm {} is the preferred placeholder.
          document_process_rules = process_rule.to_dict() if process_rule else {}
          data_source_info = document.data_source_detail_dict

          payload = {
              "id": document.id,
              "position": document.position,
              "data_source_type": document.data_source_type,
              "data_source_info": data_source_info,
              "dataset_process_rule_id": document.dataset_process_rule_id,
              "dataset_process_rule": dataset_process_rules,
              "document_process_rule": document_process_rules,
              "name": document.name,
              "created_from": document.created_from,
              "created_by": document.created_by,
              "created_at": document.created_at.timestamp(),
              "tokens": document.tokens,
              "indexing_status": document.indexing_status,
              "completed_at": cls._epoch_seconds(document.completed_at),
              "updated_at": cls._epoch_seconds(document.updated_at),
              "indexing_latency": document.indexing_latency,
              "error": document.error,
              "enabled": document.enabled,
              "disabled_at": cls._epoch_seconds(document.disabled_at),
              "disabled_by": document.disabled_by,
              "archived": document.archived,
          }
          if include_metadata:
              payload["doc_type"] = document.doc_type
              payload["doc_metadata"] = document.doc_metadata_details
          payload.update(
              {
                  "segment_count": document.segment_count,
                  "average_segment_length": document.average_segment_length,
                  "hit_count": document.hit_count,
                  "display_status": document.display_status,
                  "doc_form": document.doc_form,
                  "doc_language": document.doc_language,
              }
          )
          return payload

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, dataset_id, document_id):
          """Return the document in the shape chosen by the ``metadata`` query parameter.

          "only" returns id plus metadata; "without" returns details without metadata; "all"
          (the default) returns details including metadata.

          Raises:
              InvalidMetadataError: ``metadata`` is not one of ``METADATA_CHOICES``.
          """
          dataset_id = str(dataset_id)
          document_id = str(document_id)
          document = self.get_document(dataset_id, document_id)

          metadata = request.args.get("metadata", "all")
          if metadata not in self.METADATA_CHOICES:
              raise InvalidMetadataError(f"Invalid metadata value: {metadata}")

          if metadata == "only":
              response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
          else:
              response = self._detail_payload(dataset_id, document, include_metadata=(metadata == "all"))
          return response, 200

      @setup_required
      @login_required
      @account_initialization_required
      @cloud_edition_billing_rate_limit_check("knowledge")
      def delete(self, dataset_id, document_id):
          """Delete one document.

          Raises:
              NotFound: the dataset or document does not exist.
              DocumentIndexingError: the document is still being indexed.
          """
          dataset_id = str(dataset_id)
          document_id = str(document_id)
          dataset = DatasetService.get_dataset(dataset_id)
          if dataset is None:
              raise NotFound("Dataset not found.")
          # check user's model setting
          DatasetService.check_dataset_model_setting(dataset)

          document = self.get_document(dataset_id, document_id)
          try:
              DocumentService.delete_document(document)
          except services.errors.document.DocumentIndexingError as e:
              raise DocumentIndexingError("Cannot delete document during indexing.") from e

          return {"result": "success"}, 204
  class DocumentProcessingApi(DocumentResource):
      """Pause or resume indexing of a single document."""

      @setup_required
      @login_required
      @account_initialization_required
      @cloud_edition_billing_rate_limit_check("knowledge")
      def patch(self, dataset_id, document_id, action: Literal["pause", "resume"]):
          """Apply ``action`` ("pause" or "resume") to the document.

          Raises:
              Forbidden: the current user is not a dataset editor.
              InvalidActionError: the document is in the wrong state, or ``action`` is unknown.
          """
          dataset_id = str(dataset_id)
          document_id = str(document_id)
          document = self.get_document(dataset_id, document_id)

          # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
          if not current_user.is_dataset_editor:
              raise Forbidden()

          if action == "pause":
              if document.indexing_status != "indexing":
                  raise InvalidActionError("Document not in indexing state.")
              document.paused_by = current_user.id
              document.paused_at = naive_utc_now()
              document.is_paused = True
          elif action == "resume":
              # The "pause" branch sets only ``is_paused`` and leaves indexing_status at "indexing".
              # Checking indexing_status alone therefore rejected documents paused through this
              # endpoint, so the flag is accepted as well.
              if not document.is_paused and document.indexing_status not in {"paused", "error"}:
                  raise InvalidActionError("Document not in paused or error state.")
              document.paused_by = None
              document.paused_at = None
              document.is_paused = False
          else:
              # The route may deliver arbitrary strings; don't report success for a no-op.
              raise InvalidActionError(f"Invalid action: {action}.")

          db.session.commit()
          return {"result": "success"}, 200
  676. class DocumentMetadataApi(DocumentResource):
  677. @setup_required
  678. @login_required
  679. @account_initialization_required
  680. def put(self, dataset_id, document_id):
  681. dataset_id = str(dataset_id)
  682. document_id = str(document_id)
  683. document = self.get_document(dataset_id, document_id)
  684. req_data = request.get_json()
  685. doc_type = req_data.get("doc_type")
  686. doc_metadata = req_data.get("doc_metadata")
  687. # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
  688. if not current_user.is_dataset_editor:
  689. raise Forbidden()
  690. if doc_type is None or doc_metadata is None:
  691. raise ValueError("Both doc_type and doc_metadata must be provided.")
  692. if doc_type not in DocumentService.DOCUMENT_METADATA_SCHEMA:
  693. raise ValueError("Invalid doc_type.")
  694. if not isinstance(doc_metadata, dict):
  695. raise ValueError("doc_metadata must be a dictionary.")
  696. metadata_schema: dict = cast(dict, DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type])
  697. document.doc_metadata = {}
  698. if doc_type == "others":
  699. document.doc_metadata = doc_metadata
  700. else:
  701. for key, value_type in metadata_schema.items():
  702. value = doc_metadata.get(key)
  703. if value is not None and isinstance(value, value_type):
  704. document.doc_metadata[key] = value
  705. document.doc_type = doc_type
  706. document.updated_at = naive_utc_now()
  707. db.session.commit()
  708. return {"result": "success", "message": "Document metadata updated."}, 200
  709. class DocumentStatusApi(DocumentResource):
  710. @setup_required
  711. @login_required
  712. @account_initialization_required
  713. @cloud_edition_billing_resource_check("vector_space")
  714. @cloud_edition_billing_rate_limit_check("knowledge")
  715. def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
  716. dataset_id = str(dataset_id)
  717. dataset = DatasetService.get_dataset(dataset_id)
  718. if dataset is None:
  719. raise NotFound("Dataset not found.")
  720. # The role of the current user in the ta table must be admin, owner, or editor
  721. if not current_user.is_dataset_editor:
  722. raise Forbidden()
  723. # check user's model setting
  724. DatasetService.check_dataset_model_setting(dataset)
  725. # check user's permission
  726. DatasetService.check_dataset_permission(dataset, current_user)
  727. document_ids = request.args.getlist("document_id")
  728. try:
  729. DocumentService.batch_update_document_status(dataset, document_ids, action, current_user)
  730. except services.errors.document.DocumentIndexingError as e:
  731. raise InvalidActionError(str(e))
  732. except ValueError as e:
  733. raise InvalidActionError(str(e))
  734. except NotFound as e:
  735. raise NotFound(str(e))
  736. return {"result": "success"}, 200
  737. class DocumentPauseApi(DocumentResource):
  738. @setup_required
  739. @login_required
  740. @account_initialization_required
  741. @cloud_edition_billing_rate_limit_check("knowledge")
  742. def patch(self, dataset_id, document_id):
  743. """pause document."""
  744. dataset_id = str(dataset_id)
  745. document_id = str(document_id)
  746. dataset = DatasetService.get_dataset(dataset_id)
  747. if not dataset:
  748. raise NotFound("Dataset not found.")
  749. document = DocumentService.get_document(dataset.id, document_id)
  750. # 404 if document not found
  751. if document is None:
  752. raise NotFound("Document Not Exists.")
  753. # 403 if document is archived
  754. if DocumentService.check_archived(document):
  755. raise ArchivedDocumentImmutableError()
  756. try:
  757. # pause document
  758. DocumentService.pause_document(document)
  759. except services.errors.document.DocumentIndexingError:
  760. raise DocumentIndexingError("Cannot pause completed document.")
  761. return {"result": "success"}, 204
  762. class DocumentRecoverApi(DocumentResource):
  763. @setup_required
  764. @login_required
  765. @account_initialization_required
  766. @cloud_edition_billing_rate_limit_check("knowledge")
  767. def patch(self, dataset_id, document_id):
  768. """recover document."""
  769. dataset_id = str(dataset_id)
  770. document_id = str(document_id)
  771. dataset = DatasetService.get_dataset(dataset_id)
  772. if not dataset:
  773. raise NotFound("Dataset not found.")
  774. document = DocumentService.get_document(dataset.id, document_id)
  775. # 404 if document not found
  776. if document is None:
  777. raise NotFound("Document Not Exists.")
  778. # 403 if document is archived
  779. if DocumentService.check_archived(document):
  780. raise ArchivedDocumentImmutableError()
  781. try:
  782. # pause document
  783. DocumentService.recover_document(document)
  784. except services.errors.document.DocumentIndexingError:
  785. raise DocumentIndexingError("Document is not in paused status.")
  786. return {"result": "success"}, 204
  787. class DocumentRetryApi(DocumentResource):
  788. @setup_required
  789. @login_required
  790. @account_initialization_required
  791. @cloud_edition_billing_rate_limit_check("knowledge")
  792. def post(self, dataset_id):
  793. """retry document."""
  794. parser = reqparse.RequestParser()
  795. parser.add_argument("document_ids", type=list, required=True, nullable=False, location="json")
  796. args = parser.parse_args()
  797. dataset_id = str(dataset_id)
  798. dataset = DatasetService.get_dataset(dataset_id)
  799. retry_documents = []
  800. if not dataset:
  801. raise NotFound("Dataset not found.")
  802. for document_id in args["document_ids"]:
  803. try:
  804. document_id = str(document_id)
  805. document = DocumentService.get_document(dataset.id, document_id)
  806. # 404 if document not found
  807. if document is None:
  808. raise NotFound("Document Not Exists.")
  809. # 403 if document is archived
  810. if DocumentService.check_archived(document):
  811. raise ArchivedDocumentImmutableError()
  812. # 400 if document is completed
  813. if document.indexing_status == "completed":
  814. raise DocumentAlreadyFinishedError()
  815. retry_documents.append(document)
  816. except Exception:
  817. logger.exception("Failed to retry document, document id: %s", document_id)
  818. continue
  819. # retry document
  820. DocumentService.retry_document(dataset_id, retry_documents)
  821. return {"result": "success"}, 204
  822. class DocumentRenameApi(DocumentResource):
  823. @setup_required
  824. @login_required
  825. @account_initialization_required
  826. @marshal_with(document_fields)
  827. def post(self, dataset_id, document_id):
  828. # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
  829. if not current_user.is_dataset_editor:
  830. raise Forbidden()
  831. dataset = DatasetService.get_dataset(dataset_id)
  832. DatasetService.check_dataset_operator_permission(current_user, dataset)
  833. parser = reqparse.RequestParser()
  834. parser.add_argument("name", type=str, required=True, nullable=False, location="json")
  835. args = parser.parse_args()
  836. try:
  837. document = DocumentService.rename_document(dataset_id, document_id, args["name"])
  838. except services.errors.document.DocumentIndexingError:
  839. raise DocumentIndexingError("Cannot delete document during indexing.")
  840. return document
  841. class WebsiteDocumentSyncApi(DocumentResource):
  842. @setup_required
  843. @login_required
  844. @account_initialization_required
  845. def get(self, dataset_id, document_id):
  846. """sync website document."""
  847. dataset_id = str(dataset_id)
  848. dataset = DatasetService.get_dataset(dataset_id)
  849. if not dataset:
  850. raise NotFound("Dataset not found.")
  851. document_id = str(document_id)
  852. document = DocumentService.get_document(dataset.id, document_id)
  853. if not document:
  854. raise NotFound("Document not found.")
  855. if document.tenant_id != current_user.current_tenant_id:
  856. raise Forbidden("No permission.")
  857. if document.data_source_type != "website_crawl":
  858. raise ValueError("Document is not a website document.")
  859. # 403 if document is archived
  860. if DocumentService.check_archived(document):
  861. raise ArchivedDocumentImmutableError()
  862. # sync document
  863. DocumentService.sync_website_document(dataset_id, document)
  864. return {"result": "success"}, 200
  865. api.add_resource(GetProcessRuleApi, "/datasets/process-rule")
  866. api.add_resource(DatasetDocumentListApi, "/datasets/<uuid:dataset_id>/documents")
  867. api.add_resource(DatasetInitApi, "/datasets/init")
  868. api.add_resource(
  869. DocumentIndexingEstimateApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate"
  870. )
  871. api.add_resource(DocumentBatchIndexingEstimateApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-estimate")
  872. api.add_resource(DocumentBatchIndexingStatusApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status")
  873. api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status")
  874. api.add_resource(DocumentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
  875. api.add_resource(
  876. DocumentProcessingApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/<string:action>"
  877. )
  878. api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata")
  879. api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>/batch")
  880. api.add_resource(DocumentPauseApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause")
  881. api.add_resource(DocumentRecoverApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume")
  882. api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry")
  883. api.add_resource(DocumentRenameApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename")
  884. api.add_resource(WebsiteDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")