Nevar pievienot vairāk kā 25 tēmas. Tēmai ir jāsākas ar burtu vai ciparu, tā var saturēt domu zīmes ('-') un var būt līdz 35 simboliem gara.

dataset.py 7.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from flask import request
  17. from api.db import StatusEnum, FileSource
  18. from api.db.db_models import File
  19. from api.db.services.document_service import DocumentService
  20. from api.db.services.file2document_service import File2DocumentService
  21. from api.db.services.file_service import FileService
  22. from api.db.services.knowledgebase_service import KnowledgebaseService
  23. from api.db.services.user_service import TenantService
  24. from api.settings import RetCode
  25. from api.utils import get_uuid
  26. from api.utils.api_utils import get_result, token_required,get_error_data_result
  27. @manager.route('/dataset', methods=['POST'])
  28. @token_required
  29. def create(tenant_id):
  30. req = request.json
  31. e, t = TenantService.get_by_id(tenant_id)
  32. if "tenant_id" in req or "embedding_model" in req:
  33. return get_error_data_result(
  34. retmsg="Tenant_id or embedding_model must not be provided")
  35. chunk_count=req.get("chunk_count")
  36. document_count=req.get("document_count")
  37. if chunk_count or document_count:
  38. return get_error_data_result(retmsg="chunk_count or document_count must be 0 or not be provided")
  39. if "name" not in req:
  40. return get_error_data_result(
  41. retmsg="Name is not empty!")
  42. req['id'] = get_uuid()
  43. req["name"] = req["name"].strip()
  44. if req["name"] == "":
  45. return get_error_data_result(
  46. retmsg="Name is not empty string!")
  47. if KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
  48. return get_error_data_result(
  49. retmsg="Duplicated knowledgebase name in creating dataset.")
  50. req["tenant_id"] = req['created_by'] = tenant_id
  51. req['embedding_model'] = t.embd_id
  52. key_mapping = {
  53. "chunk_num": "chunk_count",
  54. "doc_num": "document_count",
  55. "parser_id": "parse_method",
  56. "embd_id": "embedding_model"
  57. }
  58. mapped_keys = {new_key: req[old_key] for new_key, old_key in key_mapping.items() if old_key in req}
  59. req.update(mapped_keys)
  60. if not KnowledgebaseService.save(**req):
  61. return get_error_data_result(retmsg="Create dataset error.(Database error)")
  62. renamed_data = {}
  63. e, k = KnowledgebaseService.get_by_id(req["id"])
  64. for key, value in k.to_dict().items():
  65. new_key = key_mapping.get(key, key)
  66. renamed_data[new_key] = value
  67. return get_result(data=renamed_data)
  68. @manager.route('/dataset', methods=['DELETE'])
  69. @token_required
  70. def delete(tenant_id):
  71. req = request.json
  72. ids = req.get("ids")
  73. if not ids:
  74. return get_error_data_result(
  75. retmsg="ids are required")
  76. for id in ids:
  77. kbs = KnowledgebaseService.query(id=id, tenant_id=tenant_id)
  78. if not kbs:
  79. return get_error_data_result(retmsg=f"You don't own the dataset {id}")
  80. for doc in DocumentService.query(kb_id=id):
  81. if not DocumentService.remove_document(doc, tenant_id):
  82. return get_error_data_result(
  83. retmsg="Remove document error.(Database error)")
  84. f2d = File2DocumentService.get_by_document_id(doc.id)
  85. FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
  86. File2DocumentService.delete_by_document_id(doc.id)
  87. if not KnowledgebaseService.delete_by_id(id):
  88. return get_error_data_result(
  89. retmsg="Delete dataset error.(Database serror)")
  90. return get_result(retcode=RetCode.SUCCESS)
  91. @manager.route('/dataset/<dataset_id>', methods=['PUT'])
  92. @token_required
  93. def update(tenant_id,dataset_id):
  94. if not KnowledgebaseService.query(id=dataset_id,tenant_id=tenant_id):
  95. return get_error_data_result(retmsg="You don't own the dataset")
  96. req = request.json
  97. e, t = TenantService.get_by_id(tenant_id)
  98. invalid_keys = {"id", "embd_id", "chunk_num", "doc_num", "parser_id"}
  99. if any(key in req for key in invalid_keys):
  100. return get_error_data_result(retmsg="The input parameters are invalid.")
  101. if "tenant_id" in req:
  102. if req["tenant_id"] != tenant_id:
  103. return get_error_data_result(
  104. retmsg="Can't change tenant_id.")
  105. e, kb = KnowledgebaseService.get_by_id(dataset_id)
  106. if "chunk_count" in req:
  107. if req["chunk_count"] != kb.chunk_num:
  108. return get_error_data_result(
  109. retmsg="Can't change chunk_count.")
  110. req.pop("chunk_count")
  111. if "document_count" in req:
  112. if req['document_count'] != kb.doc_num:
  113. return get_error_data_result(
  114. retmsg="Can't change document_count.")
  115. req.pop("document_count")
  116. if "parse_method" in req:
  117. if kb.chunk_num != 0 and req['parse_method'] != kb.parser_id:
  118. return get_error_data_result(
  119. retmsg="If chunk count is not 0, parse method is not changable.")
  120. req['parser_id'] = req.pop('parse_method')
  121. if "embedding_model" in req:
  122. if kb.chunk_num != 0 and req['parse_method'] != kb.parser_id:
  123. return get_error_data_result(
  124. retmsg="If chunk count is not 0, parse method is not changable.")
  125. req['embd_id'] = req.pop('embedding_model')
  126. if "name" in req:
  127. req["name"] = req["name"].strip()
  128. if req["name"].lower() != kb.name.lower() \
  129. and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id,
  130. status=StatusEnum.VALID.value)) > 0:
  131. return get_error_data_result(
  132. retmsg="Duplicated knowledgebase name in updating dataset.")
  133. if not KnowledgebaseService.update_by_id(kb.id, req):
  134. return get_error_data_result(retmsg="Update dataset error.(Database error)")
  135. return get_result(retcode=RetCode.SUCCESS)
  136. @manager.route('/dataset', methods=['GET'])
  137. @token_required
  138. def list(tenant_id):
  139. id = request.args.get("id")
  140. name = request.args.get("name")
  141. kbs = KnowledgebaseService.query(id=id,name=name,status=1)
  142. if not kbs:
  143. return get_error_data_result(retmsg="The dataset doesn't exist")
  144. page_number = int(request.args.get("page", 1))
  145. items_per_page = int(request.args.get("page_size", 1024))
  146. orderby = request.args.get("orderby", "create_time")
  147. if request.args.get("desc") == "False" or request.args.get("desc") == "false" :
  148. desc = False
  149. else:
  150. desc = True
  151. tenants = TenantService.get_joined_tenants_by_user_id(tenant_id)
  152. kbs = KnowledgebaseService.get_list(
  153. [m["tenant_id"] for m in tenants], tenant_id, page_number, items_per_page, orderby, desc, id, name)
  154. renamed_list = []
  155. for kb in kbs:
  156. key_mapping = {
  157. "chunk_num": "chunk_count",
  158. "doc_num": "document_count",
  159. "parser_id": "parse_method",
  160. "embd_id": "embedding_model"
  161. }
  162. renamed_data = {}
  163. for key, value in kb.items():
  164. new_key = key_mapping.get(key, key)
  165. renamed_data[new_key] = value
  166. renamed_list.append(renamed_data)
  167. return get_result(data=renamed_list)