
dataset.py

#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import logging

from flask import request
from peewee import OperationalError
from pydantic import ValidationError

from api import settings
from api.db import FileSource, StatusEnum
from api.db.db_models import File
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, TenantLLMService
from api.db.services.user_service import TenantService
from api.utils import get_uuid
from api.utils.api_utils import (
    check_duplicate_ids,
    dataset_readonly_fields,
    get_error_argument_result,
    get_error_data_result,
    get_parser_config,
    get_result,
    token_required,
    valid,
    valid_parser_config,
)
from api.utils.validation_utils import CreateDatasetReq, format_validation_error_message


@manager.route("/datasets", methods=["POST"])  # noqa: F821
@token_required
def create(tenant_id):
    """
    Create a new dataset.
    ---
    tags:
      - Datasets
    security:
      - ApiKeyAuth: []
    parameters:
      - in: header
        name: Authorization
        type: string
        required: true
        description: Bearer token for authentication.
      - in: body
        name: body
        description: Dataset creation parameters.
        required: true
        schema:
          type: object
          required:
            - name
          properties:
            name:
              type: string
              description: Name of the dataset.
            avatar:
              type: string
              description: Base64 encoding of the avatar.
            description:
              type: string
              description: Description of the dataset.
            embedding_model:
              type: string
              description: Name of the embedding model.
            permission:
              type: string
              enum: ['me', 'team']
              description: Dataset permission.
            chunk_method:
              type: string
              enum: ["naive", "book", "email", "laws", "manual", "one", "paper",
                     "picture", "presentation", "qa", "table", "tag"]
              description: Chunking method.
            pagerank:
              type: integer
              description: Set page rank.
            parser_config:
              type: object
              description: Parser configuration.
    responses:
      200:
        description: Successful operation.
        schema:
          type: object
          properties:
            data:
              type: object
    """
    req_i = request.json
    if not isinstance(req_i, dict):
        return get_error_argument_result(f"Invalid request payload: expected object, got {type(req_i).__name__}")

    try:
        req_v = CreateDatasetReq(**req_i)
    except ValidationError as e:
        return get_error_argument_result(format_validation_error_message(e))

    # Field name transformations during model dump:
    # | Original        | Dump Output |
    # |-----------------|-------------|
    # | embedding_model | embd_id     |
    # | chunk_method    | parser_id   |
    req = req_v.model_dump(by_alias=True)

    try:
        if KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
            return get_error_argument_result(message=f"Dataset name '{req['name']}' already exists")
    except OperationalError as e:
        logging.exception(e)
        return get_error_data_result(message="Database operation failed")

    req["parser_config"] = get_parser_config(req["parser_id"], req["parser_config"])
    req["id"] = get_uuid()
    req["tenant_id"] = tenant_id
    req["created_by"] = tenant_id

    try:
        ok, t = TenantService.get_by_id(tenant_id)
        if not ok:
            return get_error_data_result(message="Tenant not found")
    except OperationalError as e:
        logging.exception(e)
        return get_error_data_result(message="Database operation failed")

    # Fall back to the tenant's default embedding model when none is specified.
    if not req.get("embd_id"):
        req["embd_id"] = t.embd_id
    else:
        builtin_embedding_models = [
            "BAAI/bge-large-zh-v1.5@BAAI",
            "maidalun1020/bce-embedding-base_v1@Youdao",
        ]
        is_builtin_model = req["embd_id"] in builtin_embedding_models
        try:
            # model name must be model_name@model_factory
            llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["embd_id"])
            is_tenant_model = TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type="embedding")
            is_supported_model = LLMService.query(llm_name=llm_name, fid=llm_factory, model_type="embedding")
            if not (is_supported_model and (is_builtin_model or is_tenant_model)):
                return get_error_argument_result(f"The embedding_model '{req['embd_id']}' is not supported")
        except OperationalError as e:
            logging.exception(e)
            return get_error_data_result(message="Database operation failed")

    try:
        if not KnowledgebaseService.save(**req):
            return get_error_data_result(message="Database operation failed")
    except OperationalError as e:
        logging.exception(e)
        return get_error_data_result(message="Database operation failed")

    try:
        ok, k = KnowledgebaseService.get_by_id(req["id"])
        if not ok:
            return get_error_data_result(message="Dataset creation failed")
    except OperationalError as e:
        logging.exception(e)
        return get_error_data_result(message="Database operation failed")

    response_data = {}
    key_mapping = {
        "chunk_num": "chunk_count",
        "doc_num": "document_count",
        "parser_id": "chunk_method",
        "embd_id": "embedding_model",
    }
    for key, value in k.to_dict().items():
        new_key = key_mapping.get(key, key)
        response_data[new_key] = value

    return get_result(data=response_data)
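
# Hedged usage sketch (illustrative, not part of this module): calling the
# create endpoint over HTTP. The host, port, and "/api/v1" prefix are assumed
# for illustration; the route, header, and body fields come from the handler
# above. Replace <YOUR_API_KEY> with a real token.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:9380/api/v1/datasets",
#       headers={"Authorization": "Bearer <YOUR_API_KEY>"},
#       json={"name": "my_dataset", "chunk_method": "naive"},
#   )
#   print(resp.json())  # {"code": 0, "data": {...}} on success
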
@manager.route("/datasets", methods=["DELETE"])  # noqa: F821
@token_required
def delete(tenant_id):
    """
    Delete datasets.
    ---
    tags:
      - Datasets
    security:
      - ApiKeyAuth: []
    parameters:
      - in: header
        name: Authorization
        type: string
        required: true
        description: Bearer token for authentication.
      - in: body
        name: body
        description: Dataset deletion parameters.
        required: true
        schema:
          type: object
          properties:
            ids:
              type: array
              items:
                type: string
              description: List of dataset IDs to delete.
    responses:
      200:
        description: Successful operation.
        schema:
          type: object
    """
    errors = []
    success_count = 0
    req = request.json
    if not req:
        ids = None
    else:
        ids = req.get("ids")
    if not ids:
        id_list = []
        kbs = KnowledgebaseService.query(tenant_id=tenant_id)
        for kb in kbs:
            id_list.append(kb.id)
    else:
        id_list = ids
    unique_id_list, duplicate_messages = check_duplicate_ids(id_list, "dataset")
    id_list = unique_id_list

    for id in id_list:
        kbs = KnowledgebaseService.query(id=id, tenant_id=tenant_id)
        if not kbs:
            errors.append(f"You don't own the dataset {id}")
            continue
        for doc in DocumentService.query(kb_id=id):
            if not DocumentService.remove_document(doc, tenant_id):
                errors.append(f"Remove document error for dataset {id}")
                continue
            f2d = File2DocumentService.get_by_document_id(doc.id)
            FileService.filter_delete(
                [
                    File.source_type == FileSource.KNOWLEDGEBASE,
                    File.id == f2d[0].file_id,
                ]
            )
            File2DocumentService.delete_by_document_id(doc.id)
        FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kbs[0].name])
        if not KnowledgebaseService.delete_by_id(id):
            errors.append(f"Delete dataset error for {id}")
            continue
        success_count += 1

    if errors:
        if success_count > 0:
            return get_result(data={"success_count": success_count, "errors": errors}, message=f"Partially deleted {success_count} datasets with {len(errors)} errors")
        else:
            return get_error_data_result(message="; ".join(errors))
    if duplicate_messages:
        if success_count > 0:
            return get_result(
                message=f"Partially deleted {success_count} datasets with {len(duplicate_messages)} errors",
                data={"success_count": success_count, "errors": duplicate_messages},
            )
        else:
            return get_error_data_result(message=";".join(duplicate_messages))
    return get_result(code=settings.RetCode.SUCCESS)
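
# Hedged usage sketch: deleting datasets by ID. Per the handler above, omitting
# "ids" (or sending an empty body) deletes every dataset owned by the tenant.
# Host, port, and "/api/v1" prefix are assumptions for illustration.
#
#   import requests
#
#   resp = requests.delete(
#       "http://localhost:9380/api/v1/datasets",
#       headers={"Authorization": "Bearer <YOUR_API_KEY>"},
#       json={"ids": ["<dataset_id_1>", "<dataset_id_2>"]},
#   )
#   print(resp.json())
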
@manager.route("/datasets/<dataset_id>", methods=["PUT"])  # noqa: F821
@token_required
def update(tenant_id, dataset_id):
    """
    Update a dataset.
    ---
    tags:
      - Datasets
    security:
      - ApiKeyAuth: []
    parameters:
      - in: path
        name: dataset_id
        type: string
        required: true
        description: ID of the dataset to update.
      - in: header
        name: Authorization
        type: string
        required: true
        description: Bearer token for authentication.
      - in: body
        name: body
        description: Dataset update parameters.
        required: true
        schema:
          type: object
          properties:
            name:
              type: string
              description: New name of the dataset.
            permission:
              type: string
              enum: ['me', 'team']
              description: Updated permission.
            chunk_method:
              type: string
              enum: ["naive", "manual", "qa", "table", "paper", "book", "laws",
                     "presentation", "picture", "one", "email", "tag"]
              description: Updated chunking method.
            parser_config:
              type: object
              description: Updated parser configuration.
    responses:
      200:
        description: Successful operation.
        schema:
          type: object
    """
    if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
        return get_error_data_result(message="You don't own the dataset")
    req = request.json
    for k in req.keys():
        if dataset_readonly_fields(k):
            return get_result(code=settings.RetCode.ARGUMENT_ERROR, message=f"'{k}' is readonly.")
    e, t = TenantService.get_by_id(tenant_id)
    invalid_keys = {"id", "embd_id", "chunk_num", "doc_num", "parser_id", "create_date", "create_time", "created_by", "status", "token_num", "update_date", "update_time"}
    if any(key in req for key in invalid_keys):
        return get_error_data_result(message="The input parameters are invalid.")
    permission = req.get("permission")
    chunk_method = req.get("chunk_method")
    parser_config = req.get("parser_config")
    valid_parser_config(parser_config)
    valid_permission = ["me", "team"]
    valid_chunk_method = ["naive", "manual", "qa", "table", "paper", "book", "laws", "presentation", "picture", "one", "email", "tag"]
    check_validation = valid(
        permission,
        valid_permission,
        chunk_method,
        valid_chunk_method,
    )
    if check_validation:
        return check_validation
    if "tenant_id" in req:
        if req["tenant_id"] != tenant_id:
            return get_error_data_result(message="Can't change `tenant_id`.")
    e, kb = KnowledgebaseService.get_by_id(dataset_id)
    if "parser_config" in req:
        temp_dict = kb.parser_config
        temp_dict.update(req["parser_config"])
        req["parser_config"] = temp_dict
    if "chunk_count" in req:
        if req["chunk_count"] != kb.chunk_num:
            return get_error_data_result(message="Can't change `chunk_count`.")
        req.pop("chunk_count")
    if "document_count" in req:
        if req["document_count"] != kb.doc_num:
            return get_error_data_result(message="Can't change `document_count`.")
        req.pop("document_count")
    if req.get("chunk_method"):
        if kb.chunk_num != 0 and req["chunk_method"] != kb.parser_id:
            return get_error_data_result(message="If `chunk_count` is not 0, `chunk_method` is not changeable.")
        req["parser_id"] = req.pop("chunk_method")
        if req["parser_id"] != kb.parser_id:
            if not req.get("parser_config"):
                req["parser_config"] = get_parser_config(chunk_method, parser_config)
    if "embedding_model" in req:
        if kb.chunk_num != 0 and req["embedding_model"] != kb.embd_id:
            return get_error_data_result(message="If `chunk_count` is not 0, `embedding_model` is not changeable.")
        if not req.get("embedding_model"):
            return get_error_data_result("`embedding_model` can't be empty")
        valid_embedding_models = [
            "BAAI/bge-large-zh-v1.5",
            "BAAI/bge-base-en-v1.5",
            "BAAI/bge-large-en-v1.5",
            "BAAI/bge-small-en-v1.5",
            "BAAI/bge-small-zh-v1.5",
            "jinaai/jina-embeddings-v2-base-en",
            "jinaai/jina-embeddings-v2-small-en",
            "nomic-ai/nomic-embed-text-v1.5",
            "sentence-transformers/all-MiniLM-L6-v2",
            "text-embedding-v2",
            "text-embedding-v3",
            "maidalun1020/bce-embedding-base_v1",
        ]
        embd_model = LLMService.query(llm_name=req["embedding_model"], model_type="embedding")
        if embd_model:
            if req["embedding_model"] not in valid_embedding_models and not TenantLLMService.query(
                tenant_id=tenant_id,
                model_type="embedding",
                llm_name=req.get("embedding_model"),
            ):
                return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist")
        if not embd_model:
            embd_model = TenantLLMService.query(tenant_id=tenant_id, model_type="embedding", llm_name=req.get("embedding_model"))
            if not embd_model:
                return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist")
        req["embd_id"] = req.pop("embedding_model")
    if "name" in req:
        req["name"] = req["name"].strip()
        if len(req["name"]) >= 128:
            return get_error_data_result(message="Dataset name should not be longer than 128 characters.")
        if req["name"].lower() != kb.name.lower() and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 0:
            return get_error_data_result(message="Duplicated dataset name in updating dataset.")
    flds = list(req.keys())
    for f in flds:
        if req[f] == "" and f in ["permission", "parser_id", "chunk_method"]:
            del req[f]
    if not KnowledgebaseService.update_by_id(kb.id, req):
        return get_error_data_result(message="Update dataset error. (Database error)")
    return get_result(code=settings.RetCode.SUCCESS)
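
# Hedged usage sketch: renaming a dataset. Note from the handler above that
# `chunk_method` and `embedding_model` are only changeable while the dataset's
# `chunk_count` is 0. Host, port, and "/api/v1" prefix are assumptions.
#
#   import requests
#
#   resp = requests.put(
#       "http://localhost:9380/api/v1/datasets/<dataset_id>",
#       headers={"Authorization": "Bearer <YOUR_API_KEY>"},
#       json={"name": "renamed_dataset"},
#   )
#   print(resp.json())
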
@manager.route("/datasets", methods=["GET"])  # noqa: F821
@token_required
def list_datasets(tenant_id):
    """
    List datasets.
    ---
    tags:
      - Datasets
    security:
      - ApiKeyAuth: []
    parameters:
      - in: query
        name: id
        type: string
        required: false
        description: Dataset ID to filter.
      - in: query
        name: name
        type: string
        required: false
        description: Dataset name to filter.
      - in: query
        name: page
        type: integer
        required: false
        default: 1
        description: Page number.
      - in: query
        name: page_size
        type: integer
        required: false
        default: 30
        description: Number of items per page.
      - in: query
        name: orderby
        type: string
        required: false
        default: "create_time"
        description: Field to order by.
      - in: query
        name: desc
        type: boolean
        required: false
        default: true
        description: Sort in descending order.
      - in: header
        name: Authorization
        type: string
        required: true
        description: Bearer token for authentication.
    responses:
      200:
        description: Successful operation.
        schema:
          type: array
          items:
            type: object
    """
    id = request.args.get("id")
    name = request.args.get("name")
    if id:
        kbs = KnowledgebaseService.get_kb_by_id(id, tenant_id)
        if not kbs:
            return get_error_data_result(f"You don't own the dataset {id}")
    if name:
        kbs = KnowledgebaseService.get_kb_by_name(name, tenant_id)
        if not kbs:
            return get_error_data_result(f"You don't own the dataset {name}")
    page_number = int(request.args.get("page", 1))
    items_per_page = int(request.args.get("page_size", 30))
    orderby = request.args.get("orderby", "create_time")
    if request.args.get("desc", "true").lower() not in ["true", "false"]:
        return get_error_data_result("desc should be true or false")
    # Default to descending order unless desc=false is passed explicitly.
    desc = request.args.get("desc", "true").lower() != "false"
    tenants = TenantService.get_joined_tenants_by_user_id(tenant_id)
    kbs = KnowledgebaseService.get_list(
        [m["tenant_id"] for m in tenants],
        tenant_id,
        page_number,
        items_per_page,
        orderby,
        desc,
        id,
        name,
    )
    key_mapping = {
        "chunk_num": "chunk_count",
        "doc_num": "document_count",
        "parser_id": "chunk_method",
        "embd_id": "embedding_model",
    }
    renamed_list = []
    for kb in kbs:
        renamed_data = {}
        for key, value in kb.items():
            new_key = key_mapping.get(key, key)
            renamed_data[new_key] = value
        renamed_list.append(renamed_data)
    return get_result(data=renamed_list)
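
# Hedged usage sketch: listing datasets with pagination and ordering. The query
# parameter names mirror the handler above; host, port, and "/api/v1" prefix
# are assumptions for illustration.
#
#   import requests
#
#   resp = requests.get(
#       "http://localhost:9380/api/v1/datasets",
#       headers={"Authorization": "Bearer <YOUR_API_KEY>"},
#       params={"page": 1, "page_size": 30, "orderby": "create_time", "desc": "true"},
#   )
#   print(resp.json())  # {"code": 0, "data": [...]} on success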