Du kannst nicht mehr als 25 Themen auswählen Themen müssen mit entweder einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

knowledgebase_service.py 16KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from datetime import datetime
  17. from peewee import fn
  18. from api.db import StatusEnum, TenantPermission
  19. from api.db.db_models import DB, Document, Knowledgebase, Tenant, User, UserTenant
  20. from api.db.services.common_service import CommonService
  21. from api.utils import current_timestamp, datetime_format
class KnowledgebaseService(CommonService):
    """Service class for managing knowledge base operations.

    This class extends CommonService to provide specialized functionality for
    knowledge base management, including document parsing status tracking,
    access control, and configuration management. It handles operations such as
    listing, creating, updating, and deleting knowledge bases, as well as
    managing their associated documents and permissions.

    The class implements a comprehensive set of methods for:
        - Document parsing status verification
        - Knowledge base access control
        - Parser configuration management
        - Tenant-based knowledge base organization

    Attributes:
        model: The Knowledgebase model class for database operations.
    """

    model = Knowledgebase

  37. @classmethod
  38. @DB.connection_context()
  39. def accessible4deletion(cls, kb_id, user_id):
  40. """Check if a knowledge base can be deleted by a specific user.
  41. This method verifies whether a user has permission to delete a knowledge base
  42. by checking if they are the creator of that knowledge base.
  43. Args:
  44. kb_id (str): The unique identifier of the knowledge base to check.
  45. user_id (str): The unique identifier of the user attempting the deletion.
  46. Returns:
  47. bool: True if the user has permission to delete the knowledge base,
  48. False if the user doesn't have permission or the knowledge base doesn't exist.
  49. Example:
  50. >>> KnowledgebaseService.accessible4deletion("kb123", "user456")
  51. True
  52. Note:
  53. - This method only checks creator permissions
  54. - A return value of False can mean either:
  55. 1. The knowledge base doesn't exist
  56. 2. The user is not the creator of the knowledge base
  57. """
  58. # Check if a knowledge base can be deleted by a user
  59. docs = cls.model.select(
  60. cls.model.id).where(cls.model.id == kb_id, cls.model.created_by == user_id).paginate(0, 1)
  61. docs = docs.dicts()
  62. if not docs:
  63. return False
  64. return True
  65. @classmethod
  66. @DB.connection_context()
  67. def is_parsed_done(cls, kb_id):
  68. # Check if all documents in the knowledge base have completed parsing
  69. #
  70. # Args:
  71. # kb_id: Knowledge base ID
  72. #
  73. # Returns:
  74. # If all documents are parsed successfully, returns (True, None)
  75. # If any document is not fully parsed, returns (False, error_message)
  76. from api.db import TaskStatus
  77. from api.db.services.document_service import DocumentService
  78. # Get knowledge base information
  79. kbs = cls.query(id=kb_id)
  80. if not kbs:
  81. return False, "Knowledge base not found"
  82. kb = kbs[0]
  83. # Get all documents in the knowledge base
  84. docs, _ = DocumentService.get_by_kb_id(kb_id, 1, 1000, "create_time", True, "")
  85. # Check parsing status of each document
  86. for doc in docs:
  87. # If document is being parsed, don't allow chat creation
  88. if doc['run'] == TaskStatus.RUNNING.value or doc['run'] == TaskStatus.CANCEL.value or doc['run'] == TaskStatus.FAIL.value:
  89. return False, f"Document '{doc['name']}' in dataset '{kb.name}' is still being parsed. Please wait until all documents are parsed before starting a chat."
  90. # If document is not yet parsed and has no chunks, don't allow chat creation
  91. if doc['run'] == TaskStatus.UNSTART.value and doc['chunk_num'] == 0:
  92. return False, f"Document '{doc['name']}' in dataset '{kb.name}' has not been parsed yet. Please parse all documents before starting a chat."
  93. return True, None
  94. @classmethod
  95. @DB.connection_context()
  96. def list_documents_by_ids(cls, kb_ids):
  97. # Get document IDs associated with given knowledge base IDs
  98. # Args:
  99. # kb_ids: List of knowledge base IDs
  100. # Returns:
  101. # List of document IDs
  102. doc_ids = cls.model.select(Document.id.alias("document_id")).join(Document, on=(cls.model.id == Document.kb_id)).where(
  103. cls.model.id.in_(kb_ids)
  104. )
  105. doc_ids = list(doc_ids.dicts())
  106. doc_ids = [doc["document_id"] for doc in doc_ids]
  107. return doc_ids
  108. @classmethod
  109. @DB.connection_context()
  110. def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
  111. page_number, items_per_page,
  112. orderby, desc, keywords,
  113. parser_id=None
  114. ):
  115. # Get knowledge bases by tenant IDs with pagination and filtering
  116. # Args:
  117. # joined_tenant_ids: List of tenant IDs
  118. # user_id: Current user ID
  119. # page_number: Page number for pagination
  120. # items_per_page: Number of items per page
  121. # orderby: Field to order by
  122. # desc: Boolean indicating descending order
  123. # keywords: Search keywords
  124. # parser_id: Optional parser ID filter
  125. # Returns:
  126. # Tuple of (knowledge_base_list, total_count)
  127. fields = [
  128. cls.model.id,
  129. cls.model.avatar,
  130. cls.model.name,
  131. cls.model.language,
  132. cls.model.description,
  133. cls.model.tenant_id,
  134. cls.model.permission,
  135. cls.model.doc_num,
  136. cls.model.token_num,
  137. cls.model.chunk_num,
  138. cls.model.parser_id,
  139. cls.model.embd_id,
  140. User.nickname,
  141. User.avatar.alias('tenant_avatar'),
  142. cls.model.update_time
  143. ]
  144. if keywords:
  145. kbs = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)).where(
  146. ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
  147. TenantPermission.TEAM.value)) | (
  148. cls.model.tenant_id == user_id))
  149. & (cls.model.status == StatusEnum.VALID.value),
  150. (fn.LOWER(cls.model.name).contains(keywords.lower()))
  151. )
  152. else:
  153. kbs = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)).where(
  154. ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
  155. TenantPermission.TEAM.value)) | (
  156. cls.model.tenant_id == user_id))
  157. & (cls.model.status == StatusEnum.VALID.value)
  158. )
  159. if parser_id:
  160. kbs = kbs.where(cls.model.parser_id == parser_id)
  161. if desc:
  162. kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
  163. else:
  164. kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
  165. count = kbs.count()
  166. if page_number and items_per_page:
  167. kbs = kbs.paginate(page_number, items_per_page)
  168. return list(kbs.dicts()), count
  169. @classmethod
  170. @DB.connection_context()
  171. def get_kb_ids(cls, tenant_id):
  172. # Get all knowledge base IDs for a tenant
  173. # Args:
  174. # tenant_id: Tenant ID
  175. # Returns:
  176. # List of knowledge base IDs
  177. fields = [
  178. cls.model.id,
  179. ]
  180. kbs = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
  181. kb_ids = [kb.id for kb in kbs]
  182. return kb_ids
  183. @classmethod
  184. @DB.connection_context()
  185. def get_detail(cls, kb_id):
  186. # Get detailed information about a knowledge base
  187. # Args:
  188. # kb_id: Knowledge base ID
  189. # Returns:
  190. # Dictionary containing knowledge base details
  191. fields = [
  192. cls.model.id,
  193. cls.model.embd_id,
  194. cls.model.avatar,
  195. cls.model.name,
  196. cls.model.language,
  197. cls.model.description,
  198. cls.model.permission,
  199. cls.model.doc_num,
  200. cls.model.token_num,
  201. cls.model.chunk_num,
  202. cls.model.parser_id,
  203. cls.model.parser_config,
  204. cls.model.pagerank]
  205. kbs = cls.model.select(*fields).join(Tenant, on=(
  206. (Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where(
  207. (cls.model.id == kb_id),
  208. (cls.model.status == StatusEnum.VALID.value)
  209. )
  210. if not kbs:
  211. return
  212. d = kbs[0].to_dict()
  213. return d
  214. @classmethod
  215. @DB.connection_context()
  216. def update_parser_config(cls, id, config):
  217. # Update parser configuration for a knowledge base
  218. # Args:
  219. # id: Knowledge base ID
  220. # config: New parser configuration
  221. e, m = cls.get_by_id(id)
  222. if not e:
  223. raise LookupError(f"knowledgebase({id}) not found.")
  224. def dfs_update(old, new):
  225. # Deep update of nested configuration
  226. for k, v in new.items():
  227. if k not in old:
  228. old[k] = v
  229. continue
  230. if isinstance(v, dict):
  231. assert isinstance(old[k], dict)
  232. dfs_update(old[k], v)
  233. elif isinstance(v, list):
  234. assert isinstance(old[k], list)
  235. old[k] = list(set(old[k] + v))
  236. else:
  237. old[k] = v
  238. dfs_update(m.parser_config, config)
  239. cls.update_by_id(id, {"parser_config": m.parser_config})
  240. @classmethod
  241. @DB.connection_context()
  242. def get_field_map(cls, ids):
  243. # Get field mappings for knowledge bases
  244. # Args:
  245. # ids: List of knowledge base IDs
  246. # Returns:
  247. # Dictionary of field mappings
  248. conf = {}
  249. for k in cls.get_by_ids(ids):
  250. if k.parser_config and "field_map" in k.parser_config:
  251. conf.update(k.parser_config["field_map"])
  252. return conf
  253. @classmethod
  254. @DB.connection_context()
  255. def get_by_name(cls, kb_name, tenant_id):
  256. # Get knowledge base by name and tenant ID
  257. # Args:
  258. # kb_name: Knowledge base name
  259. # tenant_id: Tenant ID
  260. # Returns:
  261. # Tuple of (exists, knowledge_base)
  262. kb = cls.model.select().where(
  263. (cls.model.name == kb_name)
  264. & (cls.model.tenant_id == tenant_id)
  265. & (cls.model.status == StatusEnum.VALID.value)
  266. )
  267. if kb:
  268. return True, kb[0]
  269. return False, None
  270. @classmethod
  271. @DB.connection_context()
  272. def get_all_ids(cls):
  273. # Get all knowledge base IDs
  274. # Returns:
  275. # List of all knowledge base IDs
  276. return [m["id"] for m in cls.model.select(cls.model.id).dicts()]
  277. @classmethod
  278. @DB.connection_context()
  279. def get_list(cls, joined_tenant_ids, user_id,
  280. page_number, items_per_page, orderby, desc, id, name):
  281. # Get list of knowledge bases with filtering and pagination
  282. # Args:
  283. # joined_tenant_ids: List of tenant IDs
  284. # user_id: Current user ID
  285. # page_number: Page number for pagination
  286. # items_per_page: Number of items per page
  287. # orderby: Field to order by
  288. # desc: Boolean indicating descending order
  289. # id: Optional ID filter
  290. # name: Optional name filter
  291. # Returns:
  292. # List of knowledge bases
  293. kbs = cls.model.select()
  294. if id:
  295. kbs = kbs.where(cls.model.id == id)
  296. if name:
  297. kbs = kbs.where(cls.model.name == name)
  298. kbs = kbs.where(
  299. ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
  300. TenantPermission.TEAM.value)) | (
  301. cls.model.tenant_id == user_id))
  302. & (cls.model.status == StatusEnum.VALID.value)
  303. )
  304. if desc:
  305. kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
  306. else:
  307. kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
  308. kbs = kbs.paginate(page_number, items_per_page)
  309. return list(kbs.dicts())
  310. @classmethod
  311. @DB.connection_context()
  312. def accessible(cls, kb_id, user_id):
  313. # Check if a knowledge base is accessible by a user
  314. # Args:
  315. # kb_id: Knowledge base ID
  316. # user_id: User ID
  317. # Returns:
  318. # Boolean indicating accessibility
  319. docs = cls.model.select(
  320. cls.model.id).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
  321. ).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1)
  322. docs = docs.dicts()
  323. if not docs:
  324. return False
  325. return True
  326. @classmethod
  327. @DB.connection_context()
  328. def get_kb_by_id(cls, kb_id, user_id):
  329. # Get knowledge base by ID and user ID
  330. # Args:
  331. # kb_id: Knowledge base ID
  332. # user_id: User ID
  333. # Returns:
  334. # List containing knowledge base information
  335. kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
  336. ).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1)
  337. kbs = kbs.dicts()
  338. return list(kbs)
  339. @classmethod
  340. @DB.connection_context()
  341. def get_kb_by_name(cls, kb_name, user_id):
  342. # Get knowledge base by name and user ID
  343. # Args:
  344. # kb_name: Knowledge base name
  345. # user_id: User ID
  346. # Returns:
  347. # List containing knowledge base information
  348. kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
  349. ).where(cls.model.name == kb_name, UserTenant.user_id == user_id).paginate(0, 1)
  350. kbs = kbs.dicts()
  351. return list(kbs)
  352. @classmethod
  353. @DB.connection_context()
  354. def atomic_increase_doc_num_by_id(cls, kb_id):
  355. data = {}
  356. data["update_time"] = current_timestamp()
  357. data["update_date"] = datetime_format(datetime.now())
  358. data["doc_num"] = cls.model.doc_num + 1
  359. num = cls.model.update(data).where(cls.model.id == kb_id).execute()
  360. return num
  361. @classmethod
  362. @DB.connection_context()
  363. def update_document_number_in_init(cls, kb_id, doc_num):
  364. """
  365. Only use this function when init system
  366. """
  367. ok, kb = cls.get_by_id(kb_id)
  368. if not ok:
  369. return
  370. kb.doc_num = doc_num
  371. dirty_fields = kb.dirty_fields
  372. if cls.model._meta.combined.get("update_time") in dirty_fields:
  373. dirty_fields.remove(cls.model._meta.combined["update_time"])
  374. if cls.model._meta.combined.get("update_date") in dirty_fields:
  375. dirty_fields.remove(cls.model._meta.combined["update_date"])
  376. try:
  377. kb.save(only=dirty_fields)
  378. except ValueError as e:
  379. if str(e) == "no data to save!":
  380. pass # that's OK
  381. else:
  382. raise e